mirror of
				https://github.com/hibiken/asynq.git
				synced 2025-10-26 11:16:12 +08:00 
			
		
		
		
	Implement RDB.AggregationCheck
This commit is contained in:
		| @@ -103,17 +103,17 @@ func (a *aggregator) start(wg *sync.WaitGroup) { | ||||
| 				a.logger.Debug("Aggregator done") | ||||
| 				ticker.Stop() | ||||
| 				return | ||||
| 			case <-ticker.C: | ||||
| 				a.exec() | ||||
| 			case t := <-ticker.C: | ||||
| 				a.exec(t) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  | ||||
| func (a *aggregator) exec() { | ||||
| func (a *aggregator) exec(t time.Time) { | ||||
| 	select { | ||||
| 	case a.sema <- struct{}{}: // acquire token | ||||
| 		go a.aggregate() | ||||
| 		go a.aggregate(t) | ||||
| 	default: | ||||
| 		// If the semaphore blocks, then we are currently running max number of | ||||
| 		// aggregation checks. Skip this round and log warning. | ||||
| @@ -121,7 +121,7 @@ func (a *aggregator) exec() { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *aggregator) aggregate() { | ||||
| func (a *aggregator) aggregate(t time.Time) { | ||||
| 	defer func() { <-a.sema /* release token */ }() | ||||
| 	for _, qname := range a.queues { | ||||
| 		groups, err := a.broker.ListGroups(qname) | ||||
| @@ -130,7 +130,8 @@ func (a *aggregator) aggregate() { | ||||
| 			continue | ||||
| 		} | ||||
| 		for _, gname := range groups { | ||||
| 			aggregationSetID, err := a.broker.AggregationCheck(qname, gname, a.gracePeriod, a.maxDelay, a.maxSize) | ||||
| 			aggregationSetID, err := a.broker.AggregationCheck( | ||||
| 				qname, gname, t.Add(-a.gracePeriod), t.Add(-a.maxDelay), a.maxSize) | ||||
| 			if err != nil { | ||||
| 				a.logger.Errorf("Failed to run aggregation check: queue=%q group=%q", qname, gname) | ||||
| 				continue | ||||
|   | ||||
| @@ -245,6 +245,13 @@ func SeedCompletedQueue(tb testing.TB, r redis.UniversalClient, entries []base.Z | ||||
| 	seedRedisZSet(tb, r, base.CompletedKey(qname), entries, base.TaskStateCompleted) | ||||
| } | ||||
|  | ||||
| // SeedGroup initializes the group with the given entries. | ||||
| func SeedGroup(tb testing.TB, r redis.UniversalClient, entries []base.Z, qname, gname string) { | ||||
| 	tb.Helper() | ||||
| 	r.SAdd(context.Background(), base.AllQueues, qname) | ||||
| 	seedRedisZSet(tb, r, base.GroupKey(qname, gname), entries, base.TaskStateAggregating) | ||||
| } | ||||
|  | ||||
| // SeedAllPendingQueues initializes all of the specified queues with the given messages. | ||||
| // | ||||
| // pending maps a queue name to a list of messages. | ||||
| @@ -303,6 +310,18 @@ func SeedAllCompletedQueues(tb testing.TB, r redis.UniversalClient, completed ma | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // SeedAllGroups initializes all groups in all queues. | ||||
| // The map maps queue names to group names which maps to a list of task messages and the time it was | ||||
| // added to the group. | ||||
| func SeedAllGroups(tb testing.TB, r redis.UniversalClient, groups map[string]map[string][]base.Z) { | ||||
| 	tb.Helper() | ||||
| 	for qname, g := range groups { | ||||
| 		for gname, entries := range g { | ||||
| 			SeedGroup(tb, r, entries, qname, gname) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func seedRedisList(tb testing.TB, c redis.UniversalClient, key string, | ||||
| 	msgs []*base.TaskMessage, state base.TaskState) { | ||||
| 	tb.Helper() | ||||
|   | ||||
| @@ -26,6 +26,10 @@ type TaskMessageBuilder struct { | ||||
| 	msg *base.TaskMessage | ||||
| } | ||||
|  | ||||
| func NewTaskMessageBuilder() *TaskMessageBuilder { | ||||
| 	return &TaskMessageBuilder{} | ||||
| } | ||||
|  | ||||
| func (b *TaskMessageBuilder) lazyInit() { | ||||
| 	if b.msg == nil { | ||||
| 		b.msg = makeDefaultTaskMessage() | ||||
|   | ||||
| @@ -731,7 +731,7 @@ type Broker interface { | ||||
| 	AddToGroup(ctx context.Context, msg *TaskMessage, gname string) error | ||||
| 	AddToGroupUnique(ctx context.Context, msg *TaskMessage, groupKey string, ttl time.Duration) error | ||||
| 	ListGroups(qname string) ([]string, error) | ||||
| 	AggregationCheck(qname, gname string) (aggregationSetID string, err error) | ||||
| 	AggregationCheck(qname, gname string, gracePeriodStartTime, maxDelayTime time.Time, maxSize int) (aggregationSetID string, err error) | ||||
| 	ReadAggregationSet(qname, gname, aggregationSetID string) ([]*TaskMessage, time.Time, error) | ||||
| 	DeleteAggregationSet(ctx context.Context, qname, gname, aggregationSetID string) error | ||||
|  | ||||
|   | ||||
| @@ -12,6 +12,7 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-redis/redis/v8" | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/hibiken/asynq/internal/base" | ||||
| 	"github.com/hibiken/asynq/internal/errors" | ||||
| 	"github.com/hibiken/asynq/internal/timeutil" | ||||
| @@ -988,19 +989,139 @@ func (r *RDB) ListGroups(qname string) ([]string, error) { | ||||
| 	return nil, nil | ||||
| } | ||||
|  | ||||
| // TODO: Add comment describing what the script does. | ||||
| // KEYS[1] -> asynq:{<qname>}:g:<gname> | ||||
| // KEYS[2] -> asynq:{<qname>}:g:<gname>:<aggregation_set_id> | ||||
| // KEYS[3] -> asynq:{<qname>}:aggregation_sets | ||||
| // ------- | ||||
| // ARGV[1] -> max group size | ||||
| // ARGV[2] -> max group delay in unix time | ||||
| // ARGV[3] -> start time of the grace period | ||||
| // ARGV[4] -> aggregation set ID | ||||
| // ARGV[5] -> aggregation set expire time | ||||
| // | ||||
| // Output: | ||||
| // Returns 0 if no aggregation set was created | ||||
| // Returns 1 if an aggregation set was created | ||||
| var aggregationCheckCmd = redis.NewScript(` | ||||
| local size = redis.call("ZCARD", KEYS[1]) | ||||
| local maxSize = tonumber(ARGV[1]) | ||||
| if size >= maxSize then | ||||
| 	local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1) | ||||
| 	for _, msg in ipairs(msgs) do | ||||
| 		redis.call("SADD", KEYS[2], msg) | ||||
| 	end | ||||
| 	redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1) | ||||
| 	redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4]) | ||||
| 	return 1 | ||||
| end | ||||
| local oldestEntry = redis.call("ZRANGE", KEYS[1], 0, 0, "WITHSCORES") | ||||
| local oldestEntryScore = tonumber(oldestEntry[2]) | ||||
| local maxDelayTime = tonumber(ARGV[2]) | ||||
| if oldestEntryScore <= maxDelayTime then | ||||
| 	local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1) | ||||
| 	for _, msg in ipairs(msgs) do | ||||
| 		redis.call("SADD", KEYS[2], msg) | ||||
| 	end | ||||
| 	redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1) | ||||
| 	redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4]) | ||||
| 	return 1 | ||||
| end | ||||
| local latestEntry = redis.call("ZREVRANGE", KEYS[1], 0, 0, "WITHSCORES") | ||||
| local latestEntryScore = tonumber(latestEntry[2]) | ||||
| local gracePeriodStartTime = tonumber(ARGV[3]) | ||||
| if latestEntryScore <= gracePeriodStartTime then | ||||
| 	local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1) | ||||
| 	for _, msg in ipairs(msgs) do | ||||
| 		redis.call("SADD", KEYS[2], msg) | ||||
| 	end | ||||
| 	redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1) | ||||
| 	redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4]) | ||||
| 	return 1 | ||||
| end | ||||
| return 0 | ||||
| `) | ||||
|  | ||||
| // Task aggregation should finish within this timeout. | ||||
| // Otherwise an aggregation set should be reclaimed by the recoverer. | ||||
| const aggregationTimeout = 2 * time.Minute | ||||
|  | ||||
| // AggregationCheck checks the group identified by the given queue and group name to see if the tasks in the | ||||
| // group are ready to be aggregated. If so, it moves the tasks to be aggregated to a aggregation set and returns | ||||
| // set ID. If not, it returns an empty string for the set ID. | ||||
| func (r *RDB) AggregationCheck(qname, gname string) (string, error) { | ||||
| 	// TODO: Implement this with TDD | ||||
| 	return "", nil | ||||
| // | ||||
| // Note: It assumes that this function is called at frequency less than or equal to the gracePeriod. In other words, | ||||
| // the function only checks the most recently added task aganist the given gracePeriod. | ||||
| func (r *RDB) AggregationCheck(qname, gname string, gracePeriodStartTime, maxDelayTime time.Time, maxSize int) (string, error) { | ||||
| 	var op errors.Op = "RDB.AggregationCheck" | ||||
| 	aggregationSetID := uuid.NewString() | ||||
| 	expireTime := r.clock.Now().Add(aggregationTimeout) | ||||
| 	keys := []string{ | ||||
| 		base.GroupKey(qname, gname), | ||||
| 		base.AggregationSetKey(qname, gname, aggregationSetID), | ||||
| 		base.AllAggregationSets(qname), | ||||
| 	} | ||||
| 	argv := []interface{}{ | ||||
| 		maxSize, | ||||
| 		maxDelayTime.Unix(), | ||||
| 		gracePeriodStartTime.Unix(), | ||||
| 		aggregationSetID, | ||||
| 		expireTime.Unix(), | ||||
| 	} | ||||
| 	n, err := r.runScriptWithErrorCode(context.Background(), op, aggregationCheckCmd, keys, argv...) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	switch n { | ||||
| 	case 0: | ||||
| 		return "", nil | ||||
| 	case 1: | ||||
| 		return aggregationSetID, nil | ||||
| 	default: | ||||
| 		return "", errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from lua script: %d", n)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReadAggregationSet retrieves memebers of an aggregation set and returns list of tasks and | ||||
| // KEYS[1] -> asynq:{<qname>}:g:<gname>:<aggregation_set_id> | ||||
| // ------ | ||||
| // ARGV[1] -> task key prefix | ||||
| var readAggregationSetCmd = redis.NewScript(` | ||||
| local msgs = {} | ||||
| local ids = redis.call("SMEMBERS", KEYS[1]) | ||||
| for _, id in ipairs(ids) do | ||||
| 	local key = ARGV[1] .. id | ||||
| 	table.insert(msgs, redis.call("HGET", key, "msg")) | ||||
| end | ||||
| return msgs | ||||
| `) | ||||
|  | ||||
| // ReadAggregationSet retrieves members of an aggregation set and returns a list of tasks in the set and | ||||
| // the deadline for aggregating those tasks. | ||||
| func (r *RDB) ReadAggregationSet(qname, gname, setID string) ([]*base.TaskMessage, time.Time, error) { | ||||
| 	// TODO: Implement this with TDD | ||||
| 	return nil, time.Time{}, nil | ||||
| 	var op errors.Op = "RDB.ReadAggregationSet" | ||||
| 	ctx := context.Background() | ||||
| 	res, err := readAggregationSetCmd.Run(ctx, r.client, | ||||
| 		[]string{base.AggregationSetKey(qname, gname, setID)}, base.TaskKeyPrefix(qname)).Result() | ||||
| 	if err != nil { | ||||
| 		return nil, time.Time{}, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "smembers", Err: err}) | ||||
| 	} | ||||
| 	data, err := cast.ToStringSliceE(res) | ||||
| 	if err != nil { | ||||
| 		return nil, time.Time{}, errors.E(op, errors.Internal, fmt.Sprintf("cast error: Lua script returned unexpected value: %v", res)) | ||||
| 	} | ||||
| 	var msgs []*base.TaskMessage | ||||
| 	for _, s := range data { | ||||
| 		msg, err := base.DecodeMessage([]byte(s)) | ||||
| 		if err != nil { | ||||
| 			return nil, time.Time{}, errors.E(op, errors.Internal, fmt.Sprintf("cannot decode message: %v", err)) | ||||
| 		} | ||||
| 		msgs = append(msgs, msg) | ||||
| 	} | ||||
| 	deadlineUnix, err := r.client.ZScore(ctx, base.AllAggregationSets(qname), setID).Result() | ||||
| 	if err != nil { | ||||
| 		return nil, time.Time{}, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zscore", Err: err}) | ||||
| 	} | ||||
| 	return msgs, time.Unix(int64(deadlineUnix), 0), nil | ||||
| } | ||||
|  | ||||
| // DeleteAggregationSet deletes the aggregation set identified by the parameters. | ||||
|   | ||||
| @@ -3103,3 +3103,184 @@ func TestWriteResult(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestAggregationCheck(t *testing.T) { | ||||
| 	r := setup(t) | ||||
| 	defer r.Close() | ||||
|  | ||||
| 	now := time.Now() | ||||
| 	r.SetClock(timeutil.NewSimulatedClock(now)) | ||||
|  | ||||
| 	msg1 := h.NewTaskMessageBuilder().SetType("task1").SetGroup("mygroup").Build() | ||||
| 	msg2 := h.NewTaskMessageBuilder().SetType("task2").SetGroup("mygroup").Build() | ||||
| 	msg3 := h.NewTaskMessageBuilder().SetType("task3").SetGroup("mygroup").Build() | ||||
| 	msg4 := h.NewTaskMessageBuilder().SetType("task4").SetGroup("mygroup").Build() | ||||
| 	msg5 := h.NewTaskMessageBuilder().SetType("task5").SetGroup("mygroup").Build() | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		desc               string | ||||
| 		groups             map[string]map[string][]base.Z | ||||
| 		qname              string | ||||
| 		gname              string | ||||
| 		gracePeriod        time.Duration | ||||
| 		maxDelay           time.Duration | ||||
| 		maxSize            int | ||||
| 		shouldCreateSet    bool // whether the check should create a new aggregation set | ||||
| 		wantAggregationSet []*base.TaskMessage | ||||
| 		wantGroups         map[string]map[string][]base.Z | ||||
| 	}{ | ||||
| 		{ | ||||
| 			desc: "with a group size reaching the max size", | ||||
| 			groups: map[string]map[string][]base.Z{ | ||||
| 				"default": { | ||||
| 					"mygroup": { | ||||
| 						{Message: msg1, Score: now.Add(-5 * time.Minute).Unix()}, | ||||
| 						{Message: msg2, Score: now.Add(-3 * time.Minute).Unix()}, | ||||
| 						{Message: msg3, Score: now.Add(-2 * time.Minute).Unix()}, | ||||
| 						{Message: msg4, Score: now.Add(-1 * time.Minute).Unix()}, | ||||
| 						{Message: msg5, Score: now.Add(-10 * time.Second).Unix()}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			qname:              "default", | ||||
| 			gname:              "mygroup", | ||||
| 			gracePeriod:        1 * time.Minute, | ||||
| 			maxDelay:           10 * time.Minute, | ||||
| 			maxSize:            5, | ||||
| 			shouldCreateSet:    true, | ||||
| 			wantAggregationSet: []*base.TaskMessage{msg1, msg2, msg3, msg4, msg5}, | ||||
| 			wantGroups: map[string]map[string][]base.Z{ | ||||
| 				"default": { | ||||
| 					"mygroup": {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "with group size greater than max size", | ||||
| 			groups: map[string]map[string][]base.Z{ | ||||
| 				"default": { | ||||
| 					"mygroup": { | ||||
| 						{Message: msg1, Score: now.Add(-5 * time.Minute).Unix()}, | ||||
| 						{Message: msg2, Score: now.Add(-3 * time.Minute).Unix()}, | ||||
| 						{Message: msg3, Score: now.Add(-2 * time.Minute).Unix()}, | ||||
| 						{Message: msg4, Score: now.Add(-1 * time.Minute).Unix()}, | ||||
| 						{Message: msg5, Score: now.Add(-10 * time.Second).Unix()}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			qname:              "default", | ||||
| 			gname:              "mygroup", | ||||
| 			gracePeriod:        2 * time.Minute, | ||||
| 			maxDelay:           10 * time.Minute, | ||||
| 			maxSize:            3, | ||||
| 			shouldCreateSet:    true, | ||||
| 			wantAggregationSet: []*base.TaskMessage{msg1, msg2, msg3}, | ||||
| 			wantGroups: map[string]map[string][]base.Z{ | ||||
| 				"default": { | ||||
| 					"mygroup": { | ||||
| 						{Message: msg4, Score: now.Add(-1 * time.Minute).Unix()}, | ||||
| 						{Message: msg5, Score: now.Add(-10 * time.Second).Unix()}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "with the most recent task older than grace period", | ||||
| 			groups: map[string]map[string][]base.Z{ | ||||
| 				"default": { | ||||
| 					"mygroup": { | ||||
| 						{Message: msg1, Score: now.Add(-5 * time.Minute).Unix()}, | ||||
| 						{Message: msg2, Score: now.Add(-3 * time.Minute).Unix()}, | ||||
| 						{Message: msg3, Score: now.Add(-2 * time.Minute).Unix()}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			qname:              "default", | ||||
| 			gname:              "mygroup", | ||||
| 			gracePeriod:        1 * time.Minute, | ||||
| 			maxDelay:           10 * time.Minute, | ||||
| 			maxSize:            5, | ||||
| 			shouldCreateSet:    true, | ||||
| 			wantAggregationSet: []*base.TaskMessage{msg1, msg2, msg3}, | ||||
| 			wantGroups: map[string]map[string][]base.Z{ | ||||
| 				"default": { | ||||
| 					"mygroup": {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "with the oldest task older than max delay", | ||||
| 			groups: map[string]map[string][]base.Z{ | ||||
| 				"default": { | ||||
| 					"mygroup": { | ||||
| 						{Message: msg1, Score: now.Add(-15 * time.Minute).Unix()}, | ||||
| 						{Message: msg2, Score: now.Add(-3 * time.Minute).Unix()}, | ||||
| 						{Message: msg3, Score: now.Add(-2 * time.Minute).Unix()}, | ||||
| 						{Message: msg4, Score: now.Add(-1 * time.Minute).Unix()}, | ||||
| 						{Message: msg5, Score: now.Add(-10 * time.Second).Unix()}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			qname:              "default", | ||||
| 			gname:              "mygroup", | ||||
| 			gracePeriod:        2 * time.Minute, | ||||
| 			maxDelay:           10 * time.Minute, | ||||
| 			maxSize:            30, | ||||
| 			shouldCreateSet:    true, | ||||
| 			wantAggregationSet: []*base.TaskMessage{msg1, msg2, msg3, msg4, msg5}, | ||||
| 			wantGroups: map[string]map[string][]base.Z{ | ||||
| 				"default": { | ||||
| 					"mygroup": {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range tests { | ||||
| 		h.FlushDB(t, r.client) | ||||
| 		h.SeedAllGroups(t, r.client, tc.groups) | ||||
|  | ||||
| 		gracePeriodStartTime := now.Add(-tc.gracePeriod) | ||||
| 		maxDelayTime := now.Add(-tc.maxDelay) | ||||
| 		aggregationSetID, err := r.AggregationCheck(tc.qname, tc.gname, gracePeriodStartTime, maxDelayTime, tc.maxSize) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("%s: AggregationCheck returned error: %v", tc.desc, err) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if !tc.shouldCreateSet && aggregationSetID != "" { | ||||
| 			t.Errorf("%s: AggregationCheck returned non empty set ID. want empty ID", tc.desc) | ||||
| 			continue | ||||
| 		} | ||||
| 		if tc.shouldCreateSet && aggregationSetID == "" { | ||||
| 			t.Errorf("%s: AggregationCheck returned empty set ID. want non empty ID", tc.desc) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if !tc.shouldCreateSet { | ||||
| 			continue // below checks are intended for aggregation set | ||||
| 		} | ||||
|  | ||||
| 		msgs, deadline, err := r.ReadAggregationSet(tc.qname, tc.gname, aggregationSetID) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("%s: Failed to read aggregation set %q: %v", tc.desc, aggregationSetID, err) | ||||
| 		} | ||||
| 		if diff := cmp.Diff(tc.wantAggregationSet, msgs, h.SortMsgOpt); diff != "" { | ||||
| 			t.Errorf("%s: Mismatch found in aggregation set: (-want,+got)\n%s", tc.desc, diff) | ||||
| 		} | ||||
|  | ||||
| 		if wantDeadline := now.Add(aggregationTimeout); deadline.Unix() != wantDeadline.Unix() { | ||||
| 			t.Errorf("%s: ReadAggregationSet returned deadline=%v, want=%v", tc.desc, deadline, wantDeadline) | ||||
| 		} | ||||
|  | ||||
| 		for qname, groups := range tc.wantGroups { | ||||
| 			for gname, want := range groups { | ||||
| 				gotGroup := h.GetGroupEntries(t, r.client, qname, gname) | ||||
| 				if diff := cmp.Diff(want, gotGroup, h.SortZSetEntryOpt); diff != "" { | ||||
| 					t.Errorf("%s: Mismatch found in group zset: %q: (-want,+got)\n%s", | ||||
| 						tc.desc, base.GroupKey(qname, gname), diff) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user