diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index 26c34cc..5db9572 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -1447,6 +1447,50 @@ func (r *RDB) deleteAll(key, qname string) (int64, error) { return n, nil } +// deleteAllAggregatingCmd deletes all tasks from the given group. +// +// Input: +// KEYS[1] -> asynq:{}:g: +// KEYS[2] -> asynq:{}:groups +// ------- +// ARGV[1] -> task key prefix +// ARGV[2] -> group name +var deleteAllAggregatingCmd = redis.NewScript(` +local ids = redis.call("ZRANGE", KEYS[1], 0, -1) +for _, id in ipairs(ids) do + redis.call("DEL", ARGV[1] .. id) +end +redis.call("SREM", KEYS[2], ARGV[2]) +redis.call("DEL", KEYS[1]) +return table.getn(ids) +`) + +// DeleteAllAggregatingTasks deletes all aggregating tasks from the given group +// and returns the number of tasks deleted. +func (r *RDB) DeleteAllAggregatingTasks(qname, gname string) (int64, error) { + var op errors.Op = "rdb.DeleteAllAggregatingTasks" + if err := r.checkQueueExists(qname); err != nil { + return 0, errors.E(op, errors.CanonicalCode(err), err) + } + keys := []string{ + base.GroupKey(qname, gname), + base.AllGroups(qname), + } + argv := []interface{}{ + base.TaskKeyPrefix(qname), + gname, + } + res, err := deleteAllAggregatingCmd.Run(context.Background(), r.client, keys, argv...).Result() + if err != nil { + return 0, errors.E(op, errors.Unknown, err) + } + n, ok := res.(int64) + if !ok { + return 0, errors.E(op, errors.Internal, "command error: unexpected return value %v", res) + } + return n, nil +} + // deleteAllPendingCmd deletes all pending tasks from the given queue. // // Input: diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 246ce4d..b2039cd 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -4210,6 +4210,105 @@ func TestDeleteAllScheduledTasks(t *testing.T) { } } +func TestDeleteAllAggregatingTasks(t *testing.T) { + r := setup(t) + defer r.Close() + now := time.Now() + m1 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task1").SetGroup("group1").Build() + m2 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task2").SetGroup("group1").Build() + m3 := h.NewTaskMessageBuilder().SetQueue("custom").SetType("task3").SetGroup("group1").Build() + + fxt := struct { + tasks []*h.TaskSeedData + allQueues []string + allGroups map[string][]string + groups map[string][]*redis.Z + }{ + tasks: []*h.TaskSeedData{ + {Msg: m1, State: base.TaskStateAggregating}, + {Msg: m2, State: base.TaskStateAggregating}, + {Msg: m3, State: base.TaskStateAggregating}, + }, + allQueues: []string{"default", "custom"}, + allGroups: map[string][]string{ + base.AllGroups("default"): {"group1"}, + base.AllGroups("custom"): {"group1"}, + }, + groups: map[string][]*redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group1"): { + {Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + }, + } + + tests := []struct { + desc string + qname string + gname string + want int64 + wantAllGroups map[string][]string + wantGroups map[string][]redis.Z + }{ + { + desc: "default queue group1", + qname: "default", + gname: "group1", + want: 2, + wantAllGroups: map[string][]string{ + base.AllGroups("default"): {}, + base.AllGroups("custom"): {"group1"}, + }, + wantGroups: map[string][]redis.Z{ + base.GroupKey("default", "group1"): nil, + base.GroupKey("custom", "group1"): { + {Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + }, + }, + { + desc: "custom queue group1", + qname: "custom", + gname: "group1", + want: 1, + wantAllGroups: map[string][]string{ + base.AllGroups("default"): {"group1"}, + base.AllGroups("custom"): {}, + }, + wantGroups: map[string][]redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group1"): nil, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + h.SeedTasks(t, r.client, fxt.tasks) + h.SeedRedisSet(t, r.client, base.AllQueues, fxt.allQueues) + h.SeedRedisSets(t, r.client, fxt.allGroups) + h.SeedRedisZSets(t, r.client, fxt.groups) + + t.Run(tc.desc, func(t *testing.T) { + got, err := r.DeleteAllAggregatingTasks(tc.qname, tc.gname) + if err != nil { + t.Fatalf("DeleteAllAggregatingTasks returned error: %v", err) + } + if got != tc.want { + t.Errorf("DeleteAllAggregatingTasks = %d, want %d", got, tc.want) + } + h.AssertRedisSets(t, r.client, tc.wantAllGroups) + h.AssertRedisZSets(t, r.client, tc.wantGroups) + }) + } +} + func TestDeleteAllPendingTasks(t *testing.T) { r := setup(t) defer r.Close() diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index dc249f8..d22904e 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -605,6 +605,18 @@ func SeedRedisLists(tb testing.TB, r redis.UniversalClient, lists map[string][]s } } +func AssertRedisSets(t *testing.T, r redis.UniversalClient, wantSets map[string][]string) { + for key, want := range wantSets { + got, err := r.SMembers(context.Background(), key).Result() + if err != nil { + t.Fatalf("Failed to read set (key=%q): %v", key, err) + } + if diff := cmp.Diff(want, got, SortStringSliceOpt); diff != "" { + t.Errorf("mismatch found in set (key=%q): (-want,+got)\n%s", key, diff) + } + } +} + func AssertRedisZSets(t *testing.T, r redis.UniversalClient, wantZSets map[string][]redis.Z) { for key, want := range wantZSets { got, err := r.ZRangeWithScores(context.Background(), key, 0, -1).Result()