From 3551d3334c57b119f36fe4983807cf5cfc67ca26 Mon Sep 17 00:00:00 2001 From: Ken Hibino Date: Fri, 11 Mar 2022 10:59:05 -0800 Subject: [PATCH] Use zset for aggregation set to preserve score --- internal/asynqtest/asynqtest.go | 31 ++----------------------------- internal/rdb/rdb.go | 22 +++++++++++----------- internal/rdb/rdb_test.go | 19 ++++++++++++------- 3 files changed, 25 insertions(+), 47 deletions(-) diff --git a/internal/asynqtest/asynqtest.go b/internal/asynqtest/asynqtest.go index 1bd9c3c..6511fee 100644 --- a/internal/asynqtest/asynqtest.go +++ b/internal/asynqtest/asynqtest.go @@ -254,10 +254,10 @@ func SeedGroup(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname, seedRedisZSet(tb, r, base.GroupKey(qname, gname), entries, base.TaskStateAggregating) } -func SeedAggregationSet(tb testing.TB, r redis.UniversalClient, msgs []*base.TaskMessage, qname, gname, setID string) { +func SeedAggregationSet(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname, gname, setID string) { tb.Helper() r.SAdd(context.Background(), base.AllQueues, qname) - seedRedisSet(tb, r, base.AggregationSetKey(qname, gname, setID), msgs, base.TaskStateAggregating) + seedRedisZSet(tb, r, base.AggregationSetKey(qname, gname, setID), entries, base.TaskStateAggregating) } // SeedAllPendingQueues initializes all of the specified queues with the given messages. @@ -386,33 +386,6 @@ func seedRedisZSet(tb testing.TB, c redis.UniversalClient, key string, } } -func seedRedisSet(tb testing.TB, c redis.UniversalClient, key string, - msgs []*base.TaskMessage, state base.TaskState) { - tb.Helper() - for _, msg := range msgs { - encoded := MustMarshal(tb, msg) - if err := c.SAdd(context.Background(), key, msg.ID).Err(); err != nil { - tb.Fatal(err) - } - taskKey := base.TaskKey(msg.Queue, msg.ID) - data := map[string]interface{}{ - "msg": encoded, - "state": state.String(), - "unique_key": msg.UniqueKey, - "group": msg.GroupKey, - } - if err := c.HSet(context.Background(), taskKey, data).Err(); err != nil { - tb.Fatal(err) - } - if len(msg.UniqueKey) > 0 { - err := c.SetNX(context.Background(), msg.UniqueKey, msg.ID, 1*time.Minute).Err() - if err != nil { - tb.Fatalf("Failed to set unique lock in redis: %v", err) - } - } - } -} - // GetPendingMessages returns all pending messages in the given queue. // It also asserts the state field of the task. func GetPendingMessages(tb testing.TB, r redis.UniversalClient, qname string) []*base.TaskMessage { diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 9d8a15d..c9b5283 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -1015,9 +1015,9 @@ if size == 0 then end local maxSize = tonumber(ARGV[1]) if maxSize ~= 0 and size >= maxSize then - local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1) - for _, msg in ipairs(msgs) do - redis.call("SADD", KEYS[2], msg) + local res = redis.call("ZRANGE", KEYS[1], 0, maxSize-1, "WITHSCORES") + for i=1, table.getn(res)-1, 2 do + redis.call("ZADD", KEYS[2], tonumber(res[i+1]), res[i]) end redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1) redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4]) @@ -1030,9 +1030,9 @@ if maxDelay ~= 0 then local oldestEntryScore = tonumber(oldestEntry[2]) local maxDelayTime = currentTime - maxDelay if oldestEntryScore <= maxDelayTime then - local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1) - for _, msg in ipairs(msgs) do - redis.call("SADD", KEYS[2], msg) + local res = redis.call("ZRANGE", KEYS[1], 0, maxSize-1, "WITHSCORES") + for i=1, table.getn(res)-1, 2 do + redis.call("ZADD", KEYS[2], tonumber(res[i+1]), res[i]) end redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1) redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4]) @@ -1043,9 +1043,9 @@ local latestEntry = redis.call("ZREVRANGE", KEYS[1], 0, 0, "WITHSCORES") local latestEntryScore = tonumber(latestEntry[2]) local gracePeriodStartTime = currentTime - tonumber(ARGV[3]) if latestEntryScore <= gracePeriodStartTime then - local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1) - for _, msg in ipairs(msgs) do - redis.call("SADD", KEYS[2], msg) + local res = redis.call("ZRANGE", KEYS[1], 0, maxSize-1, "WITHSCORES") + for i=1, table.getn(res)-1, 2 do + redis.call("ZADD", KEYS[2], tonumber(res[i+1]), res[i]) end redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1) redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4]) @@ -1101,7 +1101,7 @@ func (r *RDB) AggregationCheck(qname, gname string, t time.Time, gracePeriod, ma // ARGV[1] -> task key prefix var readAggregationSetCmd = redis.NewScript(` local msgs = {} -local ids = redis.call("SMEMBERS", KEYS[1]) +local ids = redis.call("ZRANGE", KEYS[1], 0, -1) for _, id in ipairs(ids) do local key = ARGV[1] .. id table.insert(msgs, redis.call("HGET", key, "msg")) @@ -1142,7 +1142,7 @@ func (r *RDB) ReadAggregationSet(qname, gname, setID string) ([]*base.TaskMessag // ------- // ARGV[1] -> task key prefix var deleteAggregationSetCmd = redis.NewScript(` -local ids = redis.call("SMEMBERS", KEYS[1]) +local ids = redis.call("ZRANGE", KEYS[1], 0, -1) for _, id in ipairs(ids) do redis.call("DEL", ARGV[1] .. id) end diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index aa58a99..d5ceaed 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -3359,6 +3359,7 @@ func TestDeleteAggregationSet(t *testing.T) { r := setup(t) defer r.Close() + now := time.Now() ctx := context.Background() setID := uuid.NewString() msg1 := h.NewTaskMessageBuilder().SetType("foo").SetQueue("default").SetGroup("mygroup").Build() @@ -3366,16 +3367,20 @@ func TestDeleteAggregationSet(t *testing.T) { msg3 := h.NewTaskMessageBuilder().SetType("baz").SetQueue("default").SetGroup("mygroup").Build() tests := []struct { - aggregationSet []*base.TaskMessage + aggregationSet []base.Z qname string gname string setID string }{ { - aggregationSet: []*base.TaskMessage{msg1, msg2, msg3}, - qname: "default", - gname: "mygroup", - setID: setID, + aggregationSet: []base.Z{ + {msg1, now.Add(-3 * time.Minute).Unix()}, + {msg2, now.Add(-2 * time.Minute).Unix()}, + {msg3, now.Add(-1 * time.Minute).Unix()}, + }, + qname: "default", + gname: "mygroup", + setID: setID, }, } @@ -3393,8 +3398,8 @@ func TestDeleteAggregationSet(t *testing.T) { } // Check all tasks in the set are deleted. - for _, m := range tc.aggregationSet { - taskKey := base.TaskKey(m.Queue, m.ID) + for _, z := range tc.aggregationSet { + taskKey := base.TaskKey(z.Message.Queue, z.Message.ID) if r.client.Exists(ctx, taskKey).Val() != 0 { t.Errorf("task key %q still exists", taskKey) }