diff --git a/asynq.go b/asynq.go
index cda9a23..3d27d06 100644
--- a/asynq.go
+++ b/asynq.go
@@ -162,6 +162,8 @@ func newTaskInfo(msg *base.TaskMessage, state base.TaskState, nextProcessAt time
 		info.State = TaskStateArchived
 	case base.TaskStateCompleted:
 		info.State = TaskStateCompleted
+	case base.TaskStateAggregating:
+		info.State = TaskStateAggregating
 	default:
 		panic(fmt.Sprintf("internal error: unknown state: %d", state))
 	}
@@ -189,6 +191,9 @@ const (
 
 	// Indicates that the task is processed successfully and retained until the retention TTL expires.
 	TaskStateCompleted
+
+	// Indicates that the task is waiting in a group to be aggregated into one task.
+	TaskStateAggregating
 )
 
 func (s TaskState) String() string {
@@ -205,6 +210,8 @@ func (s TaskState) String() string {
 		return "archived"
 	case TaskStateCompleted:
 		return "completed"
+	case TaskStateAggregating:
+		return "aggregating"
 	}
 	panic("asynq: unknown task state")
 }
diff --git a/client.go b/client.go
index 7c4e259..34aafd5 100644
--- a/client.go
+++ b/client.go
@@ -227,6 +227,7 @@ type option struct {
 	uniqueTTL time.Duration
 	processAt time.Time
 	retention time.Duration
+	groupKey  string
 }
 
 // composeOptions merges user provided options into the default options
@@ -254,8 +255,8 @@ func composeOptions(opts ...Option) (option, error) {
 			res.queue = qname
 		case taskIDOption:
 			id := string(opt)
-			if err := validateTaskID(id); err != nil {
-				return option{}, err
+			if isBlank(id) {
+				return option{}, errors.New("task ID cannot be empty")
 			}
 			res.taskID = id
 		case timeoutOption:
@@ -274,6 +275,12 @@ func composeOptions(opts ...Option) (option, error) {
 			res.processAt = time.Now().Add(time.Duration(opt))
 		case retentionOption:
 			res.retention = time.Duration(opt)
+		case groupOption:
+			key := string(opt)
+			if isBlank(key) {
+				return option{}, errors.New("group key cannot be empty")
+			}
+			res.groupKey = key
 		default:
 			// ignore unexpected option
 		}
@@ -281,12 +288,9 @@ func composeOptions(opts ...Option) (option, error) {
 	return res, nil
 }
 
-// validates user provided task ID string.
-func validateTaskID(id string) error {
-	if strings.TrimSpace(id) == "" {
-		return errors.New("task ID cannot be empty")
-	}
-	return nil
+// isBlank returns true if the given s is empty or consists only of whitespace.
+func isBlank(s string) bool {
+	return strings.TrimSpace(s) == ""
 }
 
 const (
@@ -375,13 +379,18 @@ func (c *Client) EnqueueContext(ctx context.Context, task *Task, opts ...Option)
 	}
 	now := time.Now()
 	var state base.TaskState
-	if opt.processAt.Before(now) || opt.processAt.Equal(now) {
+	if opt.processAt.After(now) {
+		err = c.schedule(ctx, msg, opt.processAt, opt.uniqueTTL)
+		state = base.TaskStateScheduled
+	} else if opt.groupKey != "" {
+		// Use zero value for processAt since we don't know when the task will be aggregated and processed.
+		opt.processAt = time.Time{}
+		err = c.addToGroup(ctx, msg, opt.groupKey, opt.uniqueTTL)
+		state = base.TaskStateAggregating
+	} else {
 		opt.processAt = now
 		err = c.enqueue(ctx, msg, opt.uniqueTTL)
 		state = base.TaskStatePending
-	} else {
-		err = c.schedule(ctx, msg, opt.processAt, opt.uniqueTTL)
-		state = base.TaskStateScheduled
 	}
 	switch {
 	case errors.Is(err, errors.ErrDuplicateTask):
@@ -408,3 +417,10 @@ func (c *Client) schedule(ctx context.Context, msg *base.TaskMessage, t time.Tim
 	}
 	return c.rdb.Schedule(ctx, msg, t)
 }
+
+func (c *Client) addToGroup(ctx context.Context, msg *base.TaskMessage, groupKey string, uniqueTTL time.Duration) error {
+	if uniqueTTL > 0 {
+		return c.rdb.AddToGroupUnique(ctx, msg, groupKey, uniqueTTL)
+	}
+	return c.rdb.AddToGroup(ctx, msg, groupKey)
+}
diff --git a/client_test.go b/client_test.go
index 867de91..06cfec5 100644
--- a/client_test.go
+++ b/client_test.go
@@ -478,6 +478,154 @@ func TestClientEnqueue(t *testing.T) {
 	}
 }
 
+func TestClientEnqueueWithGroupOption(t *testing.T) {
+	r := setup(t)
+	client := NewClient(getRedisConnOpt(t))
+	defer client.Close()
+
+	task := NewTask("mytask", []byte("foo"))
+	now := time.Now()
+
+	tests := []struct {
+		desc          string
+		task          *Task
+		opts          []Option
+		wantInfo      *TaskInfo
+		wantPending   map[string][]*base.TaskMessage
+		wantGroups    map[string]map[string][]base.Z // map queue name to a set of groups
+		wantScheduled map[string][]base.Z
+	}{
+		{
+			desc: "With only Group option",
+			task: task,
+			opts: []Option{
+				Group("mygroup"),
+			},
+			wantInfo: &TaskInfo{
+				Queue:         "default",
+				Type:          task.Type(),
+				Payload:       task.Payload(),
+				State:         TaskStateAggregating,
+				MaxRetry:      defaultMaxRetry,
+				Retried:       0,
+				LastErr:       "",
+				LastFailedAt:  time.Time{},
+				Timeout:       defaultTimeout,
+				Deadline:      time.Time{},
+				NextProcessAt: time.Time{},
+			},
+			wantPending: map[string][]*base.TaskMessage{
+				"default": {}, // should not be pending
+			},
+			wantGroups: map[string]map[string][]base.Z{
+				"default": {
+					"mygroup": {
+						{
+							Message: &base.TaskMessage{
+								Type:     task.Type(),
+								Payload:  task.Payload(),
+								Retry:    defaultMaxRetry,
+								Queue:    "default",
+								Timeout:  int64(defaultTimeout.Seconds()),
+								Deadline: noDeadline.Unix(),
+							},
+							Score: now.Unix(),
+						},
+					},
+				},
+			},
+			wantScheduled: map[string][]base.Z{
+				"default": {},
+			},
+		},
+		{
+			desc: "With Group and ProcessIn options",
+			task: task,
+			opts: []Option{
+				Group("mygroup"),
+				ProcessIn(30 * time.Minute),
+			},
+			wantInfo: &TaskInfo{
+				Queue:         "default",
+				Type:          task.Type(),
+				Payload:       task.Payload(),
+				State:         TaskStateScheduled,
+				MaxRetry:      defaultMaxRetry,
+				Retried:       0,
+				LastErr:       "",
+				LastFailedAt:  time.Time{},
+				Timeout:       defaultTimeout,
+				Deadline:      time.Time{},
+				NextProcessAt: now.Add(30 * time.Minute),
+			},
+			wantPending: map[string][]*base.TaskMessage{
+				"default": {}, // should not be pending
+			},
+			wantGroups: map[string]map[string][]base.Z{
+				"default": {
+					"mygroup": {}, // should not be added to the group yet
+				},
+			},
+			wantScheduled: map[string][]base.Z{
+				"default": {
+					{
+						Message: &base.TaskMessage{
+							Type:     task.Type(),
+							Payload:  task.Payload(),
+							Retry:    defaultMaxRetry,
+							Queue:    "default",
+							Timeout:  int64(defaultTimeout.Seconds()),
+							Deadline: noDeadline.Unix(),
+						},
+						Score: now.Add(30 * time.Minute).Unix(),
+					},
+				},
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		h.FlushDB(t, r) // clean up db before each test case.
+
+		gotInfo, err := client.Enqueue(tc.task, tc.opts...)
+		if err != nil {
+			t.Error(err)
+			continue
+		}
+		cmpOptions := []cmp.Option{
+			cmpopts.IgnoreFields(TaskInfo{}, "ID"),
+			cmpopts.EquateApproxTime(500 * time.Millisecond),
+		}
+		if diff := cmp.Diff(tc.wantInfo, gotInfo, cmpOptions...); diff != "" {
+			t.Errorf("%s;\nEnqueue(task) returned %v, want %v; (-want,+got)\n%s",
+				tc.desc, gotInfo, tc.wantInfo, diff)
+		}
+
+		for qname, want := range tc.wantPending {
+			got := h.GetPendingMessages(t, r, qname)
+			if diff := cmp.Diff(want, got, h.IgnoreIDOpt, cmpopts.EquateEmpty()); diff != "" {
+				t.Errorf("%s;\nmismatch found in %q; (-want,+got)\n%s", tc.desc, base.PendingKey(qname), diff)
+			}
+		}
+
+		for qname, groups := range tc.wantGroups {
+			for groupKey, want := range groups {
+				got := h.GetGroupEntries(t, r, qname, groupKey)
+				if diff := cmp.Diff(want, got, h.IgnoreIDOpt, cmpopts.EquateEmpty()); diff != "" {
+					t.Errorf("%s;\nmismatch found in %q; (-want,+got)\n%s", tc.desc, base.GroupKey(qname, groupKey), diff)
+				}
+			}
+		}
+
+		for qname, want := range tc.wantScheduled {
+			gotScheduled := h.GetScheduledEntries(t, r, qname)
+			if diff := cmp.Diff(want, gotScheduled, h.IgnoreIDOpt, cmpopts.EquateEmpty()); diff != "" {
+				t.Errorf("%s;\nmismatch found in %q; (-want,+got)\n%s", tc.desc, base.ScheduledKey(qname), diff)
+			}
+		}
+	}
+}
+
 func TestClientEnqueueWithTaskIDOption(t *testing.T) {
 	r := setup(t)
 	client := NewClient(getRedisConnOpt(t))
diff --git a/internal/asynqtest/asynqtest.go b/internal/asynqtest/asynqtest.go
index 46ea859..ce440e5 100644
--- a/internal/asynqtest/asynqtest.go
+++ b/internal/asynqtest/asynqtest.go
@@ -434,6 +434,14 @@ func GetCompletedEntries(tb testing.TB, r redis.UniversalClient, qname string) [
 	return getMessagesFromZSetWithScores(tb, r, qname, base.CompletedKey, base.TaskStateCompleted)
 }
 
+// GetGroupEntries returns all messages and their scores in the given group within the queue.
+// It also asserts the state field of the tasks.
+func GetGroupEntries(tb testing.TB, r redis.UniversalClient, qname, groupKey string) []base.Z {
+	tb.Helper()
+	return getMessagesFromZSetWithScores(tb, r, qname,
+		func(qname string) string { return base.GroupKey(qname, groupKey) }, base.TaskStateAggregating)
+}
+
 // Retrieves all messages stored under `keyFn(qname)` key in redis list.
 func getMessagesFromList(tb testing.TB, r redis.UniversalClient, qname string,
 	keyFn func(qname string) string, state base.TaskState) []*base.TaskMessage {
diff --git a/internal/base/base.go b/internal/base/base.go
index 98694ff..f43c316 100644
--- a/internal/base/base.go
+++ b/internal/base/base.go
@@ -50,6 +50,7 @@ const (
 	TaskStateRetry
 	TaskStateArchived
 	TaskStateCompleted
+	TaskStateAggregating // describes a state where the task is waiting in a group to be aggregated
 )
 
 func (s TaskState) String() string {
@@ -66,6 +67,8 @@ func (s TaskState) String() string {
 		return "archived"
 	case TaskStateCompleted:
 		return "completed"
+	case TaskStateAggregating:
+		return "aggregating"
 	}
 	panic(fmt.Sprintf("internal error: unknown task state %d", s))
 }
@@ -84,6 +87,8 @@ func TaskStateFromString(s string) (TaskState, error) {
 		return TaskStateArchived, nil
 	case "completed":
 		return TaskStateCompleted, nil
+	case "aggregating":
+		return TaskStateAggregating, nil
 	}
 	return 0, errors.E(errors.FailedPrecondition, fmt.Sprintf("%q is not supported task state", s))
 }
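
For context, a minimal usage sketch of the Group option introduced by this change, from the client side. The Redis address, task type, payload, and group name below are placeholders chosen for illustration, not part of the diff:

package main

import (
	"log"

	"github.com/hibiken/asynq"
)

func main() {
	// NOTE: illustrative sketch only; substitute your own Redis address and task type.
	client := asynq.NewClient(asynq.RedisClientOpt{Addr: "localhost:6379"})
	defer client.Close()

	task := asynq.NewTask("email:welcome", []byte(`{"user_id": 42}`))

	// With the Group option the task is not enqueued as pending; it is stored in the
	// "notifications" group of the default queue and waits there to be aggregated.
	info, err := client.Enqueue(task, asynq.Group("notifications"))
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("task %s state: %v", info.ID, info.State) // prints state "aggregating"
}

As the second test case above suggests, combining Group with ProcessIn (or ProcessAt) leaves the task in the scheduled state first; it is not added to the group until its process-at time arrives.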