From b7c0c5d3aa9ebebc7047cb1f871499d903dd7085 Mon Sep 17 00:00:00 2001
From: Ken Hibino
Date: Mon, 30 Dec 2019 09:22:25 -0800
Subject: [PATCH] Handle mutated task in RDB's Done, Retry, Kill methods

It is possible that the user mutates the task's payload in a Handler
(although the docs say the task passed to a handler is read-only).
Prevent ending up in an inconsistent state by handling the case where
the user mutates the task.
---
 internal/rdb/rdb.go      |  48 ++++++-
 internal/rdb/rdb_test.go | 263 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 308 insertions(+), 3 deletions(-)

diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go
index 24a264a..d9c3087 100644
--- a/internal/rdb/rdb.go
+++ b/internal/rdb/rdb.go
@@ -73,12 +73,26 @@ func (r *RDB) Done(msg *base.TaskMessage) error {
 		return err
 	}
 	// Note: LREM count ZERO means "remove all elements equal to val"
+	// Note: The script will try removing the message by exact match first;
+	// if the task is mutated and an exact match is not found, it will fall back
+	// to a linear scan of the list and find a match by ID.
 	// KEYS[1] -> asynq:in_progress
 	// KEYS[2] -> asynq:processed:
 	// ARGV[1] -> base.TaskMessage value
 	// ARGV[2] -> stats expiration timestamp
 	script := redis.NewScript(`
-	redis.call("LREM", KEYS[1], 0, ARGV[1])
+	local x = redis.call("LREM", KEYS[1], 0, ARGV[1])
+	if tonumber(x) == 0 then
+		local target = cjson.decode(ARGV[1])
+		local data = redis.call("LRANGE", KEYS[1], 0, -1)
+		for _, s in ipairs(data) do
+			local msg = cjson.decode(s)
+			if target["ID"] == msg["ID"] then
+				redis.call("LREM", KEYS[1], 0, s)
+				break
+			end
+		end
+	end
 	local n = redis.call("INCR", KEYS[2])
 	if tonumber(n) == 1 then
 		redis.call("EXPIREAT", KEYS[2], ARGV[2])
@@ -139,6 +153,9 @@ func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string) e
 	if err != nil {
 		return err
 	}
+	// Note: The script will try removing the message by exact match first;
+	// if the task is mutated and an exact match is not found, it will fall back
+	// to a linear scan of the list and find a match by ID.
 	// KEYS[1] -> asynq:in_progress
 	// KEYS[2] -> asynq:retry
 	// KEYS[3] -> asynq:processed:
@@ -148,7 +165,18 @@ func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string) e
 	// ARGV[3] -> retry_at UNIX timestamp
 	// ARGV[4] -> stats expiration timestamp
 	script := redis.NewScript(`
-	redis.call("LREM", KEYS[1], 0, ARGV[1])
+	local x = redis.call("LREM", KEYS[1], 0, ARGV[1])
+	if tonumber(x) == 0 then
+		local target = cjson.decode(ARGV[1])
+		local data = redis.call("LRANGE", KEYS[1], 0, -1)
+		for _, s in ipairs(data) do
+			local msg = cjson.decode(s)
+			if target["ID"] == msg["ID"] then
+				redis.call("LREM", KEYS[1], 0, s)
+				break
+			end
+		end
+	end
 	redis.call("ZADD", KEYS[2], ARGV[3], ARGV[2])
 	local n = redis.call("INCR", KEYS[3])
 	if tonumber(n) == 1 then
@@ -193,6 +221,9 @@ func (r *RDB) Kill(msg *base.TaskMessage, errMsg string) error {
 	processedKey := base.ProcessedKey(now)
 	failureKey := base.FailureKey(now)
 	expireAt := now.Add(statsTTL)
+	// Note: The script will try removing the message by exact match first;
+	// if the task is mutated and an exact match is not found, it will fall back
+	// to a linear scan of the list and find a match by ID.
 	// KEYS[1] -> asynq:in_progress
 	// KEYS[2] -> asynq:dead
 	// KEYS[3] -> asynq:processed:
@@ -204,7 +235,18 @@ func (r *RDB) Kill(msg *base.TaskMessage, errMsg string) error {
 	// ARGV[5] -> max number of tasks in dead queue (e.g., 100)
 	// ARGV[6] -> stats expiration timestamp
 	script := redis.NewScript(`
-	redis.call("LREM", KEYS[1], 0, ARGV[1])
+	local x = redis.call("LREM", KEYS[1], 0, ARGV[1])
+	if tonumber(x) == 0 then
+		local target = cjson.decode(ARGV[1])
+		local data = redis.call("LRANGE", KEYS[1], 0, -1)
+		for _, s in ipairs(data) do
+			local msg = cjson.decode(s)
+			if target["ID"] == msg["ID"] then
+				redis.call("LREM", KEYS[1], 0, s)
+				break
+			end
+		end
+	end
 	redis.call("ZADD", KEYS[2], ARGV[3], ARGV[2])
 	redis.call("ZREMRANGEBYSCORE", KEYS[2], "-inf", ARGV[4])
 	redis.call("ZREMRANGEBYRANK", KEYS[2], 0, -ARGV[5])
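The ID-based fallback above is needed because the exact-match LREM compares raw bytes: once a handler mutates the payload map, the message re-encoded at Done/Retry/Kill time no longer byte-matches the JSON that was originally pushed onto asynq:in_progress, so only the ID still identifies the stored entry. A minimal standalone sketch of that mismatch (using a trimmed-down, hypothetical taskMessage struct, not asynq's actual base.TaskMessage):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// taskMessage is a stand-in for base.TaskMessage, reduced to the fields
// that matter for this illustration.
type taskMessage struct {
	ID      string
	Type    string
	Payload map[string]interface{}
}

func main() {
	msg := taskMessage{
		ID:      "3f8d6f01", // hypothetical task ID
		Type:    "send_email",
		Payload: map[string]interface{}{"subject": "hello"},
	}
	stored, _ := json.Marshal(msg) // what was pushed onto asynq:in_progress

	msg.Payload["newkey"] = 123 // handler mutates the payload map
	current, _ := json.Marshal(msg)

	// The re-encoded message no longer byte-matches the stored entry, so an
	// LREM with the current value removes nothing; matching on the ID field,
	// as the Lua scripts above do, is what still finds the stored element.
	fmt.Println(bytes.Equal(stored, current)) // false
}
```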
diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go
index fd13660..c4d8c31 100644
--- a/internal/rdb/rdb_test.go
+++ b/internal/rdb/rdb_test.go
@@ -144,6 +144,64 @@ func TestDone(t *testing.T) {
 	}
 }
 
+// Note: The user should not mutate the task payload in a Handler.
+// However, we should handle it even if the user mutates the task
+// in a Handler. This test case makes sure that we remove the task
+// from the in-progress queue when Done is called for it.
+func TestDoneWithMutatedTask(t *testing.T) {
+	r := setup(t)
+	t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "hello"})
+	t2 := h.NewTaskMessage("export_csv", map[string]interface{}{"subjct": "hola"})
+
+	tests := []struct {
+		inProgress     []*base.TaskMessage // initial state of the in-progress list
+		target         *base.TaskMessage   // task to remove
+		wantInProgress []*base.TaskMessage // final state of the in-progress list
+	}{
+		{
+			inProgress:     []*base.TaskMessage{t1, t2},
+			target:         t1,
+			wantInProgress: []*base.TaskMessage{t2},
+		},
+		{
+			inProgress:     []*base.TaskMessage{t1},
+			target:         t1,
+			wantInProgress: []*base.TaskMessage{},
+		},
+	}
+
+	for _, tc := range tests {
+		h.FlushDB(t, r.client) // clean up db before each test case
+		h.SeedInProgressQueue(t, r.client, tc.inProgress)
+
+		// Mutate payload map!
+		tc.target.Payload["newkey"] = 123
+
+		err := r.Done(tc.target)
+		if err != nil {
+			t.Errorf("(*RDB).Done(task) = %v, want nil", err)
+			continue
+		}
+
+		gotInProgress := h.GetInProgressMessages(t, r.client)
+		if diff := cmp.Diff(tc.wantInProgress, gotInProgress, h.SortMsgOpt); diff != "" {
+			t.Errorf("mismatch found in %q: (-want, +got):\n%s", base.InProgressQueue, diff)
+			continue
+		}
+
+		processedKey := base.ProcessedKey(time.Now())
+		gotProcessed := r.client.Get(processedKey).Val()
+		if gotProcessed != "1" {
+			t.Errorf("GET %q = %q, want 1", processedKey, gotProcessed)
+		}
+
+		gotTTL := r.client.TTL(processedKey).Val()
+		if gotTTL > statsTTL {
+			t.Errorf("TTL %q = %v, want less than or equal to %v", processedKey, gotTTL, statsTTL)
+		}
+	}
+}
+
 func TestRequeue(t *testing.T) {
 	r := setup(t)
 	t1 := h.NewTaskMessage("send_email", nil)
@@ -321,6 +379,105 @@ func TestRetry(t *testing.T) {
 		}
 	}
 }
+
+func TestRetryWithMutatedTask(t *testing.T) {
+	r := setup(t)
+	t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "Hola!"})
+	t2 := h.NewTaskMessage("gen_thumbnail", map[string]interface{}{"path": "some/path/to/image.jpg"})
+	t3 := h.NewTaskMessage("reindex", map[string]interface{}{})
+	t1.Retried = 10
+	errMsg := "SMTP server is not responding"
+	t1AfterRetry := &base.TaskMessage{
+		ID:       t1.ID,
+		Type:     t1.Type,
+		Payload:  t1.Payload,
+		Queue:    t1.Queue,
+		Retry:    t1.Retry,
+		Retried:  t1.Retried + 1,
+		ErrorMsg: errMsg,
+	}
+	now := time.Now()
+
+	tests := []struct {
+		inProgress     []*base.TaskMessage
+		retry          []h.ZSetEntry
+		msg            *base.TaskMessage
+		processAt      time.Time
+		errMsg         string
+		wantInProgress []*base.TaskMessage
+		wantRetry      []h.ZSetEntry
+	}{
+		{
+			inProgress: []*base.TaskMessage{t1, t2},
+			retry: []h.ZSetEntry{
+				{
+					Msg:   t3,
+					Score: now.Add(time.Minute).Unix(),
+				},
+			},
+			msg:            t1,
+			processAt:      now.Add(5 * time.Minute),
+			errMsg:         errMsg,
+			wantInProgress: []*base.TaskMessage{t2},
+			wantRetry: []h.ZSetEntry{
+				{
+					Msg:   t1AfterRetry,
+					Score: now.Add(5 * time.Minute).Unix(),
+				},
+				{
+					Msg:   t3,
+					Score: now.Add(time.Minute).Unix(),
+				},
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		h.FlushDB(t, r.client)
+		h.SeedInProgressQueue(t, r.client, tc.inProgress)
+		h.SeedRetryQueue(t, r.client, tc.retry)
+
+		// Mutate payload map!
+		tc.msg.Payload["newkey"] = "newvalue"
+
+		err := r.Retry(tc.msg, tc.processAt, tc.errMsg)
+		if err != nil {
+			t.Errorf("(*RDB).Retry = %v, want nil", err)
+			continue
+		}
+
+		gotInProgress := h.GetInProgressMessages(t, r.client)
+		if diff := cmp.Diff(tc.wantInProgress, gotInProgress, h.SortMsgOpt); diff != "" {
+			t.Errorf("mismatch found in %q; (-want, +got)\n%s", base.InProgressQueue, diff)
+		}
+
+		gotRetry := h.GetRetryEntries(t, r.client)
+		if diff := cmp.Diff(tc.wantRetry, gotRetry, h.SortZSetEntryOpt); diff != "" {
+			t.Errorf("mismatch found in %q; (-want, +got)\n%s", base.RetryQueue, diff)
+		}
+
+		processedKey := base.ProcessedKey(time.Now())
+		gotProcessed := r.client.Get(processedKey).Val()
+		if gotProcessed != "1" {
+			t.Errorf("GET %q = %q, want 1", processedKey, gotProcessed)
+		}
+		gotTTL := r.client.TTL(processedKey).Val()
+		if gotTTL > statsTTL {
+			t.Errorf("TTL %q = %v, want less than or equal to %v", processedKey, gotTTL, statsTTL)
+		}
+
+		failureKey := base.FailureKey(time.Now())
+		gotFailure := r.client.Get(failureKey).Val()
+		if gotFailure != "1" {
+			t.Errorf("GET %q = %q, want 1", failureKey, gotFailure)
+		}
+		gotTTL = r.client.TTL(failureKey).Val()
+		if gotTTL > statsTTL {
+			t.Errorf("TTL %q = %v, want less than or equal to %v", failureKey, gotTTL, statsTTL)
+		}
+	}
+}
+
 func TestKill(t *testing.T) {
 	r := setup(t)
 	t1 := h.NewTaskMessage("send_email", nil)
@@ -424,6 +581,112 @@ func TestKill(t *testing.T) {
 	}
 }
 
+func TestKillWithMutatedTask(t *testing.T) {
+	r := setup(t)
+	t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "hello"})
+	t2 := h.NewTaskMessage("reindex", map[string]interface{}{})
+	t3 := h.NewTaskMessage("generate_csv", map[string]interface{}{"path": "some/path/to/img"})
+	errMsg := "SMTP server not responding"
+	t1AfterKill := &base.TaskMessage{
+		ID:       t1.ID,
+		Type:     t1.Type,
+		Payload:  t1.Payload,
+		Queue:    t1.Queue,
+		Retry:    t1.Retry,
+		Retried:  t1.Retried,
+		ErrorMsg: errMsg,
+	}
+	now := time.Now()
+
+	// TODO(hibiken): add test cases for trimming
+	tests := []struct {
+		inProgress     []*base.TaskMessage
+		dead           []h.ZSetEntry
+		target         *base.TaskMessage // task to kill
+		wantInProgress []*base.TaskMessage
+		wantDead       []h.ZSetEntry
+	}{
+		{
+			inProgress: []*base.TaskMessage{t1, t2},
+			dead: []h.ZSetEntry{
+				{
+					Msg:   t3,
+					Score: now.Add(-time.Hour).Unix(),
+				},
+			},
+			target:         t1,
+			wantInProgress: []*base.TaskMessage{t2},
+			wantDead: []h.ZSetEntry{
+				{
+					Msg:   t1AfterKill,
+					Score: now.Unix(),
+				},
+				{
+					Msg:   t3,
+					Score: now.Add(-time.Hour).Unix(),
+				},
+			},
+		},
+		{
+			inProgress:     []*base.TaskMessage{t1, t2, t3},
+			dead:           []h.ZSetEntry{},
+			target:         t1,
+			wantInProgress: []*base.TaskMessage{t2, t3},
+			wantDead: []h.ZSetEntry{
+				{
+					Msg:   t1AfterKill,
+					Score: now.Unix(),
+				},
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		h.FlushDB(t, r.client) // clean up db before each test case
+		h.SeedInProgressQueue(t, r.client, tc.inProgress)
+		h.SeedDeadQueue(t, r.client, tc.dead)
+
+		// Mutate payload map!
+		tc.target.Payload["newkey"] = "newvalue"
+
+		err := r.Kill(tc.target, errMsg)
+		if err != nil {
+			t.Errorf("(*RDB).Kill(%v, %v) = %v, want nil", tc.target, errMsg, err)
+			continue
+		}
+
+		gotInProgress := h.GetInProgressMessages(t, r.client)
+		if diff := cmp.Diff(tc.wantInProgress, gotInProgress, h.SortMsgOpt); diff != "" {
+			t.Errorf("mismatch found in %q: (-want, +got)\n%s", base.InProgressQueue, diff)
+		}
+
+		gotDead := h.GetDeadEntries(t, r.client)
+		if diff := cmp.Diff(tc.wantDead, gotDead, h.SortZSetEntryOpt); diff != "" {
+			t.Errorf("mismatch found in %q after calling (*RDB).Kill: (-want, +got):\n%s", base.DeadQueue, diff)
+		}
+
+		processedKey := base.ProcessedKey(time.Now())
+		gotProcessed := r.client.Get(processedKey).Val()
+		if gotProcessed != "1" {
+			t.Errorf("GET %q = %q, want 1", processedKey, gotProcessed)
+		}
+		gotTTL := r.client.TTL(processedKey).Val()
+		if gotTTL > statsTTL {
+			t.Errorf("TTL %q = %v, want less than or equal to %v", processedKey, gotTTL, statsTTL)
+		}
+
+		failureKey := base.FailureKey(time.Now())
+		gotFailure := r.client.Get(failureKey).Val()
+		if gotFailure != "1" {
+			t.Errorf("GET %q = %q, want 1", failureKey, gotFailure)
+		}
+		gotTTL = r.client.TTL(failureKey).Val()
+		if gotTTL > statsTTL {
+			t.Errorf("TTL %q = %v, want less than or equal to %v", failureKey, gotTTL, statsTTL)
+		}
+	}
+}
+
 func TestRestoreUnfinished(t *testing.T) {
 	r := setup(t)
 	t1 := h.NewTaskMessage("send_email", nil)
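For readers tracing the new Lua logic, the same check-and-remove can be pictured client-side in Go. This is only a sketch under stated assumptions: a go-redis v7-style client (methods without context arguments), a trimmed-down hypothetical taskMessage type, and the asynq:in_progress key named in the patch comments. Unlike the server-side script it is not atomic, which is exactly why the patch keeps this logic in a Lua script:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/go-redis/redis/v7"
)

// taskMessage is a hypothetical stand-in for base.TaskMessage.
type taskMessage struct {
	ID      string
	Type    string
	Payload map[string]interface{}
}

const inProgressKey = "asynq:in_progress" // key name taken from the patch comments

// removeByID mirrors the Lua fallback: try an exact-value LREM first, then
// scan the list and remove the element whose ID matches. Not atomic.
func removeByID(c *redis.Client, msg *taskMessage) error {
	encoded, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	if n, err := c.LRem(inProgressKey, 0, string(encoded)).Result(); err != nil || n > 0 {
		return err // removed by exact match, or a Redis error
	}
	entries, err := c.LRange(inProgressKey, 0, -1).Result()
	if err != nil {
		return err
	}
	for _, s := range entries {
		var stored taskMessage
		if err := json.Unmarshal([]byte(s), &stored); err != nil {
			continue
		}
		if stored.ID == msg.ID {
			// Remove the raw stored value, not the re-encoded (mutated) one.
			return c.LRem(inProgressKey, 0, s).Err()
		}
	}
	return fmt.Errorf("task %s not found in %s", msg.ID, inProgressKey)
}

func main() {
	c := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	msg := &taskMessage{ID: "abc123", Type: "send_email",
		Payload: map[string]interface{}{"subject": "hello"}}
	b, _ := json.Marshal(msg)
	c.LPush(inProgressKey, string(b))

	msg.Payload["newkey"] = 123 // simulate a handler mutating the payload
	if err := removeByID(c, msg); err != nil {
		fmt.Println("remove failed:", err)
	}
}
```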