diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index c762c34..6bfbec2 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -2,8 +2,10 @@ package rdb import ( "encoding/json" + "fmt" "time" + "github.com/go-redis/redis/v7" "github.com/google/uuid" ) @@ -221,3 +223,69 @@ func (r *RDB) ListDead() ([]*DeadTask, error) { } return tasks, nil } + +// Rescue finds a task that matches the given id and score from dead queue +// and enqueues it processing. If a task that maches the id and score does +// not exist, it returns ErrTaskNotFound. +func (r *RDB) Rescue(id string, score float64) error { + n, err := r.removeAndEnqueue(deadQ, id, score) + if err != nil { + return err + } + if n == 0 { + return ErrTaskNotFound + } + return nil +} + +// RetryNow finds a task that matches the given id and score from retry queue +// and enqueues it for processing. If a task that maches the id and score does +// not exist, it returns ErrTaskNotFound. +func (r *RDB) RetryNow(id string, score float64) error { + n, err := r.removeAndEnqueue(retryQ, id, score) + if err != nil { + return err + } + if n == 0 { + return ErrTaskNotFound + } + return nil +} + +// ProcessNow finds a task that matches the given id and score from scheduled queue +// and enqueues it for processing. If a task that maches the id and score does not +// exist, it returns ErrTaskNotFound. +func (r *RDB) ProcessNow(id string, score float64) error { + n, err := r.removeAndEnqueue(scheduledQ, id, score) + if err != nil { + return err + } + if n == 0 { + return ErrTaskNotFound + } + return nil +} + +func (r *RDB) removeAndEnqueue(zset, id string, score float64) (int64, error) { + script := redis.NewScript(` + local msgs = redis.call("ZRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1]) + for _, msg in ipairs(msgs) do + local decoded = cjson.decode(msg) + if decoded["ID"] == ARGV[2] then + redis.call("ZREM", KEYS[1], msg) + redis.call("LPUSH", KEYS[2], msg) + return 1 + end + end + return 0 + `) + res, err := script.Run(r.client, []string{zset, defaultQ}, score, id).Result() + if err != nil { + return 0, err + } + n, ok := res.(int64) + if !ok { + return 0, fmt.Errorf("could not cast %v to int64", res) + } + return n, nil +} diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 8c5afad..a7a7e8a 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -10,6 +10,44 @@ import ( "github.com/google/uuid" ) +// TODO(hibiken): Replace this with cmpopts.EquateApproxTime once it becomes availalble. +// https://github.com/google/go-cmp/issues/166 +// +// EquateApproxTime returns a Comparer options that +// determine two time.Time values to be equal if they +// are within the given time interval of one another. +// Note that if both times have a monotonic clock reading, +// the monotonic time difference will be used. +// +// The zero time is treated specially: it is only considered +// equal to another zero time value. +// +// It will panic if margin is negative. +func EquateApproxTime(margin time.Duration) cmp.Option { + if margin < 0 { + panic("negative duration in EquateApproxTime") + } + return cmp.FilterValues(func(x, y time.Time) bool { + return !x.IsZero() && !y.IsZero() + }, cmp.Comparer(timeApproximator{margin}.compare)) +} + +type timeApproximator struct { + margin time.Duration +} + +func (a timeApproximator) compare(x, y time.Time) bool { + // Avoid subtracting times to avoid overflow when the + // difference is larger than the largest representible duration. + if x.After(y) { + // Ensure x is always before y + x, y = y, x + } + // We're within the margin if x+margin >= y. + // Note: time.Time doesn't have AfterOrEqual method hence the negation. + return !x.Add(a.margin).Before(y) +} + func TestCurrentStats(t *testing.T) { r := setup(t) m1 := randomTask("send_email", "default", map[string]interface{}{"subject": "hello"}) @@ -429,40 +467,230 @@ func TestListDead(t *testing.T) { var timeCmpOpt = EquateApproxTime(time.Second) -// TODO(hibiken): Replace this with cmpopts.EquateApproxTime once it becomes availalble. -// https://github.com/google/go-cmp/issues/166 -// -// EquateApproxTime returns a Comparer options that -// determine two time.Time values to be equal if they -// are within the given time interval of one another. -// Note that if both times have a monotonic clock reading, -// the monotonic time difference will be used. -// -// The zero time is treated specially: it is only considered -// equal to another zero time value. -// -// It will panic if margin is negative. -func EquateApproxTime(margin time.Duration) cmp.Option { - if margin < 0 { - panic("negative duration in EquateApproxTime") +func TestRescue(t *testing.T) { + r := setup(t) + + t1 := randomTask("send_email", "default", nil) + t2 := randomTask("gen_thumbnail", "default", nil) + s1 := float64(time.Now().Add(-5 * time.Minute).Unix()) + s2 := float64(time.Now().Add(-time.Hour).Unix()) + type deadEntry struct { + msg *TaskMessage + score float64 + } + tests := []struct { + dead []deadEntry + score float64 + id string + want error // expected return value from calling Rescue + wantDead []*TaskMessage + wantEnqueued []*TaskMessage + }{ + { + dead: []deadEntry{ + {t1, s1}, + {t2, s2}, + }, + score: s2, + id: t2.ID.String(), + want: nil, + wantDead: []*TaskMessage{t1}, + wantEnqueued: []*TaskMessage{t2}, + }, + { + dead: []deadEntry{ + {t1, s1}, + {t2, s2}, + }, + score: 123.0, + id: t2.ID.String(), + want: ErrTaskNotFound, + wantDead: []*TaskMessage{t1, t2}, + wantEnqueued: []*TaskMessage{}, + }, + } + + for _, tc := range tests { + // clean up db before each test case. + if err := r.client.FlushDB().Err(); err != nil { + t.Fatal(err) + } + // initialize dead queue + for _, d := range tc.dead { + err := r.client.ZAdd(deadQ, &redis.Z{Member: mustMarshal(t, d.msg), Score: d.score}).Err() + if err != nil { + t.Fatal(err) + } + } + + got := r.Rescue(tc.id, tc.score) + if got != tc.want { + t.Errorf("r.Rescue(%s, %0.f) = %v, want %v", tc.id, tc.score, got, tc.want) + continue + } + + gotEnqueuedRaw := r.client.LRange(defaultQ, 0, -1).Val() + gotEnqueued := mustUnmarshalSlice(t, gotEnqueuedRaw) + if diff := cmp.Diff(tc.wantEnqueued, gotEnqueued, sortMsgOpt); diff != "" { + t.Errorf("mismatch found in %q; (-want, +got)\n%s", defaultQ, diff) + } + + gotDeadRaw := r.client.ZRange(deadQ, 0, -1).Val() + gotDead := mustUnmarshalSlice(t, gotDeadRaw) + if diff := cmp.Diff(tc.wantDead, gotDead, sortMsgOpt); diff != "" { + t.Errorf("mismatch found in %q, (-want, +got)\n%s", deadQ, diff) + } } - return cmp.FilterValues(func(x, y time.Time) bool { - return !x.IsZero() && !y.IsZero() - }, cmp.Comparer(timeApproximator{margin}.compare)) } -type timeApproximator struct { - margin time.Duration +func TestRetryNow(t *testing.T) { + r := setup(t) + + t1 := randomTask("send_email", "default", nil) + t2 := randomTask("gen_thumbnail", "default", nil) + s1 := float64(time.Now().Add(-5 * time.Minute).Unix()) + s2 := float64(time.Now().Add(-time.Hour).Unix()) + type retryEntry struct { + msg *TaskMessage + score float64 + } + tests := []struct { + dead []retryEntry + score float64 + id string + want error // expected return value from calling RetryNow + wantRetry []*TaskMessage + wantEnqueued []*TaskMessage + }{ + { + dead: []retryEntry{ + {t1, s1}, + {t2, s2}, + }, + score: s2, + id: t2.ID.String(), + want: nil, + wantRetry: []*TaskMessage{t1}, + wantEnqueued: []*TaskMessage{t2}, + }, + { + dead: []retryEntry{ + {t1, s1}, + {t2, s2}, + }, + score: 123.0, + id: t2.ID.String(), + want: ErrTaskNotFound, + wantRetry: []*TaskMessage{t1, t2}, + wantEnqueued: []*TaskMessage{}, + }, + } + + for _, tc := range tests { + // clean up db before each test case. + if err := r.client.FlushDB().Err(); err != nil { + t.Fatal(err) + } + // initialize retry queue + for _, d := range tc.dead { + err := r.client.ZAdd(retryQ, &redis.Z{Member: mustMarshal(t, d.msg), Score: d.score}).Err() + if err != nil { + t.Fatal(err) + } + } + + got := r.RetryNow(tc.id, tc.score) + if got != tc.want { + t.Errorf("r.RetryNow(%s, %0.f) = %v, want %v", tc.id, tc.score, got, tc.want) + continue + } + + gotEnqueuedRaw := r.client.LRange(defaultQ, 0, -1).Val() + gotEnqueued := mustUnmarshalSlice(t, gotEnqueuedRaw) + if diff := cmp.Diff(tc.wantEnqueued, gotEnqueued, sortMsgOpt); diff != "" { + t.Errorf("mismatch found in %q; (-want, +got)\n%s", defaultQ, diff) + } + + gotRetryRaw := r.client.ZRange(retryQ, 0, -1).Val() + gotRetry := mustUnmarshalSlice(t, gotRetryRaw) + if diff := cmp.Diff(tc.wantRetry, gotRetry, sortMsgOpt); diff != "" { + t.Errorf("mismatch found in %q, (-want, +got)\n%s", retryQ, diff) + } + } } -func (a timeApproximator) compare(x, y time.Time) bool { - // Avoid subtracting times to avoid overflow when the - // difference is larger than the largest representible duration. - if x.After(y) { - // Ensure x is always before y - x, y = y, x +func TestProcessNow(t *testing.T) { + r := setup(t) + + t1 := randomTask("send_email", "default", nil) + t2 := randomTask("gen_thumbnail", "default", nil) + s1 := float64(time.Now().Add(-5 * time.Minute).Unix()) + s2 := float64(time.Now().Add(-time.Hour).Unix()) + type scheduledEntry struct { + msg *TaskMessage + score float64 + } + tests := []struct { + dead []scheduledEntry + score float64 + id string + want error // expected return value from calling ProcessNow + wantScheduled []*TaskMessage + wantEnqueued []*TaskMessage + }{ + { + dead: []scheduledEntry{ + {t1, s1}, + {t2, s2}, + }, + score: s2, + id: t2.ID.String(), + want: nil, + wantScheduled: []*TaskMessage{t1}, + wantEnqueued: []*TaskMessage{t2}, + }, + { + dead: []scheduledEntry{ + {t1, s1}, + {t2, s2}, + }, + score: 123.0, + id: t2.ID.String(), + want: ErrTaskNotFound, + wantScheduled: []*TaskMessage{t1, t2}, + wantEnqueued: []*TaskMessage{}, + }, + } + + for _, tc := range tests { + // clean up db before each test case. + if err := r.client.FlushDB().Err(); err != nil { + t.Fatal(err) + } + // initialize scheduled queue + for _, d := range tc.dead { + err := r.client.ZAdd(scheduledQ, &redis.Z{Member: mustMarshal(t, d.msg), Score: d.score}).Err() + if err != nil { + t.Fatal(err) + } + } + + got := r.ProcessNow(tc.id, tc.score) + if got != tc.want { + t.Errorf("r.RetryNow(%s, %0.f) = %v, want %v", tc.id, tc.score, got, tc.want) + continue + } + + gotEnqueuedRaw := r.client.LRange(defaultQ, 0, -1).Val() + gotEnqueued := mustUnmarshalSlice(t, gotEnqueuedRaw) + if diff := cmp.Diff(tc.wantEnqueued, gotEnqueued, sortMsgOpt); diff != "" { + t.Errorf("mismatch found in %q; (-want, +got)\n%s", defaultQ, diff) + } + + gotScheduledRaw := r.client.ZRange(scheduledQ, 0, -1).Val() + gotScheduled := mustUnmarshalSlice(t, gotScheduledRaw) + if diff := cmp.Diff(tc.wantScheduled, gotScheduled, sortMsgOpt); diff != "" { + t.Errorf("mismatch found in %q, (-want, +got)\n%s", scheduledQ, diff) + } } - // We're within the margin if x+margin >= y. - // Note: time.Time doesn't have AfterOrEqual method hence the negation. - return !x.Add(a.margin).Before(y) } diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 7ef8515..6111b99 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -14,7 +14,6 @@ import ( // Redis keys const ( - allQueues = "asynq:queues" // SET queuePrefix = "asynq:queues:" // LIST - asynq:queues: defaultQ = queuePrefix + "default" // LIST scheduledQ = "asynq:scheduled" // ZSET @@ -23,8 +22,13 @@ const ( inProgressQ = "asynq:in_progress" // LIST ) -// ErrDequeueTimeout indicates that the blocking dequeue operation timed out. -var ErrDequeueTimeout = errors.New("blocking dequeue operation timed out") +var ( + // ErrDequeueTimeout indicates that the blocking dequeue operation timed out. + ErrDequeueTimeout = errors.New("blocking dequeue operation timed out") + + // ErrTaskNotFound indicates that a task that matches the given identifier was not found. + ErrTaskNotFound = errors.New("could not find a task") +) // RDB is a client interface to query and mutate task queues. type RDB struct { @@ -72,7 +76,6 @@ func (r *RDB) Enqueue(msg *TaskMessage) error { } qname := queuePrefix + msg.Queue pipe := r.client.Pipeline() - pipe.SAdd(allQueues, qname) pipe.LPush(qname, string(bytes)) _, err = pipe.Exec() if err != nil { @@ -182,19 +185,18 @@ func (r *RDB) CheckAndEnqueue() error { return nil } -// Forward moves all tasks with a score less than the current unix time +// forward moves all tasks with a score less than the current unix time // from the given zset to the default queue. func (r *RDB) forward(from string) error { script := redis.NewScript(` local msgs = redis.call("ZRANGEBYSCORE", KEYS[1], "-inf", ARGV[1]) for _, msg in ipairs(msgs) do redis.call("ZREM", KEYS[1], msg) - redis.call("SADD", KEYS[2], KEYS[3]) - redis.call("LPUSH", KEYS[3], msg) + redis.call("LPUSH", KEYS[2], msg) end return msgs `) now := float64(time.Now().Unix()) - _, err := script.Run(r.client, []string{from, allQueues, defaultQ}, now).Result() + _, err := script.Run(r.client, []string{from, defaultQ}, now).Result() return err } diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 66a2f13..646e45a 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -112,9 +112,6 @@ func TestEnqueue(t *testing.T) { t.Errorf("LIST %q has length %d, want 1", defaultQ, len(res)) continue } - if !r.client.SIsMember(allQueues, defaultQ).Val() { - t.Errorf("SISMEMBER %q %q = false, want true", allQueues, defaultQ) - } if diff := cmp.Diff(*tc.msg, *mustUnmarshal(t, res[0])); diff != "" { t.Errorf("persisted data differed from the original input (-want, +got)\n%s", diff) }