Implement RDB.AggregationCheck

This commit is contained in:
Ken Hibino
2022-03-09 17:05:16 -08:00
parent 4542b52da8
commit 99c00bffeb
6 changed files with 339 additions and 13 deletions

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/go-redis/redis/v8"
"github.com/google/uuid"
"github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/errors"
"github.com/hibiken/asynq/internal/timeutil"
@@ -988,19 +989,139 @@ func (r *RDB) ListGroups(qname string) ([]string, error) {
return nil, nil
}
// TODO: Add comment describing what the script does.
// KEYS[1] -> asynq:{<qname>}:g:<gname>
// KEYS[2] -> asynq:{<qname>}:g:<gname>:<aggregation_set_id>
// KEYS[3] -> asynq:{<qname>}:aggregation_sets
// -------
// ARGV[1] -> max group size
// ARGV[2] -> max group delay in unix time
// ARGV[3] -> start time of the grace period
// ARGV[4] -> aggregation set ID
// ARGV[5] -> aggregation set expire time
//
// Output:
// Returns 0 if no aggregation set was created
// Returns 1 if an aggregation set was created
var aggregationCheckCmd = redis.NewScript(`
local size = redis.call("ZCARD", KEYS[1])
local maxSize = tonumber(ARGV[1])
if size >= maxSize then
local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1)
for _, msg in ipairs(msgs) do
redis.call("SADD", KEYS[2], msg)
end
redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1)
redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4])
return 1
end
local oldestEntry = redis.call("ZRANGE", KEYS[1], 0, 0, "WITHSCORES")
local oldestEntryScore = tonumber(oldestEntry[2])
local maxDelayTime = tonumber(ARGV[2])
if oldestEntryScore <= maxDelayTime then
local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1)
for _, msg in ipairs(msgs) do
redis.call("SADD", KEYS[2], msg)
end
redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1)
redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4])
return 1
end
local latestEntry = redis.call("ZREVRANGE", KEYS[1], 0, 0, "WITHSCORES")
local latestEntryScore = tonumber(latestEntry[2])
local gracePeriodStartTime = tonumber(ARGV[3])
if latestEntryScore <= gracePeriodStartTime then
local msgs = redis.call("ZRANGE", KEYS[1], 0, maxSize-1)
for _, msg in ipairs(msgs) do
redis.call("SADD", KEYS[2], msg)
end
redis.call("ZREMRANGEBYRANK", KEYS[1], 0, maxSize-1)
redis.call("ZADD", KEYS[3], ARGV[5], ARGV[4])
return 1
end
return 0
`)
// Task aggregation should finish within this timeout.
// Otherwise an aggregation set should be reclaimed by the recoverer.
const aggregationTimeout = 2 * time.Minute
// AggregationCheck checks the group identified by the given queue and group name to see if the tasks in the
// group are ready to be aggregated. If so, it moves the tasks to be aggregated to a aggregation set and returns
// set ID. If not, it returns an empty string for the set ID.
func (r *RDB) AggregationCheck(qname, gname string) (string, error) {
// TODO: Implement this with TDD
return "", nil
//
// Note: It assumes that this function is called at frequency less than or equal to the gracePeriod. In other words,
// the function only checks the most recently added task aganist the given gracePeriod.
func (r *RDB) AggregationCheck(qname, gname string, gracePeriodStartTime, maxDelayTime time.Time, maxSize int) (string, error) {
var op errors.Op = "RDB.AggregationCheck"
aggregationSetID := uuid.NewString()
expireTime := r.clock.Now().Add(aggregationTimeout)
keys := []string{
base.GroupKey(qname, gname),
base.AggregationSetKey(qname, gname, aggregationSetID),
base.AllAggregationSets(qname),
}
argv := []interface{}{
maxSize,
maxDelayTime.Unix(),
gracePeriodStartTime.Unix(),
aggregationSetID,
expireTime.Unix(),
}
n, err := r.runScriptWithErrorCode(context.Background(), op, aggregationCheckCmd, keys, argv...)
if err != nil {
return "", err
}
switch n {
case 0:
return "", nil
case 1:
return aggregationSetID, nil
default:
return "", errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from lua script: %d", n))
}
}
// ReadAggregationSet retrieves memebers of an aggregation set and returns list of tasks and
// KEYS[1] -> asynq:{<qname>}:g:<gname>:<aggregation_set_id>
// ------
// ARGV[1] -> task key prefix
var readAggregationSetCmd = redis.NewScript(`
local msgs = {}
local ids = redis.call("SMEMBERS", KEYS[1])
for _, id in ipairs(ids) do
local key = ARGV[1] .. id
table.insert(msgs, redis.call("HGET", key, "msg"))
end
return msgs
`)
// ReadAggregationSet retrieves members of an aggregation set and returns a list of tasks in the set and
// the deadline for aggregating those tasks.
func (r *RDB) ReadAggregationSet(qname, gname, setID string) ([]*base.TaskMessage, time.Time, error) {
// TODO: Implement this with TDD
return nil, time.Time{}, nil
var op errors.Op = "RDB.ReadAggregationSet"
ctx := context.Background()
res, err := readAggregationSetCmd.Run(ctx, r.client,
[]string{base.AggregationSetKey(qname, gname, setID)}, base.TaskKeyPrefix(qname)).Result()
if err != nil {
return nil, time.Time{}, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "smembers", Err: err})
}
data, err := cast.ToStringSliceE(res)
if err != nil {
return nil, time.Time{}, errors.E(op, errors.Internal, fmt.Sprintf("cast error: Lua script returned unexpected value: %v", res))
}
var msgs []*base.TaskMessage
for _, s := range data {
msg, err := base.DecodeMessage([]byte(s))
if err != nil {
return nil, time.Time{}, errors.E(op, errors.Internal, fmt.Sprintf("cannot decode message: %v", err))
}
msgs = append(msgs, msg)
}
deadlineUnix, err := r.client.ZScore(ctx, base.AllAggregationSets(qname), setID).Result()
if err != nil {
return nil, time.Time{}, errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "zscore", Err: err})
}
return msgs, time.Unix(int64(deadlineUnix), 0), nil
}
// DeleteAggregationSet deletes the aggregation set identified by the parameters.