From 99c00bffeb8dbcbc5e92ada091d4125d6fa9ee44 Mon Sep 17 00:00:00 2001 From: Ken Hibino Date: Wed, 9 Mar 2022 17:05:16 -0800 Subject: [PATCH] Implement RDB.AggregationCheck --- aggregator.go | 13 +-- internal/asynqtest/asynqtest.go | 19 ++++ internal/asynqtest/builder.go | 4 + internal/base/base.go | 2 +- internal/rdb/rdb.go | 133 +++++++++++++++++++++-- internal/rdb/rdb_test.go | 181 ++++++++++++++++++++++++++++++++ 6 files changed, 339 insertions(+), 13 deletions(-) diff --git a/aggregator.go b/aggregator.go index d2d3089..2ad30ed 100644 --- a/aggregator.go +++ b/aggregator.go @@ -103,17 +103,17 @@ func (a *aggregator) start(wg *sync.WaitGroup) { a.logger.Debug("Aggregator done") ticker.Stop() return - case <-ticker.C: - a.exec() + case t := <-ticker.C: + a.exec(t) } } }() } -func (a *aggregator) exec() { +func (a *aggregator) exec(t time.Time) { select { case a.sema <- struct{}{}: // acquire token - go a.aggregate() + go a.aggregate(t) default: // If the semaphore blocks, then we are currently running max number of // aggregation checks. Skip this round and log warning. 
@@ -121,7 +121,7 @@ func (a *aggregator) exec() { } } -func (a *aggregator) aggregate() { +func (a *aggregator) aggregate(t time.Time) { defer func() { <-a.sema /* release token */ }() for _, qname := range a.queues { groups, err := a.broker.ListGroups(qname) @@ -130,7 +130,8 @@ func (a *aggregator) aggregate() { continue } for _, gname := range groups { - aggregationSetID, err := a.broker.AggregationCheck(qname, gname, a.gracePeriod, a.maxDelay, a.maxSize) + aggregationSetID, err := a.broker.AggregationCheck( + qname, gname, t.Add(-a.gracePeriod), t.Add(-a.maxDelay), a.maxSize) if err != nil { a.logger.Errorf("Failed to run aggregation check: queue=%q group=%q", qname, gname) continue diff --git a/internal/asynqtest/asynqtest.go b/internal/asynqtest/asynqtest.go index 259c82b..f36a37c 100644 --- a/internal/asynqtest/asynqtest.go +++ b/internal/asynqtest/asynqtest.go @@ -245,6 +245,13 @@ func SeedCompletedQueue(tb testing.TB, r redis.UniversalClient, entries []base.Z seedRedisZSet(tb, r, base.CompletedKey(qname), entries, base.TaskStateCompleted) } +// SeedGroup initializes the group with the given entries. +func SeedGroup(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname, gname string) { + tb.Helper() + r.SAdd(context.Background(), base.AllQueues, qname) + seedRedisZSet(tb, r, base.GroupKey(qname, gname), entries, base.TaskStateAggregating) +} + // SeedAllPendingQueues initializes all of the specified queues with the given messages. // // pending maps a queue name to a list of messages. @@ -303,6 +310,18 @@ func SeedAllCompletedQueues(tb testing.TB, r redis.UniversalClient, completed ma } } +// SeedAllGroups initializes all groups in all queues. +// The map maps queue names to group names which maps to a list of task messages and the time it was +// added to the group. 
+func SeedAllGroups(tb testing.TB, r redis.UniversalClient, groups map[string]map[string][]base.Z) { + tb.Helper() + for qname, g := range groups { + for gname, entries := range g { + SeedGroup(tb, r, entries, qname, gname) + } + } +} + func seedRedisList(tb testing.TB, c redis.UniversalClient, key string, msgs []*base.TaskMessage, state base.TaskState) { tb.Helper() diff --git a/internal/asynqtest/builder.go b/internal/asynqtest/builder.go index 168725c..2c9498f 100644 --- a/internal/asynqtest/builder.go +++ b/internal/asynqtest/builder.go @@ -26,6 +26,10 @@ type TaskMessageBuilder struct { msg *base.TaskMessage } +func NewTaskMessageBuilder() *TaskMessageBuilder { + return &TaskMessageBuilder{} +} + func (b *TaskMessageBuilder) lazyInit() { if b.msg == nil { b.msg = makeDefaultTaskMessage() diff --git a/internal/base/base.go b/internal/base/base.go index 6b9606d..fca070d 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -731,7 +731,7 @@ type Broker interface { AddToGroup(ctx context.Context, msg *TaskMessage, gname string) error AddToGroupUnique(ctx context.Context, msg *TaskMessage, groupKey string, ttl time.Duration) error ListGroups(qname string) ([]string, error) - AggregationCheck(qname, gname string) (aggregationSetID string, err error) + AggregationCheck(qname, gname string, gracePeriodStartTime, maxDelayTime time.Time, maxSize int) (aggregationSetID string, err error) ReadAggregationSet(qname, gname, aggregationSetID string) ([]*TaskMessage, time.Time, error) DeleteAggregationSet(ctx context.Context, qname, gname, aggregationSetID string) error diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 4f71f14..0ba9690 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -12,6 +12,7 @@ import ( "time" "github.com/go-redis/redis/v8" + "github.com/google/uuid" "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/errors" "github.com/hibiken/asynq/internal/timeutil" @@ -988,19 +989,139 @@ func (r *RDB) 
ListGroups(qname string) ([]string, error) { return nil, nil } +// TODO: Add comment describing what the script does. +// KEYS[1] -> asynq:{<qname>}:g:<gname> +// KEYS[2] -> asynq:{<qname>}:g:<gname>:<aggregation_set_id> +// KEYS[3] -> asynq:{<qname>}:aggregation_sets +// ------- +// ARGV[1] -> max group size +// ARGV[2] -> max group delay in unix time +// ARGV[3] -> start time of the grace period +// ARGV[4] -> aggregation set ID +// ARGV[5] -> aggregation set expire time +// +// Output: +// Returns 0 if no aggregation set was created +// Returns 1 if an aggregation set was created +var aggregationCheckCmd = redis.NewScript(` +local size = redis.call("ZCARD", KEYS[1]) +local maxSize = tonumber(ARGV[1]) +if size >= maxSize then + local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1) + for _, msg in ipairs(msgs) do + redis.call("SADD", KEYS[2], msg) + end + redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1) + redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4]) + return 1 +end +local oldestEntry = redis.call("ZRANGE", KEYS[1], 0, 0, "WITHSCORES") +local oldestEntryScore = tonumber(oldestEntry[2]) +local maxDelayTime = tonumber(ARGV[2]) +if oldestEntryScore <= maxDelayTime then + local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1) + for _, msg in ipairs(msgs) do + redis.call("SADD", KEYS[2], msg) + end + redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1) + redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4]) + return 1 +end +local latestEntry = redis.call("ZREVRANGE", KEYS[1], 0, 0, "WITHSCORES") +local latestEntryScore = tonumber(latestEntry[2]) +local gracePeriodStartTime = tonumber(ARGV[3]) +if latestEntryScore <= gracePeriodStartTime then + local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1) + for _, msg in ipairs(msgs) do + redis.call("SADD", KEYS[2], msg) + end + redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1) + redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4]) + return 1 +end +return 0 +`) + +// Task aggregation should finish within this timeout. 
+// Otherwise an aggregation set should be reclaimed by the recoverer. +const aggregationTimeout = 2 * time.Minute + // AggregationCheck checks the group identified by the given queue and group name to see if the tasks in the // group are ready to be aggregated. If so, it moves the tasks to be aggregated to a aggregation set and returns // set ID. If not, it returns an empty string for the set ID. -func (r *RDB) AggregationCheck(qname, gname string) (string, error) { - // TODO: Implement this with TDD - return "", nil +// +// Note: It assumes that this function is called at frequency less than or equal to the gracePeriod. In other words, +// the function only checks the most recently added task against the given gracePeriod. +func (r *RDB) AggregationCheck(qname, gname string, gracePeriodStartTime, maxDelayTime time.Time, maxSize int) (string, error) { + var op errors.Op = "RDB.AggregationCheck" + aggregationSetID := uuid.NewString() + expireTime := r.clock.Now().Add(aggregationTimeout) + keys := []string{ + base.GroupKey(qname, gname), + base.AggregationSetKey(qname, gname, aggregationSetID), + base.AllAggregationSets(qname), + } + argv := []interface{}{ + maxSize, + maxDelayTime.Unix(), + gracePeriodStartTime.Unix(), + aggregationSetID, + expireTime.Unix(), + } + n, err := r.runScriptWithErrorCode(context.Background(), op, aggregationCheckCmd, keys, argv...) + if err != nil { + return "", err + } + switch n { + case 0: + return "", nil + case 1: + return aggregationSetID, nil + default: + return "", errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from lua script: %d", n)) + } } -// ReadAggregationSet retrieves memebers of an aggregation set and returns list of tasks and +// KEYS[1] -> asynq:{<qname>}:g:<gname>:<aggregation_set_id> +// ------ +// ARGV[1] -> task key prefix +var readAggregationSetCmd = redis.NewScript(` +local msgs = {} +local ids = redis.call("SMEMBERS", KEYS[1]) +for _, id in ipairs(ids) do + local key = ARGV[1] .. 
id + table.insert(msgs, redis.call("HGET", key, "msg")) +end +return msgs +`) + +// ReadAggregationSet retrieves members of an aggregation set and returns a list of tasks in the set and // the deadline for aggregating those tasks. func (r *RDB) ReadAggregationSet(qname, gname, setID string) ([]*base.TaskMessage, time.Time, error) { - // TODO: Implement this with TDD - return nil, time.Time{}, nil + var op errors.Op = "RDB.ReadAggregationSet" + ctx := context.Background() + res, err := readAggregationSetCmd.Run(ctx, r.client, + []string{base.AggregationSetKey(qname, gname, setID)}, base.TaskKeyPrefix(qname)).Result() + if err != nil { + return nil, time.Time{}, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "smembers", Err: err}) + } + data, err := cast.ToStringSliceE(res) + if err != nil { + return nil, time.Time{}, errors.E(op, errors.Internal, fmt.Sprintf("cast error: Lua script returned unexpected value: %v", res)) + } + var msgs []*base.TaskMessage + for _, s := range data { + msg, err := base.DecodeMessage([]byte(s)) + if err != nil { + return nil, time.Time{}, errors.E(op, errors.Internal, fmt.Sprintf("cannot decode message: %v", err)) + } + msgs = append(msgs, msg) + } + deadlineUnix, err := r.client.ZScore(ctx, base.AllAggregationSets(qname), setID).Result() + if err != nil { + return nil, time.Time{}, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zscore", Err: err}) + } + return msgs, time.Unix(int64(deadlineUnix), 0), nil } // DeleteAggregationSet deletes the aggregation set identified by the parameters. 
diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 9444961..7d973ec 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -3103,3 +3103,184 @@ func TestWriteResult(t *testing.T) { } } } + +func TestAggregationCheck(t *testing.T) { + r := setup(t) + defer r.Close() + + now := time.Now() + r.SetClock(timeutil.NewSimulatedClock(now)) + + msg1 := h.NewTaskMessageBuilder().SetType("task1").SetGroup("mygroup").Build() + msg2 := h.NewTaskMessageBuilder().SetType("task2").SetGroup("mygroup").Build() + msg3 := h.NewTaskMessageBuilder().SetType("task3").SetGroup("mygroup").Build() + msg4 := h.NewTaskMessageBuilder().SetType("task4").SetGroup("mygroup").Build() + msg5 := h.NewTaskMessageBuilder().SetType("task5").SetGroup("mygroup").Build() + + tests := []struct { + desc string + groups map[string]map[string][]base.Z + qname string + gname string + gracePeriod time.Duration + maxDelay time.Duration + maxSize int + shouldCreateSet bool // whether the check should create a new aggregation set + wantAggregationSet []*base.TaskMessage + wantGroups map[string]map[string][]base.Z + }{ + { + desc: "with a group size reaching the max size", + groups: map[string]map[string][]base.Z{ + "default": { + "mygroup": { + {Message: msg1, Score: now.Add(-5 * time.Minute).Unix()}, + {Message: msg2, Score: now.Add(-3 * time.Minute).Unix()}, + {Message: msg3, Score: now.Add(-2 * time.Minute).Unix()}, + {Message: msg4, Score: now.Add(-1 * time.Minute).Unix()}, + {Message: msg5, Score: now.Add(-10 * time.Second).Unix()}, + }, + }, + }, + qname: "default", + gname: "mygroup", + gracePeriod: 1 * time.Minute, + maxDelay: 10 * time.Minute, + maxSize: 5, + shouldCreateSet: true, + wantAggregationSet: []*base.TaskMessage{msg1, msg2, msg3, msg4, msg5}, + wantGroups: map[string]map[string][]base.Z{ + "default": { + "mygroup": {}, + }, + }, + }, + { + desc: "with group size greater than max size", + groups: map[string]map[string][]base.Z{ + "default": { + "mygroup": { + 
{Message: msg1, Score: now.Add(-5 * time.Minute).Unix()}, + {Message: msg2, Score: now.Add(-3 * time.Minute).Unix()}, + {Message: msg3, Score: now.Add(-2 * time.Minute).Unix()}, + {Message: msg4, Score: now.Add(-1 * time.Minute).Unix()}, + {Message: msg5, Score: now.Add(-10 * time.Second).Unix()}, + }, + }, + }, + qname: "default", + gname: "mygroup", + gracePeriod: 2 * time.Minute, + maxDelay: 10 * time.Minute, + maxSize: 3, + shouldCreateSet: true, + wantAggregationSet: []*base.TaskMessage{msg1, msg2, msg3}, + wantGroups: map[string]map[string][]base.Z{ + "default": { + "mygroup": { + {Message: msg4, Score: now.Add(-1 * time.Minute).Unix()}, + {Message: msg5, Score: now.Add(-10 * time.Second).Unix()}, + }, + }, + }, + }, + { + desc: "with the most recent task older than grace period", + groups: map[string]map[string][]base.Z{ + "default": { + "mygroup": { + {Message: msg1, Score: now.Add(-5 * time.Minute).Unix()}, + {Message: msg2, Score: now.Add(-3 * time.Minute).Unix()}, + {Message: msg3, Score: now.Add(-2 * time.Minute).Unix()}, + }, + }, + }, + qname: "default", + gname: "mygroup", + gracePeriod: 1 * time.Minute, + maxDelay: 10 * time.Minute, + maxSize: 5, + shouldCreateSet: true, + wantAggregationSet: []*base.TaskMessage{msg1, msg2, msg3}, + wantGroups: map[string]map[string][]base.Z{ + "default": { + "mygroup": {}, + }, + }, + }, + { + desc: "with the oldest task older than max delay", + groups: map[string]map[string][]base.Z{ + "default": { + "mygroup": { + {Message: msg1, Score: now.Add(-15 * time.Minute).Unix()}, + {Message: msg2, Score: now.Add(-3 * time.Minute).Unix()}, + {Message: msg3, Score: now.Add(-2 * time.Minute).Unix()}, + {Message: msg4, Score: now.Add(-1 * time.Minute).Unix()}, + {Message: msg5, Score: now.Add(-10 * time.Second).Unix()}, + }, + }, + }, + qname: "default", + gname: "mygroup", + gracePeriod: 2 * time.Minute, + maxDelay: 10 * time.Minute, + maxSize: 30, + shouldCreateSet: true, + wantAggregationSet: []*base.TaskMessage{msg1, 
msg2, msg3, msg4, msg5}, + wantGroups: map[string]map[string][]base.Z{ + "default": { + "mygroup": {}, + }, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + h.SeedAllGroups(t, r.client, tc.groups) + + gracePeriodStartTime := now.Add(-tc.gracePeriod) + maxDelayTime := now.Add(-tc.maxDelay) + aggregationSetID, err := r.AggregationCheck(tc.qname, tc.gname, gracePeriodStartTime, maxDelayTime, tc.maxSize) + if err != nil { + t.Errorf("%s: AggregationCheck returned error: %v", tc.desc, err) + continue + } + + if !tc.shouldCreateSet && aggregationSetID != "" { + t.Errorf("%s: AggregationCheck returned non empty set ID. want empty ID", tc.desc) + continue + } + if tc.shouldCreateSet && aggregationSetID == "" { + t.Errorf("%s: AggregationCheck returned empty set ID. want non empty ID", tc.desc) + continue + } + + if !tc.shouldCreateSet { + continue // below checks are intended for aggregation set + } + + msgs, deadline, err := r.ReadAggregationSet(tc.qname, tc.gname, aggregationSetID) + if err != nil { + t.Fatalf("%s: Failed to read aggregation set %q: %v", tc.desc, aggregationSetID, err) + } + if diff := cmp.Diff(tc.wantAggregationSet, msgs, h.SortMsgOpt); diff != "" { + t.Errorf("%s: Mismatch found in aggregation set: (-want,+got)\n%s", tc.desc, diff) + } + + if wantDeadline := now.Add(aggregationTimeout); deadline.Unix() != wantDeadline.Unix() { + t.Errorf("%s: ReadAggregationSet returned deadline=%v, want=%v", tc.desc, deadline, wantDeadline) + } + + for qname, groups := range tc.wantGroups { + for gname, want := range groups { + gotGroup := h.GetGroupEntries(t, r.client, qname, gname) + if diff := cmp.Diff(want, gotGroup, h.SortZSetEntryOpt); diff != "" { + t.Errorf("%s: Mismatch found in group zset: %q: (-want,+got)\n%s", + tc.desc, base.GroupKey(qname, gname), diff) + } + } + } + } +}