diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index 0056f86..75774e0 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -729,7 +729,7 @@ func (r *RDB) ListScheduled(qname string, pgn Pagination) ([]*base.TaskInfo, err if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listZSetEntries(qname, base.TaskStateScheduled, pgn) + res, err := r.listZSetEntries(qname, base.TaskStateScheduled, base.ScheduledKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -747,7 +747,7 @@ func (r *RDB) ListRetry(qname string, pgn Pagination) ([]*base.TaskInfo, error) if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listZSetEntries(qname, base.TaskStateRetry, pgn) + res, err := r.listZSetEntries(qname, base.TaskStateRetry, base.RetryKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -764,7 +764,7 @@ func (r *RDB) ListArchived(qname string, pgn Pagination) ([]*base.TaskInfo, erro if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - zs, err := r.listZSetEntries(qname, base.TaskStateArchived, pgn) + zs, err := r.listZSetEntries(qname, base.TaskStateArchived, base.ArchivedKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -781,7 +781,24 @@ func (r *RDB) ListCompleted(qname string, pgn Pagination) ([]*base.TaskInfo, err if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - zs, err := r.listZSetEntries(qname, base.TaskStateCompleted, pgn) + zs, err := r.listZSetEntries(qname, base.TaskStateCompleted, base.CompletedKey(qname), pgn) + if err != nil { + return nil, errors.E(op, errors.CanonicalCode(err), err) + } + return zs, nil +} + +// ListAggregating returns all tasks from the given group. +func (r *RDB) ListAggregating(qname, gname string, pgn Pagination) ([]*base.TaskInfo, error) { + var op errors.Op = "rdb.ListAggregating" + exists, err := r.queueExists(qname) + if err != nil { + return nil, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) + } + if !exists { + return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) + } + zs, err := r.listZSetEntries(qname, base.TaskStateAggregating, base.GroupKey(qname, gname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -817,20 +834,7 @@ return data // listZSetEntries returns a list of message and score pairs in Redis sorted-set // with the given key. -func (r *RDB) listZSetEntries(qname string, state base.TaskState, pgn Pagination) ([]*base.TaskInfo, error) { - var key string - switch state { - case base.TaskStateScheduled: - key = base.ScheduledKey(qname) - case base.TaskStateRetry: - key = base.RetryKey(qname) - case base.TaskStateArchived: - key = base.ArchivedKey(qname) - case base.TaskStateCompleted: - key = base.CompletedKey(qname) - default: - panic(fmt.Sprintf("unsupported task state: %v", state)) - } +func (r *RDB) listZSetEntries(qname string, state base.TaskState, key string, pgn Pagination) ([]*base.TaskInfo, error) { res, err := listZSetEntriesCmd.Run(context.Background(), r.client, []string{key}, pgn.start(), pgn.stop(), base.TaskKeyPrefix(qname)).Result() if err != nil { diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index ae3c608..246ce4d 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -1571,6 +1571,220 @@ func TestListCompletedPagination(t *testing.T) { } } +func TestListAggregating(t *testing.T) { + r := setup(t) + defer r.Close() + + now := time.Now() + m1 := h.NewTaskMessageBuilder().SetType("task1").SetQueue("default").SetGroup("group1").Build() + m2 := h.NewTaskMessageBuilder().SetType("task2").SetQueue("default").SetGroup("group1").Build() + m3 := h.NewTaskMessageBuilder().SetType("task3").SetQueue("default").SetGroup("group2").Build() + m4 := h.NewTaskMessageBuilder().SetType("task4").SetQueue("custom").SetGroup("group3").Build() + + fxt := struct { + tasks []*h.TaskSeedData + allQueues []string + allGroups map[string][]string + groups map[string][]*redis.Z + }{ + tasks: []*h.TaskSeedData{ + {Msg: m1, State: base.TaskStateAggregating}, + {Msg: m2, State: base.TaskStateAggregating}, + {Msg: m3, State: base.TaskStateAggregating}, + {Msg: m4, State: base.TaskStateAggregating}, + }, + allQueues: []string{"default", "custom"}, + allGroups: map[string][]string{ + base.AllGroups("default"): {"group1", "group2"}, + base.AllGroups("custom"): {"group3"}, + }, + groups: map[string][]*redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m1.ID, Score: float64(now.Add(-30 * time.Second).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + base.GroupKey("default", "group2"): { + {Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group3"): { + {Member: m4.ID, Score: float64(now.Add(-40 * time.Second).Unix())}, + }, + }, + } + + tests := []struct { + desc string + qname string + gname string + want []*base.TaskInfo + }{ + { + desc: "with group1 in default queue", + qname: "default", + gname: "group1", + want: []*base.TaskInfo{ + {Message: m1, State: base.TaskStateAggregating, NextProcessAt: time.Time{}, Result: nil}, + {Message: m2, State: base.TaskStateAggregating, NextProcessAt: time.Time{}, Result: nil}, + }, + }, + { + desc: "with group3 in custom queue", + qname: "custom", + gname: "group3", + want: []*base.TaskInfo{ + {Message: m4, State: base.TaskStateAggregating, NextProcessAt: time.Time{}, Result: nil}, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + h.SeedRedisSet(t, r.client, base.AllQueues, fxt.allQueues) + h.SeedRedisSets(t, r.client, fxt.allGroups) + h.SeedTasks(t, r.client, fxt.tasks) + h.SeedRedisZSets(t, r.client, fxt.groups) + + t.Run(tc.desc, func(t *testing.T) { + got, err := r.ListAggregating(tc.qname, tc.gname, Pagination{}) + if err != nil { + t.Fatalf("ListAggregating returned error: %v", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("ListAggregating = %v, want %v; (-want,+got)\n%s", got, tc.want, diff) + } + }) + } +} + +func TestListAggregatingPagination(t *testing.T) { + r := setup(t) + defer r.Close() + + groupkey := base.GroupKey("default", "mygroup") + fxt := struct { + tasks []*h.TaskSeedData + allQueues []string + allGroups map[string][]string + groups map[string][]*redis.Z + }{ + tasks: []*h.TaskSeedData{}, // will be populated below + allQueues: []string{"default"}, + allGroups: map[string][]string{ + base.AllGroups("default"): {"mygroup"}, + }, + groups: map[string][]*redis.Z{ + groupkey: {}, // will be populated below + }, + } + + now := time.Now() + for i := 0; i < 100; i++ { + msg := h.NewTaskMessageBuilder().SetType(fmt.Sprintf("task%d", i)).SetGroup("mygroup").Build() + fxt.tasks = append(fxt.tasks, &h.TaskSeedData{ + Msg: msg, State: base.TaskStateAggregating, + }) + fxt.groups[groupkey] = append(fxt.groups[groupkey], &redis.Z{ + Member: msg.ID, + Score: float64(now.Add(-time.Duration(100-i) * time.Second).Unix()), + }) + } + + tests := []struct { + desc string + qname string + gname string + page int + size int + wantSize int + wantFirst string + wantLast string + }{ + { + desc: "first page", + qname: "default", + gname: "mygroup", + page: 0, + size: 20, + wantSize: 20, + wantFirst: "task0", + wantLast: "task19", + }, + { + desc: "second page", + qname: "default", + gname: "mygroup", + page: 1, + size: 20, + wantSize: 20, + wantFirst: "task20", + wantLast: "task39", + }, + { + desc: "with different page size", + qname: "default", + gname: "mygroup", + page: 2, + size: 30, + wantSize: 30, + wantFirst: "task60", + wantLast: "task89", + }, + { + desc: "last page", + qname: "default", + gname: "mygroup", + page: 3, + size: 30, + wantSize: 10, + wantFirst: "task90", + wantLast: "task99", + }, + { + desc: "out of range", + qname: "default", + gname: "mygroup", + page: 4, + size: 30, + wantSize: 0, + wantFirst: "", + wantLast: "", + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + h.SeedRedisSet(t, r.client, base.AllQueues, fxt.allQueues) + h.SeedRedisSets(t, r.client, fxt.allGroups) + h.SeedTasks(t, r.client, fxt.tasks) + h.SeedRedisZSets(t, r.client, fxt.groups) + + t.Run(tc.desc, func(t *testing.T) { + got, err := r.ListAggregating(tc.qname, tc.gname, Pagination{Page: tc.page, Size: tc.size}) + if err != nil { + t.Fatalf("ListAggregating returned error: %v", err) + } + + if len(got) != tc.wantSize { + t.Errorf("got %d results, want %d", len(got), tc.wantSize) + } + + if len(got) == 0 { + return + } + + first := got[0].Message + if first.Type != tc.wantFirst { + t.Errorf("First message %q, want %q", first.Type, tc.wantFirst) + } + + last := got[len(got)-1].Message + if last.Type != tc.wantLast { + t.Errorf("Last message %q, want %q", last.Type, tc.wantLast) + } + }) + } +} + func TestListTasksError(t *testing.T) { r := setup(t) defer r.Close()