From e7c1c3ad6f747b3e86aa18d573f446270e6228aa Mon Sep 17 00:00:00 2001 From: Ken Hibino Date: Fri, 10 Dec 2021 09:07:41 -0800 Subject: [PATCH] Use clock in RDB --- internal/rdb/rdb.go | 46 +++++++++++++++++++++++------------ internal/rdb/rdb_test.go | 19 ++++++++++----- internal/timeutil/timeutil.go | 8 ++++++ 3 files changed, 51 insertions(+), 22 deletions(-) diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 191f930..70072a9 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -13,6 +13,7 @@ import ( "github.com/go-redis/redis/v8" "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/errors" + "github.com/hibiken/asynq/internal/timeutil" "github.com/spf13/cast" ) @@ -21,11 +22,15 @@ const statsTTL = 90 * 24 * time.Hour // 90 days // RDB is a client interface to query and mutate task queues. type RDB struct { client redis.UniversalClient + clock timeutil.Clock } // NewRDB returns a new instance of RDB. func NewRDB(client redis.UniversalClient) *RDB { - return &RDB{client} + return &RDB{ + client: client, + clock: timeutil.NewRealClock(), + } } // Close closes the connection with redis server. @@ -38,6 +43,13 @@ func (r *RDB) Client() redis.UniversalClient { return r.client } +// SetClock sets the clock used by RDB to the given clock. +// +// Use this function to set the clock to SimulatedClock in tests. +func (r *RDB) SetClock(c timeutil.Clock) { + r.clock = c +} + // Ping checks the connection with redis server. func (r *RDB) Ping() error { return r.client.Ping(context.Background()).Err() @@ -73,7 +85,7 @@ func (r *RDB) runScriptWithErrorCode(ctx context.Context, op errors.Op, script * // ARGV[2] -> task ID // ARGV[3] -> task timeout in seconds (0 if not timeout) // ARGV[4] -> task deadline in unix time (0 if no deadline) -// ARGV[5] -> current time +// ARGV[5] -> current uinx time in millisecond // // Output: // Returns 1 if successfully enqueued @@ -111,7 +123,7 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error { msg.ID, msg.Timeout, msg.Deadline, - time.Now().Unix(), + timeutil.UnixMilli(r.clock.Now()), } n, err := r.runScriptWithErrorCode(ctx, op, enqueueCmd, keys, argv...) if err != nil { @@ -134,7 +146,7 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error { // ARGV[3] -> task message data // ARGV[4] -> task timeout in seconds (0 if not timeout) // ARGV[5] -> task deadline in unix time (0 if no deadline) -// ARGV[6] -> current time +// ARGV[6] -> current unix time in milliseconds // // Output: // Returns 1 if successfully enqueued @@ -181,7 +193,7 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time encoded, msg.Timeout, msg.Deadline, - time.Now().Unix(), + timeutil.UnixMilli(r.clock.Now()), } n, err := r.runScriptWithErrorCode(ctx, op, enqueueUniqueCmd, keys, argv...) if err != nil { @@ -254,7 +266,7 @@ func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, deadline time.Ti base.DeadlinesKey(qname), } argv := []interface{}{ - time.Now().Unix(), + r.clock.Now().Unix(), base.TaskKeyPrefix(qname), } res, err := dequeueCmd.Run(context.Background(), r.client, keys, argv...).Result() @@ -341,7 +353,7 @@ return redis.status_reply("OK") func (r *RDB) Done(msg *base.TaskMessage) error { var op errors.Op = "rdb.Done" ctx := context.Background() - now := time.Now() + now := r.clock.Now() expireAt := now.Add(statsTTL) keys := []string{ base.ActiveKey(msg.Queue), @@ -424,7 +436,7 @@ return redis.status_reply("OK") func (r *RDB) MarkAsComplete(msg *base.TaskMessage) error { var op errors.Op = "rdb.MarkAsComplete" ctx := context.Background() - now := time.Now() + now := r.clock.Now() statsExpireAt := now.Add(statsTTL) msg.CompletedAt = now.Unix() encoded, err := base.EncodeMessage(msg) @@ -644,7 +656,7 @@ return redis.status_reply("OK")`) func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string, isFailure bool) error { var op errors.Op = "rdb.Retry" ctx := context.Background() - now := time.Now() + now := r.clock.Now() modified := *msg if isFailure { modified.Retried++ @@ -717,7 +729,7 @@ return redis.status_reply("OK")`) func (r *RDB) Archive(msg *base.TaskMessage, errMsg string) error { var op errors.Op = "rdb.Archive" ctx := context.Background() - now := time.Now() + now := r.clock.Now() modified := *msg modified.ErrorMsg = errMsg modified.LastFailedAt = now.Unix() @@ -760,8 +772,9 @@ func (r *RDB) ForwardIfReady(qnames ...string) error { // KEYS[1] -> source queue (e.g. asynq:{:scheduled or asynq:{}:retry}) // KEYS[2] -> asynq:{}:pending -// ARGV[1] -> current unix time +// ARGV[1] -> current unix time in seconds // ARGV[2] -> task key prefix +// ARGV[3] -> current unix time in milliseconds // Note: Script moves tasks up to 100 at a time to keep the runtime of script short. var forwardCmd = redis.NewScript(` local ids = redis.call("ZRANGEBYSCORE", KEYS[1], "-inf", ARGV[1], "LIMIT", 0, 100) @@ -770,15 +783,16 @@ for _, id in ipairs(ids) do redis.call("ZREM", KEYS[1], id) redis.call("HSET", ARGV[2] .. id, "state", "pending", - "pending_since", ARGV[1]) + "pending_since", ARGV[3]) end return table.getn(ids)`) // forward moves tasks with a score less than the current unix time // from the src zset to the dst list. It returns the number of tasks moved. func (r *RDB) forward(src, dst, taskKeyPrefix string) (int, error) { + now := r.clock.Now() res, err := forwardCmd.Run(context.Background(), r.client, - []string{src, dst}, time.Now().Unix(), taskKeyPrefix).Result() + []string{src, dst}, now.Unix(), taskKeyPrefix, timeutil.UnixMilli(now)).Result() if err != nil { return 0, errors.E(errors.Internal, fmt.Sprintf("redis eval error: %v", err)) } @@ -843,7 +857,7 @@ func (r *RDB) deleteExpiredCompletedTasks(qname string, batchSize int) (int64, e var op errors.Op = "rdb.DeleteExpiredCompletedTasks" keys := []string{base.CompletedKey(qname)} argv := []interface{}{ - time.Now().Unix(), + r.clock.Now().Unix(), base.TaskKeyPrefix(qname), batchSize, } @@ -921,7 +935,7 @@ func (r *RDB) WriteServerState(info *base.ServerInfo, workers []*base.WorkerInfo if err != nil { return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode server info: %v", err)) } - exp := time.Now().Add(ttl).UTC() + exp := r.clock.Now().Add(ttl).UTC() args := []interface{}{ttl.Seconds(), bytes} // args to the lua script for _, w := range workers { bytes, err := base.EncodeWorkerInfo(w) @@ -986,7 +1000,7 @@ func (r *RDB) WriteSchedulerEntries(schedulerID string, entries []*base.Schedule } args = append(args, bytes) } - exp := time.Now().Add(ttl).UTC() + exp := r.clock.Now().Add(ttl).UTC() key := base.SchedulerEntriesKey(schedulerID) err := r.client.ZAdd(ctx, base.AllSchedulers, &redis.Z{Score: float64(exp.Unix()), Member: key}).Err() if err != nil { diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index b60b24b..5f696a2 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -21,6 +21,7 @@ import ( h "github.com/hibiken/asynq/internal/asynqtest" "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/errors" + "github.com/hibiken/asynq/internal/timeutil" ) // variables used for package testing. @@ -67,6 +68,9 @@ func TestEnqueue(t *testing.T) { t2 := h.NewTaskMessageWithQueue("generate_csv", h.JSON(map[string]interface{}{}), "csv") t3 := h.NewTaskMessageWithQueue("sync", nil, "low") + enqueueTime := time.Now() + r.SetClock(timeutil.NewSimulatedClock(enqueueTime)) + tests := []struct { msg *base.TaskMessage }{ @@ -78,7 +82,6 @@ func TestEnqueue(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case. - enqueueTime := time.Now() err := r.Enqueue(context.Background(), tc.msg) if err != nil { t.Errorf("(*RDB).Enqueue(msg) = %v, want nil", err) @@ -117,7 +120,7 @@ func TestEnqueue(t *testing.T) { t.Errorf("deadline field under task-key is set to %v, want %v", deadline, want) } pendingSince := r.client.HGet(context.Background(), taskKey, "pending_since").Val() // "pending_since" field - if want := strconv.Itoa(int(enqueueTime.Unix())); pendingSince != want { + if want := strconv.Itoa(int(timeutil.UnixMilli(enqueueTime))); pendingSince != want { t.Errorf("pending_since field under task-key is set to %v, want %v", pendingSince, want) } @@ -175,6 +178,9 @@ func TestEnqueueUnique(t *testing.T) { UniqueKey: base.UniqueKey(base.DefaultQueueName, "email", h.JSON(map[string]interface{}{"user_id": 123})), } + enqueueTime := time.Now() + r.SetClock(timeutil.NewSimulatedClock(enqueueTime)) + tests := []struct { msg *base.TaskMessage ttl time.Duration // uniqueness ttl @@ -186,7 +192,6 @@ func TestEnqueueUnique(t *testing.T) { h.FlushDB(t, r.client) // clean up db before each test case. // Enqueue the first message, should succeed. - enqueueTime := time.Now() err := r.EnqueueUnique(context.Background(), tc.msg, tc.ttl) if err != nil { t.Errorf("First message: (*RDB).EnqueueUnique(%v, %v) = %v, want nil", @@ -237,7 +242,7 @@ func TestEnqueueUnique(t *testing.T) { t.Errorf("deadline field under task-key is set to %v, want %v", deadline, want) } pendingSince := r.client.HGet(context.Background(), taskKey, "pending_since").Val() // "pending_since" field - if want := strconv.Itoa(int(enqueueTime.Unix())); pendingSince != want { + if want := strconv.Itoa(int(timeutil.UnixMilli(enqueueTime))); pendingSince != want { t.Errorf("pending_since field under task-key is set to %v, want %v", pendingSince, want) } uniqueKey := r.client.HGet(context.Background(), taskKey, "unique_key").Val() // "unique_key" field @@ -2065,7 +2070,9 @@ func TestForwardIfReady(t *testing.T) { h.SeedAllScheduledQueues(t, r.client, tc.scheduled) h.SeedAllRetryQueues(t, r.client, tc.retry) - now := time.Now() // time when the method is called + now := time.Now() + r.SetClock(timeutil.NewSimulatedClock(now)) + err := r.ForwardIfReady(tc.qnames...) if err != nil { t.Errorf("(*RDB).CheckScheduled(%v) = %v, want nil", tc.qnames, err) @@ -2080,7 +2087,7 @@ func TestForwardIfReady(t *testing.T) { // Make sure "pending_since" field is set for _, msg := range gotPending { pendingSince := r.client.HGet(context.Background(), base.TaskKey(msg.Queue, msg.ID), "pending_since").Val() - if want := strconv.Itoa(int(now.Unix())); pendingSince != want { + if want := strconv.Itoa(int(timeutil.UnixMilli(now))); pendingSince != want { t.Error("pending_since field is not set for newly pending message") } } diff --git a/internal/timeutil/timeutil.go b/internal/timeutil/timeutil.go index 65691ad..05738d4 100644 --- a/internal/timeutil/timeutil.go +++ b/internal/timeutil/timeutil.go @@ -36,3 +36,11 @@ func (c *SimulatedClock) Now() time.Time { return c.t } func (c *SimulatedClock) SetTime(t time.Time) { c.t = t } func (c *SimulatedClock) AdvanceTime(d time.Duration) { c.t.Add(d) } + +// UnixMilli returns t as a Unix time, the number of milliseconds elapsed since +// January 1, 1970 UTC. +// +// TODO: Use time.UnixMilli() when we drop support for go1.16 or below +func UnixMilli(t time.Time) int64 { + return t.UnixNano() / 1e6 +}