diff --git a/internal/base/base.go b/internal/base/base.go index 615ed75..4df0782 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -684,13 +684,13 @@ type Broker interface { Enqueue(ctx context.Context, msg *TaskMessage) error EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error Dequeue(qnames ...string) (*TaskMessage, time.Time, error) - Done(msg *TaskMessage) error - MarkAsComplete(msg *TaskMessage) error - Requeue(msg *TaskMessage) error + Done(ctx context.Context, msg *TaskMessage) error + MarkAsComplete(ctx context.Context, msg *TaskMessage) error + Requeue(ctx context.Context, msg *TaskMessage) error Schedule(ctx context.Context, msg *TaskMessage, processAt time.Time) error ScheduleUnique(ctx context.Context, msg *TaskMessage, processAt time.Time, ttl time.Duration) error - Retry(msg *TaskMessage, processAt time.Time, errMsg string, isFailure bool) error - Archive(msg *TaskMessage, errMsg string) error + Retry(ctx context.Context, msg *TaskMessage, processAt time.Time, errMsg string, isFailure bool) error + Archive(ctx context.Context, msg *TaskMessage, errMsg string) error ForwardIfReady(qnames ...string) error DeleteExpiredCompletedTasks(qname string) error ListLeaseExpired(cutoff time.Time, qnames ...string) ([]*TaskMessage, error) diff --git a/internal/rdb/benchmark_test.go b/internal/rdb/benchmark_test.go index 3e28564..011d06d 100644 --- a/internal/rdb/benchmark_test.go +++ b/internal/rdb/benchmark_test.go @@ -113,7 +113,7 @@ func BenchmarkDequeueSingleQueue(b *testing.B) { } b.StartTimer() - if _, err := r.Dequeue(base.DefaultQueueName); err != nil { + if _, _, err := r.Dequeue(base.DefaultQueueName); err != nil { b.Fatalf("Dequeue failed: %v", err) } } @@ -139,7 +139,7 @@ func BenchmarkDequeueMultipleQueues(b *testing.B) { } b.StartTimer() - if _, err := r.Dequeue(qnames...); err != nil { + if _, _, err := r.Dequeue(qnames...); err != nil { b.Fatalf("Dequeue failed: %v", err) } } @@ -156,6 +156,7 @@ func BenchmarkDone(b *testing.B) { {Message: m2, Score: time.Now().Add(20 * time.Second).Unix()}, {Message: m3, Score: time.Now().Add(30 * time.Second).Unix()}, } + ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -165,7 +166,7 @@ func BenchmarkDone(b *testing.B) { asynqtest.SeedDeadlines(b, r.client, zs, base.DefaultQueueName) b.StartTimer() - if err := r.Done(msgs[0]); err != nil { + if err := r.Done(ctx, msgs[0]); err != nil { b.Fatalf("Done failed: %v", err) } } @@ -182,6 +183,7 @@ func BenchmarkRetry(b *testing.B) { {Message: m2, Score: time.Now().Add(20 * time.Second).Unix()}, {Message: m3, Score: time.Now().Add(30 * time.Second).Unix()}, } + ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -191,7 +193,7 @@ func BenchmarkRetry(b *testing.B) { asynqtest.SeedDeadlines(b, r.client, zs, base.DefaultQueueName) b.StartTimer() - if err := r.Retry(msgs[0], time.Now().Add(1*time.Minute), "error", true /*isFailure*/); err != nil { + if err := r.Retry(ctx, msgs[0], time.Now().Add(1*time.Minute), "error", true /*isFailure*/); err != nil { b.Fatalf("Retry failed: %v", err) } } @@ -208,6 +210,7 @@ func BenchmarkArchive(b *testing.B) { {Message: m2, Score: time.Now().Add(20 * time.Second).Unix()}, {Message: m3, Score: time.Now().Add(30 * time.Second).Unix()}, } + ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -217,7 +220,7 @@ func BenchmarkArchive(b *testing.B) { asynqtest.SeedDeadlines(b, r.client, zs, base.DefaultQueueName) b.StartTimer() - if err := r.Archive(msgs[0], "error"); err != nil { + if err := r.Archive(ctx, msgs[0], "error"); err != nil { b.Fatalf("Archive failed: %v", err) } } @@ -234,6 +237,7 @@ func BenchmarkRequeue(b *testing.B) { {Message: m2, Score: time.Now().Add(20 * time.Second).Unix()}, {Message: m3, Score: time.Now().Add(30 * time.Second).Unix()}, } + ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -243,7 +247,7 @@ func BenchmarkRequeue(b *testing.B) { asynqtest.SeedDeadlines(b, r.client, zs, base.DefaultQueueName) b.StartTimer() - if err := r.Requeue(msgs[0]); err != nil { + if err := r.Requeue(ctx, msgs[0]); err != nil { b.Fatalf("Requeue failed: %v", err) } } diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 715576b..f134a36 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -241,10 +241,10 @@ end return nil`) // Dequeue queries given queues in order and pops a task message -// off a queue if one exists and returns the message. +// off a queue if one exists and returns the message and its lease expiration time. // Dequeue skips a queue if the queue is paused. // If all queues are empty, ErrNoProcessableTask error is returned. -func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, err error) { +func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, leaseExpirationTime time.Time, err error) { var op errors.Op = "rdb.Dequeue" for _, qname := range qnames { keys := []string{ @@ -253,26 +253,27 @@ func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, err error) { base.ActiveKey(qname), base.LeaseKey(qname), } + leaseExpirationTime = r.clock.Now().Add(LeaseDuration) argv := []interface{}{ - r.clock.Now().Add(LeaseDuration).Unix(), + leaseExpirationTime.Unix(), base.TaskKeyPrefix(qname), } res, err := dequeueCmd.Run(context.Background(), r.client, keys, argv...).Result() if err == redis.Nil { continue } else if err != nil { - return nil, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) + return nil, time.Time{}, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) } encoded, err := cast.ToStringE(res) if err != nil { - return nil, errors.E(op, errors.Internal, fmt.Sprintf("cast error: unexpected return value from Lua script: %v", res)) + return nil, time.Time{}, errors.E(op, errors.Internal, fmt.Sprintf("cast error: unexpected return value from Lua script: %v", res)) } if msg, err = base.DecodeMessage([]byte(encoded)); err != nil { - return nil, errors.E(op, errors.Internal, fmt.Sprintf("cannot decode message: %v", err)) + return nil, time.Time{}, errors.E(op, errors.Internal, fmt.Sprintf("cannot decode message: %v", err)) } - return msg, nil + return msg, leaseExpirationTime, nil } - return nil, errors.E(op, errors.NotFound, errors.ErrNoProcessableTask) + return nil, time.Time{}, errors.E(op, errors.NotFound, errors.ErrNoProcessableTask) } // KEYS[1] -> asynq:{}:active @@ -345,9 +346,8 @@ return redis.status_reply("OK") // Done removes the task from active queue and deletes the task. // It removes a uniqueness lock acquired by the task, if any. -func (r *RDB) Done(msg *base.TaskMessage) error { +func (r *RDB) Done(ctx context.Context, msg *base.TaskMessage) error { var op errors.Op = "rdb.Done" - ctx := context.Background() now := r.clock.Now() expireAt := now.Add(statsTTL) keys := []string{ @@ -448,9 +448,8 @@ return redis.status_reply("OK") // MarkAsComplete removes the task from active queue to mark the task as completed. // It removes a uniqueness lock acquired by the task, if any. -func (r *RDB) MarkAsComplete(msg *base.TaskMessage) error { +func (r *RDB) MarkAsComplete(ctx context.Context, msg *base.TaskMessage) error { var op errors.Op = "rdb.MarkAsComplete" - ctx := context.Background() now := r.clock.Now() statsExpireAt := now.Add(statsTTL) msg.CompletedAt = now.Unix() @@ -499,9 +498,8 @@ redis.call("HSET", KEYS[4], "state", "pending") return redis.status_reply("OK")`) // Requeue moves the task from active queue to the specified queue. -func (r *RDB) Requeue(msg *base.TaskMessage) error { +func (r *RDB) Requeue(ctx context.Context, msg *base.TaskMessage) error { var op errors.Op = "rdb.Requeue" - ctx := context.Background() keys := []string{ base.ActiveKey(msg.Queue), base.LeaseKey(msg.Queue), @@ -682,9 +680,8 @@ return redis.status_reply("OK")`) // Retry moves the task from active to retry queue. // It also annotates the message with the given error message and // if isFailure is true increments the retried counter. -func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, isFailure bool) error { +func (r *RDB) Retry(ctx context.Context, msg *base.TaskMessage, processAt time.Time, errMsg string, isFailure bool) error { var op errors.Op = "rdb.Retry" - ctx := context.Background() now := r.clock.Now() modified := *msg if isFailure { @@ -770,9 +767,8 @@ return redis.status_reply("OK")`) // Archive sends the given task to archive, attaching the error message to the task. // It also trims the archive by timestamp and set size. -func (r *RDB) Archive(msg *base.TaskMessage, errMsg string) error { +func (r *RDB) Archive(ctx context.Context, msg *base.TaskMessage, errMsg string) error { var op errors.Op = "rdb.Archive" - ctx := context.Background() now := r.clock.Now() modified := *msg modified.ErrorMsg = errMsg @@ -959,14 +955,19 @@ func (r *RDB) ListLeaseExpired(cutoff time.Time, qnames ...string) ([]*base.Task } // ExtendLease extends the lease for the given tasks by LeaseDuration (30s). -func (r *RDB) ExtendLease(qname string, ids ...string) error { +// It returns a new expiration time if the operation was successful. +func (r *RDB) ExtendLease(qname string, ids ...string) (expirationTime time.Time, err error) { expireAt := r.clock.Now().Add(LeaseDuration) var zs []redis.Z for _, id := range ids { zs = append(zs, redis.Z{Member: id, Score: float64(expireAt.Unix())}) } // Use XX option to only update elements that already exist; Don't add new elements - return r.client.ZAddArgs(context.Background(), base.LeaseKey(qname), redis.ZAddArgs{XX: true, GT: true, Members: zs}).Err() + err = r.client.ZAddArgs(context.Background(), base.LeaseKey(qname), redis.ZAddArgs{XX: true, GT: true, Members: zs}).Err() + if err != nil { + return time.Time{}, err + } + return expireAt, nil } // KEYS[1] -> asynq:servers:{} diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index ec6517a..e28381d 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -340,19 +340,21 @@ func TestDequeue(t *testing.T) { } tests := []struct { - pending map[string][]*base.TaskMessage - qnames []string // list of queues to query - wantMsg *base.TaskMessage - wantPending map[string][]*base.TaskMessage - wantActive map[string][]*base.TaskMessage - wantLease map[string][]base.Z + pending map[string][]*base.TaskMessage + qnames []string // list of queues to query + wantMsg *base.TaskMessage + wantExpirationTime time.Time + wantPending map[string][]*base.TaskMessage + wantActive map[string][]*base.TaskMessage + wantLease map[string][]base.Z }{ { pending: map[string][]*base.TaskMessage{ "default": {t1}, }, - qnames: []string{"default"}, - wantMsg: t1, + qnames: []string{"default"}, + wantMsg: t1, + wantExpirationTime: now.Add(LeaseDuration), wantPending: map[string][]*base.TaskMessage{ "default": {}, }, @@ -369,8 +371,9 @@ func TestDequeue(t *testing.T) { "critical": {t2}, "low": {t3}, }, - qnames: []string{"critical", "default", "low"}, - wantMsg: t2, + qnames: []string{"critical", "default", "low"}, + wantMsg: t2, + wantExpirationTime: now.Add(LeaseDuration), wantPending: map[string][]*base.TaskMessage{ "default": {t1}, "critical": {}, @@ -393,8 +396,9 @@ func TestDequeue(t *testing.T) { "critical": {}, "low": {t3}, }, - qnames: []string{"critical", "default", "low"}, - wantMsg: t1, + qnames: []string{"critical", "default", "low"}, + wantMsg: t1, + wantExpirationTime: now.Add(LeaseDuration), wantPending: map[string][]*base.TaskMessage{ "default": {}, "critical": {}, @@ -417,7 +421,7 @@ func TestDequeue(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllPendingQueues(t, r.client, tc.pending) - gotMsg, err := r.Dequeue(tc.qnames...) + gotMsg, gotExpirationTime, err := r.Dequeue(tc.qnames...) if err != nil { t.Errorf("(*RDB).Dequeue(%v) returned error %v", tc.qnames, err) continue @@ -427,6 +431,10 @@ func TestDequeue(t *testing.T) { tc.qnames, gotMsg, tc.wantMsg) continue } + if gotExpirationTime != tc.wantExpirationTime { + t.Errorf("(*RDB).Dequeue(%v) returned expiration time %v, want %v", + tc.qnames, gotExpirationTime, tc.wantExpirationTime) + } for queue, want := range tc.wantPending { gotPending := h.GetPendingMessages(t, r.client, queue) @@ -507,7 +515,7 @@ func TestDequeueError(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllPendingQueues(t, r.client, tc.pending) - gotMsg, gotErr := r.Dequeue(tc.qnames...) + gotMsg, _, gotErr := r.Dequeue(tc.qnames...) if !errors.Is(gotErr, tc.wantErr) { t.Errorf("(*RDB).Dequeue(%v) returned error %v; want %v", tc.qnames, gotErr, tc.wantErr) @@ -630,7 +638,7 @@ func TestDequeueIgnoresPausedQueues(t *testing.T) { } h.SeedAllPendingQueues(t, r.client, tc.pending) - got, err := r.Dequeue(tc.qnames...) + got, _, err := r.Dequeue(tc.qnames...) if !cmp.Equal(got, tc.wantMsg) || !errors.Is(err, tc.wantErr) { t.Errorf("Dequeue(%v) = %v, %v; want %v, %v", tc.qnames, got, err, tc.wantMsg, tc.wantErr) @@ -764,7 +772,7 @@ func TestDone(t *testing.T) { } } - err := r.Done(tc.target) + err := r.Done(context.Background(), tc.target) if err != nil { t.Errorf("%s; (*RDB).Done(task) = %v, want nil", tc.desc, err) continue @@ -833,7 +841,7 @@ func TestDoneWithMaxCounter(t *testing.T) { t.Fatalf("Redis command failed: SET %q %v", processedTotalKey, math.MaxInt64) } - if err := r.Done(msg); err != nil { + if err := r.Done(context.Background(), msg); err != nil { t.Fatalf("RDB.Done failed: %v", err) } @@ -984,7 +992,7 @@ func TestMarkAsComplete(t *testing.T) { } } - err := r.MarkAsComplete(tc.target) + err := r.MarkAsComplete(context.Background(), tc.target) if err != nil { t.Errorf("%s; (*RDB).MarkAsCompleted(task) = %v, want nil", tc.desc, err) continue @@ -1148,7 +1156,7 @@ func TestRequeue(t *testing.T) { h.SeedAllActiveQueues(t, r.client, tc.active) h.SeedAllLease(t, r.client, tc.lease) - err := r.Requeue(tc.target) + err := r.Requeue(context.Background(), tc.target) if err != nil { t.Errorf("(*RDB).Requeue(task) = %v, want nil", err) continue @@ -1529,7 +1537,7 @@ func TestRetry(t *testing.T) { h.SeedAllLease(t, r.client, tc.lease) h.SeedAllRetryQueues(t, r.client, tc.retry) - err := r.Retry(tc.msg, tc.processAt, tc.errMsg, true /*isFailure*/) + err := r.Retry(context.Background(), tc.msg, tc.processAt, tc.errMsg, true /*isFailure*/) if err != nil { t.Errorf("(*RDB).Retry = %v, want nil", err) continue @@ -1702,7 +1710,7 @@ func TestRetryWithNonFailureError(t *testing.T) { h.SeedAllLease(t, r.client, tc.lease) h.SeedAllRetryQueues(t, r.client, tc.retry) - err := r.Retry(tc.msg, tc.processAt, tc.errMsg, false /*isFailure*/) + err := r.Retry(context.Background(), tc.msg, tc.processAt, tc.errMsg, false /*isFailure*/) if err != nil { t.Errorf("(*RDB).Retry = %v, want nil", err) continue @@ -1908,7 +1916,7 @@ func TestArchive(t *testing.T) { h.SeedAllLease(t, r.client, tc.lease) h.SeedAllArchivedQueues(t, r.client, tc.archived) - err := r.Archive(tc.target, errMsg) + err := r.Archive(context.Background(), tc.target, errMsg) if err != nil { t.Errorf("(*RDB).Archive(%v, %v) = %v, want nil", tc.target, errMsg, err) continue @@ -2304,11 +2312,12 @@ func TestExtendLease(t *testing.T) { t4 := h.NewTaskMessageWithQueue("task4", nil, "default") tests := []struct { - desc string - lease map[string][]base.Z - qname string - ids []string - wantLease map[string][]base.Z + desc string + lease map[string][]base.Z + qname string + ids []string + wantExpirationTime time.Time + wantLease map[string][]base.Z }{ { desc: "Should extends lease for a single message in a queue", @@ -2316,8 +2325,9 @@ func TestExtendLease(t *testing.T) { "default": {{Message: t1, Score: now.Add(10 * time.Second).Unix()}}, "critical": {{Message: t3, Score: now.Add(10 * time.Second).Unix()}}, }, - qname: "default", - ids: []string{t1.ID}, + qname: "default", + ids: []string{t1.ID}, + wantExpirationTime: now.Add(LeaseDuration), wantLease: map[string][]base.Z{ "default": {{Message: t1, Score: now.Add(LeaseDuration).Unix()}}, "critical": {{Message: t3, Score: now.Add(10 * time.Second).Unix()}}, @@ -2329,8 +2339,9 @@ func TestExtendLease(t *testing.T) { "default": {{Message: t1, Score: now.Add(10 * time.Second).Unix()}, {Message: t2, Score: now.Add(10 * time.Second).Unix()}}, "critical": {{Message: t3, Score: now.Add(10 * time.Second).Unix()}}, }, - qname: "default", - ids: []string{t1.ID, t2.ID}, + qname: "default", + ids: []string{t1.ID, t2.ID}, + wantExpirationTime: now.Add(LeaseDuration), wantLease: map[string][]base.Z{ "default": {{Message: t1, Score: now.Add(LeaseDuration).Unix()}, {Message: t2, Score: now.Add(LeaseDuration).Unix()}}, "critical": {{Message: t3, Score: now.Add(10 * time.Second).Unix()}}, @@ -2346,8 +2357,9 @@ func TestExtendLease(t *testing.T) { }, "critical": {{Message: t3, Score: now.Add(10 * time.Second).Unix()}}, }, - qname: "default", - ids: []string{t2.ID, t4.ID}, + qname: "default", + ids: []string{t2.ID, t4.ID}, + wantExpirationTime: now.Add(LeaseDuration), wantLease: map[string][]base.Z{ "default": { {Message: t1, Score: now.Add(10 * time.Second).Unix()}, @@ -2364,8 +2376,9 @@ func TestExtendLease(t *testing.T) { {Message: t1, Score: now.Add(10 * time.Second).Unix()}, }, }, - qname: "default", - ids: []string{t1.ID, t2.ID}, + qname: "default", + ids: []string{t1.ID, t2.ID}, + wantExpirationTime: now.Add(LeaseDuration), wantLease: map[string][]base.Z{ "default": { {Message: t1, Score: now.Add(LeaseDuration).Unix()}, @@ -2379,8 +2392,9 @@ func TestExtendLease(t *testing.T) { {Message: t1, Score: now.Add(LeaseDuration).Add(10 * time.Second).Unix()}, }, }, - qname: "default", - ids: []string{t1.ID}, + qname: "default", + ids: []string{t1.ID}, + wantExpirationTime: now.Add(LeaseDuration), wantLease: map[string][]base.Z{ "default": { {Message: t1, Score: now.Add(LeaseDuration).Add(10 * time.Second).Unix()}, @@ -2393,9 +2407,13 @@ func TestExtendLease(t *testing.T) { h.FlushDB(t, r.client) h.SeedAllLease(t, r.client, tc.lease) - if err := r.ExtendLease(tc.qname, tc.ids...); err != nil { + gotExpirationTime, err := r.ExtendLease(tc.qname, tc.ids...) + if err != nil { t.Fatalf("%s: ExtendLease(%q, %v) returned error: %v", tc.desc, tc.qname, tc.ids, err) } + if gotExpirationTime != tc.wantExpirationTime { + t.Errorf("%s: ExtendLease(%q, %v) returned expirationTime %v, want %v", tc.desc, tc.qname, tc.ids, gotExpirationTime, tc.wantExpirationTime) + } for qname, want := range tc.wantLease { gotLease := h.GetLeaseEntries(t, r.client, qname)