2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-11-10 11:31:58 +08:00

Add EnqueueContext method to Client

This commit is contained in:
Ken Hibino 2021-11-15 16:34:26 -08:00 committed by GitHub
parent e2b61c9056
commit 9f2c321e98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 131 additions and 82 deletions

View File

@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
### Added
- `EnqueueContext` method is added to `Client`.
### Fixed
- Fixed an error when user pass a duration less than 1s to `Unique` option
## [0.19.0] - 2021-11-06 ## [0.19.0] - 2021-11-06
### Changed ### Changed

View File

@ -5,6 +5,7 @@
package asynq package asynq
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@ -292,7 +293,7 @@ func (c *Client) Close() error {
return c.rdb.Close() return c.rdb.Close()
} }
// Enqueue enqueues the given task to be processed asynchronously. // Enqueue enqueues the given task to a queue.
// //
// Enqueue returns TaskInfo and nil error if the task is enqueued successfully, otherwise returns a non-nil error. // Enqueue returns TaskInfo and nil error if the task is enqueued successfully, otherwise returns a non-nil error.
// //
@ -302,7 +303,25 @@ func (c *Client) Close() error {
// By deafult, max retry is set to 25 and timeout is set to 30 minutes. // By deafult, max retry is set to 25 and timeout is set to 30 minutes.
// //
// If no ProcessAt or ProcessIn options are provided, the task will be pending immediately. // If no ProcessAt or ProcessIn options are provided, the task will be pending immediately.
//
// Enqueue uses context.Background internally; to specify the context, use EnqueueContext.
func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) { func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) {
return c.EnqueueContext(context.Background(), task, opts...)
}
// EnqueueContext enqueues the given task to a queue.
//
// EnqueueContext returns TaskInfo and nil error if the task is enqueued successfully, otherwise returns a non-nil error.
//
// The argument opts specifies the behavior of task processing.
// If there are conflicting Option values the last one overrides others.
// Any options provided to NewTask can be overridden by options passed to Enqueue.
// By deafult, max retry is set to 25 and timeout is set to 30 minutes.
//
// If no ProcessAt or ProcessIn options are provided, the task will be pending immediately.
//
// The first argument context applies to the enqueue operation. To specify task timeout and deadline, use Timeout and Deadline option instead.
func (c *Client) EnqueueContext(ctx context.Context, task *Task, opts ...Option) (*TaskInfo, error) {
if strings.TrimSpace(task.Type()) == "" { if strings.TrimSpace(task.Type()) == "" {
return nil, fmt.Errorf("task typename cannot be empty") return nil, fmt.Errorf("task typename cannot be empty")
} }
@ -343,10 +362,10 @@ func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) {
var state base.TaskState var state base.TaskState
if opt.processAt.Before(now) || opt.processAt.Equal(now) { if opt.processAt.Before(now) || opt.processAt.Equal(now) {
opt.processAt = now opt.processAt = now
err = c.enqueue(msg, opt.uniqueTTL) err = c.enqueue(ctx, msg, opt.uniqueTTL)
state = base.TaskStatePending state = base.TaskStatePending
} else { } else {
err = c.schedule(msg, opt.processAt, opt.uniqueTTL) err = c.schedule(ctx, msg, opt.processAt, opt.uniqueTTL)
state = base.TaskStateScheduled state = base.TaskStateScheduled
} }
switch { switch {
@ -360,17 +379,17 @@ func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) {
return newTaskInfo(msg, state, opt.processAt, nil), nil return newTaskInfo(msg, state, opt.processAt, nil), nil
} }
func (c *Client) enqueue(msg *base.TaskMessage, uniqueTTL time.Duration) error { func (c *Client) enqueue(ctx context.Context, msg *base.TaskMessage, uniqueTTL time.Duration) error {
if uniqueTTL > 0 { if uniqueTTL > 0 {
return c.rdb.EnqueueUnique(msg, uniqueTTL) return c.rdb.EnqueueUnique(ctx, msg, uniqueTTL)
} }
return c.rdb.Enqueue(msg) return c.rdb.Enqueue(ctx, msg)
} }
func (c *Client) schedule(msg *base.TaskMessage, t time.Time, uniqueTTL time.Duration) error { func (c *Client) schedule(ctx context.Context, msg *base.TaskMessage, t time.Time, uniqueTTL time.Duration) error {
if uniqueTTL > 0 { if uniqueTTL > 0 {
ttl := t.Add(uniqueTTL).Sub(time.Now()) ttl := t.Add(uniqueTTL).Sub(time.Now())
return c.rdb.ScheduleUnique(msg, t, ttl) return c.rdb.ScheduleUnique(ctx, msg, t, ttl)
} }
return c.rdb.Schedule(msg, t) return c.rdb.Schedule(ctx, msg, t)
} }

View File

@ -660,14 +660,14 @@ func (c *Cancelations) Get(id string) (fn context.CancelFunc, ok bool) {
// See rdb.RDB as a reference implementation. // See rdb.RDB as a reference implementation.
type Broker interface { type Broker interface {
Ping() error Ping() error
Enqueue(msg *TaskMessage) error Enqueue(ctx context.Context, msg *TaskMessage) error
EnqueueUnique(msg *TaskMessage, ttl time.Duration) error EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error
Dequeue(qnames ...string) (*TaskMessage, time.Time, error) Dequeue(qnames ...string) (*TaskMessage, time.Time, error)
Done(msg *TaskMessage) error Done(msg *TaskMessage) error
MarkAsComplete(msg *TaskMessage) error MarkAsComplete(msg *TaskMessage) error
Requeue(msg *TaskMessage) error Requeue(msg *TaskMessage) error
Schedule(msg *TaskMessage, processAt time.Time) error Schedule(ctx context.Context, msg *TaskMessage, processAt time.Time) error
ScheduleUnique(msg *TaskMessage, processAt time.Time, ttl time.Duration) error ScheduleUnique(ctx context.Context, msg *TaskMessage, processAt time.Time, ttl time.Duration) error
Retry(msg *TaskMessage, processAt time.Time, errMsg string, isFailure bool) error Retry(msg *TaskMessage, processAt time.Time, errMsg string, isFailure bool) error
Archive(msg *TaskMessage, errMsg string) error Archive(msg *TaskMessage, errMsg string) error
ForwardIfReady(qnames ...string) error ForwardIfReady(qnames ...string) error

View File

@ -5,6 +5,7 @@
package rdb package rdb
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@ -15,6 +16,7 @@ import (
func BenchmarkEnqueue(b *testing.B) { func BenchmarkEnqueue(b *testing.B) {
r := setup(b) r := setup(b)
ctx := context.Background()
msg := asynqtest.NewTaskMessage("task1", nil) msg := asynqtest.NewTaskMessage("task1", nil)
b.ResetTimer() b.ResetTimer()
@ -23,7 +25,7 @@ func BenchmarkEnqueue(b *testing.B) {
asynqtest.FlushDB(b, r.client) asynqtest.FlushDB(b, r.client)
b.StartTimer() b.StartTimer()
if err := r.Enqueue(msg); err != nil { if err := r.Enqueue(ctx, msg); err != nil {
b.Fatalf("Enqueue failed: %v", err) b.Fatalf("Enqueue failed: %v", err)
} }
} }
@ -31,6 +33,7 @@ func BenchmarkEnqueue(b *testing.B) {
func BenchmarkEnqueueUnique(b *testing.B) { func BenchmarkEnqueueUnique(b *testing.B) {
r := setup(b) r := setup(b)
ctx := context.Background()
msg := &base.TaskMessage{ msg := &base.TaskMessage{
Type: "task1", Type: "task1",
Payload: nil, Payload: nil,
@ -45,7 +48,7 @@ func BenchmarkEnqueueUnique(b *testing.B) {
asynqtest.FlushDB(b, r.client) asynqtest.FlushDB(b, r.client)
b.StartTimer() b.StartTimer()
if err := r.EnqueueUnique(msg, uniqueTTL); err != nil { if err := r.EnqueueUnique(ctx, msg, uniqueTTL); err != nil {
b.Fatalf("EnqueueUnique failed: %v", err) b.Fatalf("EnqueueUnique failed: %v", err)
} }
} }
@ -53,6 +56,7 @@ func BenchmarkEnqueueUnique(b *testing.B) {
func BenchmarkSchedule(b *testing.B) { func BenchmarkSchedule(b *testing.B) {
r := setup(b) r := setup(b)
ctx := context.Background()
msg := asynqtest.NewTaskMessage("task1", nil) msg := asynqtest.NewTaskMessage("task1", nil)
processAt := time.Now().Add(3 * time.Minute) processAt := time.Now().Add(3 * time.Minute)
b.ResetTimer() b.ResetTimer()
@ -62,7 +66,7 @@ func BenchmarkSchedule(b *testing.B) {
asynqtest.FlushDB(b, r.client) asynqtest.FlushDB(b, r.client)
b.StartTimer() b.StartTimer()
if err := r.Schedule(msg, processAt); err != nil { if err := r.Schedule(ctx, msg, processAt); err != nil {
b.Fatalf("Schedule failed: %v", err) b.Fatalf("Schedule failed: %v", err)
} }
} }
@ -70,6 +74,7 @@ func BenchmarkSchedule(b *testing.B) {
func BenchmarkScheduleUnique(b *testing.B) { func BenchmarkScheduleUnique(b *testing.B) {
r := setup(b) r := setup(b)
ctx := context.Background()
msg := &base.TaskMessage{ msg := &base.TaskMessage{
Type: "task1", Type: "task1",
Payload: nil, Payload: nil,
@ -85,7 +90,7 @@ func BenchmarkScheduleUnique(b *testing.B) {
asynqtest.FlushDB(b, r.client) asynqtest.FlushDB(b, r.client)
b.StartTimer() b.StartTimer()
if err := r.ScheduleUnique(msg, processAt, uniqueTTL); err != nil { if err := r.ScheduleUnique(ctx, msg, processAt, uniqueTTL); err != nil {
b.Fatalf("EnqueueUnique failed: %v", err) b.Fatalf("EnqueueUnique failed: %v", err)
} }
} }
@ -93,6 +98,7 @@ func BenchmarkScheduleUnique(b *testing.B) {
func BenchmarkDequeueSingleQueue(b *testing.B) { func BenchmarkDequeueSingleQueue(b *testing.B) {
r := setup(b) r := setup(b)
ctx := context.Background()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -101,7 +107,7 @@ func BenchmarkDequeueSingleQueue(b *testing.B) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
m := asynqtest.NewTaskMessageWithQueue( m := asynqtest.NewTaskMessageWithQueue(
fmt.Sprintf("task%d", i), nil, base.DefaultQueueName) fmt.Sprintf("task%d", i), nil, base.DefaultQueueName)
if err := r.Enqueue(m); err != nil { if err := r.Enqueue(ctx, m); err != nil {
b.Fatalf("Enqueue failed: %v", err) b.Fatalf("Enqueue failed: %v", err)
} }
} }
@ -116,6 +122,7 @@ func BenchmarkDequeueSingleQueue(b *testing.B) {
func BenchmarkDequeueMultipleQueues(b *testing.B) { func BenchmarkDequeueMultipleQueues(b *testing.B) {
qnames := []string{"critical", "default", "low"} qnames := []string{"critical", "default", "low"}
r := setup(b) r := setup(b)
ctx := context.Background()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -125,7 +132,7 @@ func BenchmarkDequeueMultipleQueues(b *testing.B) {
for _, qname := range qnames { for _, qname := range qnames {
m := asynqtest.NewTaskMessageWithQueue( m := asynqtest.NewTaskMessageWithQueue(
fmt.Sprintf("%s_task%d", qname, i), nil, qname) fmt.Sprintf("%s_task%d", qname, i), nil, qname)
if err := r.Enqueue(m); err != nil { if err := r.Enqueue(ctx, m); err != nil {
b.Fatalf("Enqueue failed: %v", err) b.Fatalf("Enqueue failed: %v", err)
} }
} }

View File

@ -879,7 +879,7 @@ func TestListScheduledPagination(t *testing.T) {
// create 100 tasks with an increasing number of wait time. // create 100 tasks with an increasing number of wait time.
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
msg := h.NewTaskMessage(fmt.Sprintf("task %d", i), nil) msg := h.NewTaskMessage(fmt.Sprintf("task %d", i), nil)
if err := r.Schedule(msg, time.Now().Add(time.Duration(i)*time.Second)); err != nil { if err := r.Schedule(context.Background(), msg, time.Now().Add(time.Duration(i)*time.Second)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@ -43,16 +43,16 @@ func (r *RDB) Ping() error {
return r.client.Ping(context.Background()).Err() return r.client.Ping(context.Background()).Err()
} }
func (r *RDB) runScript(op errors.Op, script *redis.Script, keys []string, args ...interface{}) error { func (r *RDB) runScript(ctx context.Context, op errors.Op, script *redis.Script, keys []string, args ...interface{}) error {
if err := script.Run(context.Background(), r.client, keys, args...).Err(); err != nil { if err := script.Run(ctx, r.client, keys, args...).Err(); err != nil {
return errors.E(op, errors.Internal, fmt.Sprintf("redis eval error: %v", err)) return errors.E(op, errors.Internal, fmt.Sprintf("redis eval error: %v", err))
} }
return nil return nil
} }
// Runs the given script with keys and args and retuns the script's return value as int64. // Runs the given script with keys and args and retuns the script's return value as int64.
func (r *RDB) runScriptWithErrorCode(op errors.Op, script *redis.Script, keys []string, args ...interface{}) (int64, error) { func (r *RDB) runScriptWithErrorCode(ctx context.Context, op errors.Op, script *redis.Script, keys []string, args ...interface{}) (int64, error) {
res, err := script.Run(context.Background(), r.client, keys, args...).Result() res, err := script.Run(ctx, r.client, keys, args...).Result()
if err != nil { if err != nil {
return 0, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) return 0, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err))
} }
@ -91,13 +91,13 @@ return 1
`) `)
// Enqueue adds the given task to the pending list of the queue. // Enqueue adds the given task to the pending list of the queue.
func (r *RDB) Enqueue(msg *base.TaskMessage) error { func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error {
var op errors.Op = "rdb.Enqueue" var op errors.Op = "rdb.Enqueue"
encoded, err := base.EncodeMessage(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
} }
if err := r.client.SAdd(context.Background(), base.AllQueues, msg.Queue).Err(); err != nil { if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
} }
keys := []string{ keys := []string{
@ -110,7 +110,7 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error {
msg.Timeout, msg.Timeout,
msg.Deadline, msg.Deadline,
} }
n, err := r.runScriptWithErrorCode(op, enqueueCmd, keys, argv...) n, err := r.runScriptWithErrorCode(ctx, op, enqueueCmd, keys, argv...)
if err != nil { if err != nil {
return err return err
} }
@ -156,13 +156,13 @@ return 1
// EnqueueUnique inserts the given task if the task's uniqueness lock can be acquired. // EnqueueUnique inserts the given task if the task's uniqueness lock can be acquired.
// It returns ErrDuplicateTask if the lock cannot be acquired. // It returns ErrDuplicateTask if the lock cannot be acquired.
func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time.Duration) error {
var op errors.Op = "rdb.EnqueueUnique" var op errors.Op = "rdb.EnqueueUnique"
encoded, err := base.EncodeMessage(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return errors.E(op, errors.Internal, "cannot encode task message: %v", err) return errors.E(op, errors.Internal, "cannot encode task message: %v", err)
} }
if err := r.client.SAdd(context.Background(), base.AllQueues, msg.Queue).Err(); err != nil { if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
} }
keys := []string{ keys := []string{
@ -177,7 +177,7 @@ func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error {
msg.Timeout, msg.Timeout,
msg.Deadline, msg.Deadline,
} }
n, err := r.runScriptWithErrorCode(op, enqueueUniqueCmd, keys, argv...) n, err := r.runScriptWithErrorCode(ctx, op, enqueueUniqueCmd, keys, argv...)
if err != nil { if err != nil {
return err return err
} }
@ -334,6 +334,7 @@ return redis.status_reply("OK")
// It removes a uniqueness lock acquired by the task, if any. // It removes a uniqueness lock acquired by the task, if any.
func (r *RDB) Done(msg *base.TaskMessage) error { func (r *RDB) Done(msg *base.TaskMessage) error {
var op errors.Op = "rdb.Done" var op errors.Op = "rdb.Done"
ctx := context.Background()
now := time.Now() now := time.Now()
expireAt := now.Add(statsTTL) expireAt := now.Add(statsTTL)
keys := []string{ keys := []string{
@ -349,9 +350,9 @@ func (r *RDB) Done(msg *base.TaskMessage) error {
// Note: We cannot pass empty unique key when running this script in redis-cluster. // Note: We cannot pass empty unique key when running this script in redis-cluster.
if len(msg.UniqueKey) > 0 { if len(msg.UniqueKey) > 0 {
keys = append(keys, msg.UniqueKey) keys = append(keys, msg.UniqueKey)
return r.runScript(op, doneUniqueCmd, keys, argv...) return r.runScript(ctx, op, doneUniqueCmd, keys, argv...)
} }
return r.runScript(op, doneCmd, keys, argv...) return r.runScript(ctx, op, doneCmd, keys, argv...)
} }
// KEYS[1] -> asynq:{<qname>}:active // KEYS[1] -> asynq:{<qname>}:active
@ -416,6 +417,7 @@ return redis.status_reply("OK")
// It removes a uniqueness lock acquired by the task, if any. // It removes a uniqueness lock acquired by the task, if any.
func (r *RDB) MarkAsComplete(msg *base.TaskMessage) error { func (r *RDB) MarkAsComplete(msg *base.TaskMessage) error {
var op errors.Op = "rdb.MarkAsComplete" var op errors.Op = "rdb.MarkAsComplete"
ctx := context.Background()
now := time.Now() now := time.Now()
statsExpireAt := now.Add(statsTTL) statsExpireAt := now.Add(statsTTL)
msg.CompletedAt = now.Unix() msg.CompletedAt = now.Unix()
@ -439,9 +441,9 @@ func (r *RDB) MarkAsComplete(msg *base.TaskMessage) error {
// Note: We cannot pass empty unique key when running this script in redis-cluster. // Note: We cannot pass empty unique key when running this script in redis-cluster.
if len(msg.UniqueKey) > 0 { if len(msg.UniqueKey) > 0 {
keys = append(keys, msg.UniqueKey) keys = append(keys, msg.UniqueKey)
return r.runScript(op, markAsCompleteUniqueCmd, keys, argv...) return r.runScript(ctx, op, markAsCompleteUniqueCmd, keys, argv...)
} }
return r.runScript(op, markAsCompleteCmd, keys, argv...) return r.runScript(ctx, op, markAsCompleteCmd, keys, argv...)
} }
// KEYS[1] -> asynq:{<qname>}:active // KEYS[1] -> asynq:{<qname>}:active
@ -464,13 +466,14 @@ return redis.status_reply("OK")`)
// Requeue moves the task from active queue to the specified queue. // Requeue moves the task from active queue to the specified queue.
func (r *RDB) Requeue(msg *base.TaskMessage) error { func (r *RDB) Requeue(msg *base.TaskMessage) error {
var op errors.Op = "rdb.Requeue" var op errors.Op = "rdb.Requeue"
ctx := context.Background()
keys := []string{ keys := []string{
base.ActiveKey(msg.Queue), base.ActiveKey(msg.Queue),
base.DeadlinesKey(msg.Queue), base.DeadlinesKey(msg.Queue),
base.PendingKey(msg.Queue), base.PendingKey(msg.Queue),
base.TaskKey(msg.Queue, msg.ID), base.TaskKey(msg.Queue, msg.ID),
} }
return r.runScript(op, requeueCmd, keys, msg.ID) return r.runScript(ctx, op, requeueCmd, keys, msg.ID)
} }
// KEYS[1] -> asynq:{<qname>}:t:<task_id> // KEYS[1] -> asynq:{<qname>}:t:<task_id>
@ -498,13 +501,13 @@ return 1
`) `)
// Schedule adds the task to the scheduled set to be processed in the future. // Schedule adds the task to the scheduled set to be processed in the future.
func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { func (r *RDB) Schedule(ctx context.Context, msg *base.TaskMessage, processAt time.Time) error {
var op errors.Op = "rdb.Schedule" var op errors.Op = "rdb.Schedule"
encoded, err := base.EncodeMessage(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
} }
if err := r.client.SAdd(context.Background(), base.AllQueues, msg.Queue).Err(); err != nil { if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
} }
keys := []string{ keys := []string{
@ -518,7 +521,7 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error {
msg.Timeout, msg.Timeout,
msg.Deadline, msg.Deadline,
} }
n, err := r.runScriptWithErrorCode(op, scheduleCmd, keys, argv...) n, err := r.runScriptWithErrorCode(ctx, op, scheduleCmd, keys, argv...)
if err != nil { if err != nil {
return err return err
} }
@ -562,13 +565,13 @@ return 1
// ScheduleUnique adds the task to the backlog queue to be processed in the future if the uniqueness lock can be acquired. // ScheduleUnique adds the task to the backlog queue to be processed in the future if the uniqueness lock can be acquired.
// It returns ErrDuplicateTask if the lock cannot be acquired. // It returns ErrDuplicateTask if the lock cannot be acquired.
func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error { func (r *RDB) ScheduleUnique(ctx context.Context, msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error {
var op errors.Op = "rdb.ScheduleUnique" var op errors.Op = "rdb.ScheduleUnique"
encoded, err := base.EncodeMessage(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode task message: %v", err)) return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode task message: %v", err))
} }
if err := r.client.SAdd(context.Background(), base.AllQueues, msg.Queue).Err(); err != nil { if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
} }
keys := []string{ keys := []string{
@ -584,7 +587,7 @@ func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl tim
msg.Timeout, msg.Timeout,
msg.Deadline, msg.Deadline,
} }
n, err := r.runScriptWithErrorCode(op, scheduleUniqueCmd, keys, argv...) n, err := r.runScriptWithErrorCode(ctx, op, scheduleUniqueCmd, keys, argv...)
if err != nil { if err != nil {
return err return err
} }
@ -634,6 +637,7 @@ return redis.status_reply("OK")`)
// if isFailure is true increments the retried counter. // if isFailure is true increments the retried counter.
func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, isFailure bool) error { func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, isFailure bool) error {
var op errors.Op = "rdb.Retry" var op errors.Op = "rdb.Retry"
ctx := context.Background()
now := time.Now() now := time.Now()
modified := *msg modified := *msg
if isFailure { if isFailure {
@ -661,7 +665,7 @@ func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, i
expireAt.Unix(), expireAt.Unix(),
isFailure, isFailure,
} }
return r.runScript(op, retryCmd, keys, argv...) return r.runScript(ctx, op, retryCmd, keys, argv...)
} }
const ( const (
@ -706,6 +710,7 @@ return redis.status_reply("OK")`)
// It also trims the archive by timestamp and set size. // It also trims the archive by timestamp and set size.
func (r *RDB) Archive(msg *base.TaskMessage, errMsg string) error { func (r *RDB) Archive(msg *base.TaskMessage, errMsg string) error {
var op errors.Op = "rdb.Archive" var op errors.Op = "rdb.Archive"
ctx := context.Background()
now := time.Now() now := time.Now()
modified := *msg modified := *msg
modified.ErrorMsg = errMsg modified.ErrorMsg = errMsg
@ -732,7 +737,7 @@ func (r *RDB) Archive(msg *base.TaskMessage, errMsg string) error {
maxArchiveSize, maxArchiveSize,
expireAt.Unix(), expireAt.Unix(),
} }
return r.runScript(op, archiveCmd, keys, argv...) return r.runScript(ctx, op, archiveCmd, keys, argv...)
} }
// ForwardIfReady checks scheduled and retry sets of the given queues // ForwardIfReady checks scheduled and retry sets of the given queues
@ -903,6 +908,7 @@ return redis.status_reply("OK")`)
// WriteServerState writes server state data to redis with expiration set to the value ttl. // WriteServerState writes server state data to redis with expiration set to the value ttl.
func (r *RDB) WriteServerState(info *base.ServerInfo, workers []*base.WorkerInfo, ttl time.Duration) error { func (r *RDB) WriteServerState(info *base.ServerInfo, workers []*base.WorkerInfo, ttl time.Duration) error {
var op errors.Op = "rdb.WriteServerState" var op errors.Op = "rdb.WriteServerState"
ctx := context.Background()
bytes, err := base.EncodeServerInfo(info) bytes, err := base.EncodeServerInfo(info)
if err != nil { if err != nil {
return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode server info: %v", err)) return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode server info: %v", err))
@ -918,13 +924,13 @@ func (r *RDB) WriteServerState(info *base.ServerInfo, workers []*base.WorkerInfo
} }
skey := base.ServerInfoKey(info.Host, info.PID, info.ServerID) skey := base.ServerInfoKey(info.Host, info.PID, info.ServerID)
wkey := base.WorkersKey(info.Host, info.PID, info.ServerID) wkey := base.WorkersKey(info.Host, info.PID, info.ServerID)
if err := r.client.ZAdd(context.Background(), base.AllServers, &redis.Z{Score: float64(exp.Unix()), Member: skey}).Err(); err != nil { if err := r.client.ZAdd(ctx, base.AllServers, &redis.Z{Score: float64(exp.Unix()), Member: skey}).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
} }
if err := r.client.ZAdd(context.Background(), base.AllWorkers, &redis.Z{Score: float64(exp.Unix()), Member: wkey}).Err(); err != nil { if err := r.client.ZAdd(ctx, base.AllWorkers, &redis.Z{Score: float64(exp.Unix()), Member: wkey}).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zadd", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zadd", Err: err})
} }
return r.runScript(op, writeServerStateCmd, []string{skey, wkey}, args...) return r.runScript(ctx, op, writeServerStateCmd, []string{skey, wkey}, args...)
} }
// KEYS[1] -> asynq:servers:{<host:pid:sid>} // KEYS[1] -> asynq:servers:{<host:pid:sid>}
@ -937,15 +943,16 @@ return redis.status_reply("OK")`)
// ClearServerState deletes server state data from redis. // ClearServerState deletes server state data from redis.
func (r *RDB) ClearServerState(host string, pid int, serverID string) error { func (r *RDB) ClearServerState(host string, pid int, serverID string) error {
var op errors.Op = "rdb.ClearServerState" var op errors.Op = "rdb.ClearServerState"
ctx := context.Background()
skey := base.ServerInfoKey(host, pid, serverID) skey := base.ServerInfoKey(host, pid, serverID)
wkey := base.WorkersKey(host, pid, serverID) wkey := base.WorkersKey(host, pid, serverID)
if err := r.client.ZRem(context.Background(), base.AllServers, skey).Err(); err != nil { if err := r.client.ZRem(ctx, base.AllServers, skey).Err(); err != nil {
return errors.E(op, errors.Internal, &errors.RedisCommandError{Command: "zrem", Err: err}) return errors.E(op, errors.Internal, &errors.RedisCommandError{Command: "zrem", Err: err})
} }
if err := r.client.ZRem(context.Background(), base.AllWorkers, wkey).Err(); err != nil { if err := r.client.ZRem(ctx, base.AllWorkers, wkey).Err(); err != nil {
return errors.E(op, errors.Internal, &errors.RedisCommandError{Command: "zrem", Err: err}) return errors.E(op, errors.Internal, &errors.RedisCommandError{Command: "zrem", Err: err})
} }
return r.runScript(op, clearServerStateCmd, []string{skey, wkey}) return r.runScript(ctx, op, clearServerStateCmd, []string{skey, wkey})
} }
// KEYS[1] -> asynq:schedulers:{<schedulerID>} // KEYS[1] -> asynq:schedulers:{<schedulerID>}
@ -962,6 +969,7 @@ return redis.status_reply("OK")`)
// WriteSchedulerEntries writes scheduler entries data to redis with expiration set to the value ttl. // WriteSchedulerEntries writes scheduler entries data to redis with expiration set to the value ttl.
func (r *RDB) WriteSchedulerEntries(schedulerID string, entries []*base.SchedulerEntry, ttl time.Duration) error { func (r *RDB) WriteSchedulerEntries(schedulerID string, entries []*base.SchedulerEntry, ttl time.Duration) error {
var op errors.Op = "rdb.WriteSchedulerEntries" var op errors.Op = "rdb.WriteSchedulerEntries"
ctx := context.Background()
args := []interface{}{ttl.Seconds()} args := []interface{}{ttl.Seconds()}
for _, e := range entries { for _, e := range entries {
bytes, err := base.EncodeSchedulerEntry(e) bytes, err := base.EncodeSchedulerEntry(e)
@ -972,21 +980,22 @@ func (r *RDB) WriteSchedulerEntries(schedulerID string, entries []*base.Schedule
} }
exp := time.Now().Add(ttl).UTC() exp := time.Now().Add(ttl).UTC()
key := base.SchedulerEntriesKey(schedulerID) key := base.SchedulerEntriesKey(schedulerID)
err := r.client.ZAdd(context.Background(), base.AllSchedulers, &redis.Z{Score: float64(exp.Unix()), Member: key}).Err() err := r.client.ZAdd(ctx, base.AllSchedulers, &redis.Z{Score: float64(exp.Unix()), Member: key}).Err()
if err != nil { if err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zadd", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zadd", Err: err})
} }
return r.runScript(op, writeSchedulerEntriesCmd, []string{key}, args...) return r.runScript(ctx, op, writeSchedulerEntriesCmd, []string{key}, args...)
} }
// ClearSchedulerEntries deletes scheduler entries data from redis. // ClearSchedulerEntries deletes scheduler entries data from redis.
func (r *RDB) ClearSchedulerEntries(scheduelrID string) error { func (r *RDB) ClearSchedulerEntries(scheduelrID string) error {
var op errors.Op = "rdb.ClearSchedulerEntries" var op errors.Op = "rdb.ClearSchedulerEntries"
ctx := context.Background()
key := base.SchedulerEntriesKey(scheduelrID) key := base.SchedulerEntriesKey(scheduelrID)
if err := r.client.ZRem(context.Background(), base.AllSchedulers, key).Err(); err != nil { if err := r.client.ZRem(ctx, base.AllSchedulers, key).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zrem", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zrem", Err: err})
} }
if err := r.client.Del(context.Background(), key).Err(); err != nil { if err := r.client.Del(ctx, key).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "del", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "del", Err: err})
} }
return nil return nil
@ -995,8 +1004,9 @@ func (r *RDB) ClearSchedulerEntries(scheduelrID string) error {
// CancelationPubSub returns a pubsub for cancelation messages. // CancelationPubSub returns a pubsub for cancelation messages.
func (r *RDB) CancelationPubSub() (*redis.PubSub, error) { func (r *RDB) CancelationPubSub() (*redis.PubSub, error) {
var op errors.Op = "rdb.CancelationPubSub" var op errors.Op = "rdb.CancelationPubSub"
pubsub := r.client.Subscribe(context.Background(), base.CancelChannel) ctx := context.Background()
_, err := pubsub.Receive(context.Background()) pubsub := r.client.Subscribe(ctx, base.CancelChannel)
_, err := pubsub.Receive(ctx)
if err != nil { if err != nil {
return nil, errors.E(op, errors.Unknown, fmt.Sprintf("redis pubsub receive error: %v", err)) return nil, errors.E(op, errors.Unknown, fmt.Sprintf("redis pubsub receive error: %v", err))
} }
@ -1007,7 +1017,8 @@ func (r *RDB) CancelationPubSub() (*redis.PubSub, error) {
// The message is the ID for the task to be canceled. // The message is the ID for the task to be canceled.
func (r *RDB) PublishCancelation(id string) error { func (r *RDB) PublishCancelation(id string) error {
var op errors.Op = "rdb.PublishCancelation" var op errors.Op = "rdb.PublishCancelation"
if err := r.client.Publish(context.Background(), base.CancelChannel, id).Err(); err != nil { ctx := context.Background()
if err := r.client.Publish(ctx, base.CancelChannel, id).Err(); err != nil {
return errors.E(op, errors.Unknown, fmt.Sprintf("redis pubsub publish error: %v", err)) return errors.E(op, errors.Unknown, fmt.Sprintf("redis pubsub publish error: %v", err))
} }
return nil return nil
@ -1028,6 +1039,7 @@ const maxEvents = 1000
// RecordSchedulerEnqueueEvent records the time when the given task was enqueued. // RecordSchedulerEnqueueEvent records the time when the given task was enqueued.
func (r *RDB) RecordSchedulerEnqueueEvent(entryID string, event *base.SchedulerEnqueueEvent) error { func (r *RDB) RecordSchedulerEnqueueEvent(entryID string, event *base.SchedulerEnqueueEvent) error {
var op errors.Op = "rdb.RecordSchedulerEnqueueEvent" var op errors.Op = "rdb.RecordSchedulerEnqueueEvent"
ctx := context.Background()
data, err := base.EncodeSchedulerEnqueueEvent(event) data, err := base.EncodeSchedulerEnqueueEvent(event)
if err != nil { if err != nil {
return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode scheduler enqueue event: %v", err)) return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode scheduler enqueue event: %v", err))
@ -1040,14 +1052,15 @@ func (r *RDB) RecordSchedulerEnqueueEvent(entryID string, event *base.SchedulerE
data, data,
maxEvents, maxEvents,
} }
return r.runScript(op, recordSchedulerEnqueueEventCmd, keys, argv...) return r.runScript(ctx, op, recordSchedulerEnqueueEventCmd, keys, argv...)
} }
// ClearSchedulerHistory deletes the enqueue event history for the given scheduler entry. // ClearSchedulerHistory deletes the enqueue event history for the given scheduler entry.
func (r *RDB) ClearSchedulerHistory(entryID string) error { func (r *RDB) ClearSchedulerHistory(entryID string) error {
var op errors.Op = "rdb.ClearSchedulerHistory" var op errors.Op = "rdb.ClearSchedulerHistory"
ctx := context.Background()
key := base.SchedulerHistoryKey(entryID) key := base.SchedulerHistoryKey(entryID)
if err := r.client.Del(context.Background(), key).Err(); err != nil { if err := r.client.Del(ctx, key).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "del", Err: err}) return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "del", Err: err})
} }
return nil return nil
@ -1056,8 +1069,9 @@ func (r *RDB) ClearSchedulerHistory(entryID string) error {
// WriteResult writes the given result data for the specified task. // WriteResult writes the given result data for the specified task.
func (r *RDB) WriteResult(qname, taskID string, data []byte) (int, error) { func (r *RDB) WriteResult(qname, taskID string, data []byte) (int, error) {
var op errors.Op = "rdb.WriteResult" var op errors.Op = "rdb.WriteResult"
ctx := context.Background()
taskKey := base.TaskKey(qname, taskID) taskKey := base.TaskKey(qname, taskID)
if err := r.client.HSet(context.Background(), taskKey, "result", data).Err(); err != nil { if err := r.client.HSet(ctx, taskKey, "result", data).Err(); err != nil {
return 0, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "hset", Err: err}) return 0, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "hset", Err: err})
} }
return len(data), nil return len(data), nil

View File

@ -78,7 +78,7 @@ func TestEnqueue(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case. h.FlushDB(t, r.client) // clean up db before each test case.
err := r.Enqueue(tc.msg) err := r.Enqueue(context.Background(), tc.msg)
if err != nil { if err != nil {
t.Errorf("(*RDB).Enqueue(msg) = %v, want nil", err) t.Errorf("(*RDB).Enqueue(msg) = %v, want nil", err)
continue continue
@ -148,11 +148,11 @@ func TestEnqueueTaskIdConflictError(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case. h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.Enqueue(tc.firstMsg); err != nil { if err := r.Enqueue(context.Background(), tc.firstMsg); err != nil {
t.Errorf("First message: Enqueue failed: %v", err) t.Errorf("First message: Enqueue failed: %v", err)
continue continue
} }
if err := r.Enqueue(tc.secondMsg); !errors.Is(err, errors.ErrTaskIdConflict) { if err := r.Enqueue(context.Background(), tc.secondMsg); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: Enqueue returned %v, want %v", err, errors.ErrTaskIdConflict) t.Errorf("Second message: Enqueue returned %v, want %v", err, errors.ErrTaskIdConflict)
continue continue
} }
@ -181,7 +181,7 @@ func TestEnqueueUnique(t *testing.T) {
h.FlushDB(t, r.client) // clean up db before each test case. h.FlushDB(t, r.client) // clean up db before each test case.
// Enqueue the first message, should succeed. // Enqueue the first message, should succeed.
err := r.EnqueueUnique(tc.msg, tc.ttl) err := r.EnqueueUnique(context.Background(), tc.msg, tc.ttl)
if err != nil { if err != nil {
t.Errorf("First message: (*RDB).EnqueueUnique(%v, %v) = %v, want nil", t.Errorf("First message: (*RDB).EnqueueUnique(%v, %v) = %v, want nil",
tc.msg, tc.ttl, err) tc.msg, tc.ttl, err)
@ -241,7 +241,7 @@ func TestEnqueueUnique(t *testing.T) {
} }
// Enqueue the second message, should fail. // Enqueue the second message, should fail.
got := r.EnqueueUnique(tc.msg, tc.ttl) got := r.EnqueueUnique(context.Background(), tc.msg, tc.ttl)
if !errors.Is(got, errors.ErrDuplicateTask) { if !errors.Is(got, errors.ErrDuplicateTask) {
t.Errorf("Second message: (*RDB).EnqueueUnique(msg, ttl) = %v, want %v", got, errors.ErrDuplicateTask) t.Errorf("Second message: (*RDB).EnqueueUnique(msg, ttl) = %v, want %v", got, errors.ErrDuplicateTask)
continue continue
@ -282,11 +282,11 @@ func TestEnqueueUniqueTaskIdConflictError(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case. h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.EnqueueUnique(tc.firstMsg, ttl); err != nil { if err := r.EnqueueUnique(context.Background(), tc.firstMsg, ttl); err != nil {
t.Errorf("First message: EnqueueUnique failed: %v", err) t.Errorf("First message: EnqueueUnique failed: %v", err)
continue continue
} }
if err := r.EnqueueUnique(tc.secondMsg, ttl); !errors.Is(err, errors.ErrTaskIdConflict) { if err := r.EnqueueUnique(context.Background(), tc.secondMsg, ttl); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: EnqueueUnique returned %v, want %v", err, errors.ErrTaskIdConflict) t.Errorf("Second message: EnqueueUnique returned %v, want %v", err, errors.ErrTaskIdConflict)
continue continue
} }
@ -1162,7 +1162,7 @@ func TestSchedule(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case h.FlushDB(t, r.client) // clean up db before each test case
err := r.Schedule(tc.msg, tc.processAt) err := r.Schedule(context.Background(), tc.msg, tc.processAt)
if err != nil { if err != nil {
t.Errorf("(*RDB).Schedule(%v, %v) = %v, want nil", t.Errorf("(*RDB).Schedule(%v, %v) = %v, want nil",
tc.msg, tc.processAt, err) tc.msg, tc.processAt, err)
@ -1245,11 +1245,11 @@ func TestScheduleTaskIdConflictError(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case. h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.Schedule(tc.firstMsg, processAt); err != nil { if err := r.Schedule(context.Background(), tc.firstMsg, processAt); err != nil {
t.Errorf("First message: Schedule failed: %v", err) t.Errorf("First message: Schedule failed: %v", err)
continue continue
} }
if err := r.Schedule(tc.secondMsg, processAt); !errors.Is(err, errors.ErrTaskIdConflict) { if err := r.Schedule(context.Background(), tc.secondMsg, processAt); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: Schedule returned %v, want %v", err, errors.ErrTaskIdConflict) t.Errorf("Second message: Schedule returned %v, want %v", err, errors.ErrTaskIdConflict)
continue continue
} }
@ -1279,7 +1279,7 @@ func TestScheduleUnique(t *testing.T) {
h.FlushDB(t, r.client) // clean up db before each test case h.FlushDB(t, r.client) // clean up db before each test case
desc := "(*RDB).ScheduleUnique(msg, processAt, ttl)" desc := "(*RDB).ScheduleUnique(msg, processAt, ttl)"
err := r.ScheduleUnique(tc.msg, tc.processAt, tc.ttl) err := r.ScheduleUnique(context.Background(), tc.msg, tc.processAt, tc.ttl)
if err != nil { if err != nil {
t.Errorf("Frist task: %s = %v, want nil", desc, err) t.Errorf("Frist task: %s = %v, want nil", desc, err)
continue continue
@ -1336,7 +1336,7 @@ func TestScheduleUnique(t *testing.T) {
} }
// Enqueue the second message, should fail. // Enqueue the second message, should fail.
got := r.ScheduleUnique(tc.msg, tc.processAt, tc.ttl) got := r.ScheduleUnique(context.Background(), tc.msg, tc.processAt, tc.ttl)
if !errors.Is(got, errors.ErrDuplicateTask) { if !errors.Is(got, errors.ErrDuplicateTask) {
t.Errorf("Second task: %s = %v, want %v", desc, got, errors.ErrDuplicateTask) t.Errorf("Second task: %s = %v, want %v", desc, got, errors.ErrDuplicateTask)
continue continue
@ -1379,11 +1379,11 @@ func TestScheduleUniqueTaskIdConflictError(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case. h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.ScheduleUnique(tc.firstMsg, processAt, ttl); err != nil { if err := r.ScheduleUnique(context.Background(), tc.firstMsg, processAt, ttl); err != nil {
t.Errorf("First message: ScheduleUnique failed: %v", err) t.Errorf("First message: ScheduleUnique failed: %v", err)
continue continue
} }
if err := r.ScheduleUnique(tc.secondMsg, processAt, ttl); !errors.Is(err, errors.ErrTaskIdConflict) { if err := r.ScheduleUnique(context.Background(), tc.secondMsg, processAt, ttl); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: ScheduleUnique returned %v, want %v", err, errors.ErrTaskIdConflict) t.Errorf("Second message: ScheduleUnique returned %v, want %v", err, errors.ErrTaskIdConflict)
continue continue
} }

View File

@ -6,6 +6,7 @@
package testbroker package testbroker
import ( import (
"context"
"errors" "errors"
"sync" "sync"
"time" "time"
@ -45,22 +46,22 @@ func (tb *TestBroker) Wakeup() {
tb.sleeping = false tb.sleeping = false
} }
func (tb *TestBroker) Enqueue(msg *base.TaskMessage) error { func (tb *TestBroker) Enqueue(ctx context.Context, msg *base.TaskMessage) error {
tb.mu.Lock() tb.mu.Lock()
defer tb.mu.Unlock() defer tb.mu.Unlock()
if tb.sleeping { if tb.sleeping {
return errRedisDown return errRedisDown
} }
return tb.real.Enqueue(msg) return tb.real.Enqueue(ctx, msg)
} }
func (tb *TestBroker) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { func (tb *TestBroker) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time.Duration) error {
tb.mu.Lock() tb.mu.Lock()
defer tb.mu.Unlock() defer tb.mu.Unlock()
if tb.sleeping { if tb.sleeping {
return errRedisDown return errRedisDown
} }
return tb.real.EnqueueUnique(msg, ttl) return tb.real.EnqueueUnique(ctx, msg, ttl)
} }
func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) { func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) {
@ -99,22 +100,22 @@ func (tb *TestBroker) Requeue(msg *base.TaskMessage) error {
return tb.real.Requeue(msg) return tb.real.Requeue(msg)
} }
func (tb *TestBroker) Schedule(msg *base.TaskMessage, processAt time.Time) error { func (tb *TestBroker) Schedule(ctx context.Context, msg *base.TaskMessage, processAt time.Time) error {
tb.mu.Lock() tb.mu.Lock()
defer tb.mu.Unlock() defer tb.mu.Unlock()
if tb.sleeping { if tb.sleeping {
return errRedisDown return errRedisDown
} }
return tb.real.Schedule(msg, processAt) return tb.real.Schedule(ctx, msg, processAt)
} }
func (tb *TestBroker) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error { func (tb *TestBroker) ScheduleUnique(ctx context.Context, msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error {
tb.mu.Lock() tb.mu.Lock()
defer tb.mu.Unlock() defer tb.mu.Unlock()
if tb.sleeping { if tb.sleeping {
return errRedisDown return errRedisDown
} }
return tb.real.ScheduleUnique(msg, processAt, ttl) return tb.real.ScheduleUnique(ctx, msg, processAt, ttl)
} }
func (tb *TestBroker) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, isFailure bool) error { func (tb *TestBroker) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, isFailure bool) error {

View File

@ -126,7 +126,7 @@ func TestProcessorSuccessWithSingleQueue(t *testing.T) {
p.start(&sync.WaitGroup{}) p.start(&sync.WaitGroup{})
for _, msg := range tc.incoming { for _, msg := range tc.incoming {
err := rdbClient.Enqueue(msg) err := rdbClient.Enqueue(context.Background(), msg)
if err != nil { if err != nil {
p.shutdown() p.shutdown()
t.Fatal(err) t.Fatal(err)