mirror of
https://github.com/hibiken/asynq.git
synced 2024-12-26 07:42:17 +08:00
Add TaskID option to allow user to specify task id
This commit is contained in:
parent
dbdd9c6d5f
commit
9e2f88c00d
@ -11,6 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- `NewTask` takes `Option` as variadic argument
|
||||
|
||||
### Added
|
||||
|
||||
- `TaskID` option is added to allow user to specify task ID.
|
||||
- `ErrTaskIDConflict` sentinel error value is added.
|
||||
|
||||
### Removed
|
||||
|
||||
- `Client.SetDefaultOptions` is removed. Use `NewTask` instead to pass default options for tasks.
|
||||
|
36
client.go
36
client.go
@ -45,6 +45,7 @@ const (
|
||||
UniqueOpt
|
||||
ProcessAtOpt
|
||||
ProcessInOpt
|
||||
TaskIDOpt
|
||||
)
|
||||
|
||||
// Option specifies the task processing behavior.
|
||||
@ -63,6 +64,7 @@ type Option interface {
|
||||
type (
|
||||
retryOption int
|
||||
queueOption string
|
||||
taskIDOption string
|
||||
timeoutOption time.Duration
|
||||
deadlineOption time.Time
|
||||
uniqueOption time.Duration
|
||||
@ -94,6 +96,15 @@ func (qname queueOption) String() string { return fmt.Sprintf("Queue(%q)", s
|
||||
func (qname queueOption) Type() OptionType { return QueueOpt }
|
||||
func (qname queueOption) Value() interface{} { return string(qname) }
|
||||
|
||||
// TaskID returns an option to specify the task ID.
|
||||
func TaskID(id string) Option {
|
||||
return taskIDOption(id)
|
||||
}
|
||||
|
||||
func (id taskIDOption) String() string { return fmt.Sprintf("TaskID(%q)", string(id)) }
|
||||
func (id taskIDOption) Type() OptionType { return TaskIDOpt }
|
||||
func (id taskIDOption) Value() interface{} { return string(id) }
|
||||
|
||||
// Timeout returns an option to specify how long a task may run.
|
||||
// If the timeout elapses before the Handler returns, then the task
|
||||
// will be retried.
|
||||
@ -172,9 +183,15 @@ func (d processInOption) Value() interface{} { return time.Duration(d) }
|
||||
// ErrDuplicateTask error only applies to tasks enqueued with a Unique option.
|
||||
var ErrDuplicateTask = errors.New("task already exists")
|
||||
|
||||
// ErrTaskIDConflict indicates that the given task could not be enqueued since its task ID already exists.
|
||||
//
|
||||
// ErrTaskIDConflict error only applies to tasks enqueued with a TaskID option.
|
||||
var ErrTaskIDConflict = errors.New("task ID conflicts with another task")
|
||||
|
||||
type option struct {
|
||||
retry int
|
||||
queue string
|
||||
taskID string
|
||||
timeout time.Duration
|
||||
deadline time.Time
|
||||
uniqueTTL time.Duration
|
||||
@ -189,6 +206,7 @@ func composeOptions(opts ...Option) (option, error) {
|
||||
res := option{
|
||||
retry: defaultMaxRetry,
|
||||
queue: base.DefaultQueueName,
|
||||
taskID: uuid.NewString(),
|
||||
timeout: 0, // do not set to deafultTimeout here
|
||||
deadline: time.Time{},
|
||||
processAt: time.Now(),
|
||||
@ -203,6 +221,12 @@ func composeOptions(opts ...Option) (option, error) {
|
||||
return option{}, err
|
||||
}
|
||||
res.queue = qname
|
||||
case taskIDOption:
|
||||
id := string(opt)
|
||||
if err := validateTaskID(id); err != nil {
|
||||
return option{}, err
|
||||
}
|
||||
res.taskID = id
|
||||
case timeoutOption:
|
||||
res.timeout = time.Duration(opt)
|
||||
case deadlineOption:
|
||||
@ -220,6 +244,14 @@ func composeOptions(opts ...Option) (option, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// validates user provided task ID string.
|
||||
func validateTaskID(id string) error {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
return errors.New("task ID cannot be empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// Default max retry count used if nothing is specified.
|
||||
defaultMaxRetry = 25
|
||||
@ -276,7 +308,7 @@ func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) {
|
||||
uniqueKey = base.UniqueKey(opt.queue, task.Type(), task.Payload())
|
||||
}
|
||||
msg := &base.TaskMessage{
|
||||
ID: uuid.NewString(),
|
||||
ID: opt.taskID,
|
||||
Type: task.Type(),
|
||||
Payload: task.Payload(),
|
||||
Queue: opt.queue,
|
||||
@ -298,6 +330,8 @@ func (c *Client) Enqueue(task *Task, opts ...Option) (*TaskInfo, error) {
|
||||
switch {
|
||||
case errors.Is(err, errors.ErrDuplicateTask):
|
||||
return nil, fmt.Errorf("%w", ErrDuplicateTask)
|
||||
case errors.Is(err, errors.ErrTaskIdConflict):
|
||||
return nil, fmt.Errorf("%w", ErrTaskIDConflict)
|
||||
case err != nil:
|
||||
return nil, err
|
||||
}
|
||||
|
104
client_test.go
104
client_test.go
@ -444,6 +444,100 @@ func TestClientEnqueue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientEnqueueWithTaskIDOption(t *testing.T) {
|
||||
r := setup(t)
|
||||
client := NewClient(getRedisConnOpt(t))
|
||||
defer client.Close()
|
||||
|
||||
task := NewTask("send_email", nil)
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
task *Task
|
||||
opts []Option
|
||||
wantInfo *TaskInfo
|
||||
wantPending map[string][]*base.TaskMessage
|
||||
}{
|
||||
{
|
||||
desc: "With a valid TaskID option",
|
||||
task: task,
|
||||
opts: []Option{
|
||||
TaskID("custom_id"),
|
||||
},
|
||||
wantInfo: &TaskInfo{
|
||||
ID: "custom_id",
|
||||
Queue: "default",
|
||||
Type: task.Type(),
|
||||
Payload: task.Payload(),
|
||||
State: TaskStatePending,
|
||||
MaxRetry: defaultMaxRetry,
|
||||
Retried: 0,
|
||||
LastErr: "",
|
||||
LastFailedAt: time.Time{},
|
||||
Timeout: defaultTimeout,
|
||||
Deadline: time.Time{},
|
||||
NextProcessAt: now,
|
||||
},
|
||||
wantPending: map[string][]*base.TaskMessage{
|
||||
"default": {
|
||||
{
|
||||
ID: "custom_id",
|
||||
Type: task.Type(),
|
||||
Payload: task.Payload(),
|
||||
Retry: defaultMaxRetry,
|
||||
Queue: "default",
|
||||
Timeout: int64(defaultTimeout.Seconds()),
|
||||
Deadline: noDeadline.Unix(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
h.FlushDB(t, r) // clean up db before each test case.
|
||||
|
||||
gotInfo, err := client.Enqueue(tc.task, tc.opts...)
|
||||
if err != nil {
|
||||
t.Errorf("got non-nil error %v, want nil", err)
|
||||
continue
|
||||
}
|
||||
|
||||
cmpOptions := []cmp.Option{
|
||||
cmpopts.EquateApproxTime(500 * time.Millisecond),
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantInfo, gotInfo, cmpOptions...); diff != "" {
|
||||
t.Errorf("%s;\nEnqueue(task) returned %v, want %v; (-want,+got)\n%s",
|
||||
tc.desc, gotInfo, tc.wantInfo, diff)
|
||||
}
|
||||
|
||||
for qname, want := range tc.wantPending {
|
||||
got := h.GetPendingMessages(t, r, qname)
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("%s;\nmismatch found in %q; (-want,+got)\n%s", tc.desc, base.PendingKey(qname), diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientEnqueueWithConflictingTaskID(t *testing.T) {
|
||||
setup(t)
|
||||
client := NewClient(getRedisConnOpt(t))
|
||||
defer client.Close()
|
||||
|
||||
const taskID = "custom_id"
|
||||
task := NewTask("foo", nil)
|
||||
|
||||
if _, err := client.Enqueue(task, TaskID(taskID)); err != nil {
|
||||
t.Fatalf("First task: Enqueue failed: %v", err)
|
||||
}
|
||||
_, err := client.Enqueue(task, TaskID(taskID))
|
||||
if !errors.Is(err, ErrTaskIDConflict) {
|
||||
t.Errorf("Second task: Enqueue returned %v, want %v", err, ErrTaskIDConflict)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientEnqueueWithProcessInOption(t *testing.T) {
|
||||
r := setup(t)
|
||||
client := NewClient(getRedisConnOpt(t))
|
||||
@ -596,6 +690,16 @@ func TestClientEnqueueError(t *testing.T) {
|
||||
task: NewTask(" ", h.JSON(map[string]interface{}{})),
|
||||
opts: []Option{},
|
||||
},
|
||||
{
|
||||
desc: "With empty task ID",
|
||||
task: NewTask("foo", nil),
|
||||
opts: []Option{TaskID("")},
|
||||
},
|
||||
{
|
||||
desc: "With blank task ID",
|
||||
task: NewTask("foo", nil),
|
||||
opts: []Option{TaskID(" ")},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
25
inspector.go
25
inspector.go
@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
"github.com/hibiken/asynq/internal/errors"
|
||||
"github.com/hibiken/asynq/internal/rdb"
|
||||
@ -178,11 +177,7 @@ func (i *Inspector) DeleteQueue(qname string, force bool) error {
|
||||
// Returns ErrQueueNotFound if a queue with the given name doesn't exist.
|
||||
// Returns ErrTaskNotFound if a task with the given id doesn't exist in the queue.
|
||||
func (i *Inspector) GetTaskInfo(qname, id string) (*TaskInfo, error) {
|
||||
taskid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("asynq: %s is not a valid task id", id)
|
||||
}
|
||||
info, err := i.rdb.GetTaskInfo(qname, taskid)
|
||||
info, err := i.rdb.GetTaskInfo(qname, id)
|
||||
switch {
|
||||
case errors.IsQueueNotFound(err):
|
||||
return nil, fmt.Errorf("asynq: %w", ErrQueueNotFound)
|
||||
@ -437,11 +432,7 @@ func (i *Inspector) DeleteTask(qname, id string) error {
|
||||
if err := base.ValidateQueueName(qname); err != nil {
|
||||
return fmt.Errorf("asynq: %v", err)
|
||||
}
|
||||
taskid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("asynq: %s is not a valid task id", id)
|
||||
}
|
||||
err = i.rdb.DeleteTask(qname, taskid)
|
||||
err := i.rdb.DeleteTask(qname, id)
|
||||
switch {
|
||||
case errors.IsQueueNotFound(err):
|
||||
return fmt.Errorf("asynq: %w", ErrQueueNotFound)
|
||||
@ -495,11 +486,7 @@ func (i *Inspector) RunTask(qname, id string) error {
|
||||
if err := base.ValidateQueueName(qname); err != nil {
|
||||
return fmt.Errorf("asynq: %v", err)
|
||||
}
|
||||
taskid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("asynq: %s is not a valid task id", id)
|
||||
}
|
||||
err = i.rdb.RunTask(qname, taskid)
|
||||
err := i.rdb.RunTask(qname, id)
|
||||
switch {
|
||||
case errors.IsQueueNotFound(err):
|
||||
return fmt.Errorf("asynq: %w", ErrQueueNotFound)
|
||||
@ -552,11 +539,7 @@ func (i *Inspector) ArchiveTask(qname, id string) error {
|
||||
if err := base.ValidateQueueName(qname); err != nil {
|
||||
return fmt.Errorf("asynq: err")
|
||||
}
|
||||
taskid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("asynq: %s is not a valid task id", id)
|
||||
}
|
||||
err = i.rdb.ArchiveTask(qname, taskid)
|
||||
err := i.rdb.ArchiveTask(qname, id)
|
||||
switch {
|
||||
case errors.IsQueueNotFound(err):
|
||||
return fmt.Errorf("asynq: %w", ErrQueueNotFound)
|
||||
|
Loading…
Reference in New Issue
Block a user