diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index c96bcae..12614d8 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -1073,6 +1073,66 @@ func (r *RDB) ArchiveAllScheduledTasks(qname string) (int64, error) { return n, nil } +// archiveAllAggregatingCmd archives all tasks in the given group. +// +// Input: +// KEYS[1] -> asynq:{}:g: +// KEYS[2] -> asynq:{}:archived +// KEYS[3] -> asynq:{}:groups +// ------- +// ARGV[1] -> current timestamp +// ARGV[2] -> cutoff timestamp (e.g., 90 days ago) +// ARGV[3] -> max number of tasks in archive (e.g., 100) +// ARGV[4] -> task key prefix (asynq:{}:t:) +// ARGV[5] -> group name +// +// Output: +// integer: Number of tasks archived +var archiveAllAggregatingCmd = redis.NewScript(` +local ids = redis.call("ZRANGE", KEYS[1], 0, -1) +for _, id in ipairs(ids) do + redis.call("ZADD", KEYS[2], ARGV[1], id) + redis.call("HSET", ARGV[4] .. id, "state", "archived") +end +redis.call("ZREMRANGEBYSCORE", KEYS[2], "-inf", ARGV[2]) +redis.call("ZREMRANGEBYRANK", KEYS[2], 0, -ARGV[3]) +redis.call("DEL", KEYS[1]) +redis.call("SREM", KEYS[3], ARGV[5]) +return table.getn(ids) +`) + +// ArchiveAllAggregatingTasks archives all aggregating tasks from the given group +// and returns the number of tasks archived. +// If a queue with the given name doesn't exist, it returns QueueNotFoundError. +func (r *RDB) ArchiveAllAggregatingTasks(qname, gname string) (int64, error) { + var op errors.Op = "rdb.ArchiveAllAggregatingTasks" + if err := r.checkQueueExists(qname); err != nil { + return 0, errors.E(op, errors.CanonicalCode(err), err) + } + keys := []string{ + base.GroupKey(qname, gname), + base.ArchivedKey(qname), + base.AllGroups(qname), + } + now := r.clock.Now() + argv := []interface{}{ + now.Unix(), + now.AddDate(0, 0, -archivedExpirationInDays).Unix(), + maxArchiveSize, + base.TaskKeyPrefix(qname), + gname, + } + res, err := archiveAllAggregatingCmd.Run(context.Background(), r.client, keys, argv...).Result() + if err != nil { + return 0, errors.E(op, errors.Internal, err) + } + n, ok := res.(int64) + if !ok { + return 0, errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from script %v", res)) + } + return n, nil +} + // archiveAllPendingCmd is a Lua script that moves all pending tasks from // the given queue to archived state. // diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 68c3073..1678b6d 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -3144,6 +3144,121 @@ func TestArchiveAllPendingTasks(t *testing.T) { } } } + +func TestArchiveAllAggregatingTasks(t *testing.T) { + r := setup(t) + defer r.Close() + now := time.Now() + r.SetClock(timeutil.NewSimulatedClock(now)) + + m1 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task1").SetGroup("group1").Build() + m2 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task2").SetGroup("group1").Build() + m3 := h.NewTaskMessageBuilder().SetQueue("custom").SetType("task3").SetGroup("group2").Build() + + fxt := struct { + tasks []*h.TaskSeedData + allQueues []string + allGroups map[string][]string + groups map[string][]*redis.Z + }{ + tasks: []*h.TaskSeedData{ + {Msg: m1, State: base.TaskStateAggregating}, + {Msg: m2, State: base.TaskStateAggregating}, + {Msg: m3, State: base.TaskStateAggregating}, + }, + allQueues: []string{"default", "custom"}, + allGroups: map[string][]string{ + base.AllGroups("default"): {"group1"}, + base.AllGroups("custom"): {"group2"}, + }, + groups: map[string][]*redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group2"): { + {Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + }, + } + + tests := []struct { + desc string + qname string + gname string + want int64 + wantArchived map[string][]redis.Z + wantGroups map[string][]redis.Z + wantAllGroups map[string][]string + }{ + { + desc: "archive tasks in a group with multiple tasks", + qname: "default", + gname: "group1", + want: 2, + wantArchived: map[string][]redis.Z{ + base.ArchivedKey("default"): { + {Member: m1.ID, Score: float64(now.Unix())}, + {Member: m2.ID, Score: float64(now.Unix())}, + }, + }, + wantGroups: map[string][]redis.Z{ + base.GroupKey("default", "group1"): {}, + base.GroupKey("custom", "group2"): { + {Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + }, + wantAllGroups: map[string][]string{ + base.AllGroups("default"): {}, + base.AllGroups("custom"): {"group2"}, + }, + }, + { + desc: "archive tasks in a group with a single task", + qname: "custom", + gname: "group2", + want: 1, + wantArchived: map[string][]redis.Z{ + base.ArchivedKey("custom"): { + {Member: m3.ID, Score: float64(now.Unix())}, + }, + }, + wantGroups: map[string][]redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group2"): {}, + }, + wantAllGroups: map[string][]string{ + base.AllGroups("default"): {"group1"}, + base.AllGroups("custom"): {}, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + h.SeedTasks(t, r.client, fxt.tasks) + h.SeedRedisSet(t, r.client, base.AllQueues, fxt.allQueues) + h.SeedRedisSets(t, r.client, fxt.allGroups) + h.SeedRedisZSets(t, r.client, fxt.groups) + + t.Run(tc.desc, func(t *testing.T) { + got, err := r.ArchiveAllAggregatingTasks(tc.qname, tc.gname) + if err != nil { + t.Fatalf("ArchiveAllAggregatingTasks returned error: %v", err) + } + if got != tc.want { + t.Errorf("ArchiveAllAggregatingTasks = %d, want %d", got, tc.want) + } + h.AssertRedisZSets(t, r.client, tc.wantArchived) + h.AssertRedisZSets(t, r.client, tc.wantGroups) + h.AssertRedisSets(t, r.client, tc.wantAllGroups) + }) + } +} + func TestArchiveAllRetryTasks(t *testing.T) { r := setup(t) defer r.Close()