diff --git a/CHANGELOG.md b/CHANGELOG.md index fa5569d..e28b9ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- `EnqueueContext` method is added to `Client`. + +### Fixed + +- Fixed an error when a user passes a duration less than 1s to the `Unique` option + ## [0.19.0] - 2021-11-06 ### Changed diff --git a/client.go b/client.go index 61db5bf..078116e 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ package asynq import ( + "context" "fmt" "strings" "time" @@ -292,7 +293,7 @@ func (c *Client) Close() error { return c.rdb.Close() } -// Enqueue enqueues the given task to be processed asynchronously. +// Enqueue enqueues the given task to a queue. // // Enqueue returns TaskInfo and nil error if the task is enqueued successfully, otherwise returns a non-nil error. // @@ -302,7 +303,25 @@ func (c *Client) Close() error { // By deafult, max retry is set to 25 and timeout is set to 30 minutes. // // If no ProcessAt or ProcessIn options are provided, the task will be pending immediately. +// +// Enqueue uses context.Background internally; to specify the context, use EnqueueContext. func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) { + return c.EnqueueContext(context.Background(), task, opts...) +} + +// EnqueueContext enqueues the given task to a queue. +// +// EnqueueContext returns TaskInfo and nil error if the task is enqueued successfully, otherwise returns a non-nil error. +// +// The argument opts specifies the behavior of task processing. +// If there are conflicting Option values the last one overrides others. +// Any options provided to NewTask can be overridden by options passed to Enqueue. +// By default, max retry is set to 25 and timeout is set to 30 minutes. +// +// If no ProcessAt or ProcessIn options are provided, the task will be pending immediately. 
+// +// The first argument context applies to the enqueue operation. To specify task timeout and deadline, use Timeout and Deadline option instead. +func (c *Client) EnqueueContext(ctx context.Context, task *Task, opts ...Option) (*TaskInfo, error) { if strings.TrimSpace(task.Type()) == "" { return nil, fmt.Errorf("task typename cannot be empty") } @@ -343,10 +362,10 @@ func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) { var state base.TaskState if opt.processAt.Before(now) || opt.processAt.Equal(now) { opt.processAt = now - err = c.enqueue(msg, opt.uniqueTTL) + err = c.enqueue(ctx, msg, opt.uniqueTTL) state = base.TaskStatePending } else { - err = c.schedule(msg, opt.processAt, opt.uniqueTTL) + err = c.schedule(ctx, msg, opt.processAt, opt.uniqueTTL) state = base.TaskStateScheduled } switch { @@ -360,17 +379,17 @@ func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) { return newTaskInfo(msg, state, opt.processAt, nil), nil } -func (c *Client) enqueue(msg *base.TaskMessage, uniqueTTL time.Duration) error { +func (c *Client) enqueue(ctx context.Context, msg *base.TaskMessage, uniqueTTL time.Duration) error { if uniqueTTL > 0 { - return c.rdb.EnqueueUnique(msg, uniqueTTL) + return c.rdb.EnqueueUnique(ctx, msg, uniqueTTL) } - return c.rdb.Enqueue(msg) + return c.rdb.Enqueue(ctx, msg) } -func (c *Client) schedule(msg *base.TaskMessage, t time.Time, uniqueTTL time.Duration) error { +func (c *Client) schedule(ctx context.Context, msg *base.TaskMessage, t time.Time, uniqueTTL time.Duration) error { if uniqueTTL > 0 { ttl := t.Add(uniqueTTL).Sub(time.Now()) - return c.rdb.ScheduleUnique(msg, t, ttl) + return c.rdb.ScheduleUnique(ctx, msg, t, ttl) } - return c.rdb.Schedule(msg, t) + return c.rdb.Schedule(ctx, msg, t) } diff --git a/internal/base/base.go b/internal/base/base.go index 5315f17..2ef52fe 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -660,14 +660,14 @@ func (c *Cancelations) Get(id string) (fn 
context.CancelFunc, ok bool) { // See rdb.RDB as a reference implementation. type Broker interface { Ping() error - Enqueue(msg *TaskMessage) error - EnqueueUnique(msg *TaskMessage, ttl time.Duration) error + Enqueue(ctx context.Context, msg *TaskMessage) error + EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error Dequeue(qnames ...string) (*TaskMessage, time.Time, error) Done(msg *TaskMessage) error MarkAsComplete(msg *TaskMessage) error Requeue(msg *TaskMessage) error - Schedule(msg *TaskMessage, processAt time.Time) error - ScheduleUnique(msg *TaskMessage, processAt time.Time, ttl time.Duration) error + Schedule(ctx context.Context, msg *TaskMessage, processAt time.Time) error + ScheduleUnique(ctx context.Context, msg *TaskMessage, processAt time.Time, ttl time.Duration) error Retry(msg *TaskMessage, processAt time.Time, errMsg string, isFailure bool) error Archive(msg *TaskMessage, errMsg string) error ForwardIfReady(qnames ...string) error diff --git a/internal/rdb/benchmark_test.go b/internal/rdb/benchmark_test.go index 8880b3b..468e309 100644 --- a/internal/rdb/benchmark_test.go +++ b/internal/rdb/benchmark_test.go @@ -5,6 +5,7 @@ package rdb import ( + "context" "fmt" "testing" "time" @@ -15,6 +16,7 @@ import ( func BenchmarkEnqueue(b *testing.B) { r := setup(b) + ctx := context.Background() msg := asynqtest.NewTaskMessage("task1", nil) b.ResetTimer() @@ -23,7 +25,7 @@ func BenchmarkEnqueue(b *testing.B) { asynqtest.FlushDB(b, r.client) b.StartTimer() - if err := r.Enqueue(msg); err != nil { + if err := r.Enqueue(ctx, msg); err != nil { b.Fatalf("Enqueue failed: %v", err) } } @@ -31,6 +33,7 @@ func BenchmarkEnqueue(b *testing.B) { func BenchmarkEnqueueUnique(b *testing.B) { r := setup(b) + ctx := context.Background() msg := &base.TaskMessage{ Type: "task1", Payload: nil, @@ -45,7 +48,7 @@ func BenchmarkEnqueueUnique(b *testing.B) { asynqtest.FlushDB(b, r.client) b.StartTimer() - if err := r.EnqueueUnique(msg, uniqueTTL); err != nil 
{ + if err := r.EnqueueUnique(ctx, msg, uniqueTTL); err != nil { b.Fatalf("EnqueueUnique failed: %v", err) } } @@ -53,6 +56,7 @@ func BenchmarkEnqueueUnique(b *testing.B) { func BenchmarkSchedule(b *testing.B) { r := setup(b) + ctx := context.Background() msg := asynqtest.NewTaskMessage("task1", nil) processAt := time.Now().Add(3 * time.Minute) b.ResetTimer() @@ -62,7 +66,7 @@ func BenchmarkSchedule(b *testing.B) { asynqtest.FlushDB(b, r.client) b.StartTimer() - if err := r.Schedule(msg, processAt); err != nil { + if err := r.Schedule(ctx, msg, processAt); err != nil { b.Fatalf("Schedule failed: %v", err) } } @@ -70,6 +74,7 @@ func BenchmarkSchedule(b *testing.B) { func BenchmarkScheduleUnique(b *testing.B) { r := setup(b) + ctx := context.Background() msg := &base.TaskMessage{ Type: "task1", Payload: nil, @@ -85,7 +90,7 @@ func BenchmarkScheduleUnique(b *testing.B) { asynqtest.FlushDB(b, r.client) b.StartTimer() - if err := r.ScheduleUnique(msg, processAt, uniqueTTL); err != nil { + if err := r.ScheduleUnique(ctx, msg, processAt, uniqueTTL); err != nil { b.Fatalf("EnqueueUnique failed: %v", err) } } @@ -93,6 +98,7 @@ func BenchmarkScheduleUnique(b *testing.B) { func BenchmarkDequeueSingleQueue(b *testing.B) { r := setup(b) + ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -101,7 +107,7 @@ func BenchmarkDequeueSingleQueue(b *testing.B) { for i := 0; i < 10; i++ { m := asynqtest.NewTaskMessageWithQueue( fmt.Sprintf("task%d", i), nil, base.DefaultQueueName) - if err := r.Enqueue(m); err != nil { + if err := r.Enqueue(ctx, m); err != nil { b.Fatalf("Enqueue failed: %v", err) } } @@ -116,6 +122,7 @@ func BenchmarkDequeueSingleQueue(b *testing.B) { func BenchmarkDequeueMultipleQueues(b *testing.B) { qnames := []string{"critical", "default", "low"} r := setup(b) + ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -125,7 +132,7 @@ func BenchmarkDequeueMultipleQueues(b *testing.B) { for _, qname := range qnames { m := 
asynqtest.NewTaskMessageWithQueue( fmt.Sprintf("%s_task%d", qname, i), nil, qname) - if err := r.Enqueue(m); err != nil { + if err := r.Enqueue(ctx, m); err != nil { b.Fatalf("Enqueue failed: %v", err) } } diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 8b380ab..1a9f49f 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -879,7 +879,7 @@ func TestListScheduledPagination(t *testing.T) { // create 100 tasks with an increasing number of wait time. for i := 0; i < 100; i++ { msg := h.NewTaskMessage(fmt.Sprintf("task %d", i), nil) - if err := r.Schedule(msg, time.Now().Add(time.Duration(i)*time.Second)); err != nil { + if err := r.Schedule(context.Background(), msg, time.Now().Add(time.Duration(i)*time.Second)); err != nil { t.Fatal(err) } } diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 3a611d7..a2a0fe7 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -43,16 +43,16 @@ func (r *RDB) Ping() error { return r.client.Ping(context.Background()).Err() } -func (r *RDB) runScript(op errors.Op, script *redis.Script, keys []string, args ...interface{}) error { - if err := script.Run(context.Background(), r.client, keys, args...).Err(); err != nil { +func (r *RDB) runScript(ctx context.Context, op errors.Op, script *redis.Script, keys []string, args ...interface{}) error { + if err := script.Run(ctx, r.client, keys, args...).Err(); err != nil { return errors.E(op, errors.Internal, fmt.Sprintf("redis eval error: %v", err)) } return nil } // Runs the given script with keys and args and retuns the script's return value as int64. 
-func (r *RDB) runScriptWithErrorCode(op errors.Op, script *redis.Script, keys []string, args ...interface{}) (int64, error) { - res, err := script.Run(context.Background(), r.client, keys, args...).Result() +func (r *RDB) runScriptWithErrorCode(ctx context.Context, op errors.Op, script *redis.Script, keys []string, args ...interface{}) (int64, error) { + res, err := script.Run(ctx, r.client, keys, args...).Result() if err != nil { return 0, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) } @@ -91,13 +91,13 @@ return 1 `) // Enqueue adds the given task to the pending list of the queue. -func (r *RDB) Enqueue(msg *base.TaskMessage) error { +func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error { var op errors.Op = "rdb.Enqueue" encoded, err := base.EncodeMessage(msg) if err != nil { return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) } - if err := r.client.SAdd(context.Background(), base.AllQueues, msg.Queue).Err(); err != nil { + if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) } keys := []string{ @@ -110,7 +110,7 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error { msg.Timeout, msg.Deadline, } - n, err := r.runScriptWithErrorCode(op, enqueueCmd, keys, argv...) + n, err := r.runScriptWithErrorCode(ctx, op, enqueueCmd, keys, argv...) if err != nil { return err } @@ -156,13 +156,13 @@ return 1 // EnqueueUnique inserts the given task if the task's uniqueness lock can be acquired. // It returns ErrDuplicateTask if the lock cannot be acquired. 
-func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { +func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time.Duration) error { var op errors.Op = "rdb.EnqueueUnique" encoded, err := base.EncodeMessage(msg) if err != nil { return errors.E(op, errors.Internal, "cannot encode task message: %v", err) } - if err := r.client.SAdd(context.Background(), base.AllQueues, msg.Queue).Err(); err != nil { + if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) } keys := []string{ @@ -177,7 +177,7 @@ func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { msg.Timeout, msg.Deadline, } - n, err := r.runScriptWithErrorCode(op, enqueueUniqueCmd, keys, argv...) + n, err := r.runScriptWithErrorCode(ctx, op, enqueueUniqueCmd, keys, argv...) if err != nil { return err } @@ -334,6 +334,7 @@ return redis.status_reply("OK") // It removes a uniqueness lock acquired by the task, if any. func (r *RDB) Done(msg *base.TaskMessage) error { var op errors.Op = "rdb.Done" + ctx := context.Background() now := time.Now() expireAt := now.Add(statsTTL) keys := []string{ @@ -349,9 +350,9 @@ func (r *RDB) Done(msg *base.TaskMessage) error { // Note: We cannot pass empty unique key when running this script in redis-cluster. if len(msg.UniqueKey) > 0 { keys = append(keys, msg.UniqueKey) - return r.runScript(op, doneUniqueCmd, keys, argv...) + return r.runScript(ctx, op, doneUniqueCmd, keys, argv...) } - return r.runScript(op, doneCmd, keys, argv...) + return r.runScript(ctx, op, doneCmd, keys, argv...) } // KEYS[1] -> asynq:{}:active @@ -416,6 +417,7 @@ return redis.status_reply("OK") // It removes a uniqueness lock acquired by the task, if any. 
func (r *RDB) MarkAsComplete(msg *base.TaskMessage) error { var op errors.Op = "rdb.MarkAsComplete" + ctx := context.Background() now := time.Now() statsExpireAt := now.Add(statsTTL) msg.CompletedAt = now.Unix() @@ -439,9 +441,9 @@ func (r *RDB) MarkAsComplete(msg *base.TaskMessage) error { // Note: We cannot pass empty unique key when running this script in redis-cluster. if len(msg.UniqueKey) > 0 { keys = append(keys, msg.UniqueKey) - return r.runScript(op, markAsCompleteUniqueCmd, keys, argv...) + return r.runScript(ctx, op, markAsCompleteUniqueCmd, keys, argv...) } - return r.runScript(op, markAsCompleteCmd, keys, argv...) + return r.runScript(ctx, op, markAsCompleteCmd, keys, argv...) } // KEYS[1] -> asynq:{}:active @@ -464,13 +466,14 @@ return redis.status_reply("OK")`) // Requeue moves the task from active queue to the specified queue. func (r *RDB) Requeue(msg *base.TaskMessage) error { var op errors.Op = "rdb.Requeue" + ctx := context.Background() keys := []string{ base.ActiveKey(msg.Queue), base.DeadlinesKey(msg.Queue), base.PendingKey(msg.Queue), base.TaskKey(msg.Queue, msg.ID), } - return r.runScript(op, requeueCmd, keys, msg.ID) + return r.runScript(ctx, op, requeueCmd, keys, msg.ID) } // KEYS[1] -> asynq:{}:t: @@ -498,13 +501,13 @@ return 1 `) // Schedule adds the task to the scheduled set to be processed in the future. 
-func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { +func (r *RDB) Schedule(ctx context.Context, msg *base.TaskMessage, processAt time.Time) error { var op errors.Op = "rdb.Schedule" encoded, err := base.EncodeMessage(msg) if err != nil { return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) } - if err := r.client.SAdd(context.Background(), base.AllQueues, msg.Queue).Err(); err != nil { + if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) } keys := []string{ @@ -518,7 +521,7 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { msg.Timeout, msg.Deadline, } - n, err := r.runScriptWithErrorCode(op, scheduleCmd, keys, argv...) + n, err := r.runScriptWithErrorCode(ctx, op, scheduleCmd, keys, argv...) if err != nil { return err } @@ -562,13 +565,13 @@ return 1 // ScheduleUnique adds the task to the backlog queue to be processed in the future if the uniqueness lock can be acquired. // It returns ErrDuplicateTask if the lock cannot be acquired. 
-func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error { +func (r *RDB) ScheduleUnique(ctx context.Context, msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error { var op errors.Op = "rdb.ScheduleUnique" encoded, err := base.EncodeMessage(msg) if err != nil { return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode task message: %v", err)) } - if err := r.client.SAdd(context.Background(), base.AllQueues, msg.Queue).Err(); err != nil { + if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) } keys := []string{ @@ -584,7 +587,7 @@ func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl tim msg.Timeout, msg.Deadline, } - n, err := r.runScriptWithErrorCode(op, scheduleUniqueCmd, keys, argv...) + n, err := r.runScriptWithErrorCode(ctx, op, scheduleUniqueCmd, keys, argv...) if err != nil { return err } @@ -634,6 +637,7 @@ return redis.status_reply("OK")`) // if isFailure is true increments the retried counter. func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, isFailure bool) error { var op errors.Op = "rdb.Retry" + ctx := context.Background() now := time.Now() modified := *msg if isFailure { @@ -661,7 +665,7 @@ func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, i expireAt.Unix(), isFailure, } - return r.runScript(op, retryCmd, keys, argv...) + return r.runScript(ctx, op, retryCmd, keys, argv...) } const ( @@ -706,6 +710,7 @@ return redis.status_reply("OK")`) // It also trims the archive by timestamp and set size. 
func (r *RDB) Archive(msg *base.TaskMessage, errMsg string) error { var op errors.Op = "rdb.Archive" + ctx := context.Background() now := time.Now() modified := *msg modified.ErrorMsg = errMsg @@ -732,7 +737,7 @@ func (r *RDB) Archive(msg *base.TaskMessage, errMsg string) error { maxArchiveSize, expireAt.Unix(), } - return r.runScript(op, archiveCmd, keys, argv...) + return r.runScript(ctx, op, archiveCmd, keys, argv...) } // ForwardIfReady checks scheduled and retry sets of the given queues @@ -903,6 +908,7 @@ return redis.status_reply("OK")`) // WriteServerState writes server state data to redis with expiration set to the value ttl. func (r *RDB) WriteServerState(info *base.ServerInfo, workers []*base.WorkerInfo, ttl time.Duration) error { var op errors.Op = "rdb.WriteServerState" + ctx := context.Background() bytes, err := base.EncodeServerInfo(info) if err != nil { return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode server info: %v", err)) @@ -918,13 +924,13 @@ func (r *RDB) WriteServerState(info *base.ServerInfo, workers []*base.WorkerInfo } skey := base.ServerInfoKey(info.Host, info.PID, info.ServerID) wkey := base.WorkersKey(info.Host, info.PID, info.ServerID) - if err := r.client.ZAdd(context.Background(), base.AllServers, &redis.Z{Score: float64(exp.Unix()), Member: skey}).Err(); err != nil { + if err := r.client.ZAdd(ctx, base.AllServers, &redis.Z{Score: float64(exp.Unix()), Member: skey}).Err(); err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) } - if err := r.client.ZAdd(context.Background(), base.AllWorkers, &redis.Z{Score: float64(exp.Unix()), Member: wkey}).Err(); err != nil { + if err := r.client.ZAdd(ctx, base.AllWorkers, &redis.Z{Score: float64(exp.Unix()), Member: wkey}).Err(); err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zadd", Err: err}) } - return r.runScript(op, writeServerStateCmd, []string{skey, wkey}, args...) 
+ return r.runScript(ctx, op, writeServerStateCmd, []string{skey, wkey}, args...) } // KEYS[1] -> asynq:servers:{} @@ -937,15 +943,16 @@ return redis.status_reply("OK")`) // ClearServerState deletes server state data from redis. func (r *RDB) ClearServerState(host string, pid int, serverID string) error { var op errors.Op = "rdb.ClearServerState" + ctx := context.Background() skey := base.ServerInfoKey(host, pid, serverID) wkey := base.WorkersKey(host, pid, serverID) - if err := r.client.ZRem(context.Background(), base.AllServers, skey).Err(); err != nil { + if err := r.client.ZRem(ctx, base.AllServers, skey).Err(); err != nil { return errors.E(op, errors.Internal, &errors.RedisCommandError{Command: "zrem", Err: err}) } - if err := r.client.ZRem(context.Background(), base.AllWorkers, wkey).Err(); err != nil { + if err := r.client.ZRem(ctx, base.AllWorkers, wkey).Err(); err != nil { return errors.E(op, errors.Internal, &errors.RedisCommandError{Command: "zrem", Err: err}) } - return r.runScript(op, clearServerStateCmd, []string{skey, wkey}) + return r.runScript(ctx, op, clearServerStateCmd, []string{skey, wkey}) } // KEYS[1] -> asynq:schedulers:{} @@ -962,6 +969,7 @@ return redis.status_reply("OK")`) // WriteSchedulerEntries writes scheduler entries data to redis with expiration set to the value ttl. 
func (r *RDB) WriteSchedulerEntries(schedulerID string, entries []*base.SchedulerEntry, ttl time.Duration) error { var op errors.Op = "rdb.WriteSchedulerEntries" + ctx := context.Background() args := []interface{}{ttl.Seconds()} for _, e := range entries { bytes, err := base.EncodeSchedulerEntry(e) @@ -972,21 +980,22 @@ func (r *RDB) WriteSchedulerEntries(schedulerID string, entries []*base.Schedule } exp := time.Now().Add(ttl).UTC() key := base.SchedulerEntriesKey(schedulerID) - err := r.client.ZAdd(context.Background(), base.AllSchedulers, &redis.Z{Score: float64(exp.Unix()), Member: key}).Err() + err := r.client.ZAdd(ctx, base.AllSchedulers, &redis.Z{Score: float64(exp.Unix()), Member: key}).Err() if err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zadd", Err: err}) } - return r.runScript(op, writeSchedulerEntriesCmd, []string{key}, args...) + return r.runScript(ctx, op, writeSchedulerEntriesCmd, []string{key}, args...) } // ClearSchedulerEntries deletes scheduler entries data from redis. func (r *RDB) ClearSchedulerEntries(scheduelrID string) error { var op errors.Op = "rdb.ClearSchedulerEntries" + ctx := context.Background() key := base.SchedulerEntriesKey(scheduelrID) - if err := r.client.ZRem(context.Background(), base.AllSchedulers, key).Err(); err != nil { + if err := r.client.ZRem(ctx, base.AllSchedulers, key).Err(); err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zrem", Err: err}) } - if err := r.client.Del(context.Background(), key).Err(); err != nil { + if err := r.client.Del(ctx, key).Err(); err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "del", Err: err}) } return nil @@ -995,8 +1004,9 @@ func (r *RDB) ClearSchedulerEntries(scheduelrID string) error { // CancelationPubSub returns a pubsub for cancelation messages. 
func (r *RDB) CancelationPubSub() (*redis.PubSub, error) { var op errors.Op = "rdb.CancelationPubSub" - pubsub := r.client.Subscribe(context.Background(), base.CancelChannel) - _, err := pubsub.Receive(context.Background()) + ctx := context.Background() + pubsub := r.client.Subscribe(ctx, base.CancelChannel) + _, err := pubsub.Receive(ctx) if err != nil { return nil, errors.E(op, errors.Unknown, fmt.Sprintf("redis pubsub receive error: %v", err)) } @@ -1007,7 +1017,8 @@ func (r *RDB) CancelationPubSub() (*redis.PubSub, error) { // The message is the ID for the task to be canceled. func (r *RDB) PublishCancelation(id string) error { var op errors.Op = "rdb.PublishCancelation" - if err := r.client.Publish(context.Background(), base.CancelChannel, id).Err(); err != nil { + ctx := context.Background() + if err := r.client.Publish(ctx, base.CancelChannel, id).Err(); err != nil { return errors.E(op, errors.Unknown, fmt.Sprintf("redis pubsub publish error: %v", err)) } return nil @@ -1028,6 +1039,7 @@ const maxEvents = 1000 // RecordSchedulerEnqueueEvent records the time when the given task was enqueued. func (r *RDB) RecordSchedulerEnqueueEvent(entryID string, event *base.SchedulerEnqueueEvent) error { var op errors.Op = "rdb.RecordSchedulerEnqueueEvent" + ctx := context.Background() data, err := base.EncodeSchedulerEnqueueEvent(event) if err != nil { return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode scheduler enqueue event: %v", err)) @@ -1040,14 +1052,15 @@ func (r *RDB) RecordSchedulerEnqueueEvent(entryID string, event *base.SchedulerE data, maxEvents, } - return r.runScript(op, recordSchedulerEnqueueEventCmd, keys, argv...) + return r.runScript(ctx, op, recordSchedulerEnqueueEventCmd, keys, argv...) } // ClearSchedulerHistory deletes the enqueue event history for the given scheduler entry. 
func (r *RDB) ClearSchedulerHistory(entryID string) error { var op errors.Op = "rdb.ClearSchedulerHistory" + ctx := context.Background() key := base.SchedulerHistoryKey(entryID) - if err := r.client.Del(context.Background(), key).Err(); err != nil { + if err := r.client.Del(ctx, key).Err(); err != nil { return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "del", Err: err}) } return nil @@ -1056,8 +1069,9 @@ func (r *RDB) ClearSchedulerHistory(entryID string) error { // WriteResult writes the given result data for the specified task. func (r *RDB) WriteResult(qname, taskID string, data []byte) (int, error) { var op errors.Op = "rdb.WriteResult" + ctx := context.Background() taskKey := base.TaskKey(qname, taskID) - if err := r.client.HSet(context.Background(), taskKey, "result", data).Err(); err != nil { + if err := r.client.HSet(ctx, taskKey, "result", data).Err(); err != nil { return 0, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "hset", Err: err}) } return len(data), nil diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index abb0d39..8d6a464 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -78,7 +78,7 @@ func TestEnqueue(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case. - err := r.Enqueue(tc.msg) + err := r.Enqueue(context.Background(), tc.msg) if err != nil { t.Errorf("(*RDB).Enqueue(msg) = %v, want nil", err) continue @@ -148,11 +148,11 @@ func TestEnqueueTaskIdConflictError(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case. 
- if err := r.Enqueue(tc.firstMsg); err != nil { + if err := r.Enqueue(context.Background(), tc.firstMsg); err != nil { t.Errorf("First message: Enqueue failed: %v", err) continue } - if err := r.Enqueue(tc.secondMsg); !errors.Is(err, errors.ErrTaskIdConflict) { + if err := r.Enqueue(context.Background(), tc.secondMsg); !errors.Is(err, errors.ErrTaskIdConflict) { t.Errorf("Second message: Enqueue returned %v, want %v", err, errors.ErrTaskIdConflict) continue } @@ -181,7 +181,7 @@ func TestEnqueueUnique(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case. // Enqueue the first message, should succeed. - err := r.EnqueueUnique(tc.msg, tc.ttl) + err := r.EnqueueUnique(context.Background(), tc.msg, tc.ttl) if err != nil { t.Errorf("First message: (*RDB).EnqueueUnique(%v, %v) = %v, want nil", tc.msg, tc.ttl, err) @@ -241,7 +241,7 @@ func TestEnqueueUnique(t *testing.T) { } // Enqueue the second message, should fail. - got := r.EnqueueUnique(tc.msg, tc.ttl) + got := r.EnqueueUnique(context.Background(), tc.msg, tc.ttl) if !errors.Is(got, errors.ErrDuplicateTask) { t.Errorf("Second message: (*RDB).EnqueueUnique(msg, ttl) = %v, want %v", got, errors.ErrDuplicateTask) continue @@ -282,11 +282,11 @@ func TestEnqueueUniqueTaskIdConflictError(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case. 
- if err := r.EnqueueUnique(tc.firstMsg, ttl); err != nil { + if err := r.EnqueueUnique(context.Background(), tc.firstMsg, ttl); err != nil { t.Errorf("First message: EnqueueUnique failed: %v", err) continue } - if err := r.EnqueueUnique(tc.secondMsg, ttl); !errors.Is(err, errors.ErrTaskIdConflict) { + if err := r.EnqueueUnique(context.Background(), tc.secondMsg, ttl); !errors.Is(err, errors.ErrTaskIdConflict) { t.Errorf("Second message: EnqueueUnique returned %v, want %v", err, errors.ErrTaskIdConflict) continue } @@ -1162,7 +1162,7 @@ func TestSchedule(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case - err := r.Schedule(tc.msg, tc.processAt) + err := r.Schedule(context.Background(), tc.msg, tc.processAt) if err != nil { t.Errorf("(*RDB).Schedule(%v, %v) = %v, want nil", tc.msg, tc.processAt, err) @@ -1245,11 +1245,11 @@ func TestScheduleTaskIdConflictError(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case. 
- if err := r.Schedule(tc.firstMsg, processAt); err != nil { + if err := r.Schedule(context.Background(), tc.firstMsg, processAt); err != nil { t.Errorf("First message: Schedule failed: %v", err) continue } - if err := r.Schedule(tc.secondMsg, processAt); !errors.Is(err, errors.ErrTaskIdConflict) { + if err := r.Schedule(context.Background(), tc.secondMsg, processAt); !errors.Is(err, errors.ErrTaskIdConflict) { t.Errorf("Second message: Schedule returned %v, want %v", err, errors.ErrTaskIdConflict) continue } @@ -1279,7 +1279,7 @@ func TestScheduleUnique(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case desc := "(*RDB).ScheduleUnique(msg, processAt, ttl)" - err := r.ScheduleUnique(tc.msg, tc.processAt, tc.ttl) + err := r.ScheduleUnique(context.Background(), tc.msg, tc.processAt, tc.ttl) if err != nil { t.Errorf("Frist task: %s = %v, want nil", desc, err) continue @@ -1336,7 +1336,7 @@ func TestScheduleUnique(t *testing.T) { } // Enqueue the second message, should fail. - got := r.ScheduleUnique(tc.msg, tc.processAt, tc.ttl) + got := r.ScheduleUnique(context.Background(), tc.msg, tc.processAt, tc.ttl) if !errors.Is(got, errors.ErrDuplicateTask) { t.Errorf("Second task: %s = %v, want %v", desc, got, errors.ErrDuplicateTask) continue @@ -1379,11 +1379,11 @@ func TestScheduleUniqueTaskIdConflictError(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case. 
- if err := r.ScheduleUnique(tc.firstMsg, processAt, ttl); err != nil { + if err := r.ScheduleUnique(context.Background(), tc.firstMsg, processAt, ttl); err != nil { t.Errorf("First message: ScheduleUnique failed: %v", err) continue } - if err := r.ScheduleUnique(tc.secondMsg, processAt, ttl); !errors.Is(err, errors.ErrTaskIdConflict) { + if err := r.ScheduleUnique(context.Background(), tc.secondMsg, processAt, ttl); !errors.Is(err, errors.ErrTaskIdConflict) { t.Errorf("Second message: ScheduleUnique returned %v, want %v", err, errors.ErrTaskIdConflict) continue } diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index cec9463..5f3a023 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -6,6 +6,7 @@ package testbroker import ( + "context" "errors" "sync" "time" @@ -45,22 +46,22 @@ func (tb *TestBroker) Wakeup() { tb.sleeping = false } -func (tb *TestBroker) Enqueue(msg *base.TaskMessage) error { +func (tb *TestBroker) Enqueue(ctx context.Context, msg *base.TaskMessage) error { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { return errRedisDown } - return tb.real.Enqueue(msg) + return tb.real.Enqueue(ctx, msg) } -func (tb *TestBroker) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { +func (tb *TestBroker) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time.Duration) error { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { return errRedisDown } - return tb.real.EnqueueUnique(msg, ttl) + return tb.real.EnqueueUnique(ctx, msg, ttl) } func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) { @@ -99,22 +100,22 @@ func (tb *TestBroker) Requeue(msg *base.TaskMessage) error { return tb.real.Requeue(msg) } -func (tb *TestBroker) Schedule(msg *base.TaskMessage, processAt time.Time) error { +func (tb *TestBroker) Schedule(ctx context.Context, msg *base.TaskMessage, processAt time.Time) error { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { 
return errRedisDown } - return tb.real.Schedule(msg, processAt) + return tb.real.Schedule(ctx, msg, processAt) } -func (tb *TestBroker) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error { +func (tb *TestBroker) ScheduleUnique(ctx context.Context, msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { return errRedisDown } - return tb.real.ScheduleUnique(msg, processAt, ttl) + return tb.real.ScheduleUnique(ctx, msg, processAt, ttl) } func (tb *TestBroker) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, isFailure bool) error { diff --git a/processor_test.go b/processor_test.go index 4ee28be..21ba01f 100644 --- a/processor_test.go +++ b/processor_test.go @@ -126,7 +126,7 @@ func TestProcessorSuccessWithSingleQueue(t *testing.T) { p.start(&sync.WaitGroup{}) for _, msg := range tc.incoming { - err := rdbClient.Enqueue(msg) + err := rdbClient.Enqueue(context.Background(), msg) if err != nil { p.shutdown() t.Fatal(err)