diff --git a/heartbeat_test.go b/heartbeat_test.go index ec75ace..3ddd4c5 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -205,7 +205,7 @@ func TestHeartbeater(t *testing.T) { // Wait for heartbeater to write to redis time.Sleep(tc.interval * 2) - ss, err := rdbClient.ListServers() + ss, err := rdbClient.ListServers(context.Background()) if err != nil { t.Errorf("%s: could not read server info from redis: %v", tc.desc, err) hb.shutdown() @@ -289,7 +289,7 @@ func TestHeartbeater(t *testing.T) { Status: "closed", ActiveWorkerCount: len(tc.startedWorkers) - len(tc.finishedTasks), } - ss, err = rdbClient.ListServers() + ss, err = rdbClient.ListServers(context.Background()) if err != nil { t.Errorf("%s: could not read server status from redis: %v", tc.desc, err) hb.shutdown() diff --git a/inspector.go b/inspector.go index e583cf6..49e36e6 100644 --- a/inspector.go +++ b/inspector.go @@ -5,12 +5,14 @@ package asynq import ( + "context" "fmt" "strconv" "strings" "time" "github.com/go-redis/redis/v8" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/errors" "github.com/hibiken/asynq/internal/rdb" @@ -39,13 +41,27 @@ func (i *Inspector) Close() error { } // Queues returns a list of all queue names. +// +// Queues uses context.Background internally; to specify the context, use QueuesContext. func (i *Inspector) Queues() ([]string, error) { - return i.rdb.AllQueues() + return i.QueuesContext(context.Background()) +} + +// QueuesContext returns a list of all queue names. +func (i *Inspector) QueuesContext(ctx context.Context) ([]string, error) { + return i.rdb.AllQueues(ctx) } // Groups returns a list of all groups within the given queue. +// +// Groups uses context.Background internally; to specify the context, use GroupsContext. func (i *Inspector) Groups(queue string) ([]*GroupInfo, error) { - stats, err := i.rdb.GroupStats(queue) + return i.GroupsContext(context.Background(), queue) +} + +// GroupsContext returns a list of all groups within the given queue. +func (i *Inspector) GroupsContext(ctx context.Context, queue string) ([]*GroupInfo, error) { + stats, err := i.rdb.GroupStats(ctx, queue) if err != nil { return nil, err } @@ -123,10 +139,15 @@ type QueueInfo struct { // GetQueueInfo returns current information of the given queue. func (i *Inspector) GetQueueInfo(queue string) (*QueueInfo, error) { + return i.GetQueueInfoContext(context.Background(), queue) +} + +// GetQueueInfoContext returns current information of the given queue. +func (i *Inspector) GetQueueInfoContext(ctx context.Context, queue string) (*QueueInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, err } - stats, err := i.rdb.CurrentStats(queue) + stats, err := i.rdb.CurrentStats(ctx, queue) if err != nil { return nil, err } @@ -167,10 +188,15 @@ type DailyStats struct { // History returns a list of stats from the last n days. func (i *Inspector) History(queue string, n int) ([]*DailyStats, error) { + return i.HistoryContext(context.Background(), queue, n) +} + +// HistoryContext returns a list of stats from the last n days. +func (i *Inspector) HistoryContext(ctx context.Context, queue string, n int) ([]*DailyStats, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, err } - stats, err := i.rdb.HistoricalStats(queue, n) + stats, err := i.rdb.HistoricalStats(ctx, queue, n) if err != nil { return nil, err } @@ -208,7 +234,21 @@ var ( // If force is set to false and the specified queue is not empty, DeleteQueue // returns ErrQueueNotEmpty. func (i *Inspector) DeleteQueue(queue string, force bool) error { - err := i.rdb.RemoveQueue(queue, force) + return i.DeleteQueueContext(context.Background(), queue, force) +} + +// DeleteQueueContext removes the specified queue. +// +// If force is set to true, DeleteQueue will remove the queue regardless of +// the queue size as long as no tasks are active in the queue. +// If force is set to false, DeleteQueue will remove the queue only if +// the queue is empty. +// +// If the specified queue does not exist, DeleteQueue returns ErrQueueNotFound. +// If force is set to false and the specified queue is not empty, DeleteQueue +// returns ErrQueueNotEmpty. +func (i *Inspector) DeleteQueueContext(ctx context.Context, queue string, force bool) error { + err := i.rdb.RemoveQueue(ctx, queue, force) if errors.IsQueueNotFound(err) { return fmt.Errorf("%w: queue=%q", ErrQueueNotFound, queue) } @@ -223,7 +263,15 @@ func (i *Inspector) DeleteQueue(queue string, force bool) error { // Returns an error wrapping ErrQueueNotFound if a queue with the given name doesn't exist. // Returns an error wrapping ErrTaskNotFound if a task with the given id doesn't exist in the queue. func (i *Inspector) GetTaskInfo(queue, id string) (*TaskInfo, error) { - info, err := i.rdb.GetTaskInfo(queue, id) + return i.GetTaskInfoContext(context.Background(), queue, id) +} + +// GetTaskInfoContext retrieves task information given a task id and queue name. +// +// Returns an error wrapping ErrQueueNotFound if a queue with the given name doesn't exist. +// Returns an error wrapping ErrTaskNotFound if a task with the given id doesn't exist in the queue. +func (i *Inspector) GetTaskInfoContext(ctx context.Context, queue, id string) (*TaskInfo, error) { + info, err := i.rdb.GetTaskInfo(ctx, queue, id) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -300,12 +348,19 @@ func Page(n int) ListOption { // // By default, it retrieves the first 30 tasks. func (i *Inspector) ListPendingTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { + return i.ListPendingTasksContext(context.Background(), queue, opts...) +} + +// ListPendingTasksContext retrieves pending tasks from the specified queue. +// +// By default, it retrieves the first 30 tasks. +func (i *Inspector) ListPendingTasksContext(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListPending(queue, pgn) + infos, err := i.rdb.ListPending(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -328,19 +383,26 @@ func (i *Inspector) ListPendingTasks(queue string, opts ...ListOption) ([]*TaskI // // By default, it retrieves the first 30 tasks. func (i *Inspector) ListActiveTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { + return i.ListActiveTasksContext(context.Background(), queue, opts...) +} + +// ListActiveTasksContext retrieves active tasks from the specified queue. +// +// By default, it retrieves the first 30 tasks. +func (i *Inspector) ListActiveTasksContext(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListActive(queue, pgn) + infos, err := i.rdb.ListActive(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) case err != nil: return nil, fmt.Errorf("asynq: %v", err) } - expired, err := i.rdb.ListLeaseExpired(time.Now(), queue) + expired, err := i.rdb.ListLeaseExpiredContext(ctx, time.Now(), queue) if err != nil { return nil, fmt.Errorf("asynq: %v", err) } @@ -368,12 +430,19 @@ func (i *Inspector) ListActiveTasks(queue string, opts ...ListOption) ([]*TaskIn // // By default, it retrieves the first 30 tasks. func (i *Inspector) ListAggregatingTasks(queue, group string, opts ...ListOption) ([]*TaskInfo, error) { + return i.ListAggregatingTasksContext(context.Background(), queue, group, opts...) +} + +// ListAggregatingTasksContext retrieves scheduled tasks from the specified group. +// +// By default, it retrieves the first 30 tasks. +func (i *Inspector) ListAggregatingTasksContext(ctx context.Context, queue, group string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListAggregating(queue, group, pgn) + infos, err := i.rdb.ListAggregating(ctx, queue, group, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -397,12 +466,20 @@ func (i *Inspector) ListAggregatingTasks(queue, group string, opts ...ListOption // // By default, it retrieves the first 30 tasks. func (i *Inspector) ListScheduledTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { + return i.ListScheduledTasksContext(context.Background(), queue, opts...) +} + +// ListScheduledTasksContext retrieves scheduled tasks from the specified queue. +// Tasks are sorted by NextProcessAt in ascending order. +// +// By default, it retrieves the first 30 tasks. +func (i *Inspector) ListScheduledTasksContext(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListScheduled(queue, pgn) + infos, err := i.rdb.ListScheduled(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -426,12 +503,20 @@ func (i *Inspector) ListScheduledTasks(queue string, opts ...ListOption) ([]*Tas // // By default, it retrieves the first 30 tasks. func (i *Inspector) ListRetryTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { + return i.ListRetryTasksContext(context.Background(), queue, opts...) +} + +// ListRetryTasksContext retrieves retry tasks from the specified queue. +// Tasks are sorted by NextProcessAt in ascending order. +// +// By default, it retrieves the first 30 tasks. +func (i *Inspector) ListRetryTasksContext(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListRetry(queue, pgn) + infos, err := i.rdb.ListRetry(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -455,12 +540,20 @@ func (i *Inspector) ListRetryTasks(queue string, opts ...ListOption) ([]*TaskInf // // By default, it retrieves the first 30 tasks. func (i *Inspector) ListArchivedTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { + return i.ListArchivedTasksContext(context.Background(), queue, opts...) +} + +// ListArchivedTasksContext retrieves archived tasks from the specified queue. +// Tasks are sorted by LastFailedAt in descending order. +// +// By default, it retrieves the first 30 tasks. +func (i *Inspector) ListArchivedTasksContext(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListArchived(queue, pgn) + infos, err := i.rdb.ListArchived(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -484,12 +577,20 @@ func (i *Inspector) ListArchivedTasks(queue string, opts ...ListOption) ([]*Task // // By default, it retrieves the first 30 tasks. func (i *Inspector) ListCompletedTasks(queue string, opts ...ListOption) ([]*TaskInfo, error) { + return i.ListCompletedTasksContext(context.Background(), queue, opts...) +} + +// ListCompletedTasksContext retrieves completed tasks from the specified queue. +// Tasks are sorted by expiration time (i.e. CompletedAt + Retention) in descending order. +// +// By default, it retrieves the first 30 tasks. +func (i *Inspector) ListCompletedTasksContext(ctx context.Context, queue string, opts ...ListOption) ([]*TaskInfo, error) { if err := base.ValidateQueueName(queue); err != nil { return nil, fmt.Errorf("asynq: %v", err) } opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - infos, err := i.rdb.ListCompleted(queue, pgn) + infos, err := i.rdb.ListCompleted(ctx, queue, pgn) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -511,60 +612,96 @@ func (i *Inspector) ListCompletedTasks(queue string, opts ...ListOption) ([]*Tas // DeleteAllPendingTasks deletes all pending tasks from the specified queue, // and reports the number tasks deleted. func (i *Inspector) DeleteAllPendingTasks(queue string) (int, error) { + return i.DeleteAllPendingTasksContext(context.Background(), queue) +} + +// DeleteAllPendingTasksContext deletes all pending tasks from the specified queue, +// and reports the number tasks deleted. +func (i *Inspector) DeleteAllPendingTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.DeleteAllPendingTasks(queue) + n, err := i.rdb.DeleteAllPendingTasks(ctx, queue) return int(n), err } // DeleteAllScheduledTasks deletes all scheduled tasks from the specified queue, // and reports the number tasks deleted. func (i *Inspector) DeleteAllScheduledTasks(queue string) (int, error) { + return i.DeleteAllScheduledTasksContext(context.Background(), queue) +} + +// DeleteAllScheduledTasksContext deletes all scheduled tasks from the specified queue, +// and reports the number tasks deleted. +func (i *Inspector) DeleteAllScheduledTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.DeleteAllScheduledTasks(queue) + n, err := i.rdb.DeleteAllScheduledTasks(ctx, queue) return int(n), err } // DeleteAllRetryTasks deletes all retry tasks from the specified queue, // and reports the number tasks deleted. func (i *Inspector) DeleteAllRetryTasks(queue string) (int, error) { + return i.DeleteAllRetryTasksContext(context.Background(), queue) +} + +// DeleteAllRetryTasksContext deletes all retry tasks from the specified queue, +// and reports the number tasks deleted. +func (i *Inspector) DeleteAllRetryTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.DeleteAllRetryTasks(queue) + n, err := i.rdb.DeleteAllRetryTasks(ctx, queue) return int(n), err } // DeleteAllArchivedTasks deletes all archived tasks from the specified queue, // and reports the number tasks deleted. func (i *Inspector) DeleteAllArchivedTasks(queue string) (int, error) { + return i.DeleteAllArchivedTasksContext(context.Background(), queue) +} + +// DeleteAllArchivedTasksContext deletes all archived tasks from the specified queue, +// and reports the number tasks deleted. +func (i *Inspector) DeleteAllArchivedTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.DeleteAllArchivedTasks(queue) + n, err := i.rdb.DeleteAllArchivedTasks(ctx, queue) return int(n), err } // DeleteAllCompletedTasks deletes all completed tasks from the specified queue, // and reports the number tasks deleted. func (i *Inspector) DeleteAllCompletedTasks(queue string) (int, error) { + return i.DeleteAllCompletedTasksContext(context.Background(), queue) +} + +// DeleteAllCompletedTasksContext deletes all completed tasks from the specified queue, +// and reports the number tasks deleted. +func (i *Inspector) DeleteAllCompletedTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.DeleteAllCompletedTasks(queue) + n, err := i.rdb.DeleteAllCompletedTasks(ctx, queue) return int(n), err } // DeleteAllAggregatingTasks deletes all tasks from the specified group, // and reports the number of tasks deleted. func (i *Inspector) DeleteAllAggregatingTasks(queue, group string) (int, error) { + return i.DeleteAllAggregatingTasksContext(context.Background(), queue, group) +} + +// DeleteAllAggregatingTasksContext deletes all tasks from the specified group, +// and reports the number of tasks deleted. +func (i *Inspector) DeleteAllAggregatingTasksContext(ctx context.Context, queue, group string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.DeleteAllAggregatingTasks(queue, group) + n, err := i.rdb.DeleteAllAggregatingTasks(ctx, queue, group) return int(n), err } @@ -575,11 +712,22 @@ func (i *Inspector) DeleteAllAggregatingTasks(queue, group string) (int, error) // If a queue with the given name doesn't exist, it returns an error wrapping ErrQueueNotFound. // If a task with the given id doesn't exist in the queue, it returns an error wrapping ErrTaskNotFound. // If the task is in active state, it returns a non-nil error. -func (i *Inspector) DeleteTask(queue, id string) error { - if err := base.ValidateQueueName(queue); err != nil { +func (i *Inspector) DeleteTask(qname, id string) error { + return i.DeleteTaskContext(context.Background(), qname, id) +} + +// DeleteTaskContext deletes a task with the given id from the given queue. +// The task needs to be in pending, scheduled, retry, or archived state, +// otherwise DeleteTask will return an error. +// +// If a queue with the given name doesn't exist, it returns an error wrapping ErrQueueNotFound. +// If a task with the given id doesn't exist in the queue, it returns an error wrapping ErrTaskNotFound. +// If the task is in active state, it returns a non-nil error. +func (i *Inspector) DeleteTaskContext(ctx context.Context, qname, id string) error { + if err := base.ValidateQueueName(qname); err != nil { return fmt.Errorf("asynq: %v", err) } - err := i.rdb.DeleteTask(queue, id) + err := i.rdb.DeleteTask(ctx, qname, id) switch { case errors.IsQueueNotFound(err): return fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -589,46 +737,69 @@ func (i *Inspector) DeleteTask(queue, id string) error { return fmt.Errorf("asynq: %v", err) } return nil - } // RunAllScheduledTasks schedules all scheduled tasks from the given queue to run, // and reports the number of tasks scheduled to run. func (i *Inspector) RunAllScheduledTasks(queue string) (int, error) { + return i.RunAllScheduledTasksContext(context.Background(), queue) +} + +// RunAllScheduledTasksContext schedules all scheduled tasks from the given queue to run, +// and reports the number of tasks scheduled to run. +func (i *Inspector) RunAllScheduledTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.RunAllScheduledTasks(queue) + n, err := i.rdb.RunAllScheduledTasks(ctx, queue) return int(n), err } // RunAllRetryTasks schedules all retry tasks from the given queue to run, // and reports the number of tasks scheduled to run. func (i *Inspector) RunAllRetryTasks(queue string) (int, error) { + return i.RunAllRetryTasksContext(context.Background(), queue) +} + +// RunAllRetryTasksContext schedules all retry tasks from the given queue to run, +// and reports the number of tasks scheduled to run. +func (i *Inspector) RunAllRetryTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.RunAllRetryTasks(queue) + n, err := i.rdb.RunAllRetryTasks(ctx, queue) return int(n), err } // RunAllArchivedTasks schedules all archived tasks from the given queue to run, // and reports the number of tasks scheduled to run. func (i *Inspector) RunAllArchivedTasks(queue string) (int, error) { + return i.RunAllArchivedTasksContext(context.Background(), queue) +} + +// RunAllArchivedTasksContext schedules all archived tasks from the given queue to run, +// and reports the number of tasks scheduled to run. +func (i *Inspector) RunAllArchivedTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.RunAllArchivedTasks(queue) + n, err := i.rdb.RunAllArchivedTasks(ctx, queue) return int(n), err } // RunAllAggregatingTasks schedules all tasks from the given grou to run. // and reports the number of tasks scheduled to run. func (i *Inspector) RunAllAggregatingTasks(queue, group string) (int, error) { + return i.RunAllAggregatingTasksContext(context.Background(), queue, group) +} + +// RunAllAggregatingTasksContext schedules all tasks from the given grou to run. +// and reports the number of tasks scheduled to run. +func (i *Inspector) RunAllAggregatingTasksContext(ctx context.Context, queue, group string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.RunAllAggregatingTasks(queue, group) + n, err := i.rdb.RunAllAggregatingTasks(ctx, queue, group) return int(n), err } @@ -640,10 +811,21 @@ func (i *Inspector) RunAllAggregatingTasks(queue, group string) (int, error) { // If a task with the given id doesn't exist in the queue, it returns an error wrapping ErrTaskNotFound. // If the task is in pending or active state, it returns a non-nil error. func (i *Inspector) RunTask(queue, id string) error { + return i.RunTaskContext(context.Background(), queue, id) +} + +// RunTaskContext updates the task to pending state given a queue name and task id. +// The task needs to be in scheduled, retry, or archived state, otherwise RunTask +// will return an error. +// +// If a queue with the given name doesn't exist, it returns an error wrapping ErrQueueNotFound. +// If a task with the given id doesn't exist in the queue, it returns an error wrapping ErrTaskNotFound. +// If the task is in pending or active state, it returns a non-nil error. +func (i *Inspector) RunTaskContext(ctx context.Context, queue, id string) error { if err := base.ValidateQueueName(queue); err != nil { return fmt.Errorf("asynq: %v", err) } - err := i.rdb.RunTask(queue, id) + err := i.rdb.RunTask(ctx, queue, id) switch { case errors.IsQueueNotFound(err): return fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -658,40 +840,64 @@ func (i *Inspector) RunTask(queue, id string) error { // ArchiveAllPendingTasks archives all pending tasks from the given queue, // and reports the number of tasks archived. func (i *Inspector) ArchiveAllPendingTasks(queue string) (int, error) { + return i.ArchiveAllPendingTasksContext(context.Background(), queue) +} + +// ArchiveAllPendingTasksContext archives all pending tasks from the given queue, +// and reports the number of tasks archived. +func (i *Inspector) ArchiveAllPendingTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.ArchiveAllPendingTasks(queue) + n, err := i.rdb.ArchiveAllPendingTasks(ctx, queue) return int(n), err } // ArchiveAllScheduledTasks archives all scheduled tasks from the given queue, -// and reports the number of tasks archiveed. +// and reports the number of tasks archived. func (i *Inspector) ArchiveAllScheduledTasks(queue string) (int, error) { + return i.ArchiveAllScheduledTasksContext(context.Background(), queue) +} + +// ArchiveAllScheduledTasksContext archives all scheduled tasks from the given queue, +// and reports the number of tasks archived. +func (i *Inspector) ArchiveAllScheduledTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.ArchiveAllScheduledTasks(queue) + n, err := i.rdb.ArchiveAllScheduledTasks(ctx, queue) return int(n), err } // ArchiveAllRetryTasks archives all retry tasks from the given queue, -// and reports the number of tasks archiveed. +// and reports the number of tasks archived. func (i *Inspector) ArchiveAllRetryTasks(queue string) (int, error) { + return i.ArchiveAllRetryTasksContext(context.Background(), queue) +} + +// ArchiveAllRetryTasksContext archives all retry tasks from the given queue, +// and reports the number of tasks archived. +func (i *Inspector) ArchiveAllRetryTasksContext(ctx context.Context, queue string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.ArchiveAllRetryTasks(queue) + n, err := i.rdb.ArchiveAllRetryTasks(ctx, queue) return int(n), err } // ArchiveAllAggregatingTasks archives all tasks from the given group, // and reports the number of tasks archived. func (i *Inspector) ArchiveAllAggregatingTasks(queue, group string) (int, error) { + return i.ArchiveAllAggregatingTasksContext(context.Background(), queue, group) +} + +// ArchiveAllAggregatingTasksContext archives all tasks from the given group, +// and reports the number of tasks archived. +func (i *Inspector) ArchiveAllAggregatingTasksContext(ctx context.Context, queue, group string) (int, error) { if err := base.ValidateQueueName(queue); err != nil { return 0, err } - n, err := i.rdb.ArchiveAllAggregatingTasks(queue, group) + n, err := i.rdb.ArchiveAllAggregatingTasks(ctx, queue, group) return int(n), err } @@ -703,10 +909,21 @@ func (i *Inspector) ArchiveAllAggregatingTasks(queue, group string) (int, error) // If a task with the given id doesn't exist in the queue, it returns an error wrapping ErrTaskNotFound. // If the task is in already archived, it returns a non-nil error. func (i *Inspector) ArchiveTask(queue, id string) error { + return i.ArchiveTaskContext(context.Background(), queue, id) +} + +// ArchiveTaskContext archives a task with the given id in the given queue. +// The task needs to be in pending, scheduled, or retry state, otherwise ArchiveTask +// will return an error. +// +// If a queue with the given name doesn't exist, it returns an error wrapping ErrQueueNotFound. +// If a task with the given id doesn't exist in the queue, it returns an error wrapping ErrTaskNotFound. +// If the task is in already archived, it returns a non-nil error. +func (i *Inspector) ArchiveTaskContext(ctx context.Context, queue, id string) error { if err := base.ValidateQueueName(queue); err != nil { return fmt.Errorf("asynq: err") } - err := i.rdb.ArchiveTask(queue, id) + err := i.rdb.ArchiveTask(ctx, queue, id) switch { case errors.IsQueueNotFound(err): return fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -723,34 +940,59 @@ func (i *Inspector) ArchiveTask(queue, id string) error { // guarantee that the task with the given id will be canceled. The return // value only indicates whether the cancelation signal has been sent. func (i *Inspector) CancelProcessing(id string) error { - return i.rdb.PublishCancelation(id) + return i.CancelProcessingContext(context.Background(), id) +} + +// CancelProcessingContext sends a signal to cancel processing of the task +// given a task id. CancelProcessing is best-effort, which means that it does not +// guarantee that the task with the given id will be canceled. The return +// value only indicates whether the cancelation signal has been sent. +func (i *Inspector) CancelProcessingContext(ctx context.Context, id string) error { + return i.rdb.PublishCancelationContext(ctx, id) } // PauseQueue pauses task processing on the specified queue. // If the queue is already paused, it will return a non-nil error. func (i *Inspector) PauseQueue(queue string) error { + return i.PauseQueueContext(context.Background(), queue) +} + +// PauseQueueContext pauses task processing on the specified queue. +// If the queue is already paused, it will return a non-nil error. +func (i *Inspector) PauseQueueContext(ctx context.Context, queue string) error { if err := base.ValidateQueueName(queue); err != nil { return err } - return i.rdb.Pause(queue) + return i.rdb.Pause(ctx, queue) } // UnpauseQueue resumes task processing on the specified queue. // If the queue is not paused, it will return a non-nil error. func (i *Inspector) UnpauseQueue(queue string) error { + return i.UnpauseQueueContext(context.Background(), queue) +} + +// UnpauseQueueContext resumes task processing on the specified queue. +// If the queue is not paused, it will return a non-nil error. +func (i *Inspector) UnpauseQueueContext(ctx context.Context, queue string) error { if err := base.ValidateQueueName(queue); err != nil { return err } - return i.rdb.Unpause(queue) + return i.rdb.Unpause(ctx, queue) } // Servers return a list of running servers' information. func (i *Inspector) Servers() ([]*ServerInfo, error) { - servers, err := i.rdb.ListServers() + return i.ServersContext(context.Background()) +} + +// ServersContext return a list of running servers' information. +func (i *Inspector) ServersContext(ctx context.Context) ([]*ServerInfo, error) { + servers, err := i.rdb.ListServers(ctx) if err != nil { return nil, err } - workers, err := i.rdb.ListWorkers() + workers, err := i.rdb.ListWorkers(ctx) if err != nil { return nil, err } @@ -832,7 +1074,12 @@ type WorkerInfo struct { // ClusterKeySlot returns an integer identifying the hash slot the given queue hashes to. func (i *Inspector) ClusterKeySlot(queue string) (int64, error) { - return i.rdb.ClusterKeySlot(queue) + return i.ClusterKeySlotContext(context.Background(), queue) +} + +// ClusterKeySlotContext returns an integer identifying the hash slot the given queue hashes to. +func (i *Inspector) ClusterKeySlotContext(ctx context.Context, queue string) (int64, error) { + return i.rdb.ClusterKeySlot(ctx, queue) } // ClusterNode describes a node in redis cluster. @@ -848,7 +1095,14 @@ type ClusterNode struct { // // Only relevant if task queues are stored in redis cluster. func (i *Inspector) ClusterNodes(queue string) ([]*ClusterNode, error) { - nodes, err := i.rdb.ClusterNodes(queue) + return i.ClusterNodesContext(context.Background(), queue) +} + +// ClusterNodesContext returns a list of nodes the given queue belongs to. +// +// Only relevant if task queues are stored in redis cluster. +func (i *Inspector) ClusterNodesContext(ctx context.Context, queue string) ([]*ClusterNode, error) { + nodes, err := i.rdb.ClusterNodes(ctx, queue) if err != nil { return nil, err } @@ -884,8 +1138,14 @@ type SchedulerEntry struct { // SchedulerEntries returns a list of all entries registered with // currently running schedulers. func (i *Inspector) SchedulerEntries() ([]*SchedulerEntry, error) { + return i.SchedulerEntriesContext(context.Background()) +} + +// SchedulerEntriesContext returns a list of all entries registered with +// currently running schedulers. +func (i *Inspector) SchedulerEntriesContext(ctx context.Context) ([]*SchedulerEntry, error) { var entries []*SchedulerEntry - res, err := i.rdb.ListSchedulerEntries() + res, err := i.rdb.ListSchedulerEntries(ctx) if err != nil { return nil, err } @@ -997,9 +1257,16 @@ type SchedulerEnqueueEvent struct { // // By default, it retrieves the first 30 tasks. func (i *Inspector) ListSchedulerEnqueueEvents(entryID string, opts ...ListOption) ([]*SchedulerEnqueueEvent, error) { + return i.ListSchedulerEnqueueEventsContext(context.Background(), entryID, opts...) +} + +// ListSchedulerEnqueueEventsContext retrieves a list of enqueue events from the specified scheduler entry. +// +// By default, it retrieves the first 30 tasks. +func (i *Inspector) ListSchedulerEnqueueEventsContext(ctx context.Context, entryID string, opts ...ListOption) ([]*SchedulerEnqueueEvent, error) { opt := composeListOptions(opts...) pgn := rdb.Pagination{Size: opt.pageSize, Page: opt.pageNum - 1} - data, err := i.rdb.ListSchedulerEnqueueEvents(entryID, pgn) + data, err := i.rdb.ListSchedulerEnqueueEvents(ctx, entryID, pgn) if err != nil { return nil, err } diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index 6deb4f1..4e4b470 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -11,14 +11,15 @@ import ( "time" "github.com/go-redis/redis/v8" + "github.com/spf13/cast" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/errors" - "github.com/spf13/cast" ) // AllQueues returns a list of all queue names. -func (r *RDB) AllQueues() ([]string, error) { - return r.client.SMembers(context.Background(), base.AllQueues).Result() +func (r *RDB) AllQueues(ctx context.Context) ([]string, error) { + return r.client.SMembers(ctx, base.AllQueues).Result() } // Stats represents a state of queues at a certain time. @@ -137,9 +138,9 @@ table.insert(res, aggregating_count) return res`) // CurrentStats returns a current state of the queues. -func (r *RDB) CurrentStats(qname string) (*Stats, error) { +func (r *RDB) CurrentStats(ctx context.Context, qname string) (*Stats, error) { var op errors.Op = "rdb.CurrentStats" - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return nil, errors.E(op, errors.Unknown, err) } @@ -165,7 +166,7 @@ func (r *RDB) CurrentStats(qname string) (*Stats, error) { base.TaskKeyPrefix(qname), base.GroupKeyPrefix(qname), } - res, err := currentStatsCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := currentStatsCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return nil, errors.E(op, errors.Unknown, err) } @@ -228,7 +229,7 @@ func (r *RDB) CurrentStats(qname string) (*Stats, error) { } } stats.Size = size - memusg, err := r.memoryUsage(qname) + memusg, err := r.memoryUsage(ctx, qname) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -315,7 +316,7 @@ end return memusg `) -func (r *RDB) memoryUsage(qname string) (int64, error) { +func (r *RDB) memoryUsage(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.memoryUsage" const ( taskSampleSize = 20 @@ -337,7 +338,7 @@ func (r *RDB) memoryUsage(qname string) (int64, error) { groupSampleSize, base.GroupKeyPrefix(qname), } - res, err := memoryUsageCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := memoryUsageCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return 0, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) } @@ -360,12 +361,12 @@ end return res`) // HistoricalStats returns a list of stats from the last n days for the given queue. -func (r *RDB) HistoricalStats(qname string, n int) ([]*DailyStats, error) { +func (r *RDB) HistoricalStats(ctx context.Context, qname string, n int) ([]*DailyStats, error) { var op errors.Op = "rdb.HistoricalStats" if n < 1 { return nil, errors.E(op, errors.FailedPrecondition, "the number of days must be positive") } - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return nil, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) } @@ -382,7 +383,7 @@ func (r *RDB) HistoricalStats(qname string, n int) ([]*DailyStats, error) { keys = append(keys, base.ProcessedKey(qname, ts)) keys = append(keys, base.FailedKey(qname, ts)) } - res, err := historicalStatsCmd.Run(context.Background(), r.client, keys).Result() + res, err := historicalStatsCmd.Run(ctx, r.client, keys).Result() if err != nil { return nil, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) } @@ -403,8 +404,8 @@ func (r *RDB) HistoricalStats(qname string, n int) ([]*DailyStats, error) { } // RedisInfo returns a map of redis info. -func (r *RDB) RedisInfo() (map[string]string, error) { - res, err := r.client.Info(context.Background()).Result() +func (r *RDB) RedisInfo(ctx context.Context) (map[string]string, error) { + res, err := r.client.Info(ctx).Result() if err != nil { return nil, err } @@ -412,8 +413,8 @@ func (r *RDB) RedisInfo() (map[string]string, error) { } // RedisClusterInfo returns a map of redis cluster info. -func (r *RDB) RedisClusterInfo() (map[string]string, error) { - res, err := r.client.ClusterInfo(context.Background()).Result() +func (r *RDB) RedisClusterInfo(ctx context.Context) (map[string]string, error) { + res, err := r.client.ClusterInfo(ctx).Result() if err != nil { return nil, err } @@ -442,8 +443,8 @@ func reverse(x []*base.TaskInfo) { // checkQueueExists verifies whether the queue exists. // It returns QueueNotFoundError if queue doesn't exist. -func (r *RDB) checkQueueExists(qname string) error { - exists, err := r.queueExists(qname) +func (r *RDB) checkQueueExists(ctx context.Context, qname string) error { + exists, err := r.queueExists(ctx, qname) if err != nil { return errors.E(errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) } @@ -482,9 +483,9 @@ var getTaskInfoCmd = redis.NewScript(` `) // GetTaskInfo returns a TaskInfo describing the task from the given queue. -func (r *RDB) GetTaskInfo(qname, id string) (*base.TaskInfo, error) { +func (r *RDB) GetTaskInfo(ctx context.Context, qname, id string) (*base.TaskInfo, error) { var op errors.Op = "rdb.GetTaskInfo" - if err := r.checkQueueExists(qname); err != nil { + if err := r.checkQueueExists(ctx, qname); err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } keys := []string{base.TaskKey(qname, id)} @@ -493,7 +494,7 @@ func (r *RDB) GetTaskInfo(qname, id string) (*base.TaskInfo, error) { r.clock.Now().Unix(), base.QueueKeyPrefix(qname), } - res, err := getTaskInfoCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := getTaskInfoCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { if err.Error() == "NOT FOUND" { return nil, errors.E(op, errors.NotFound, &errors.TaskNotFoundError{Queue: qname, ID: id}) @@ -575,11 +576,11 @@ end return res `) -func (r *RDB) GroupStats(qname string) ([]*GroupStat, error) { +func (r *RDB) GroupStats(ctx context.Context, qname string) ([]*GroupStat, error) { var op errors.Op = "RDB.GroupStats" keys := []string{base.AllGroups(qname)} argv := []interface{}{base.GroupKeyPrefix(qname)} - res, err := groupStatsCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := groupStatsCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return nil, errors.E(op, errors.Unknown, err) } @@ -616,16 +617,16 @@ func (p Pagination) stop() int64 { } // ListPending returns pending tasks that are ready to be processed. -func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListPending(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListPending" - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return nil, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) } if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listMessages(qname, base.TaskStatePending, pgn) + res, err := r.listMessages(ctx, qname, base.TaskStatePending, pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -633,16 +634,16 @@ func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error } // ListActive returns all tasks that are currently being processed for the given queue. -func (r *RDB) ListActive(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListActive(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListActive" - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return nil, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) } if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listMessages(qname, base.TaskStateActive, pgn) + res, err := r.listMessages(ctx, qname, base.TaskStateActive, pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -666,7 +667,7 @@ return data `) // listMessages returns a list of TaskInfo in Redis list with the given key. -func (r *RDB) listMessages(qname string, state base.TaskState, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) listMessages(ctx context.Context, qname string, state base.TaskState, pgn Pagination) ([]*base.TaskInfo, error) { var key string switch state { case base.TaskStateActive: @@ -680,7 +681,7 @@ func (r *RDB) listMessages(qname string, state base.TaskState, pgn Pagination) ( // correct range and reverse the list to get the tasks with pagination. stop := -pgn.start() - 1 start := -pgn.stop() - 1 - res, err := listMessagesCmd.Run(context.Background(), r.client, + res, err := listMessagesCmd.Run(ctx, r.client, []string{key}, start, stop, base.TaskKeyPrefix(qname)).Result() if err != nil { return nil, errors.E(errors.Unknown, err) @@ -717,16 +718,16 @@ func (r *RDB) listMessages(qname string, state base.TaskState, pgn Pagination) ( // ListScheduled returns all tasks from the given queue that are scheduled // to be processed in the future. -func (r *RDB) ListScheduled(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListScheduled(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListScheduled" - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return nil, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) } if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listZSetEntries(qname, base.TaskStateScheduled, base.ScheduledKey(qname), pgn) + res, err := r.listZSetEntries(ctx, qname, base.TaskStateScheduled, base.ScheduledKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -734,17 +735,17 @@ func (r *RDB) ListScheduled(qname string, pgn Pagination) ([]*base.TaskInfo, err } // ListRetry returns all tasks from the given queue that have failed before -// and willl be retried in the future. -func (r *RDB) ListRetry(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +// and will be retried in the future. +func (r *RDB) ListRetry(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListRetry" - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return nil, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) } if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - res, err := r.listZSetEntries(qname, base.TaskStateRetry, base.RetryKey(qname), pgn) + res, err := r.listZSetEntries(ctx, qname, base.TaskStateRetry, base.RetryKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -752,16 +753,16 @@ func (r *RDB) ListRetry(qname string, pgn Pagination) ([]*base.TaskInfo, error) } // ListArchived returns all tasks from the given queue that have exhausted its retry limit. -func (r *RDB) ListArchived(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListArchived(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListArchived" - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return nil, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) } if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - zs, err := r.listZSetEntries(qname, base.TaskStateArchived, base.ArchivedKey(qname), pgn) + zs, err := r.listZSetEntries(ctx, qname, base.TaskStateArchived, base.ArchivedKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -769,16 +770,16 @@ func (r *RDB) ListArchived(qname string, pgn Pagination) ([]*base.TaskInfo, erro } // ListCompleted returns all tasks from the given queue that have completed successfully. -func (r *RDB) ListCompleted(qname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListCompleted(ctx context.Context, qname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListCompleted" - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return nil, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) } if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - zs, err := r.listZSetEntries(qname, base.TaskStateCompleted, base.CompletedKey(qname), pgn) + zs, err := r.listZSetEntries(ctx, qname, base.TaskStateCompleted, base.CompletedKey(qname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -786,16 +787,16 @@ func (r *RDB) ListCompleted(qname string, pgn Pagination) ([]*base.TaskInfo, err } // ListAggregating returns all tasks from the given group. -func (r *RDB) ListAggregating(qname, gname string, pgn Pagination) ([]*base.TaskInfo, error) { +func (r *RDB) ListAggregating(ctx context.Context, qname, gname string, pgn Pagination) ([]*base.TaskInfo, error) { var op errors.Op = "rdb.ListAggregating" - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return nil, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sismember", Err: err}) } if !exists { return nil, errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) } - zs, err := r.listZSetEntries(qname, base.TaskStateAggregating, base.GroupKey(qname, gname), pgn) + zs, err := r.listZSetEntries(ctx, qname, base.TaskStateAggregating, base.GroupKey(qname, gname), pgn) if err != nil { return nil, errors.E(op, errors.CanonicalCode(err), err) } @@ -803,8 +804,8 @@ func (r *RDB) ListAggregating(qname, gname string, pgn Pagination) ([]*base.Task } // Reports whether a queue with the given name exists. -func (r *RDB) queueExists(qname string) (bool, error) { - return r.client.SIsMember(context.Background(), base.AllQueues, qname).Result() +func (r *RDB) queueExists(ctx context.Context, qname string) (bool, error) { + return r.client.SIsMember(ctx, base.AllQueues, qname).Result() } // KEYS[1] -> key for ids set (e.g. asynq:{}:scheduled) @@ -831,8 +832,8 @@ return data // listZSetEntries returns a list of message and score pairs in Redis sorted-set // with the given key. -func (r *RDB) listZSetEntries(qname string, state base.TaskState, key string, pgn Pagination) ([]*base.TaskInfo, error) { - res, err := listZSetEntriesCmd.Run(context.Background(), r.client, []string{key}, +func (r *RDB) listZSetEntries(ctx context.Context, qname string, state base.TaskState, key string, pgn Pagination) ([]*base.TaskInfo, error) { + res, err := listZSetEntriesCmd.Run(ctx, r.client, []string{key}, pgn.start(), pgn.stop(), base.TaskKeyPrefix(qname)).Result() if err != nil { return nil, errors.E(errors.Unknown, err) @@ -880,9 +881,9 @@ func (r *RDB) listZSetEntries(qname string, state base.TaskState, key string, pg // RunAllScheduledTasks enqueues all scheduled tasks from the given queue // and returns the number of tasks enqueued. // If a queue with the given name doesn't exist, it returns QueueNotFoundError. -func (r *RDB) RunAllScheduledTasks(qname string) (int64, error) { +func (r *RDB) RunAllScheduledTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.RunAllScheduledTasks" - n, err := r.runAll(base.ScheduledKey(qname), qname) + n, err := r.runAll(ctx, base.ScheduledKey(qname), qname) if errors.IsQueueNotFound(err) { return 0, errors.E(op, errors.NotFound, err) } @@ -895,9 +896,9 @@ func (r *RDB) RunAllScheduledTasks(qname string) (int64, error) { // RunAllRetryTasks enqueues all retry tasks from the given queue // and returns the number of tasks enqueued. // If a queue with the given name doesn't exist, it returns QueueNotFoundError. -func (r *RDB) RunAllRetryTasks(qname string) (int64, error) { +func (r *RDB) RunAllRetryTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.RunAllRetryTasks" - n, err := r.runAll(base.RetryKey(qname), qname) + n, err := r.runAll(ctx, base.RetryKey(qname), qname) if errors.IsQueueNotFound(err) { return 0, errors.E(op, errors.NotFound, err) } @@ -910,9 +911,9 @@ func (r *RDB) RunAllRetryTasks(qname string) (int64, error) { // RunAllArchivedTasks enqueues all archived tasks from the given queue // and returns the number of tasks enqueued. // If a queue with the given name doesn't exist, it returns QueueNotFoundError. -func (r *RDB) RunAllArchivedTasks(qname string) (int64, error) { +func (r *RDB) RunAllArchivedTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.RunAllArchivedTasks" - n, err := r.runAll(base.ArchivedKey(qname), qname) + n, err := r.runAll(ctx, base.ArchivedKey(qname), qname) if errors.IsQueueNotFound(err) { return 0, errors.E(op, errors.NotFound, err) } @@ -948,9 +949,9 @@ return table.getn(ids) // RunAllAggregatingTasks schedules all tasks from the given queue to run // and returns the number of tasks scheduled to run. // If a queue with the given name doesn't exist, it returns QueueNotFoundError. -func (r *RDB) RunAllAggregatingTasks(qname, gname string) (int64, error) { +func (r *RDB) RunAllAggregatingTasks(ctx context.Context, qname, gname string) (int64, error) { var op errors.Op = "rdb.RunAllAggregatingTasks" - if err := r.checkQueueExists(qname); err != nil { + if err := r.checkQueueExists(ctx, qname); err != nil { return 0, errors.E(op, errors.CanonicalCode(err), err) } keys := []string{ @@ -962,7 +963,7 @@ func (r *RDB) RunAllAggregatingTasks(qname, gname string) (int64, error) { base.TaskKeyPrefix(qname), gname, } - res, err := runAllAggregatingCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := runAllAggregatingCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return 0, errors.E(op, errors.Internal, err) } @@ -1025,9 +1026,9 @@ return 1 // If a queue with the given name doesn't exist, it returns QueueNotFoundError. // If a task with the given id doesn't exist in the queue, it returns TaskNotFoundError // If a task is in active or pending state it returns non-nil error with Code FailedPrecondition. -func (r *RDB) RunTask(qname, id string) error { +func (r *RDB) RunTask(ctx context.Context, qname, id string) error { var op errors.Op = "rdb.RunTask" - if err := r.checkQueueExists(qname); err != nil { + if err := r.checkQueueExists(ctx, qname); err != nil { return errors.E(op, errors.CanonicalCode(err), err) } keys := []string{ @@ -1040,7 +1041,7 @@ func (r *RDB) RunTask(qname, id string) error { base.QueueKeyPrefix(qname), base.GroupKeyPrefix(qname), } - res, err := runTaskCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := runTaskCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return errors.E(op, errors.Unknown, err) } @@ -1082,8 +1083,8 @@ end redis.call("DEL", KEYS[1]) return table.getn(ids)`) -func (r *RDB) runAll(zset, qname string) (int64, error) { - if err := r.checkQueueExists(qname); err != nil { +func (r *RDB) runAll(ctx context.Context, zset, qname string) (int64, error) { + if err := r.checkQueueExists(ctx, qname); err != nil { return 0, err } keys := []string{ @@ -1093,7 +1094,7 @@ func (r *RDB) runAll(zset, qname string) (int64, error) { argv := []interface{}{ base.TaskKeyPrefix(qname), } - res, err := runAllCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := runAllCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return 0, err } @@ -1110,9 +1111,9 @@ func (r *RDB) runAll(zset, qname string) (int64, error) { // ArchiveAllRetryTasks archives all retry tasks from the given queue and // returns the number of tasks that were moved. // If a queue with the given name doesn't exist, it returns QueueNotFoundError. -func (r *RDB) ArchiveAllRetryTasks(qname string) (int64, error) { +func (r *RDB) ArchiveAllRetryTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.ArchiveAllRetryTasks" - n, err := r.archiveAll(base.RetryKey(qname), base.ArchivedKey(qname), qname) + n, err := r.archiveAll(ctx, base.RetryKey(qname), base.ArchivedKey(qname), qname) if errors.IsQueueNotFound(err) { return 0, errors.E(op, errors.NotFound, err) } @@ -1125,9 +1126,9 @@ func (r *RDB) ArchiveAllRetryTasks(qname string) (int64, error) { // ArchiveAllScheduledTasks archives all scheduled tasks from the given queue and // returns the number of tasks that were moved. // If a queue with the given name doesn't exist, it returns QueueNotFoundError. -func (r *RDB) ArchiveAllScheduledTasks(qname string) (int64, error) { +func (r *RDB) ArchiveAllScheduledTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.ArchiveAllScheduledTasks" - n, err := r.archiveAll(base.ScheduledKey(qname), base.ArchivedKey(qname), qname) + n, err := r.archiveAll(ctx, base.ScheduledKey(qname), base.ArchivedKey(qname), qname) if errors.IsQueueNotFound(err) { return 0, errors.E(op, errors.NotFound, err) } @@ -1168,9 +1169,9 @@ return table.getn(ids) // ArchiveAllAggregatingTasks archives all aggregating tasks from the given group // and returns the number of tasks archived. // If a queue with the given name doesn't exist, it returns QueueNotFoundError. -func (r *RDB) ArchiveAllAggregatingTasks(qname, gname string) (int64, error) { +func (r *RDB) ArchiveAllAggregatingTasks(ctx context.Context, qname, gname string) (int64, error) { var op errors.Op = "rdb.ArchiveAllAggregatingTasks" - if err := r.checkQueueExists(qname); err != nil { + if err := r.checkQueueExists(ctx, qname); err != nil { return 0, errors.E(op, errors.CanonicalCode(err), err) } keys := []string{ @@ -1186,7 +1187,7 @@ func (r *RDB) ArchiveAllAggregatingTasks(qname, gname string) (int64, error) { base.TaskKeyPrefix(qname), gname, } - res, err := archiveAllAggregatingCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := archiveAllAggregatingCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return 0, errors.E(op, errors.Internal, err) } @@ -1225,9 +1226,9 @@ return table.getn(ids)`) // ArchiveAllPendingTasks archives all pending tasks from the given queue and // returns the number of tasks moved. // If a queue with the given name doesn't exist, it returns QueueNotFoundError. -func (r *RDB) ArchiveAllPendingTasks(qname string) (int64, error) { +func (r *RDB) ArchiveAllPendingTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.ArchiveAllPendingTasks" - if err := r.checkQueueExists(qname); err != nil { + if err := r.checkQueueExists(ctx, qname); err != nil { return 0, errors.E(op, errors.CanonicalCode(err), err) } keys := []string{ @@ -1241,7 +1242,7 @@ func (r *RDB) ArchiveAllPendingTasks(qname string) (int64, error) { maxArchiveSize, base.TaskKeyPrefix(qname), } - res, err := archiveAllPendingCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := archiveAllPendingCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return 0, errors.E(op, errors.Internal, err) } @@ -1314,9 +1315,9 @@ return 1 // If a task with the given id doesn't exist in the queue, it returns TaskNotFoundError // If a task is already archived, it returns TaskAlreadyArchivedError. // If a task is in active state it returns non-nil error with FailedPrecondition code. -func (r *RDB) ArchiveTask(qname, id string) error { +func (r *RDB) ArchiveTask(ctx context.Context, qname, id string) error { var op errors.Op = "rdb.ArchiveTask" - if err := r.checkQueueExists(qname); err != nil { + if err := r.checkQueueExists(ctx, qname); err != nil { return errors.E(op, errors.CanonicalCode(err), err) } keys := []string{ @@ -1333,7 +1334,7 @@ func (r *RDB) ArchiveTask(qname, id string) error { base.QueueKeyPrefix(qname), base.GroupKeyPrefix(qname), } - res, err := archiveTaskCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := archiveTaskCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return errors.E(op, errors.Unknown, err) } @@ -1382,8 +1383,8 @@ redis.call("ZREMRANGEBYRANK", KEYS[2], 0, -ARGV[3]) redis.call("DEL", KEYS[1]) return table.getn(ids)`) -func (r *RDB) archiveAll(src, dst, qname string) (int64, error) { - if err := r.checkQueueExists(qname); err != nil { +func (r *RDB) archiveAll(ctx context.Context, src, dst, qname string) (int64, error) { + if err := r.checkQueueExists(ctx, qname); err != nil { return 0, err } keys := []string{ @@ -1398,7 +1399,7 @@ func (r *RDB) archiveAll(src, dst, qname string) (int64, error) { base.TaskKeyPrefix(qname), qname, } - res, err := archiveAllCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := archiveAllCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return 0, err } @@ -1462,9 +1463,9 @@ return redis.call("DEL", KEYS[1]) // If a queue with the given name doesn't exist, it returns QueueNotFoundError. // If a task with the given id doesn't exist in the queue, it returns TaskNotFoundError // If a task is in active state it returns non-nil error with Code FailedPrecondition. -func (r *RDB) DeleteTask(qname, id string) error { +func (r *RDB) DeleteTask(ctx context.Context, qname, id string) error { var op errors.Op = "rdb.DeleteTask" - if err := r.checkQueueExists(qname); err != nil { + if err := r.checkQueueExists(ctx, qname); err != nil { return errors.E(op, errors.CanonicalCode(err), err) } keys := []string{ @@ -1476,7 +1477,7 @@ func (r *RDB) DeleteTask(qname, id string) error { base.QueueKeyPrefix(qname), base.GroupKeyPrefix(qname), } - res, err := deleteTaskCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := deleteTaskCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return errors.E(op, errors.Unknown, err) } @@ -1498,9 +1499,9 @@ func (r *RDB) DeleteTask(qname, id string) error { // DeleteAllArchivedTasks deletes all archived tasks from the given queue // and returns the number of tasks deleted. -func (r *RDB) DeleteAllArchivedTasks(qname string) (int64, error) { +func (r *RDB) DeleteAllArchivedTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.DeleteAllArchivedTasks" - n, err := r.deleteAll(base.ArchivedKey(qname), qname) + n, err := r.deleteAll(ctx, base.ArchivedKey(qname), qname) if errors.IsQueueNotFound(err) { return 0, errors.E(op, errors.NotFound, err) } @@ -1512,9 +1513,9 @@ func (r *RDB) DeleteAllArchivedTasks(qname string) (int64, error) { // DeleteAllRetryTasks deletes all retry tasks from the given queue // and returns the number of tasks deleted. -func (r *RDB) DeleteAllRetryTasks(qname string) (int64, error) { +func (r *RDB) DeleteAllRetryTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.DeleteAllRetryTasks" - n, err := r.deleteAll(base.RetryKey(qname), qname) + n, err := r.deleteAll(ctx, base.RetryKey(qname), qname) if errors.IsQueueNotFound(err) { return 0, errors.E(op, errors.NotFound, err) } @@ -1526,9 +1527,9 @@ func (r *RDB) DeleteAllRetryTasks(qname string) (int64, error) { // DeleteAllScheduledTasks deletes all scheduled tasks from the given queue // and returns the number of tasks deleted. -func (r *RDB) DeleteAllScheduledTasks(qname string) (int64, error) { +func (r *RDB) DeleteAllScheduledTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.DeleteAllScheduledTasks" - n, err := r.deleteAll(base.ScheduledKey(qname), qname) + n, err := r.deleteAll(ctx, base.ScheduledKey(qname), qname) if errors.IsQueueNotFound(err) { return 0, errors.E(op, errors.NotFound, err) } @@ -1540,9 +1541,9 @@ func (r *RDB) DeleteAllScheduledTasks(qname string) (int64, error) { // DeleteAllCompletedTasks deletes all completed tasks from the given queue // and returns the number of tasks deleted. -func (r *RDB) DeleteAllCompletedTasks(qname string) (int64, error) { +func (r *RDB) DeleteAllCompletedTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.DeleteAllCompletedTasks" - n, err := r.deleteAll(base.CompletedKey(qname), qname) + n, err := r.deleteAll(ctx, base.CompletedKey(qname), qname) if errors.IsQueueNotFound(err) { return 0, errors.E(op, errors.NotFound, err) } @@ -1574,15 +1575,15 @@ end redis.call("DEL", KEYS[1]) return table.getn(ids)`) -func (r *RDB) deleteAll(key, qname string) (int64, error) { - if err := r.checkQueueExists(qname); err != nil { +func (r *RDB) deleteAll(ctx context.Context, key, qname string) (int64, error) { + if err := r.checkQueueExists(ctx, qname); err != nil { return 0, err } argv := []interface{}{ base.TaskKeyPrefix(qname), qname, } - res, err := deleteAllCmd.Run(context.Background(), r.client, []string{key}, argv...).Result() + res, err := deleteAllCmd.Run(ctx, r.client, []string{key}, argv...).Result() if err != nil { return 0, err } @@ -1613,9 +1614,9 @@ return table.getn(ids) // DeleteAllAggregatingTasks deletes all aggregating tasks from the given group // and returns the number of tasks deleted. -func (r *RDB) DeleteAllAggregatingTasks(qname, gname string) (int64, error) { +func (r *RDB) DeleteAllAggregatingTasks(ctx context.Context, qname, gname string) (int64, error) { var op errors.Op = "rdb.DeleteAllAggregatingTasks" - if err := r.checkQueueExists(qname); err != nil { + if err := r.checkQueueExists(ctx, qname); err != nil { return 0, errors.E(op, errors.CanonicalCode(err), err) } keys := []string{ @@ -1626,7 +1627,7 @@ func (r *RDB) DeleteAllAggregatingTasks(qname, gname string) (int64, error) { base.TaskKeyPrefix(qname), gname, } - res, err := deleteAllAggregatingCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := deleteAllAggregatingCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return 0, errors.E(op, errors.Unknown, err) } @@ -1656,9 +1657,9 @@ return table.getn(ids)`) // DeleteAllPendingTasks deletes all pending tasks from the given queue // and returns the number of tasks deleted. -func (r *RDB) DeleteAllPendingTasks(qname string) (int64, error) { +func (r *RDB) DeleteAllPendingTasks(ctx context.Context, qname string) (int64, error) { var op errors.Op = "rdb.DeleteAllPendingTasks" - if err := r.checkQueueExists(qname); err != nil { + if err := r.checkQueueExists(ctx, qname); err != nil { return 0, errors.E(op, errors.CanonicalCode(err), err) } keys := []string{ @@ -1667,7 +1668,7 @@ func (r *RDB) DeleteAllPendingTasks(qname string) (int64, error) { argv := []interface{}{ base.TaskKeyPrefix(qname), } - res, err := deleteAllPendingCmd.Run(context.Background(), r.client, keys, argv...).Result() + res, err := deleteAllPendingCmd.Run(ctx, r.client, keys, argv...).Result() if err != nil { return 0, errors.E(op, errors.Unknown, err) } @@ -1796,9 +1797,9 @@ return 1`) // as long as no tasks are active for the queue. // If force is set to false, it will only remove the queue if // the queue is empty. -func (r *RDB) RemoveQueue(qname string, force bool) error { +func (r *RDB) RemoveQueue(ctx context.Context, qname string, force bool) error { var op errors.Op = "rdb.RemoveQueue" - exists, err := r.queueExists(qname) + exists, err := r.queueExists(ctx, qname) if err != nil { return err } @@ -1819,7 +1820,7 @@ func (r *RDB) RemoveQueue(qname string, force bool) error { base.ArchivedKey(qname), base.LeaseKey(qname), } - res, err := script.Run(context.Background(), r.client, keys, base.TaskKeyPrefix(qname)).Result() + res, err := script.Run(ctx, r.client, keys, base.TaskKeyPrefix(qname)).Result() if err != nil { return errors.E(op, errors.Unknown, err) } @@ -1829,7 +1830,7 @@ func (r *RDB) RemoveQueue(qname string, force bool) error { } switch n { case 1: - if err := r.client.SRem(context.Background(), base.AllQueues, qname).Err(); err != nil { + if err := r.client.SRem(ctx, base.AllQueues, qname).Err(); err != nil { return errors.E(op, errors.Unknown, err) } return nil @@ -1850,9 +1851,9 @@ redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", now-1) return keys`) // ListServers returns the list of server info. -func (r *RDB) ListServers() ([]*base.ServerInfo, error) { +func (r *RDB) ListServers(ctx context.Context) ([]*base.ServerInfo, error) { now := r.clock.Now() - res, err := listServerKeysCmd.Run(context.Background(), r.client, []string{base.AllServers}, now.Unix()).Result() + res, err := listServerKeysCmd.Run(ctx, r.client, []string{base.AllServers}, now.Unix()).Result() if err != nil { return nil, err } @@ -1862,7 +1863,7 @@ func (r *RDB) ListServers() ([]*base.ServerInfo, error) { } var servers []*base.ServerInfo for _, key := range keys { - data, err := r.client.Get(context.Background(), key).Result() + data, err := r.client.Get(ctx, key).Result() if err != nil { continue // skip bad data } @@ -1883,10 +1884,10 @@ redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", now-1) return keys`) // ListWorkers returns the list of worker stats. -func (r *RDB) ListWorkers() ([]*base.WorkerInfo, error) { +func (r *RDB) ListWorkers(ctx context.Context) ([]*base.WorkerInfo, error) { var op errors.Op = "rdb.ListWorkers" now := r.clock.Now() - res, err := listWorkersCmd.Run(context.Background(), r.client, []string{base.AllWorkers}, now.Unix()).Result() + res, err := listWorkersCmd.Run(ctx, r.client, []string{base.AllWorkers}, now.Unix()).Result() if err != nil { return nil, errors.E(op, errors.Unknown, err) } @@ -1896,7 +1897,7 @@ func (r *RDB) ListWorkers() ([]*base.WorkerInfo, error) { } var workers []*base.WorkerInfo for _, key := range keys { - data, err := r.client.HVals(context.Background(), key).Result() + data, err := r.client.HVals(ctx, key).Result() if err != nil { continue // skip bad data } @@ -1919,9 +1920,9 @@ redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", now-1) return keys`) // ListSchedulerEntries returns the list of scheduler entries. -func (r *RDB) ListSchedulerEntries() ([]*base.SchedulerEntry, error) { +func (r *RDB) ListSchedulerEntries(ctx context.Context) ([]*base.SchedulerEntry, error) { now := r.clock.Now() - res, err := listSchedulerKeysCmd.Run(context.Background(), r.client, []string{base.AllSchedulers}, now.Unix()).Result() + res, err := listSchedulerKeysCmd.Run(ctx, r.client, []string{base.AllSchedulers}, now.Unix()).Result() if err != nil { return nil, err } @@ -1931,7 +1932,7 @@ func (r *RDB) ListSchedulerEntries() ([]*base.SchedulerEntry, error) { } var entries []*base.SchedulerEntry for _, key := range keys { - data, err := r.client.LRange(context.Background(), key, 0, -1).Result() + data, err := r.client.LRange(ctx, key, 0, -1).Result() if err != nil { continue // skip bad data } @@ -1947,9 +1948,9 @@ func (r *RDB) ListSchedulerEntries() ([]*base.SchedulerEntry, error) { } // ListSchedulerEnqueueEvents returns the list of scheduler enqueue events. -func (r *RDB) ListSchedulerEnqueueEvents(entryID string, pgn Pagination) ([]*base.SchedulerEnqueueEvent, error) { +func (r *RDB) ListSchedulerEnqueueEvents(ctx context.Context, entryID string, pgn Pagination) ([]*base.SchedulerEnqueueEvent, error) { key := base.SchedulerHistoryKey(entryID) - zs, err := r.client.ZRevRangeWithScores(context.Background(), key, pgn.start(), pgn.stop()).Result() + zs, err := r.client.ZRevRangeWithScores(ctx, key, pgn.start(), pgn.stop()).Result() if err != nil { return nil, err } @@ -1969,9 +1970,9 @@ func (r *RDB) ListSchedulerEnqueueEvents(entryID string, pgn Pagination) ([]*bas } // Pause pauses processing of tasks from the given queue. -func (r *RDB) Pause(qname string) error { +func (r *RDB) Pause(ctx context.Context, qname string) error { key := base.PausedKey(qname) - ok, err := r.client.SetNX(context.Background(), key, r.clock.Now().Unix(), 0).Result() + ok, err := r.client.SetNX(ctx, key, r.clock.Now().Unix(), 0).Result() if err != nil { return err } @@ -1982,9 +1983,9 @@ func (r *RDB) Pause(qname string) error { } // Unpause resumes processing of tasks from the given queue. -func (r *RDB) Unpause(qname string) error { +func (r *RDB) Unpause(ctx context.Context, qname string) error { key := base.PausedKey(qname) - deleted, err := r.client.Del(context.Background(), key).Result() + deleted, err := r.client.Del(ctx, key).Result() if err != nil { return err } @@ -1995,18 +1996,18 @@ func (r *RDB) Unpause(qname string) error { } // ClusterKeySlot returns an integer identifying the hash slot the given queue hashes to. -func (r *RDB) ClusterKeySlot(qname string) (int64, error) { +func (r *RDB) ClusterKeySlot(ctx context.Context, qname string) (int64, error) { key := base.PendingKey(qname) - return r.client.ClusterKeySlot(context.Background(), key).Result() + return r.client.ClusterKeySlot(ctx, key).Result() } // ClusterNodes returns a list of nodes the given queue belongs to. -func (r *RDB) ClusterNodes(qname string) ([]redis.ClusterNode, error) { - keyslot, err := r.ClusterKeySlot(qname) +func (r *RDB) ClusterNodes(ctx context.Context, qname string) ([]redis.ClusterNode, error) { + keyslot, err := r.ClusterKeySlot(ctx, qname) if err != nil { return nil, err } - clusterSlots, err := r.client.ClusterSlots(context.Background()).Result() + clusterSlots, err := r.client.ClusterSlots(ctx).Result() if err != nil { return nil, err } diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 5987100..6caa316 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/errors" h "github.com/hibiken/asynq/internal/testutil" @@ -42,7 +43,7 @@ func TestAllQueues(t *testing.T) { t.Fatalf("could not initialize all queue set: %v", err) } } - got, err := r.AllQueues() + got, err := r.AllQueues(context.Background()) if err != nil { t.Errorf("AllQueues() returned an error: %v", err) continue @@ -280,7 +281,7 @@ func TestCurrentStats(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case for _, qname := range tc.paused { - if err := r.Pause(qname); err != nil { + if err := r.Pause(context.Background(), qname); err != nil { t.Fatal(err) } } @@ -315,7 +316,7 @@ func TestCurrentStats(t *testing.T) { r.client.HSet(ctx, base.TaskKey(qname, oldestPendingMessageID), "pending_since", enqueueTime.UnixNano()) } - got, err := r.CurrentStats(tc.qname) + got, err := r.CurrentStats(context.Background(), tc.qname) if err != nil { t.Errorf("r.CurrentStats(%q) = %v, %v, want %v, nil", tc.qname, got, err, tc.want) continue @@ -334,7 +335,7 @@ func TestCurrentStatsWithNonExistentQueue(t *testing.T) { defer r.Close() qname := "non-existent" - got, err := r.CurrentStats(qname) + got, err := r.CurrentStats(context.Background(), qname) if !errors.IsQueueNotFound(err) { t.Fatalf("r.CurrentStats(%q) = %v, %v, want nil, %v", qname, got, err, &errors.QueueNotFoundError{Queue: qname}) } @@ -367,7 +368,7 @@ func TestHistoricalStats(t *testing.T) { r.client.Set(context.Background(), failedKey, (i+1)*10, 0) } - got, err := r.HistoricalStats(tc.qname, tc.n) + got, err := r.HistoricalStats(context.Background(), tc.qname, tc.n) if err != nil { t.Errorf("RDB.HistoricalStats(%q, %d) returned error: %v", tc.qname, tc.n, err) continue @@ -399,7 +400,7 @@ func TestRedisInfo(t *testing.T) { r := setup(t) defer r.Close() - info, err := r.RedisInfo() + info, err := r.RedisInfo(context.Background()) if err != nil { t.Fatalf("RDB.RedisInfo() returned error: %v", err) } @@ -504,7 +505,7 @@ func TestGroupStats(t *testing.T) { h.SeedRedisZSets(t, r.client, fixtures.groups) t.Run(tc.desc, func(t *testing.T) { - got, err := r.GroupStats(tc.qname) + got, err := r.GroupStats(context.Background(), tc.qname) if err != nil { t.Fatalf("GroupStats returned error: %v", err) } @@ -646,7 +647,7 @@ func TestGetTaskInfo(t *testing.T) { } for _, tc := range tests { - got, err := r.GetTaskInfo(tc.qname, tc.id) + got, err := r.GetTaskInfo(context.Background(), tc.qname, tc.id) if err != nil { t.Errorf("GetTaskInfo(%q, %v) returned error: %v", tc.qname, tc.id, err) continue @@ -726,7 +727,7 @@ func TestGetTaskInfoError(t *testing.T) { } for _, tc := range tests { - info, err := r.GetTaskInfo(tc.qname, tc.id) + info, err := r.GetTaskInfo(context.Background(), tc.qname, tc.id) if info != nil { t.Errorf("GetTaskInfo(%q, %v) returned info: %v", tc.qname, tc.id, info) } @@ -796,7 +797,7 @@ func TestListPending(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllPendingQueues(t, r.client, tc.pending) - got, err := r.ListPending(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListPending(context.Background(), tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListPending(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -846,7 +847,7 @@ func TestListPendingPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListPending(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + got, err := r.ListPending(context.Background(), tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListPending(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page) if err != nil { t.Errorf("%s; %s returned error %v", tc.desc, op, err) @@ -915,7 +916,7 @@ func TestListActive(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllActiveQueues(t, r.client, tc.inProgress) - got, err := r.ListActive(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListActive(context.Background(), tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListActive(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.inProgress) @@ -955,7 +956,7 @@ func TestListActivePagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListActive(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + got, err := r.ListActive(context.Background(), tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListActive(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page) if err != nil { t.Errorf("%s; %s returned error %v", tc.desc, op, err) @@ -1050,7 +1051,7 @@ func TestListScheduled(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - got, err := r.ListScheduled(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListScheduled(context.Background(), tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListScheduled(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -1091,7 +1092,7 @@ func TestListScheduledPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListScheduled(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + got, err := r.ListScheduled(context.Background(), tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListScheduled(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page) if err != nil { t.Errorf("%s; %s returned error %v", tc.desc, op, err) @@ -1204,7 +1205,7 @@ func TestListRetry(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllRetryQueues(t, r.client, tc.retry) - got, err := r.ListRetry(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListRetry(context.Background(), tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListRetry(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -1248,7 +1249,7 @@ func TestListRetryPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListRetry(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + got, err := r.ListRetry(context.Background(), tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListRetry(%q, Pagination{Size: %d, Page: %d})", tc.qname, tc.size, tc.page) if err != nil { @@ -1357,7 +1358,7 @@ func TestListArchived(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllArchivedQueues(t, r.client, tc.archived) - got, err := r.ListArchived(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListArchived(context.Background(), tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListArchived(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -1398,7 +1399,7 @@ func TestListArchivedPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListArchived(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + got, err := r.ListArchived(context.Background(), tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListArchived(Pagination{Size: %d, Page: %d})", tc.size, tc.page) if err != nil { @@ -1497,7 +1498,7 @@ func TestListCompleted(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllCompletedQueues(t, r.client, tc.completed) - got, err := r.ListCompleted(tc.qname, Pagination{Size: 20, Page: 0}) + got, err := r.ListCompleted(context.Background(), tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListCompleted(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.want) @@ -1539,7 +1540,7 @@ func TestListCompletedPagination(t *testing.T) { } for _, tc := range tests { - got, err := r.ListCompleted(tc.qname, Pagination{Size: tc.size, Page: tc.page}) + got, err := r.ListCompleted(context.Background(), tc.qname, Pagination{Size: tc.size, Page: tc.page}) op := fmt.Sprintf("r.ListCompleted(Pagination{Size: %d, Page: %d})", tc.size, tc.page) if err != nil { @@ -1645,7 +1646,7 @@ func TestListAggregating(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - got, err := r.ListAggregating(tc.qname, tc.gname, Pagination{}) + got, err := r.ListAggregating(context.Background(), tc.qname, tc.gname, Pagination{}) if err != nil { t.Fatalf("ListAggregating returned error: %v", err) } @@ -1759,7 +1760,7 @@ func TestListAggregatingPagination(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - got, err := r.ListAggregating(tc.qname, tc.gname, Pagination{Page: tc.page, Size: tc.size}) + got, err := r.ListAggregating(context.Background(), tc.qname, tc.gname, Pagination{Page: tc.page, Size: tc.size}) if err != nil { t.Fatalf("ListAggregating returned error: %v", err) } @@ -1803,19 +1804,19 @@ func TestListTasksError(t *testing.T) { for _, tc := range tests { pgn := Pagination{Page: 0, Size: 20} - if _, got := r.ListActive(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListActive(context.Background(), tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListActive returned %v", tc.desc, got) } - if _, got := r.ListPending(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListPending(context.Background(), tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListPending returned %v", tc.desc, got) } - if _, got := r.ListScheduled(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListScheduled(context.Background(), tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListScheduled returned %v", tc.desc, got) } - if _, got := r.ListRetry(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListRetry(context.Background(), tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListRetry returned %v", tc.desc, got) } - if _, got := r.ListArchived(tc.qname, pgn); !tc.match(got) { + if _, got := r.ListArchived(context.Background(), tc.qname, pgn); !tc.match(got) { t.Errorf("%s: ListArchived returned %v", tc.desc, got) } } @@ -1885,7 +1886,7 @@ func TestRunArchivedTask(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllArchivedQueues(t, r.client, tc.archived) - if got := r.RunTask(tc.qname, tc.id); got != nil { + if got := r.RunTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("r.RunTask(%q, %s) returned error: %v", tc.qname, tc.id, got) continue } @@ -1965,7 +1966,7 @@ func TestRunRetryTask(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllRetryQueues(t, r.client, tc.retry) // initialize retry queue - if got := r.RunTask(tc.qname, tc.id); got != nil { + if got := r.RunTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("r.RunTask(%q, %s) returned error: %v", tc.qname, tc.id, got) continue } @@ -2079,7 +2080,7 @@ func TestRunAggregatingTask(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - err := r.RunTask(tc.qname, tc.id) + err := r.RunTask(context.Background(), tc.qname, tc.id) if err != nil { t.Fatalf("RunTask returned error: %v", err) } @@ -2150,7 +2151,7 @@ func TestRunScheduledTask(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - if got := r.RunTask(tc.qname, tc.id); got != nil { + if got := r.RunTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("r.RunTask(%q, %s) returned error: %v", tc.qname, tc.id, got) continue } @@ -2297,7 +2298,7 @@ func TestRunTaskError(t *testing.T) { h.SeedAllPendingQueues(t, r.client, tc.pending) h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - got := r.RunTask(tc.qname, tc.id) + got := r.RunTask(context.Background(), tc.qname, tc.id) if !tc.match(got) { t.Errorf("%s: unexpected return value %v", tc.desc, got) continue @@ -2406,7 +2407,7 @@ func TestRunAllScheduledTasks(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - got, err := r.RunAllScheduledTasks(tc.qname) + got, err := r.RunAllScheduledTasks(context.Background(), tc.qname) if err != nil { t.Errorf("%s; r.RunAllScheduledTasks(%q) = %v, %v; want %v, nil", tc.desc, tc.qname, got, err, tc.want) @@ -2512,7 +2513,7 @@ func TestRunAllRetryTasks(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllRetryQueues(t, r.client, tc.retry) - got, err := r.RunAllRetryTasks(tc.qname) + got, err := r.RunAllRetryTasks(context.Background(), tc.qname) if err != nil { t.Errorf("%s; r.RunAllRetryTasks(%q) = %v, %v; want %v, nil", tc.desc, tc.qname, got, err, tc.want) @@ -2618,7 +2619,7 @@ func TestRunAllArchivedTasks(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllArchivedQueues(t, r.client, tc.archived) - got, err := r.RunAllArchivedTasks(tc.qname) + got, err := r.RunAllArchivedTasks(context.Background(), tc.qname) if err != nil { t.Errorf("%s; r.RunAllArchivedTasks(%q) = %v, %v; want %v, nil", tc.desc, tc.qname, got, err, tc.want) @@ -2662,16 +2663,16 @@ func TestRunAllTasksError(t *testing.T) { } for _, tc := range tests { - if _, got := r.RunAllScheduledTasks(tc.qname); !tc.match(got) { + if _, got := r.RunAllScheduledTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: RunAllScheduledTasks returned %v", tc.desc, got) } - if _, got := r.RunAllRetryTasks(tc.qname); !tc.match(got) { + if _, got := r.RunAllRetryTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: RunAllRetryTasks returned %v", tc.desc, got) } - if _, got := r.RunAllArchivedTasks(tc.qname); !tc.match(got) { + if _, got := r.RunAllArchivedTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: RunAllArchivedTasks returned %v", tc.desc, got) } - if _, got := r.RunAllAggregatingTasks(tc.qname, "mygroup"); !tc.match(got) { + if _, got := r.RunAllAggregatingTasks(context.Background(), tc.qname, "mygroup"); !tc.match(got) { t.Errorf("%s: RunAllAggregatingTasks returned %v", tc.desc, got) } } @@ -2772,7 +2773,7 @@ func TestRunAllAggregatingTasks(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - got, err := r.RunAllAggregatingTasks(tc.qname, tc.gname) + got, err := r.RunAllAggregatingTasks(context.Background(), tc.qname, tc.gname) if err != nil { t.Fatalf("RunAllAggregatingTasks returned error: %v", err) } @@ -2863,7 +2864,7 @@ func TestArchiveRetryTask(t *testing.T) { h.SeedAllRetryQueues(t, r.client, tc.retry) h.SeedAllArchivedQueues(t, r.client, tc.archived) - if got := r.ArchiveTask(tc.qname, tc.id); got != nil { + if got := r.ArchiveTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("(*RDB).ArchiveTask(%q, %v) returned error: %v", tc.qname, tc.id, got) continue @@ -2964,7 +2965,7 @@ func TestArchiveScheduledTask(t *testing.T) { h.SeedAllScheduledQueues(t, r.client, tc.scheduled) h.SeedAllArchivedQueues(t, r.client, tc.archived) - if got := r.ArchiveTask(tc.qname, tc.id); got != nil { + if got := r.ArchiveTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("(*RDB).ArchiveTask(%q, %v) returned error: %v", tc.qname, tc.id, got) continue @@ -3085,7 +3086,7 @@ func TestArchiveAggregatingTask(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - err := r.ArchiveTask(tc.qname, tc.id) + err := r.ArchiveTask(context.Background(), tc.qname, tc.id) if err != nil { t.Fatalf("ArchiveTask returned error: %v", err) } @@ -3156,7 +3157,7 @@ func TestArchivePendingTask(t *testing.T) { h.SeedAllPendingQueues(t, r.client, tc.pending) h.SeedAllArchivedQueues(t, r.client, tc.archived) - if got := r.ArchiveTask(tc.qname, tc.id); got != nil { + if got := r.ArchiveTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("(*RDB).ArchiveTask(%q, %v) returned error: %v", tc.qname, tc.id, got) continue @@ -3304,7 +3305,7 @@ func TestArchiveTaskError(t *testing.T) { h.SeedAllScheduledQueues(t, r.client, tc.scheduled) h.SeedAllArchivedQueues(t, r.client, tc.archived) - got := r.ArchiveTask(tc.qname, tc.id) + got := r.ArchiveTask(context.Background(), tc.qname, tc.id) if !tc.match(got) { t.Errorf("%s: returned error didn't match: got=%v", tc.desc, got) continue @@ -3446,7 +3447,7 @@ func TestArchiveAllPendingTasks(t *testing.T) { h.SeedAllPendingQueues(t, r.client, tc.pending) h.SeedAllArchivedQueues(t, r.client, tc.archived) - got, err := r.ArchiveAllPendingTasks(tc.qname) + got, err := r.ArchiveAllPendingTasks(context.Background(), tc.qname) if got != tc.want || err != nil { t.Errorf("(*RDB).KillAllRetryTasks(%q) = %v, %v; want %v, nil", tc.qname, got, err, tc.want) @@ -3571,7 +3572,7 @@ func TestArchiveAllAggregatingTasks(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - got, err := r.ArchiveAllAggregatingTasks(tc.qname, tc.gname) + got, err := r.ArchiveAllAggregatingTasks(context.Background(), tc.qname, tc.gname) if err != nil { t.Fatalf("ArchiveAllAggregatingTasks returned error: %v", err) } @@ -3710,7 +3711,7 @@ func TestArchiveAllRetryTasks(t *testing.T) { h.SeedAllRetryQueues(t, r.client, tc.retry) h.SeedAllArchivedQueues(t, r.client, tc.archived) - got, err := r.ArchiveAllRetryTasks(tc.qname) + got, err := r.ArchiveAllRetryTasks(context.Background(), tc.qname) if got != tc.want || err != nil { t.Errorf("(*RDB).KillAllRetryTasks(%q) = %v, %v; want %v, nil", tc.qname, got, err, tc.want) @@ -3860,7 +3861,7 @@ func TestArchiveAllScheduledTasks(t *testing.T) { h.SeedAllScheduledQueues(t, r.client, tc.scheduled) h.SeedAllArchivedQueues(t, r.client, tc.archived) - got, err := r.ArchiveAllScheduledTasks(tc.qname) + got, err := r.ArchiveAllScheduledTasks(context.Background(), tc.qname) if got != tc.want || err != nil { t.Errorf("(*RDB).KillAllScheduledTasks(%q) = %v, %v; want %v, nil", tc.qname, got, err, tc.want) @@ -3902,13 +3903,13 @@ func TestArchiveAllTasksError(t *testing.T) { } for _, tc := range tests { - if _, got := r.ArchiveAllPendingTasks(tc.qname); !tc.match(got) { + if _, got := r.ArchiveAllPendingTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: ArchiveAllPendingTasks returned %v", tc.desc, got) } - if _, got := r.ArchiveAllScheduledTasks(tc.qname); !tc.match(got) { + if _, got := r.ArchiveAllScheduledTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: ArchiveAllScheduledTasks returned %v", tc.desc, got) } - if _, got := r.ArchiveAllRetryTasks(tc.qname); !tc.match(got) { + if _, got := r.ArchiveAllRetryTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: ArchiveAllRetryTasks returned %v", tc.desc, got) } } @@ -3966,7 +3967,7 @@ func TestDeleteArchivedTask(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllArchivedQueues(t, r.client, tc.archived) - if got := r.DeleteTask(tc.qname, tc.id); got != nil { + if got := r.DeleteTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("r.DeleteTask(%q, %v) returned error: %v", tc.qname, tc.id, got) continue } @@ -4032,7 +4033,7 @@ func TestDeleteRetryTask(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllRetryQueues(t, r.client, tc.retry) - if got := r.DeleteTask(tc.qname, tc.id); got != nil { + if got := r.DeleteTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("r.DeleteTask(%q, %v) returned error: %v", tc.qname, tc.id, got) continue } @@ -4098,7 +4099,7 @@ func TestDeleteScheduledTask(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - if got := r.DeleteTask(tc.qname, tc.id); got != nil { + if got := r.DeleteTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("r.DeleteTask(%q, %v) returned error: %v", tc.qname, tc.id, got) continue } @@ -4197,7 +4198,7 @@ func TestDeleteAggregatingTask(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - err := r.DeleteTask(tc.qname, tc.id) + err := r.DeleteTask(context.Background(), tc.qname, tc.id) if err != nil { t.Fatalf("DeleteTask returned error: %v", err) } @@ -4248,7 +4249,7 @@ func TestDeletePendingTask(t *testing.T) { h.FlushDB(t, r.client) h.SeedAllPendingQueues(t, r.client, tc.pending) - if got := r.DeleteTask(tc.qname, tc.id); got != nil { + if got := r.DeleteTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("r.DeleteTask(%q, %v) returned error: %v", tc.qname, tc.id, got) continue } @@ -4300,7 +4301,7 @@ func TestDeleteTaskWithUniqueLock(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - if got := r.DeleteTask(tc.qname, tc.id); got != nil { + if got := r.DeleteTask(context.Background(), tc.qname, tc.id); got != nil { t.Errorf("r.DeleteTask(%q, %v) returned error: %v", tc.qname, tc.id, got) continue } @@ -4395,7 +4396,7 @@ func TestDeleteTaskError(t *testing.T) { h.SeedAllActiveQueues(t, r.client, tc.active) h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - got := r.DeleteTask(tc.qname, tc.id) + got := r.DeleteTask(context.Background(), tc.qname, tc.id) if !tc.match(got) { t.Errorf("%s: r.DeleteTask(qname, id) returned %v", tc.desc, got) continue @@ -4463,7 +4464,7 @@ func TestDeleteAllArchivedTasks(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllArchivedQueues(t, r.client, tc.archived) - got, err := r.DeleteAllArchivedTasks(tc.qname) + got, err := r.DeleteAllArchivedTasks(context.Background(), tc.qname) if err != nil { t.Errorf("r.DeleteAllArchivedTasks(%q) returned error: %v", tc.qname, err) } @@ -4533,7 +4534,7 @@ func TestDeleteAllCompletedTasks(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllCompletedQueues(t, r.client, tc.completed) - got, err := r.DeleteAllCompletedTasks(tc.qname) + got, err := r.DeleteAllCompletedTasks(context.Background(), tc.qname) if err != nil { t.Errorf("r.DeleteAllCompletedTasks(%q) returned error: %v", tc.qname, err) } @@ -4600,7 +4601,7 @@ func TestDeleteAllArchivedTasksWithUniqueKey(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllArchivedQueues(t, r.client, tc.archived) - got, err := r.DeleteAllArchivedTasks(tc.qname) + got, err := r.DeleteAllArchivedTasks(context.Background(), tc.qname) if err != nil { t.Errorf("r.DeleteAllArchivedTasks(%q) returned error: %v", tc.qname, err) } @@ -4668,7 +4669,7 @@ func TestDeleteAllRetryTasks(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllRetryQueues(t, r.client, tc.retry) - got, err := r.DeleteAllRetryTasks(tc.qname) + got, err := r.DeleteAllRetryTasks(context.Background(), tc.qname) if err != nil { t.Errorf("r.DeleteAllRetryTasks(%q) returned error: %v", tc.qname, err) } @@ -4730,7 +4731,7 @@ func TestDeleteAllScheduledTasks(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - got, err := r.DeleteAllScheduledTasks(tc.qname) + got, err := r.DeleteAllScheduledTasks(context.Background(), tc.qname) if err != nil { t.Errorf("r.DeleteAllScheduledTasks(%q) returned error: %v", tc.qname, err) } @@ -4832,7 +4833,7 @@ func TestDeleteAllAggregatingTasks(t *testing.T) { h.SeedRedisZSets(t, r.client, fxt.groups) t.Run(tc.desc, func(t *testing.T) { - got, err := r.DeleteAllAggregatingTasks(tc.qname, tc.gname) + got, err := r.DeleteAllAggregatingTasks(context.Background(), tc.qname, tc.gname) if err != nil { t.Fatalf("DeleteAllAggregatingTasks returned error: %v", err) } @@ -4886,7 +4887,7 @@ func TestDeleteAllPendingTasks(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllPendingQueues(t, r.client, tc.pending) - got, err := r.DeleteAllPendingTasks(tc.qname) + got, err := r.DeleteAllPendingTasks(context.Background(), tc.qname) if err != nil { t.Errorf("r.DeleteAllPendingTasks(%q) returned error: %v", tc.qname, err) } @@ -4919,16 +4920,16 @@ func TestDeleteAllTasksError(t *testing.T) { } for _, tc := range tests { - if _, got := r.DeleteAllPendingTasks(tc.qname); !tc.match(got) { + if _, got := r.DeleteAllPendingTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: DeleteAllPendingTasks returned %v", tc.desc, got) } - if _, got := r.DeleteAllScheduledTasks(tc.qname); !tc.match(got) { + if _, got := r.DeleteAllScheduledTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: DeleteAllScheduledTasks returned %v", tc.desc, got) } - if _, got := r.DeleteAllRetryTasks(tc.qname); !tc.match(got) { + if _, got := r.DeleteAllRetryTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: DeleteAllRetryTasks returned %v", tc.desc, got) } - if _, got := r.DeleteAllArchivedTasks(tc.qname); !tc.match(got) { + if _, got := r.DeleteAllArchivedTasks(context.Background(), tc.qname); !tc.match(got) { t.Errorf("%s: DeleteAllArchivedTasks returned %v", tc.desc, got) } } @@ -5009,7 +5010,7 @@ func TestRemoveQueue(t *testing.T) { h.SeedAllRetryQueues(t, r.client, tc.retry) h.SeedAllArchivedQueues(t, r.client, tc.archived) - err := r.RemoveQueue(tc.qname, tc.force) + err := r.RemoveQueue(context.Background(), tc.qname, tc.force) if err != nil { t.Errorf("(*RDB).RemoveQueue(%q, %t) = %v, want nil", tc.qname, tc.force, err) @@ -5147,7 +5148,7 @@ func TestRemoveQueueError(t *testing.T) { h.SeedAllRetryQueues(t, r.client, tc.retry) h.SeedAllArchivedQueues(t, r.client, tc.archived) - got := r.RemoveQueue(tc.qname, tc.force) + got := r.RemoveQueue(context.Background(), tc.qname, tc.force) if !tc.match(got) { t.Errorf("%s; returned error didn't match expected value; got=%v", tc.desc, got) continue @@ -5238,7 +5239,7 @@ func TestListServers(t *testing.T) { } } - got, err := r.ListServers() + got, err := r.ListServers(context.Background()) if err != nil { t.Errorf("r.ListServers returned an error: %v", err) } @@ -5314,7 +5315,7 @@ func TestListWorkers(t *testing.T) { continue } - got, err := r.ListWorkers() + got, err := r.ListWorkers(context.Background()) if err != nil { t.Errorf("(*RDB).ListWorkers() returned an error: %v", err) continue @@ -5353,7 +5354,7 @@ func TestWriteListClearSchedulerEntries(t *testing.T) { if err := r.WriteSchedulerEntries(schedulerID, data, 30*time.Second); err != nil { t.Fatalf("WriteSchedulerEnties failed: %v", err) } - entries, err := r.ListSchedulerEntries() + entries, err := r.ListSchedulerEntries(context.Background()) if err != nil { t.Fatalf("ListSchedulerEntries failed: %v", err) } @@ -5363,7 +5364,7 @@ func TestWriteListClearSchedulerEntries(t *testing.T) { if err := r.ClearSchedulerEntries(schedulerID); err != nil { t.Fatalf("ClearSchedulerEntries failed: %v", err) } - entries, err = r.ListSchedulerEntries() + entries, err = r.ListSchedulerEntries(context.Background()) if err != nil { t.Fatalf("ListSchedulerEntries() after clear failed: %v", err) } @@ -5418,7 +5419,7 @@ loop: continue loop } } - got, err := r.ListSchedulerEnqueueEvents(tc.entryID, Pagination{Size: 20, Page: 0}) + got, err := r.ListSchedulerEnqueueEvents(context.Background(), tc.entryID, Pagination{Size: 20, Page: 0}) if err != nil { t.Errorf("ListSchedulerEnqueueEvents(%q) failed: %v", tc.entryID, err) continue @@ -5465,7 +5466,7 @@ func TestRecordSchedulerEnqueueEventTrimsDataSet(t *testing.T) { if n := r.client.ZCard(context.Background(), key).Val(); n != maxEvents { t.Fatalf("unexpected number of events; got %d, want %d", n, maxEvents) } - events, err := r.ListSchedulerEnqueueEvents(entryID, Pagination{Size: maxEvents}) + events, err := r.ListSchedulerEnqueueEvents(context.Background(), entryID, Pagination{Size: maxEvents}) if err != nil { t.Fatalf("ListSchedulerEnqueueEvents failed: %v", err) } @@ -5490,7 +5491,7 @@ func TestPause(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) - err := r.Pause(tc.qname) + err := r.Pause(context.Background(), tc.qname) if err != nil { t.Errorf("Pause(%q) returned error: %v", tc.qname, err) } @@ -5515,12 +5516,12 @@ func TestPauseError(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) for _, qname := range tc.paused { - if err := r.Pause(qname); err != nil { + if err := r.Pause(context.Background(), qname); err != nil { t.Fatalf("could not pause %q: %v", qname, err) } } - err := r.Pause(tc.qname) + err := r.Pause(context.Background(), tc.qname) if err == nil { t.Errorf("%s; Pause(%q) returned nil: want error", tc.desc, tc.qname) } @@ -5540,12 +5541,12 @@ func TestUnpause(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) for _, qname := range tc.paused { - if err := r.Pause(qname); err != nil { + if err := r.Pause(context.Background(), qname); err != nil { t.Fatalf("could not pause %q: %v", qname, err) } } - err := r.Unpause(tc.qname) + err := r.Unpause(context.Background(), tc.qname) if err != nil { t.Errorf("Unpause(%q) returned error: %v", tc.qname, err) } @@ -5570,12 +5571,12 @@ func TestUnpauseError(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) for _, qname := range tc.paused { - if err := r.Pause(qname); err != nil { + if err := r.Pause(context.Background(), qname); err != nil { t.Fatalf("could not pause %q: %v", qname, err) } } - err := r.Unpause(tc.qname) + err := r.Unpause(context.Background(), tc.qname) if err == nil { t.Errorf("%s; Unpause(%q) returned nil: want error", tc.desc, tc.qname) } diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 7b25f15..3cfabeb 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -13,10 +13,11 @@ import ( "github.com/go-redis/redis/v8" "github.com/google/uuid" + "github.com/spf13/cast" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/errors" "github.com/hibiken/asynq/internal/timeutil" - "github.com/spf13/cast" ) const statsTTL = 90 * 24 * time.Hour // 90 days @@ -1291,10 +1292,15 @@ return res // ListLeaseExpired returns a list of task messages with an expired lease from the given queues. func (r *RDB) ListLeaseExpired(cutoff time.Time, qnames ...string) ([]*base.TaskMessage, error) { + return r.ListLeaseExpiredContext(context.Background(), cutoff, qnames...) +} + +// ListLeaseExpiredContext returns a list of task messages with an expired lease from the given queues. +func (r *RDB) ListLeaseExpiredContext(ctx context.Context, cutoff time.Time, qnames ...string) ([]*base.TaskMessage, error) { var op errors.Op = "rdb.ListLeaseExpired" var msgs []*base.TaskMessage for _, qname := range qnames { - res, err := listLeaseExpiredCmd.Run(context.Background(), r.client, + res, err := listLeaseExpiredCmd.Run(ctx, r.client, []string{base.LeaseKey(qname)}, cutoff.Unix(), base.TaskKeyPrefix(qname)).Result() if err != nil { @@ -1459,8 +1465,13 @@ func (r *RDB) CancelationPubSub() (*redis.PubSub, error) { // PublishCancelation publish cancelation message to all subscribers. // The message is the ID for the task to be canceled. func (r *RDB) PublishCancelation(id string) error { + return r.PublishCancelationContext(context.Background(), id) +} + +// PublishCancelationContext publish cancelation message to all subscribers. +// The message is the ID for the task to be canceled. +func (r *RDB) PublishCancelationContext(ctx context.Context, id string) error { var op errors.Op = "rdb.PublishCancelation" - ctx := context.Background() if err := r.client.Publish(ctx, base.CancelChannel, id).Err(); err != nil { return errors.E(op, errors.Unknown, fmt.Sprintf("redis pubsub publish error: %v", err)) } diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 5094aa9..8ff45fe 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -616,7 +616,7 @@ func TestDequeueIgnoresPausedQueues(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case for _, qname := range tc.paused { - if err := r.Pause(qname); err != nil { + if err := r.Pause(context.Background(), qname); err != nil { t.Fatal(err) } }