diff --git a/internal/base/base.go b/internal/base/base.go index 185a839..a03d381 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -267,7 +267,7 @@ func (c *Cancelations) GetAll() []context.CancelFunc { type Broker interface { Enqueue(msg *TaskMessage) error EnqueueUnique(msg *TaskMessage, ttl time.Duration) error - Dequeue(qnames ...string) (*TaskMessage, error) + Dequeue(qnames ...string) (*TaskMessage, int, error) Done(msg *TaskMessage) error Requeue(msg *TaskMessage) error Schedule(msg *TaskMessage, processAt time.Time) error diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 9314358..0d0a7bb 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -102,22 +102,26 @@ func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { return nil } -// Dequeue queries given queues in order and pops a task message if there is one and returns it. +// Dequeue queries given queues in order and pops a task message +// off a queue if one exists and returns the message and deadline in Unix time in seconds. // Dequeue skips a queue if the queue is paused. // If all queues are empty, ErrNoProcessableTask error is returned. -func (r *RDB) Dequeue(qnames ...string) (*base.TaskMessage, error) { +func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, deadline int, err error) { var qkeys []interface{} for _, q := range qnames { qkeys = append(qkeys, base.QueueKey(q)) } - data, err := r.dequeue(qkeys...) + data, deadline, err := r.dequeue(qkeys...) if err == redis.Nil { - return nil, ErrNoProcessableTask + return nil, 0, ErrNoProcessableTask } if err != nil { - return nil, err + return nil, 0, err } - return base.DecodeMessage(data) + if msg, err = base.DecodeMessage(data); err != nil { + return nil, 0, err + } + return msg, deadline, nil } // KEYS[1] -> asynq:in_progress @@ -134,9 +138,9 @@ var dequeueCmd = redis.NewScript(` for i = 2, table.getn(ARGV) do local qkey = ARGV[i] if redis.call("SISMEMBER", KEYS[2], qkey) == 0 then - local res = redis.call("RPOPLPUSH", qkey, KEYS[1]) - if res then - local decoded = cjson.decode(res) + local msg = redis.call("RPOPLPUSH", qkey, KEYS[1]) + if msg then + local decoded = cjson.decode(msg) local timeout = decoded["Timeout"] local deadline = decoded["Deadline"] local score @@ -149,23 +153,36 @@ for i = 2, table.getn(ARGV) do else return redis.error_reply("asynq internal error: both timeout and deadline are not set") end - redis.call("ZADD", KEYS[3], score, res) - return res + redis.call("ZADD", KEYS[3], score, msg) + return {msg, score} end end end return nil`) -func (r *RDB) dequeue(qkeys ...interface{}) (data string, err error) { +func (r *RDB) dequeue(qkeys ...interface{}) (msgjson string, deadline int, err error) { var args []interface{} args = append(args, time.Now().Unix()) args = append(args, qkeys...) res, err := dequeueCmd.Run(r.client, []string{base.InProgressQueue, base.PausedQueues, base.KeyDeadlines}, args...).Result() if err != nil { - return "", err + return "", 0, err } - return cast.ToStringE(res) + data, err := cast.ToSliceE(res) + if err != nil { + return "", 0, err + } + if len(data) != 2 { + return "", 0, fmt.Errorf("asynq: internal error: dequeue command returned %v values", len(data)) + } + if msgjson, err = cast.ToStringE(data[0]); err != nil { + return "", 0, err + } + if deadline, err = cast.ToIntE(data[1]); err != nil { + return "", 0, err + } + return msgjson, deadline, nil } // KEYS[1] -> asynq:in_progress diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 3732feb..7c2d1ed 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -143,7 +143,8 @@ func TestDequeue(t *testing.T) { tests := []struct { enqueued map[string][]*base.TaskMessage args []string // list of queues to query - want *base.TaskMessage + wantMsg *base.TaskMessage + wantDeadline int err error wantEnqueued map[string][]*base.TaskMessage wantInProgress []*base.TaskMessage @@ -153,9 +154,10 @@ func TestDequeue(t *testing.T) { enqueued: map[string][]*base.TaskMessage{ "default": {t1}, }, - args: []string{"default"}, - want: t1, - err: nil, + args: []string{"default"}, + wantMsg: t1, + wantDeadline: t1Deadline, + err: nil, wantEnqueued: map[string][]*base.TaskMessage{ "default": {}, }, @@ -171,9 +173,10 @@ func TestDequeue(t *testing.T) { enqueued: map[string][]*base.TaskMessage{ "default": {}, }, - args: []string{"default"}, - want: nil, - err: ErrNoProcessableTask, + args: []string{"default"}, + wantMsg: nil, + wantDeadline: 0, + err: ErrNoProcessableTask, wantEnqueued: map[string][]*base.TaskMessage{ "default": {}, }, @@ -186,9 +189,10 @@ func TestDequeue(t *testing.T) { "critical": {t2}, "low": {t3}, }, - args: []string{"critical", "default", "low"}, - want: t2, - err: nil, + args: []string{"critical", "default", "low"}, + wantMsg: t2, + wantDeadline: t2Deadline, + err: nil, wantEnqueued: map[string][]*base.TaskMessage{ "default": {t1}, "critical": {}, @@ -208,9 +212,10 @@ func TestDequeue(t *testing.T) { "critical": {}, "low": {t2, t1}, }, - args: []string{"critical", "default", "low"}, - want: t3, - err: nil, + args: []string{"critical", "default", "low"}, + wantMsg: t3, + wantDeadline: t3Deadline, + err: nil, wantEnqueued: map[string][]*base.TaskMessage{ "default": {}, "critical": {}, @@ -230,9 +235,10 @@ func TestDequeue(t *testing.T) { "critical": {}, "low": {}, }, - args: []string{"critical", "default", "low"}, - want: nil, - err: ErrNoProcessableTask, + args: []string{"critical", "default", "low"}, + wantMsg: nil, + wantDeadline: 0, + err: ErrNoProcessableTask, wantEnqueued: map[string][]*base.TaskMessage{ "default": {}, "critical": {}, @@ -249,10 +255,20 @@ func TestDequeue(t *testing.T) { h.SeedEnqueuedQueue(t, r.client, msgs, queue) } - got, err := r.Dequeue(tc.args...) - if !cmp.Equal(got, tc.want) || err != tc.err { - t.Errorf("(*RDB).Dequeue(%v) = %v, %v; want %v, %v", - tc.args, got, err, tc.want, tc.err) + gotMsg, gotDeadline, err := r.Dequeue(tc.args...) + if err != tc.err { + t.Errorf("(*RDB).Dequeue(%v) returned error %v; want %v", + tc.args, err, tc.err) + continue + } + if !cmp.Equal(gotMsg, tc.wantMsg) || err != tc.err { + t.Errorf("(*RDB).Dequeue(%v) returned message %v; want %v", + tc.args, gotMsg, tc.wantMsg) + continue + } + if gotDeadline != tc.wantDeadline { + t.Errorf("(*RDB).Dequeue(%v) returned deadline %v; want %v", + tc.args, gotDeadline, tc.wantDeadline) continue } @@ -284,7 +300,7 @@ func TestDequeueIgnoresPausedQueues(t *testing.T) { paused []string // list of paused queues enqueued map[string][]*base.TaskMessage args []string // list of queues to query - want *base.TaskMessage + wantMsg *base.TaskMessage err error wantEnqueued map[string][]*base.TaskMessage wantInProgress []*base.TaskMessage @@ -295,9 +311,9 @@ func TestDequeueIgnoresPausedQueues(t *testing.T) { "default": {t1}, "critical": {t2}, }, - args: []string{"default", "critical"}, - want: t2, - err: nil, + args: []string{"default", "critical"}, + wantMsg: t2, + err: nil, wantEnqueued: map[string][]*base.TaskMessage{ "default": {t1}, "critical": {}, @@ -309,9 +325,9 @@ func TestDequeueIgnoresPausedQueues(t *testing.T) { enqueued: map[string][]*base.TaskMessage{ "default": {t1}, }, - args: []string{"default"}, - want: nil, - err: ErrNoProcessableTask, + args: []string{"default"}, + wantMsg: nil, + err: ErrNoProcessableTask, wantEnqueued: map[string][]*base.TaskMessage{ "default": {t1}, }, @@ -323,9 +339,9 @@ func TestDequeueIgnoresPausedQueues(t *testing.T) { "default": {t1}, "critical": {t2}, }, - args: []string{"default", "critical"}, - want: nil, - err: ErrNoProcessableTask, + args: []string{"default", "critical"}, + wantMsg: nil, + err: ErrNoProcessableTask, wantEnqueued: map[string][]*base.TaskMessage{ "default": {t1}, "critical": {t2}, @@ -345,10 +361,10 @@ func TestDequeueIgnoresPausedQueues(t *testing.T) { h.SeedEnqueuedQueue(t, r.client, msgs, queue) } - got, err := r.Dequeue(tc.args...) - if !cmp.Equal(got, tc.want) || err != tc.err { + got, _, err := r.Dequeue(tc.args...) + if !cmp.Equal(got, tc.wantMsg) || err != tc.err { t.Errorf("Dequeue(%v) = %v, %v; want %v, %v", - tc.args, got, err, tc.want, tc.err) + tc.args, got, err, tc.wantMsg, tc.err) continue } diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index 08e407f..d577fd9 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -60,11 +60,11 @@ func (tb *TestBroker) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) er return tb.real.EnqueueUnique(msg, ttl) } -func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, error) { +func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, int, error) { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { - return nil, errRedisDown + return nil, 0, errRedisDown } return tb.real.Dequeue(qnames...) }