From b4a8227776d3c455502865ee4768889853ed7387 Mon Sep 17 00:00:00 2001 From: Aziz Aliyev Date: Tue, 18 Mar 2025 16:15:33 +0400 Subject: [PATCH] Implement UpdateTaskPayload method for inspector --- inspector.go | 24 +++++++ inspector_test.go | 142 ++++++++++++++++++++++++++++++++++++++++ internal/rdb/inspect.go | 87 ++++++++++++++++++++++++ 3 files changed, 253 insertions(+) diff --git a/inspector.go b/inspector.go index e361d22..e4800cb 100644 --- a/inspector.go +++ b/inspector.go @@ -583,6 +583,30 @@ func (i *Inspector) DeleteAllAggregatingTasks(queue, group string) (int, error) return int(n), err } +// UpdateTaskPayload updates a task with the given id from the given queue with given payload. +// The task needs to be in scheduled state, +// otherwise UpdateTaskPayload will return an error. +// +// If a queue with the given name doesn't exist, it returns an error wrapping ErrQueueNotFound. +// If a task with the given id doesn't exist in the queue, it returns an error wrapping ErrTaskNotFound. +// If the task is not in scheduled state, it returns a non-nil error. +func (i *Inspector) UpdateTaskPayload(queue, id string, payload []byte) error { + if err := base.ValidateQueueName(queue); err != nil { + return fmt.Errorf("asynq: %v", err) + } + err := i.rdb.UpdateTaskPayload(queue, id, payload) + switch { + case errors.IsQueueNotFound(err): + return fmt.Errorf("asynq: %w", ErrQueueNotFound) + case errors.IsTaskNotFound(err): + return fmt.Errorf("asynq: %w", ErrTaskNotFound) + case err != nil: + return fmt.Errorf("asynq: %v", err) + } + return nil + +} + // DeleteTask deletes a task with the given id from the given queue. // The task needs to be in pending, scheduled, retry, or archived state, // otherwise DeleteTask will return an error. diff --git a/inspector_test.go b/inspector_test.go index 1d30a03..7ec385e 100644 --- a/inspector_test.go +++ b/inspector_test.go @@ -2369,6 +2369,148 @@ func TestInspectorRunAllArchivedTasks(t *testing.T) { } } +func TestInspectorUpdateTaskPayloadUpdatesScheduledTaskPayload(t *testing.T) { + r := setup(t) + defer r.Close() + m1_old := h.NewTaskMessage("task1", []byte("m1_old")) + m1_new := h.NewTaskMessage("task1", nil) + m1_new.ID = m1_old.ID + m2_old := h.NewTaskMessage("task2", nil) + m2_new := h.NewTaskMessage("task2", []byte("m2_new")) + m2_new.ID = m2_old.ID + m3_old := h.NewTaskMessageWithQueue("task3", []byte("m3_old"), "custom") + m3_new := h.NewTaskMessageWithQueue("task3", []byte("m3_new"), "custom") + m3_new.ID = m3_old.ID + + now := time.Now() + z1_old := base.Z{Message: m1_old, Score: now.Add(5 * time.Minute).Unix()} + z1_new := base.Z{Message: m1_new, Score: now.Add(5 * time.Minute).Unix()} + z2_old := base.Z{Message: m2_old, Score: now.Add(15 * time.Minute).Unix()} + z2_new := base.Z{Message: m2_new, Score: now.Add(15 * time.Minute).Unix()} + z3_old := base.Z{Message: m3_old, Score: now.Add(2 * time.Minute).Unix()} + z3_new := base.Z{Message: m3_new, Score: now.Add(2 * time.Minute).Unix()} + + inspector := NewInspector(getRedisConnOpt(t)) + + tests := []struct { + scheduled map[string][]base.Z + qname string + id string + newPayload []byte + wantScheduled map[string][]base.Z + }{ + { + scheduled: map[string][]base.Z{ + "default": {z1_old, z2_old}, + "custom": {z3_old}, + }, + qname: "default", + id: createScheduledTask(z2_old).ID, + newPayload: m2_new.Payload, + wantScheduled: map[string][]base.Z{ + "default": {z1_old, z2_new}, + "custom": {z3_old}, + }, + }, + { + scheduled: map[string][]base.Z{ + "default": {z1_old, z2_old}, + "custom": {z3_old}, + }, + qname: "default", + id: createScheduledTask(z1_old).ID, + newPayload: m1_new.Payload, + wantScheduled: map[string][]base.Z{ + "default": {z1_new, z2_old}, + "custom": {z3_old}, + }, + }, + { + scheduled: map[string][]base.Z{ + "default": {z1_old, z2_old}, + "custom": {z3_old}, + }, + qname: "custom", + id: createScheduledTask(z3_old).ID, + newPayload: m3_new.Payload, + wantScheduled: map[string][]base.Z{ + "default": {z1_old, z2_old}, + "custom": {z3_new}, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r) + h.SeedAllScheduledQueues(t, r, tc.scheduled) + + if err := inspector.UpdateTaskPayload(tc.qname, tc.id, tc.newPayload); err != nil { + t.Errorf("UpdateTask(%q, %q) returned error: %v", tc.qname, tc.id, err) + } + for qname, want := range tc.wantScheduled { + gotScheduled := h.GetScheduledEntries(t, r, qname) + if diff := cmp.Diff(want, gotScheduled, h.SortZSetEntryOpt); diff != "" { + t.Errorf("unexpected scheduled tasks in queue %q: (-want, +got)\n%s", qname, diff) + } + + } + } +} + +func TestInspectorUpdateTaskPayloadError(t *testing.T) { + r := setup(t) + defer r.Close() + m1 := h.NewTaskMessage("task1", nil) + m2 := h.NewTaskMessage("task2", nil) + m3 := h.NewTaskMessageWithQueue("task3", nil, "custom") + + now := time.Now() + z1 := base.Z{Message: m1, Score: now.Add(5 * time.Minute).Unix()} + z2 := base.Z{Message: m2, Score: now.Add(15 * time.Minute).Unix()} + z3 := base.Z{Message: m3, Score: now.Add(2 * time.Minute).Unix()} + + inspector := NewInspector(getRedisConnOpt(t)) + + tests := []struct { + tasks map[string][]base.Z + qname string + id string + newPayload []byte + wantErr error + }{ + { + tasks: map[string][]base.Z{ + "default": {z1, z2}, + "custom": {z3}, + }, + qname: "nonexistent", + id: createScheduledTask(z2).ID, + newPayload: nil, + wantErr: ErrQueueNotFound, + }, + { + tasks: map[string][]base.Z{ + "default": {z1, z2}, + "custom": {z3}, + }, + qname: "default", + id: uuid.NewString(), + newPayload: nil, + wantErr: ErrTaskNotFound, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r) + h.SeedAllScheduledQueues(t, r, tc.tasks) + + if err := inspector.UpdateTaskPayload(tc.qname, tc.id, tc.newPayload); !errors.Is(err, tc.wantErr) { + t.Errorf("UpdateTask(%q, %q) = %v, want %v", tc.qname, tc.id, err, tc.wantErr) + continue + } + } +} + func TestInspectorDeleteTaskDeletesPendingTask(t *testing.T) { r := setup(t) defer r.Close() diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index a18c4e2..4bd7975 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -1412,6 +1412,93 @@ func (r *RDB) archiveAll(src, dst, qname string) (int64, error) { return n, nil } +// Input: +// KEYS[1] -> asynq:{}:t: +// -- +// ARGV[1] -> task message data +// +// Output: +// Numeric code indicating the status: +// Returns 1 if task is successfully updated. +// Returns 0 if task is not found. +// Returns -1 if task is not in scheduled state. +var updateTaskPayloadCmd = redis.NewScript(` +-- Check if given taks exists +if redis.call("EXISTS", KEYS[1]) == 0 then + return 0 +end +local state, pending_since, group, unique_key = unpack(redis.call("HMGET", KEYS[1], "state", "pending_since", "group", "unique_key")) +if state ~= "scheduled" then + return -1 +end +local redis_call_args = {"state", state} + +if pending_since then + table.insert(redis_call_args, "pending_since") + table.insert(redis_call_args, pending_since) +end +if group then + table.insert(redis_call_args, "group") + table.insert(redis_call_args, group) +end +if unique_key then + table.insert(redis_call_args, "unique_key") + table.insert(redis_call_args, unique_key) +end +redis.call("HSET", KEYS[1], "msg", ARGV[1], unpack(redis_call_args)) +return 1 +`) + +// UpdateTaskPayload finds a task that matches the id from the given queue and updates it's payload. +// It returns nil if it successfully updated the task payload. +// +// If a queue with the given name doesn't exist, it returns QueueNotFoundError. +// If a task with the given id doesn't exist in the queue, it returns TaskNotFoundError +// If a task is in active state it returns non-nil error with Code FailedPrecondition. +func (r *RDB) UpdateTaskPayload(qname, id string, payload []byte) error { + var op errors.Op = "rdb.UpdateTask" + if err := r.checkQueueExists(qname); err != nil { + return errors.E(op, errors.CanonicalCode(err), err) + } + + taskInfo, err := r.GetTaskInfo(qname, id) + if err != nil { + return errors.E(op, errors.Unknown, err) + } + + taskInfo.Message.Payload = payload + + encoded, err := base.EncodeMessage(taskInfo.Message) + if err != nil { + return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) + } + keys := []string{ + base.TaskKey(qname, id), + } + argv := []interface{}{ + encoded, + } + + res, err := updateTaskPayloadCmd.Run(context.Background(), r.client, keys, argv...).Result() + if err != nil { + return errors.E(op, errors.Unknown, err) + } + n, ok := res.(int64) + if !ok { + return errors.E(op, errors.Internal, fmt.Sprintf("cast error: updateTaskCmd script returned unexported value %v", res)) + } + switch n { + case 1: + return nil + case 0: + return errors.E(op, errors.NotFound, &errors.TaskNotFoundError{Queue: qname, ID: id}) + case -1: + return errors.E(op, errors.FailedPrecondition, "cannot update task that is not in scheduled state.") + default: + return errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from updateTaskCmd script: %d", n)) + } +} + // Input: // KEYS[1] -> asynq:{}:t: // KEYS[2] -> asynq:{}:groups