diff --git a/internal/asynqtest/asynqtest.go b/internal/asynqtest/asynqtest.go index 66831ba..3eb3076 100644 --- a/internal/asynqtest/asynqtest.go +++ b/internal/asynqtest/asynqtest.go @@ -182,42 +182,42 @@ func FlushDB(tb testing.TB, r redis.UniversalClient) { func SeedPendingQueue(tb testing.TB, r redis.UniversalClient, msgs []*base.TaskMessage, qname string) { tb.Helper() r.SAdd(base.AllQueues, qname) - seedRedisList(tb, r, base.PendingKey(qname), msgs) + seedRedisList(tb, r, base.PendingKey(qname), msgs, "pending") } // SeedActiveQueue initializes the active queue with the given messages. func SeedActiveQueue(tb testing.TB, r redis.UniversalClient, msgs []*base.TaskMessage, qname string) { tb.Helper() r.SAdd(base.AllQueues, qname) - seedRedisList(tb, r, base.ActiveKey(qname), msgs) + seedRedisList(tb, r, base.ActiveKey(qname), msgs, "active") } // SeedScheduledQueue initializes the scheduled queue with the given messages. func SeedScheduledQueue(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname string) { tb.Helper() r.SAdd(base.AllQueues, qname) - seedRedisZSet(tb, r, base.ScheduledKey(qname), entries) + seedRedisZSet(tb, r, base.ScheduledKey(qname), entries, "scheduled") } // SeedRetryQueue initializes the retry queue with the given messages. func SeedRetryQueue(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname string) { tb.Helper() r.SAdd(base.AllQueues, qname) - seedRedisZSet(tb, r, base.RetryKey(qname), entries) + seedRedisZSet(tb, r, base.RetryKey(qname), entries, "retry") } // SeedArchivedQueue initializes the archived queue with the given messages. func SeedArchivedQueue(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname string) { tb.Helper() r.SAdd(base.AllQueues, qname) - seedRedisZSet(tb, r, base.ArchivedKey(qname), entries) + seedRedisZSet(tb, r, base.ArchivedKey(qname), entries, "archived") } // SeedDeadlines initializes the deadlines set with the given entries. func SeedDeadlines(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname string) { tb.Helper() r.SAdd(base.AllQueues, qname) - seedRedisZSet(tb, r, base.DeadlinesKey(qname), entries) + seedRedisZSet(tb, r, base.DeadlinesKey(qname), entries, "active") } // SeedAllPendingQueues initializes all of the specified queues with the given messages. @@ -270,7 +270,8 @@ func SeedAllDeadlines(tb testing.TB, r redis.UniversalClient, deadlines map[stri } } -func seedRedisList(tb testing.TB, c redis.UniversalClient, key string, msgs []*base.TaskMessage) { +func seedRedisList(tb testing.TB, c redis.UniversalClient, key string, + msgs []*base.TaskMessage, state string) { tb.Helper() for _, msg := range msgs { encoded := MustMarshal(tb, msg) @@ -280,6 +281,7 @@ func seedRedisList(tb testing.TB, c redis.UniversalClient, key string, msgs []*b key := base.TaskKey(msg.Queue, msg.ID.String()) data := map[string]interface{}{ "msg": encoded, + "state": state, "timeout": msg.Timeout, "deadline": msg.Deadline, } @@ -289,7 +291,8 @@ func seedRedisList(tb testing.TB, c redis.UniversalClient, key string, msgs []*b } } -func seedRedisZSet(tb testing.TB, c redis.UniversalClient, key string, items []base.Z) { +func seedRedisZSet(tb testing.TB, c redis.UniversalClient, key string, + items []base.Z, state string) { tb.Helper() for _, item := range items { msg := item.Message @@ -301,6 +304,7 @@ func seedRedisZSet(tb testing.TB, c redis.UniversalClient, key string, items []b key := base.TaskKey(msg.Queue, msg.ID.String()) data := map[string]interface{}{ "msg": encoded, + "state": state, "timeout": msg.Timeout, "deadline": msg.Deadline, } @@ -311,91 +315,116 @@ func seedRedisZSet(tb testing.TB, c redis.UniversalClient, key string, items []b } // GetPendingMessages returns all pending messages in the given queue. +// It also asserts the state field of the task. func GetPendingMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage { tb.Helper() - return getMessagesFromList(tb, r, qname, base.PendingKey) + return getMessagesFromList(tb, r, qname, base.PendingKey, "pending") } // GetActiveMessages returns all active messages in the given queue. +// It also asserts the state field of the task. func GetActiveMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage { tb.Helper() - return getMessagesFromList(tb, r, qname, base.ActiveKey) + return getMessagesFromList(tb, r, qname, base.ActiveKey, "active") } // GetScheduledMessages returns all scheduled task messages in the given queue. +// It also asserts the state field of the task. func GetScheduledMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage { tb.Helper() - return getMessagesFromZSet(tb, r, qname, base.ScheduledKey) + return getMessagesFromZSet(tb, r, qname, base.ScheduledKey, "scheduled") } // GetRetryMessages returns all retry messages in the given queue. +// It also asserts the state field of the task. func GetRetryMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage { tb.Helper() - return getMessagesFromZSet(tb, r, qname, base.RetryKey) + return getMessagesFromZSet(tb, r, qname, base.RetryKey, "retry") } // GetArchivedMessages returns all archived messages in the given queue. +// It also asserts the state field of the task. func GetArchivedMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage { tb.Helper() - return getMessagesFromZSet(tb, r, qname, base.ArchivedKey) + return getMessagesFromZSet(tb, r, qname, base.ArchivedKey, "archived") } // GetScheduledEntries returns all scheduled messages and its score in the given queue. +// It also asserts the state field of the task. func GetScheduledEntries(tb testing.TB, r redis.UniversalClient, qname string) []base.Z { tb.Helper() - return getMessagesFromZSetWithScores(tb, r, qname, base.ScheduledKey) + return getMessagesFromZSetWithScores(tb, r, qname, base.ScheduledKey, "scheduled") } // GetRetryEntries returns all retry messages and its score in the given queue. +// It also asserts the state field of the task. func GetRetryEntries(tb testing.TB, r redis.UniversalClient, qname string) []base.Z { tb.Helper() - return getMessagesFromZSetWithScores(tb, r, qname, base.RetryKey) + return getMessagesFromZSetWithScores(tb, r, qname, base.RetryKey, "retry") } // GetArchivedEntries returns all archived messages and its score in the given queue. +// It also asserts the state field of the task. func GetArchivedEntries(tb testing.TB, r redis.UniversalClient, qname string) []base.Z { tb.Helper() - return getMessagesFromZSetWithScores(tb, r, qname, base.ArchivedKey) + return getMessagesFromZSetWithScores(tb, r, qname, base.ArchivedKey, "archived") } // GetDeadlinesEntries returns all task messages and its score in the deadlines set for the given queue. +// It also asserts the state field of the task. func GetDeadlinesEntries(tb testing.TB, r redis.UniversalClient, qname string) []base.Z { tb.Helper() - return getMessagesFromZSetWithScores(tb, r, qname, base.DeadlinesKey) + return getMessagesFromZSetWithScores(tb, r, qname, base.DeadlinesKey, "active") } // Retrieves all messages stored under `keyFn(qname)` key in redis list. -func getMessagesFromList(tb testing.TB, r redis.UniversalClient, qname string, keyFn func(qname string) string) []*base.TaskMessage { +func getMessagesFromList(tb testing.TB, r redis.UniversalClient, qname string, + keyFn func(qname string) string, state string) []*base.TaskMessage { tb.Helper() ids := r.LRange(keyFn(qname), 0, -1).Val() var msgs []*base.TaskMessage for _, id := range ids { - data := r.HGet(base.TaskKey(qname, id), "msg").Val() + taskKey := base.TaskKey(qname, id) + data := r.HGet(taskKey, "msg").Val() msgs = append(msgs, MustUnmarshal(tb, data)) + if gotState := r.HGet(taskKey, "state").Val(); gotState != state { + tb.Errorf("task (id=%q) is in %q state, want %q", id, gotState, state) + } } return msgs } // Retrieves all messages stored under `keyFn(qname)` key in redis zset (sorted-set). -func getMessagesFromZSet(tb testing.TB, r redis.UniversalClient, qname string, keyFn func(qname string) string) []*base.TaskMessage { +func getMessagesFromZSet(tb testing.TB, r redis.UniversalClient, qname string, + keyFn func(qname string) string, state string) []*base.TaskMessage { tb.Helper() ids := r.ZRange(keyFn(qname), 0, -1).Val() var msgs []*base.TaskMessage for _, id := range ids { - msg := r.HGet(base.TaskKey(qname, id), "msg").Val() + taskKey := base.TaskKey(qname, id) + msg := r.HGet(taskKey, "msg").Val() msgs = append(msgs, MustUnmarshal(tb, msg)) + if gotState := r.HGet(taskKey, "state").Val(); gotState != state { + tb.Errorf("task (id=%q) is in %q state, want %q", id, gotState, state) + } } return msgs } // Retrieves all messages along with their scores stored under `keyFn(qname)` key in redis zset (sorted-set). -func getMessagesFromZSetWithScores(tb testing.TB, r redis.UniversalClient, qname string, keyFn func(qname string) string) []base.Z { +func getMessagesFromZSetWithScores(tb testing.TB, r redis.UniversalClient, + qname string, keyFn func(qname string) string, state string) []base.Z { tb.Helper() zs := r.ZRangeWithScores(keyFn(qname), 0, -1).Val() var res []base.Z for _, z := range zs { - msg := r.HGet(base.TaskKey(qname, z.Member.(string)), "msg").Val() + taskID := z.Member.(string) + taskKey := base.TaskKey(qname, taskID) + msg := r.HGet(taskKey, "msg").Val() res = append(res, base.Z{Message: MustUnmarshal(tb, msg), Score: int64(z.Score)}) + if gotState := r.HGet(taskKey, "state").Val(); gotState != state { + tb.Errorf("task (id=%q) is in state %q, want %q", taskID, gotState, state) + } } return res } diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 04e9f3b..fbb1825 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -556,10 +556,7 @@ func (r *RDB) Archive(msg *base.TaskMessage, errMsg string) error { // and move any tasks that are ready to be processed to the pending set. func (r *RDB) ForwardIfReady(qnames ...string) error { for _, qname := range qnames { - if err := r.forwardAll(base.ScheduledKey(qname), base.PendingKey(qname)); err != nil { - return err - } - if err := r.forwardAll(base.RetryKey(qname), base.PendingKey(qname)); err != nil { + if err := r.forwardAll(qname); err != nil { return err } } @@ -567,36 +564,43 @@ func (r *RDB) ForwardIfReady(qnames ...string) error { } // KEYS[1] -> source queue (e.g. asynq:{:scheduled or asynq:{}:retry}) -// KEYS[2] -> destination queue (e.g. asynq:{}) +// KEYS[2] -> asynq:{}:pending // ARGV[1] -> current unix time +// ARGV[2] -> task key prefix // Note: Script moves tasks up to 100 at a time to keep the runtime of script short. var forwardCmd = redis.NewScript(` local ids = redis.call("ZRANGEBYSCORE", KEYS[1], "-inf", ARGV[1], "LIMIT", 0, 100) for _, id in ipairs(ids) do redis.call("LPUSH", KEYS[2], id) redis.call("ZREM", KEYS[1], id) + redis.call("HSET", ARGV[2] .. id, "state", "pending") end return table.getn(ids)`) // forward moves tasks with a score less than the current unix time // from the src zset to the dst list. It returns the number of tasks moved. -func (r *RDB) forward(src, dst string) (int, error) { +func (r *RDB) forward(src, dst, taskKeyPrefix string) (int, error) { now := float64(time.Now().Unix()) - res, err := forwardCmd.Run(r.client, []string{src, dst}, now).Result() + res, err := forwardCmd.Run(r.client, []string{src, dst}, now, taskKeyPrefix).Result() if err != nil { return 0, err } return cast.ToInt(res), nil } -// forwardAll moves tasks with a score less than the current unix time from the src zset, -// until there's no more tasks. -func (r *RDB) forwardAll(src, dst string) (err error) { - n := 1 - for n != 0 { - n, err = r.forward(src, dst) - if err != nil { - return err +// forwardAll checks for tasks in scheduled/retry state that are ready to be run, and updates +// their state to "pending". +func (r *RDB) forwardAll(qname string) (err error) { + sources := []string{base.ScheduledKey(qname), base.RetryKey(qname)} + dst := base.PendingKey(qname) + taskKeyPrefix := base.TaskKeyPrefix(qname) + for _, src := range sources { + n := 1 + for n != 0 { + n, err = r.forward(src, dst, taskKeyPrefix) + if err != nil { + return err + } } } return nil