diff --git a/CHANGELOG.md b/CHANGELOG.md index 892195e..95be409 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `NewTask` takes `Option` as variadic argument +### Added + +- `TaskID` option is added to allow user to specify task ID. +- `ErrTaskIDConflict` sentinel error value is added. + ### Removed - `Client.SetDefaultOptions` is removed. Use `NewTask` instead to pass default options for tasks. diff --git a/client.go b/client.go index d14c17e..0d61a6a 100644 --- a/client.go +++ b/client.go @@ -45,6 +45,7 @@ const ( UniqueOpt ProcessAtOpt ProcessInOpt + TaskIDOpt ) // Option specifies the task processing behavior. @@ -63,6 +64,7 @@ type Option interface { type ( retryOption int queueOption string + taskIDOption string timeoutOption time.Duration deadlineOption time.Time uniqueOption time.Duration @@ -94,6 +96,15 @@ func (qname queueOption) String() string { return fmt.Sprintf("Queue(%q)", s func (qname queueOption) Type() OptionType { return QueueOpt } func (qname queueOption) Value() interface{} { return string(qname) } +// TaskID returns an option to specify the task ID. +func TaskID(id string) Option { + return taskIDOption(id) +} + +func (id taskIDOption) String() string { return fmt.Sprintf("TaskID(%q)", string(id)) } +func (id taskIDOption) Type() OptionType { return TaskIDOpt } +func (id taskIDOption) Value() interface{} { return string(id) } + // Timeout returns an option to specify how long a task may run. // If the timeout elapses before the Handler returns, then the task // will be retried. @@ -172,9 +183,15 @@ func (d processInOption) Value() interface{} { return time.Duration(d) } // ErrDuplicateTask error only applies to tasks enqueued with a Unique option. var ErrDuplicateTask = errors.New("task already exists") +// ErrTaskIDConflict indicates that the given task could not be enqueued since its task ID already exists. +// +// ErrTaskIDConflict error only applies to tasks enqueued with a TaskID option. +var ErrTaskIDConflict = errors.New("task ID conflicts with another task") + type option struct { retry int queue string + taskID string timeout time.Duration deadline time.Time uniqueTTL time.Duration @@ -189,6 +206,7 @@ func composeOptions(opts ...Option) (option, error) { res := option{ retry: defaultMaxRetry, queue: base.DefaultQueueName, + taskID: uuid.NewString(), timeout: 0, // do not set to deafultTimeout here deadline: time.Time{}, processAt: time.Now(), @@ -203,6 +221,12 @@ func composeOptions(opts ...Option) (option, error) { return option{}, err } res.queue = qname + case taskIDOption: + id := string(opt) + if err := validateTaskID(id); err != nil { + return option{}, err + } + res.taskID = id case timeoutOption: res.timeout = time.Duration(opt) case deadlineOption: @@ -220,6 +244,14 @@ func composeOptions(opts ...Option) (option, error) { return res, nil } +// validates user provided task ID string. +func validateTaskID(id string) error { + if strings.TrimSpace(id) == "" { + return errors.New("task ID cannot be empty") + } + return nil +} + const ( // Default max retry count used if nothing is specified. defaultMaxRetry = 25 @@ -276,7 +308,7 @@ func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) { uniqueKey = base.UniqueKey(opt.queue, task.Type(), task.Payload()) } msg := &base.TaskMessage{ - ID: uuid.NewString(), + ID: opt.taskID, Type: task.Type(), Payload: task.Payload(), Queue: opt.queue, @@ -298,6 +330,8 @@ func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) { switch { case errors.Is(err, errors.ErrDuplicateTask): return nil, fmt.Errorf("%w", ErrDuplicateTask) + case errors.Is(err, errors.ErrTaskIdConflict): + return nil, fmt.Errorf("%w", ErrTaskIDConflict) case err != nil: return nil, err } diff --git a/client_test.go b/client_test.go index bfe205b..a4e9751 100644 --- a/client_test.go +++ b/client_test.go @@ -444,6 +444,100 @@ func TestClientEnqueue(t *testing.T) { } } +func TestClientEnqueueWithTaskIDOption(t *testing.T) { + r := setup(t) + client := NewClient(getRedisConnOpt(t)) + defer client.Close() + + task := NewTask("send_email", nil) + now := time.Now() + + tests := []struct { + desc string + task *Task + opts []Option + wantInfo *TaskInfo + wantPending map[string][]*base.TaskMessage + }{ + { + desc: "With a valid TaskID option", + task: task, + opts: []Option{ + TaskID("custom_id"), + }, + wantInfo: &TaskInfo{ + ID: "custom_id", + Queue: "default", + Type: task.Type(), + Payload: task.Payload(), + State: TaskStatePending, + MaxRetry: defaultMaxRetry, + Retried: 0, + LastErr: "", + LastFailedAt: time.Time{}, + Timeout: defaultTimeout, + Deadline: time.Time{}, + NextProcessAt: now, + }, + wantPending: map[string][]*base.TaskMessage{ + "default": { + { + ID: "custom_id", + Type: task.Type(), + Payload: task.Payload(), + Retry: defaultMaxRetry, + Queue: "default", + Timeout: int64(defaultTimeout.Seconds()), + Deadline: noDeadline.Unix(), + }, + }, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r) // clean up db before each test case. + + gotInfo, err := client.Enqueue(tc.task, tc.opts...) + if err != nil { + t.Errorf("got non-nil error %v, want nil", err) + continue + } + + cmpOptions := []cmp.Option{ + cmpopts.EquateApproxTime(500 * time.Millisecond), + } + if diff := cmp.Diff(tc.wantInfo, gotInfo, cmpOptions...); diff != "" { + t.Errorf("%s;\nEnqueue(task) returned %v, want %v; (-want,+got)\n%s", + tc.desc, gotInfo, tc.wantInfo, diff) + } + + for qname, want := range tc.wantPending { + got := h.GetPendingMessages(t, r, qname) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("%s;\nmismatch found in %q; (-want,+got)\n%s", tc.desc, base.PendingKey(qname), diff) + } + } + } +} + +func TestClientEnqueueWithConflictingTaskID(t *testing.T) { + setup(t) + client := NewClient(getRedisConnOpt(t)) + defer client.Close() + + const taskID = "custom_id" + task := NewTask("foo", nil) + + if _, err := client.Enqueue(task, TaskID(taskID)); err != nil { + t.Fatalf("First task: Enqueue failed: %v", err) + } + _, err := client.Enqueue(task, TaskID(taskID)) + if !errors.Is(err, ErrTaskIDConflict) { + t.Errorf("Second task: Enqueue returned %v, want %v", err, ErrTaskIDConflict) + } +} + func TestClientEnqueueWithProcessInOption(t *testing.T) { r := setup(t) client := NewClient(getRedisConnOpt(t)) @@ -596,6 +690,16 @@ func TestClientEnqueueError(t *testing.T) { task: NewTask(" ", h.JSON(map[string]interface{}{})), opts: []Option{}, }, + { + desc: "With empty task ID", + task: NewTask("foo", nil), + opts: []Option{TaskID("")}, + }, + { + desc: "With blank task ID", + task: NewTask("foo", nil), + opts: []Option{TaskID(" ")}, + }, } for _, tc := range tests { diff --git a/inspector.go b/inspector.go index 70826bc..e3da671 100644 --- a/inspector.go +++ b/inspector.go @@ -11,7 +11,6 @@ import ( "time" "github.com/go-redis/redis/v8" - "github.com/google/uuid" "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/errors" "github.com/hibiken/asynq/internal/rdb" @@ -178,11 +177,7 @@ func (i *Inspector) DeleteQueue(qname string, force bool) error { // Returns ErrQueueNotFound if a queue with the given name doesn't exist. // Returns ErrTaskNotFound if a task with the given id doesn't exist in the queue. func (i *Inspector) GetTaskInfo(qname, id string) (*TaskInfo, error) { - taskid, err := uuid.Parse(id) - if err != nil { - return nil, fmt.Errorf("asynq: %s is not a valid task id", id) - } - info, err := i.rdb.GetTaskInfo(qname, taskid) + info, err := i.rdb.GetTaskInfo(qname, id) switch { case errors.IsQueueNotFound(err): return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -437,11 +432,7 @@ func (i *Inspector) DeleteTask(qname, id string) error { if err := base.ValidateQueueName(qname); err != nil { return fmt.Errorf("asynq: %v", err) } - taskid, err := uuid.Parse(id) - if err != nil { - return fmt.Errorf("asynq: %s is not a valid task id", id) - } - err = i.rdb.DeleteTask(qname, taskid) + err := i.rdb.DeleteTask(qname, id) switch { case errors.IsQueueNotFound(err): return fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -495,11 +486,7 @@ func (i *Inspector) RunTask(qname, id string) error { if err := base.ValidateQueueName(qname); err != nil { return fmt.Errorf("asynq: %v", err) } - taskid, err := uuid.Parse(id) - if err != nil { - return fmt.Errorf("asynq: %s is not a valid task id", id) - } - err = i.rdb.RunTask(qname, taskid) + err := i.rdb.RunTask(qname, id) switch { case errors.IsQueueNotFound(err): return fmt.Errorf("asynq: %w", ErrQueueNotFound) @@ -552,11 +539,7 @@ func (i *Inspector) ArchiveTask(qname, id string) error { if err := base.ValidateQueueName(qname); err != nil { return fmt.Errorf("asynq: err") } - taskid, err := uuid.Parse(id) - if err != nil { - return fmt.Errorf("asynq: %s is not a valid task id", id) - } - err = i.rdb.ArchiveTask(qname, taskid) + err := i.rdb.ArchiveTask(qname, id) switch { case errors.IsQueueNotFound(err): return fmt.Errorf("asynq: %w", ErrQueueNotFound)