diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index 3f6e856..77adf2a 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -440,19 +440,6 @@ func (r *RDB) listZSetEntries(key, qname string, pgn Pagination) ([]base.Z, erro return zs, nil } -// RunTask finds a task that matches the id from the given queue and stages it for processing. -// If a task that matches the id does not exist, it returns ErrTaskNotFound. -func (r *RDB) RunTask(qname string, id uuid.UUID) error { - n, err := r.runTask(qname, id.String()) - if err != nil { - return err - } - if n == 0 { - return ErrTaskNotFound - } - return nil -} - // RunAllScheduledTasks enqueues all scheduled tasks from the given queue // and returns the number of tasks enqueued. func (r *RDB) RunAllScheduledTasks(qname string) (int64, error) { @@ -471,19 +458,37 @@ func (r *RDB) RunAllArchivedTasks(qname string) (int64, error) { return r.removeAndRunAll(base.ArchivedKey(qname), base.PendingKey(qname)) } +// runTaskCmd is a Lua script that updates the given task to pending state. +// +// Input: // KEYS[1] -> asynq:{}:t: // KEYS[2] -> asynq:{}:pending +// KEYS[3] -> all queues key +// -- // ARGV[1] -> task ID // ARGV[2] -> queue key prefix; asynq:{}: +// ARGV[3] -> queue name +// +// Output: +// Numeric code indicating the status: +// Returns 1 if task is successfully updated. +// Returns 0 if task is not found. +// Returns -1 if queue doesn't exist. +// Returns -2 if task is in active state. +// Returns -3 if task is in pending state. +// Returns error reply if unexpected error occurs. var runTaskCmd = redis.NewScript(` +if redis.call("SISMEMBER", KEYS[3], ARGV[3]) == 0 then + return -1 +end if redis.call("EXISTS", KEYS[1]) == 0 then return 0 end local state = redis.call("HGET", KEYS[1], "state") if state == "active" then - return redis.error_reply("task is already running") + return -2 elseif state == "pending" then - return redis.error_reply("task is already pending to be run") + return -3 end local n = redis.call("ZREM", ARGV[2] .. state, ARGV[1]) if n == 0 then @@ -494,24 +499,46 @@ redis.call("HSET", KEYS[1], "state", "pending") return 1 `) -func (r *RDB) runTask(qname, id string) (int64, error) { +// RunTask finds a task that matches the id from the given queue and updates it to pending state. +// It returns nil if it successfully updated the task. +// +// If a queue with the given name doesn't exist, it returns QueueNotFoundError. +// If a task with the given id doesn't exist in the queue, it returns TaskNotFoundError +// If a task is in active or pending state it returns non-nil error with Code FailedPrecondition. +func (r *RDB) RunTask(qname string, id uuid.UUID) error { + var op errors.Op = "rdb.RunTask" keys := []string{ - base.TaskKey(qname, id), + base.TaskKey(qname, id.String()), base.PendingKey(qname), + base.AllQueues, } argv := []interface{}{ - id, + id.String(), base.QueueKeyPrefix(qname), + qname, } res, err := runTaskCmd.Run(r.client, keys, argv...).Result() if err != nil { - return 0, err + return errors.E(op, errors.Unknown, err) } n, ok := res.(int64) if !ok { - return 0, fmt.Errorf("internal error: could not cast %v to int64", res) + return errors.E(op, errors.Internal, fmt.Sprintf("cast error: unexpected return value from Lua script: %v", res)) + } + switch n { + case 1: + return nil + case 0: + return errors.E(op, errors.NotFound, &errors.TaskNotFoundError{Queue: qname, ID: id.String()}) + case -1: + return errors.E(op, errors.NotFound, &errors.QueueNotFoundError{Queue: qname}) + case -2: + return errors.E(op, errors.FailedPrecondition, "task is already running") + case -3: + return errors.E(op, errors.FailedPrecondition, "task is already in pending state") + default: + return errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from Lua script %d", n)) } - return n, nil } var removeAndRunAllCmd = redis.NewScript(` diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 275f323..2734c5d 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -1010,7 +1010,6 @@ func TestRunArchivedTask(t *testing.T) { archived map[string][]base.Z qname string id uuid.UUID - want error // expected return value from calling RunArchivedTask wantArchived map[string][]*base.TaskMessage wantPending map[string][]*base.TaskMessage }{ @@ -1023,7 +1022,6 @@ func TestRunArchivedTask(t *testing.T) { }, qname: "default", id: t2.ID, - want: nil, wantArchived: map[string][]*base.TaskMessage{ "default": {t1}, }, @@ -1031,23 +1029,6 @@ func TestRunArchivedTask(t *testing.T) { "default": {t2}, }, }, - { - archived: map[string][]base.Z{ - "default": { - {Message: t1, Score: s1}, - {Message: t2, Score: s2}, - }, - }, - qname: "default", - id: uuid.New(), - want: ErrTaskNotFound, - wantArchived: map[string][]*base.TaskMessage{ - "default": {t1, t2}, - }, - wantPending: map[string][]*base.TaskMessage{ - "default": {}, - }, - }, { archived: map[string][]base.Z{ "default": { @@ -1060,7 +1041,6 @@ func TestRunArchivedTask(t *testing.T) { }, qname: "critical", id: t3.ID, - want: nil, wantArchived: map[string][]*base.TaskMessage{ "default": {t1, t2}, "critical": {}, @@ -1076,9 +1056,8 @@ func TestRunArchivedTask(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllArchivedQueues(t, r.client, tc.archived) - got := r.RunTask(tc.qname, tc.id) - if got != tc.want { - t.Errorf("r.RunTask(%q, %s) = %v, want %v", tc.qname, tc.id, got, tc.want) + if got := r.RunTask(tc.qname, tc.id); got != nil { + t.Errorf("r.RunTask(%q, %s) returned error: %v", tc.qname, tc.id, got) continue } @@ -1111,7 +1090,6 @@ func TestRunRetryTask(t *testing.T) { retry map[string][]base.Z qname string id uuid.UUID - want error // expected return value from calling RunRetryTask wantRetry map[string][]*base.TaskMessage wantPending map[string][]*base.TaskMessage }{ @@ -1124,7 +1102,6 @@ func TestRunRetryTask(t *testing.T) { }, qname: "default", id: t2.ID, - want: nil, wantRetry: map[string][]*base.TaskMessage{ "default": {t1}, }, @@ -1132,23 +1109,6 @@ func TestRunRetryTask(t *testing.T) { "default": {t2}, }, }, - { - retry: map[string][]base.Z{ - "default": { - {Message: t1, Score: s1}, - {Message: t2, Score: s2}, - }, - }, - qname: "default", - id: uuid.New(), - want: ErrTaskNotFound, - wantRetry: map[string][]*base.TaskMessage{ - "default": {t1, t2}, - }, - wantPending: map[string][]*base.TaskMessage{ - "default": {}, - }, - }, { retry: map[string][]base.Z{ "default": { @@ -1161,7 +1121,6 @@ func TestRunRetryTask(t *testing.T) { }, qname: "low", id: t3.ID, - want: nil, wantRetry: map[string][]*base.TaskMessage{ "default": {t1, t2}, "low": {}, @@ -1177,9 +1136,8 @@ func TestRunRetryTask(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllRetryQueues(t, r.client, tc.retry) // initialize retry queue - got := r.RunTask(tc.qname, tc.id) - if got != tc.want { - t.Errorf("r.RunTask(%q, %s) = %v, want %v", tc.qname, tc.id, got, tc.want) + if got := r.RunTask(tc.qname, tc.id); got != nil { + t.Errorf("r.RunTask(%q, %s) returned error: %v", tc.qname, tc.id, got) continue } @@ -1212,7 +1170,6 @@ func TestRunScheduledTask(t *testing.T) { scheduled map[string][]base.Z qname string id uuid.UUID - want error // expected return value from calling RunScheduledTask wantScheduled map[string][]*base.TaskMessage wantPending map[string][]*base.TaskMessage }{ @@ -1225,7 +1182,6 @@ func TestRunScheduledTask(t *testing.T) { }, qname: "default", id: t2.ID, - want: nil, wantScheduled: map[string][]*base.TaskMessage{ "default": {t1}, }, @@ -1233,23 +1189,6 @@ func TestRunScheduledTask(t *testing.T) { "default": {t2}, }, }, - { - scheduled: map[string][]base.Z{ - "default": { - {Message: t1, Score: s1}, - {Message: t2, Score: s2}, - }, - }, - qname: "default", - id: uuid.New(), - want: ErrTaskNotFound, - wantScheduled: map[string][]*base.TaskMessage{ - "default": {t1, t2}, - }, - wantPending: map[string][]*base.TaskMessage{ - "default": {}, - }, - }, { scheduled: map[string][]base.Z{ "default": { @@ -1262,7 +1201,6 @@ func TestRunScheduledTask(t *testing.T) { }, qname: "notifications", id: t3.ID, - want: nil, wantScheduled: map[string][]*base.TaskMessage{ "default": {t1, t2}, "notifications": {}, @@ -1278,9 +1216,8 @@ func TestRunScheduledTask(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case h.SeedAllScheduledQueues(t, r.client, tc.scheduled) - got := r.RunTask(tc.qname, tc.id) - if got != tc.want { - t.Errorf("r.RunTask(%q, %s) = %v, want %v", tc.qname, tc.id, got, tc.want) + if got := r.RunTask(tc.qname, tc.id); got != nil { + t.Errorf("r.RunTask(%q, %s) returned error: %v", tc.qname, tc.id, got) continue } @@ -1300,6 +1237,162 @@ func TestRunScheduledTask(t *testing.T) { } } +func TestRunTaskError(t *testing.T) { + r := setup(t) + defer r.Close() + t1 := h.NewTaskMessage("send_email", nil) + s1 := time.Now().Add(-5 * time.Minute).Unix() + + tests := []struct { + desc string + active map[string][]*base.TaskMessage + pending map[string][]*base.TaskMessage + scheduled map[string][]base.Z + qname string + id uuid.UUID + match func(err error) bool + wantActive map[string][]*base.TaskMessage + wantPending map[string][]*base.TaskMessage + wantScheduled map[string][]*base.TaskMessage + }{ + { + desc: "It should return QueueNotFoundError if the queue doesn't exist", + active: map[string][]*base.TaskMessage{ + "default": {}, + }, + pending: map[string][]*base.TaskMessage{ + "default": {}, + }, + scheduled: map[string][]base.Z{ + "default": { + {Message: t1, Score: s1}, + }, + }, + qname: "nonexistent", + id: t1.ID, + match: errors.IsQueueNotFound, + wantActive: map[string][]*base.TaskMessage{ + "default": {}, + }, + wantPending: map[string][]*base.TaskMessage{ + "default": {}, + }, + wantScheduled: map[string][]*base.TaskMessage{ + "default": {t1}, + }, + }, + { + desc: "It should return TaskNotFound if the task is not found in the queue", + active: map[string][]*base.TaskMessage{ + "default": {}, + }, + pending: map[string][]*base.TaskMessage{ + "default": {}, + }, + scheduled: map[string][]base.Z{ + "default": { + {Message: t1, Score: s1}, + }, + }, + qname: "default", + id: uuid.New(), + match: errors.IsTaskNotFound, + wantActive: map[string][]*base.TaskMessage{ + "default": {}, + }, + wantPending: map[string][]*base.TaskMessage{ + "default": {}, + }, + wantScheduled: map[string][]*base.TaskMessage{ + "default": {t1}, + }, + }, + { + desc: "It should return FailedPrecondition error if task is already active", + active: map[string][]*base.TaskMessage{ + "default": {t1}, + }, + pending: map[string][]*base.TaskMessage{ + "default": {}, + }, + scheduled: map[string][]base.Z{ + "default": {}, + }, + qname: "default", + id: t1.ID, + match: func(err error) bool { return errors.CanonicalCode(err) == errors.FailedPrecondition }, + wantActive: map[string][]*base.TaskMessage{ + "default": {t1}, + }, + wantPending: map[string][]*base.TaskMessage{ + "default": {}, + }, + wantScheduled: map[string][]*base.TaskMessage{ + "default": {}, + }, + }, + { + desc: "It should return FailedPrecondition error if task is already pending", + active: map[string][]*base.TaskMessage{ + "default": {}, + }, + pending: map[string][]*base.TaskMessage{ + "default": {t1}, + }, + scheduled: map[string][]base.Z{ + "default": {}, + }, + qname: "default", + id: t1.ID, + match: func(err error) bool { return errors.CanonicalCode(err) == errors.FailedPrecondition }, + wantActive: map[string][]*base.TaskMessage{ + "default": {}, + }, + wantPending: map[string][]*base.TaskMessage{ + "default": {t1}, + }, + wantScheduled: map[string][]*base.TaskMessage{ + "default": {}, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) // clean up db before each test case + h.SeedAllActiveQueues(t, r.client, tc.active) + h.SeedAllPendingQueues(t, r.client, tc.pending) + h.SeedAllScheduledQueues(t, r.client, tc.scheduled) + + got := r.RunTask(tc.qname, tc.id) + if !tc.match(got) { + t.Errorf("%s: unexpected return value %v", tc.desc, got) + continue + } + + for qname, want := range tc.wantActive { + gotActive := h.GetActiveMessages(t, r.client, qname) + if diff := cmp.Diff(want, gotActive, h.SortMsgOpt); diff != "" { + t.Errorf("mismatch found in %q: (-want, +got)\n%s", base.ActiveKey(qname), diff) + } + } + + for qname, want := range tc.wantPending { + gotPending := h.GetPendingMessages(t, r.client, qname) + if diff := cmp.Diff(want, gotPending, h.SortMsgOpt); diff != "" { + t.Errorf("mismatch found in %q; (-want, +got)\n%s", base.PendingKey(qname), diff) + } + } + + for qname, want := range tc.wantScheduled { + gotScheduled := h.GetScheduledMessages(t, r.client, qname) + if diff := cmp.Diff(want, gotScheduled, h.SortMsgOpt); diff != "" { + t.Errorf("mismatch found in %q, (-want, +got)\n%s", base.ScheduledKey(qname), diff) + } + } + } + +} + func TestRunAllScheduledTasks(t *testing.T) { r := setup(t) defer r.Close()