diff --git a/internal/asynqtest/asynqtest.go b/internal/asynqtest/asynqtest.go index fbd4666..fada959 100644 --- a/internal/asynqtest/asynqtest.go +++ b/internal/asynqtest/asynqtest.go @@ -7,8 +7,10 @@ package asynqtest import ( "encoding/json" + "fmt" "math" "sort" + "strconv" "strings" "testing" "time" @@ -68,6 +70,15 @@ var SortZSetEntryOpt = cmp.Transformer("SortZSetEntries", func(in []base.Z) []ba return out }) +// SortTaskInfos is an cmp.Option to sort TaskInfo for comparing slice of task infos. +var SortTaskInfos = cmp.Transformer("SortTaskInfos", func(in []*base.TaskInfo) []*base.TaskInfo { + out := append([]*base.TaskInfo(nil), in...) // Copy input to avoid mutating it + sort.Slice(out, func(i, j int) bool { + return out[i].ID.String() < out[j].ID.String() + }) + return out +}) + // SortServerInfoOpt is a cmp.Option to sort base.ServerInfo for comparing slice of process info. var SortServerInfoOpt = cmp.Transformer("SortServerInfo", func(in []*base.ServerInfo) []*base.ServerInfo { out := append([]*base.ServerInfo(nil), in...) // Copy input to avoid mutating it @@ -478,24 +489,77 @@ func getMessagesFromZSetWithScores(tb testing.TB, r redis.UniversalClient, qname return res } -// GetRetryEntries returns all retry messages and its score in the given queue. -func GetRetryTasks(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskInfo { +// GetRetryTaskInfos returns all retry tasks' TaskInfo from the given queue. +func GetRetryTaskInfos(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskInfo { tb.Helper() - zs := r.ZRangeWithScores(base.RetryKey(qname), 0, -1).Val() - var tasks []*base.TaskInfo - for _, z := range zs { - vals := r.HMGet(base.TaskKey(qname, z.Member.(string)), "msg", "state", "process_at", "last_failed_at").Val() - if len(vals) != 4 { - tb.Fatalf("unexpected number of values returned from HMGET command, got %d elements, want 4", len(vals)) - } - if vals[0] == redis.Nil { - tb.Fatalf("msg field contained nil for task ID %v", z.Member) - } - if vals[1] == redis.Nil { - tb.Fatalf("state field contained nil for task ID %v", z.Member) - } - // TODO: continue from here - - } - return res + return getTaskInfosFromZSet(tb, r, qname, base.RetryKey) +} + +// GetArchivedTaskInfos returns all archived tasks' TaskInfo from the given queue. +func GetArchivedTaskInfos(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskInfo { + tb.Helper() + return getTaskInfosFromZSet(tb, r, qname, base.ArchivedKey) +} + +func getTaskInfosFromZSet(tb testing.TB, r redis.UniversalClient, qname string, + keyFn func(qname string) string) []*base.TaskInfo { + tb.Helper() + ids := r.ZRange(keyFn(qname), 0, -1).Val() + var tasks []*base.TaskInfo + for _, id := range ids { + vals := r.HMGet(base.TaskKey(qname, id), "msg", "state", "process_at", "last_failed_at").Val() + info, err := makeTaskInfo(vals) + if err != nil { + tb.Fatalf("could not make task info from values returned by HMGET: %v", err) + } + tasks = append(tasks, info) + } + return tasks +} + +// makeTaskInfo takes values returned from HMGET(TASK_KEY, "msg", "state", "process_at", "last_failed_at") +// command and return a TaskInfo. It assumes that `vals` contains four values for each field. +func makeTaskInfo(vals []interface{}) (*base.TaskInfo, error) { + if len(vals) != 4 { + return nil, fmt.Errorf("asynq internal error: HMGET command returned %d elements", len(vals)) + } + // Note: The "msg", "state" fields are non-nil; + // whereas the "process_at", "last_failed_at" fields can be nil. + encoded := vals[0] + if encoded == nil { + return nil, fmt.Errorf("asynq internal error: HMGET field 'msg' was nil") + } + msg, err := base.DecodeMessage([]byte(encoded.(string))) + if err != nil { + return nil, err + } + state := vals[1] + if state == nil { + return nil, fmt.Errorf("asynq internal error: HMGET field 'state' was nil") + } + processAt, err := parseIntOrDefault(vals[2], 0) + if err != nil { + return nil, err + } + lastFailedAt, err := parseIntOrDefault(vals[3], 0) + if err != nil { + return nil, err + } + return &base.TaskInfo{ + TaskMessage: msg, + State: strings.ToLower(state.(string)), + NextProcessAt: processAt, + LastFailedAt: lastFailedAt, + }, nil +} + +// Parses val as base10 64-bit integer if val contains a value. +// Uses default value if val is nil. +// +// Assumes val contains either string value or nil. +func parseIntOrDefault(val interface{}, defaultVal int64) (int64, error) { + if val == nil { + return defaultVal, nil + } + return strconv.ParseInt(val.(string), 10, 64) } diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index 13addcd..1e61abf 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -688,6 +688,7 @@ if n == 0 then return 0 end redis.call("ZADD", KEYS[2], ARGV[3], ARGV[1]) +redis.call("HSET", KEYS[1], "state", "ARCHIVED", "process_at", 0) redis.call("ZREMRANGEBYSCORE", KEYS[2], "-inf", ARGV[4]) redis.call("ZREMRANGEBYRANK", KEYS[2], 0, -ARGV[5]) return 1 diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 1adfa05..20df424 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -1795,12 +1795,14 @@ func TestArchiveRetryTask(t *testing.T) { m2 := h.NewTaskMessage("task2", nil) m3 := h.NewTaskMessageWithQueue("task3", nil, "custom") m4 := h.NewTaskMessageWithQueue("task4", nil, "custom") - t1 := time.Now().Add(1 * time.Minute) - t2 := time.Now().Add(1 * time.Hour) - t3 := time.Now().Add(2 * time.Hour) - t4 := time.Now().Add(3 * time.Hour) + now := time.Now() + t1 := now.Add(1 * time.Minute) + t2 := now.Add(1 * time.Hour) + t3 := now.Add(2 * time.Hour) + t4 := now.Add(3 * time.Hour) tests := []struct { + desc string retry map[string][]base.Z archived map[string][]base.Z qname string @@ -1810,6 +1812,7 @@ func TestArchiveRetryTask(t *testing.T) { wantArchived map[string][]*base.TaskInfo }{ { + desc: "archives task in the default queue", retry: map[string][]base.Z{ "default": { {Message: m1, Score: t1.Unix()}, @@ -1824,7 +1827,7 @@ func TestArchiveRetryTask(t *testing.T) { want: nil, wantRetry: map[string][]*base.TaskInfo{ "default": { - {TaskMessage: m2, State: "retry", NextProcessAt: t2.Unix(), LastFailedAt: 0}, + {TaskMessage: m2, State: "retry", NextProcessAt: t2.Unix(), LastFailedAt: now.Unix()}, }, }, wantArchived: map[string][]*base.TaskInfo{ @@ -1834,6 +1837,7 @@ func TestArchiveRetryTask(t *testing.T) { }, }, { + desc: "returns ErrTaskNotFound with non-existent task ID", retry: map[string][]base.Z{ "default": {{Message: m1, Score: t1.Unix()}}, }, @@ -1845,7 +1849,7 @@ func TestArchiveRetryTask(t *testing.T) { want: ErrTaskNotFound, wantRetry: map[string][]*base.TaskInfo{ "default": { - {TaskMessage: m1, State: "retry", NextProcessAt: t1.Unix(), LastFailedAt: 0}, + {TaskMessage: m1, State: "retry", NextProcessAt: t1.Unix(), LastFailedAt: now.Unix()}, }, }, wantArchived: map[string][]*base.TaskInfo{ @@ -1855,6 +1859,7 @@ func TestArchiveRetryTask(t *testing.T) { }, }, { + desc: "archives tasks in a custom named queue", retry: map[string][]base.Z{ "default": { {Message: m1, Score: t1.Unix()}, @@ -1874,11 +1879,11 @@ func TestArchiveRetryTask(t *testing.T) { want: nil, wantRetry: map[string][]*base.TaskInfo{ "default": { - {TaskMessage: m1, State: "retry", NextProcessAt: t1.Unix(), LastFailedAt: 0}, - {TaskMessage: m2, State: "retry", NextProcessAt: t2.Unix(), LastFailedAt: 0}, + {TaskMessage: m1, State: "retry", NextProcessAt: t1.Unix(), LastFailedAt: now.Unix()}, + {TaskMessage: m2, State: "retry", NextProcessAt: t2.Unix(), LastFailedAt: now.Unix()}, }, "custom": { - {TaskMessage: m4, State: "retry", NextProcessAt: t4.Unix(), LastFailedAt: 0}, + {TaskMessage: m4, State: "retry", NextProcessAt: t4.Unix(), LastFailedAt: now.Unix()}, }, }, wantArchived: map[string][]*base.TaskInfo{ @@ -1897,24 +1902,24 @@ func TestArchiveRetryTask(t *testing.T) { got := r.ArchiveTask(tc.qname, tc.id.String()) if got != tc.want { - t.Errorf("(*RDB).ArchiveTask(%q, %v) = %v, want %v", - tc.qname, tc.id, got, tc.want) + t.Errorf("%s; (*RDB).ArchiveTask(%q, %v) = %v, want %v", + tc.desc, tc.qname, tc.id, got, tc.want) continue } for qname, want := range tc.wantRetry { - gotRetry := h.GetRetryEntries(t, r.client, qname) - if diff := cmp.Diff(want, gotRetry, h.SortZSetEntryOpt, unixTimeCmpOpt); diff != "" { - t.Errorf("mismatch found in %q; (-want,+got)\n%s", - base.RetryKey(qname), diff) + gotRetry := h.GetRetryTaskInfos(t, r.client, qname) + if diff := cmp.Diff(want, gotRetry, h.SortTaskInfos, unixTimeCmpOpt); diff != "" { + t.Errorf("%s; mismatch found in %q; (-want,+got)\n%s", + tc.desc, base.RetryKey(qname), diff) } } for qname, want := range tc.wantArchived { - gotDead := h.GetArchivedEntries(t, r.client, qname) - if diff := cmp.Diff(want, gotDead, h.SortZSetEntryOpt, unixTimeCmpOpt); diff != "" { - t.Errorf("mismatch found in %q; (-want,+got)\n%s", - base.ArchivedKey(qname), diff) + gotArchived := h.GetArchivedTaskInfos(t, r.client, qname) + if diff := cmp.Diff(want, gotArchived, h.SortTaskInfos, unixTimeCmpOpt); diff != "" { + t.Errorf("%s; mismatch found in %q; (-want,+got)\n%s", + tc.desc, base.ArchivedKey(qname), diff) } } }