diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index a5a1c36..b989839 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -978,9 +978,11 @@ func (r *RDB) RunAllAggregatingTasks(qname, gname string) (int64, error) { // Input: // KEYS[1] -> asynq:{}:t: // KEYS[2] -> asynq:{}:pending +// KEYS[3] -> asynq:{}:groups // -- // ARGV[1] -> task ID // ARGV[2] -> queue key prefix; asynq:{}: +// ARGV[3] -> group key prefix // // Output: // Numeric code indicating the status: @@ -993,15 +995,24 @@ var runTaskCmd = redis.NewScript(` if redis.call("EXISTS", KEYS[1]) == 0 then return 0 end -local state = redis.call("HGET", KEYS[1], "state") +local state, group = unpack(redis.call("HMGET", KEYS[1], "state", "group")) if state == "active" then return -1 elseif state == "pending" then return -2 -end -local n = redis.call("ZREM", ARGV[2] .. state, ARGV[1]) -if n == 0 then - return redis.error_reply("internal error: task id not found in zset " .. tostring(state)) +elseif state == "aggregating" then + local n = redis.call("ZREM", ARGV[3] .. group, ARGV[1]) + if n == 0 then + return redis.error_reply("internal error: task id not found in zset " .. tostring(ARGV[3] .. group)) + end + if redis.call("ZCARD", ARGV[3] .. group) == 0 then + redis.call("SREM", KEYS[3], group) + end +else + local n = redis.call("ZREM", ARGV[2] .. state, ARGV[1]) + if n == 0 then + return redis.error_reply("internal error: task id not found in zset " .. tostring(ARGV[2] .. state)) + end end redis.call("LPUSH", KEYS[2], ARGV[1]) redis.call("HSET", KEYS[1], "state", "pending") @@ -1022,10 +1033,12 @@ func (r *RDB) RunTask(qname, id string) error { keys := []string{ base.TaskKey(qname, id), base.PendingKey(qname), + base.AllGroups(qname), } argv := []interface{}{ id, base.QueueKeyPrefix(qname), + base.GroupKeyPrefix(qname), } res, err := runTaskCmd.Run(context.Background(), r.client, keys, argv...).Result() if err != nil { diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 5212920..5987100 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -1986,6 +1986,111 @@ func TestRunRetryTask(t *testing.T) { } } +func TestRunAggregatingTask(t *testing.T) { + r := setup(t) + defer r.Close() + now := time.Now() + r.SetClock(timeutil.NewSimulatedClock(now)) + m1 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task1").SetGroup("group1").Build() + m2 := h.NewTaskMessageBuilder().SetQueue("default").SetType("task2").SetGroup("group1").Build() + m3 := h.NewTaskMessageBuilder().SetQueue("custom").SetType("task3").SetGroup("group1").Build() + + fxt := struct { + tasks []*h.TaskSeedData + allQueues []string + allGroups map[string][]string + groups map[string][]*redis.Z + }{ + tasks: []*h.TaskSeedData{ + {Msg: m1, State: base.TaskStateAggregating}, + {Msg: m2, State: base.TaskStateAggregating}, + {Msg: m3, State: base.TaskStateAggregating}, + }, + allQueues: []string{"default", "custom"}, + allGroups: map[string][]string{ + base.AllGroups("default"): {"group1"}, + base.AllGroups("custom"): {"group1"}, + }, + groups: map[string][]*redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group1"): { + {Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + }, + } + + tests := []struct { + desc string + qname string + id string + wantPending map[string][]string + wantAllGroups map[string][]string + wantGroups map[string][]redis.Z + }{ + { + desc: "schedules task from a group with multiple tasks", + qname: "default", + id: m1.ID, + wantPending: map[string][]string{ + base.PendingKey("default"): {m1.ID}, + }, + wantAllGroups: map[string][]string{ + base.AllGroups("default"): {"group1"}, + base.AllGroups("custom"): {"group1"}, + }, + wantGroups: map[string][]redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group1"): { + {Member: m3.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + }, + }, + }, + { + desc: "schedules task from a group with a single task", + qname: "custom", + id: m3.ID, + wantPending: map[string][]string{ + base.PendingKey("custom"): {m3.ID}, + }, + wantAllGroups: map[string][]string{ + base.AllGroups("default"): {"group1"}, + base.AllGroups("custom"): {}, + }, + wantGroups: map[string][]redis.Z{ + base.GroupKey("default", "group1"): { + {Member: m1.ID, Score: float64(now.Add(-20 * time.Second).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-25 * time.Second).Unix())}, + }, + base.GroupKey("custom", "group1"): {}, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + h.SeedTasks(t, r.client, fxt.tasks) + h.SeedRedisSet(t, r.client, base.AllQueues, fxt.allQueues) + h.SeedRedisSets(t, r.client, fxt.allGroups) + h.SeedRedisZSets(t, r.client, fxt.groups) + + t.Run(tc.desc, func(t *testing.T) { + err := r.RunTask(tc.qname, tc.id) + if err != nil { + t.Fatalf("RunTask returned error: %v", err) + } + + h.AssertRedisLists(t, r.client, tc.wantPending) + h.AssertRedisZSets(t, r.client, tc.wantGroups) + h.AssertRedisSets(t, r.client, tc.wantAllGroups) + }) + } +} + func TestRunScheduledTask(t *testing.T) { r := setup(t) defer r.Close()