2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-12-26 15:52:18 +08:00

Add RDB.AddToGroup and RDB.AddToGroupUnique methods

This commit is contained in:
Ken Hibino 2022-03-05 06:08:43 -08:00
parent e3d2939a4c
commit 8b582899ad
2 changed files with 350 additions and 7 deletions

View File

@ -497,6 +497,126 @@ func (r *RDB) Requeue(ctx context.Context, msg *base.TaskMessage) error {
return r.runScript(ctx, op, requeueCmd, keys, msg.ID) return r.runScript(ctx, op, requeueCmd, keys, msg.ID)
} }
// KEYS[1] -> asynq:{<qname>}:t:<task_id>
// KEYS[2] -> asynq:{<qname>}:g:<group_key>
// KEYS[3] -> asynq:{<qname>}:groups
// -------
// ARGV[1] -> task message data
// ARGV[2] -> task ID
// ARGV[3] -> current time in Unix time
// ARGV[4] -> group key
//
// Output:
// Returns 1 if successfully added
// Returns 0 if task ID already exists
var addToGroupCmd = redis.NewScript(`
if redis.call("EXISTS", KEYS[1]) == 1 then
return 0
end
redis.call("HSET", KEYS[1],
"msg", ARGV[1],
"state", "aggregating")
redis.call("ZADD", KEYS[2], ARGV[3], ARGV[2])
redis.call("SADD", KEYS[3], ARGV[4])
return 1
`)
func (r *RDB) AddToGroup(ctx context.Context, msg *base.TaskMessage, groupKey string) error {
var op errors.Op = "rdb.AddToGroup"
encoded, err := base.EncodeMessage(msg)
if err != nil {
return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
}
if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
}
keys := []string{
base.TaskKey(msg.Queue, msg.ID),
base.GroupKey(msg.Queue, groupKey),
base.AllGroups(msg.Queue),
}
argv := []interface{}{
encoded,
msg.ID,
r.clock.Now().Unix(),
groupKey,
}
n, err := r.runScriptWithErrorCode(ctx, op, addToGroupCmd, keys, argv...)
if err != nil {
return err
}
if n == 0 {
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
}
return nil
}
// KEYS[1] -> asynq:{<qname>}:t:<task_id>
// KEYS[2] -> asynq:{<qname>}:g:<group_key>
// KEYS[3] -> asynq:{<qname>}:groups
// KEYS[4] -> unique key
// -------
// ARGV[1] -> task message data
// ARGV[2] -> task ID
// ARGV[3] -> current time in Unix time
// ARGV[4] -> group key
// ARGV[5] -> uniqueness lock TTL
//
// Output:
// Returns 1 if successfully added
// Returns 0 if task ID already exists
// Returns -1 if task unique key already exists
var addToGroupUniqueCmd = redis.NewScript(`
local ok = redis.call("SET", KEYS[4], ARGV[2], "NX", "EX", ARGV[5])
if not ok then
return -1
end
if redis.call("EXISTS", KEYS[1]) == 1 then
return 0
end
redis.call("HSET", KEYS[1],
"msg", ARGV[1],
"state", "aggregating")
redis.call("ZADD", KEYS[2], ARGV[3], ARGV[2])
redis.call("SADD", KEYS[3], ARGV[4])
return 1
`)
func (r *RDB) AddToGroupUnique(ctx context.Context, msg *base.TaskMessage, groupKey string, ttl time.Duration) error {
var op errors.Op = "rdb.AddToGroupUnique"
encoded, err := base.EncodeMessage(msg)
if err != nil {
return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
}
if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
}
keys := []string{
base.TaskKey(msg.Queue, msg.ID),
base.GroupKey(msg.Queue, groupKey),
base.AllGroups(msg.Queue),
base.UniqueKey(msg.Queue, msg.Type, msg.Payload),
}
argv := []interface{}{
encoded,
msg.ID,
r.clock.Now().Unix(),
groupKey,
int(ttl.Seconds()),
}
n, err := r.runScriptWithErrorCode(ctx, op, addToGroupUniqueCmd, keys, argv...)
if err != nil {
return err
}
if n == -1 {
return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask)
}
if n == 0 {
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
}
return nil
}
// KEYS[1] -> asynq:{<qname>}:t:<task_id> // KEYS[1] -> asynq:{<qname>}:t:<task_id>
// KEYS[2] -> asynq:{<qname>}:scheduled // KEYS[2] -> asynq:{<qname>}:scheduled
// ------- // -------

View File

@ -1167,6 +1167,232 @@ func TestRequeue(t *testing.T) {
} }
} }
func TestAddToGroup(t *testing.T) {
r := setup(t)
defer r.Close()
now := time.Now()
r.SetClock(timeutil.NewSimulatedClock(now))
msg := h.NewTaskMessage("mytask", []byte("foo"))
ctx := context.Background()
tests := []struct {
msg *base.TaskMessage
groupKey string
}{
{
msg: msg,
groupKey: "mygroup",
},
}
for _, tc := range tests {
h.FlushDB(t, r.client)
err := r.AddToGroup(ctx, tc.msg, tc.groupKey)
if err != nil {
t.Errorf("r.AddToGroup(ctx, msg, %q) returned error: %v", tc.groupKey, err)
continue
}
// Check Group zset has task ID
gkey := base.GroupKey(tc.msg.Queue, tc.groupKey)
zs := r.client.ZRangeWithScores(ctx, gkey, 0, -1).Val()
if n := len(zs); n != 1 {
t.Errorf("Redis ZSET %q contains %d elements, want 1", gkey, n)
continue
}
if got := zs[0].Member.(string); got != tc.msg.ID {
t.Errorf("Redis ZSET %q member: got %v, want %v", gkey, got, tc.msg.ID)
continue
}
if got := int64(zs[0].Score); got != now.Unix() {
t.Errorf("Redis ZSET %q score: got %d, want %d", gkey, got, now.Unix())
continue
}
// Check the values under the task key.
taskKey := base.TaskKey(tc.msg.Queue, tc.msg.ID)
encoded := r.client.HGet(ctx, taskKey, "msg").Val() // "msg" field
decoded := h.MustUnmarshal(t, encoded)
if diff := cmp.Diff(tc.msg, decoded); diff != "" {
t.Errorf("persisted message was %v, want %v; (-want, +got)\n%s", decoded, tc.msg, diff)
}
state := r.client.HGet(ctx, taskKey, "state").Val() // "state" field
if want := "aggregating"; state != want {
t.Errorf("state field under task-key is set to %q, want %q", state, want)
}
// Check queue is in the AllQueues set.
if !r.client.SIsMember(context.Background(), base.AllQueues, tc.msg.Queue).Val() {
t.Errorf("%q is not a member of SET %q", tc.msg.Queue, base.AllQueues)
}
}
}
func TestAddToGroupeTaskIdConflictError(t *testing.T) {
r := setup(t)
defer r.Close()
ctx := context.Background()
m1 := base.TaskMessage{
ID: "custom_id",
Type: "foo",
Payload: nil,
UniqueKey: "unique_key_one",
}
m2 := base.TaskMessage{
ID: "custom_id",
Type: "bar",
Payload: nil,
UniqueKey: "unique_key_two",
}
const groupKey = "mygroup"
tests := []struct {
firstMsg *base.TaskMessage
secondMsg *base.TaskMessage
}{
{firstMsg: &m1, secondMsg: &m2},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.AddToGroup(ctx, tc.firstMsg, groupKey); err != nil {
t.Errorf("First message: AddToGroup failed: %v", err)
continue
}
if err := r.AddToGroup(ctx, tc.secondMsg, groupKey); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: AddToGroup returned %v, want %v", err, errors.ErrTaskIdConflict)
continue
}
}
}
func TestAddToGroupUnique(t *testing.T) {
r := setup(t)
defer r.Close()
now := time.Now()
r.SetClock(timeutil.NewSimulatedClock(now))
msg := h.NewTaskMessage("mytask", []byte("foo"))
msg.UniqueKey = base.UniqueKey(msg.Queue, msg.Type, msg.Payload)
ctx := context.Background()
tests := []struct {
msg *base.TaskMessage
groupKey string
ttl time.Duration
}{
{
msg: msg,
groupKey: "mygroup",
ttl: 30 * time.Second,
},
}
for _, tc := range tests {
h.FlushDB(t, r.client)
err := r.AddToGroupUnique(ctx, tc.msg, tc.groupKey, tc.ttl)
if err != nil {
t.Errorf("First message: r.AddToGroupUnique(ctx, msg, %q) returned error: %v", tc.groupKey, err)
continue
}
// Check Group zset has task ID
gkey := base.GroupKey(tc.msg.Queue, tc.groupKey)
zs := r.client.ZRangeWithScores(ctx, gkey, 0, -1).Val()
if n := len(zs); n != 1 {
t.Errorf("Redis ZSET %q contains %d elements, want 1", gkey, n)
continue
}
if got := zs[0].Member.(string); got != tc.msg.ID {
t.Errorf("Redis ZSET %q member: got %v, want %v", gkey, got, tc.msg.ID)
continue
}
if got := int64(zs[0].Score); got != now.Unix() {
t.Errorf("Redis ZSET %q score: got %d, want %d", gkey, got, now.Unix())
continue
}
// Check the values under the task key.
taskKey := base.TaskKey(tc.msg.Queue, tc.msg.ID)
encoded := r.client.HGet(ctx, taskKey, "msg").Val() // "msg" field
decoded := h.MustUnmarshal(t, encoded)
if diff := cmp.Diff(tc.msg, decoded); diff != "" {
t.Errorf("persisted message was %v, want %v; (-want, +got)\n%s", decoded, tc.msg, diff)
}
state := r.client.HGet(ctx, taskKey, "state").Val() // "state" field
if want := "aggregating"; state != want {
t.Errorf("state field under task-key is set to %q, want %q", state, want)
}
// Check queue is in the AllQueues set.
if !r.client.SIsMember(context.Background(), base.AllQueues, tc.msg.Queue).Val() {
t.Errorf("%q is not a member of SET %q", tc.msg.Queue, base.AllQueues)
}
got := r.AddToGroupUnique(ctx, tc.msg, tc.groupKey, tc.ttl)
if !errors.Is(got, errors.ErrDuplicateTask) {
t.Errorf("Second message: r.AddGroupUnique(ctx, msg, %q) = %v, want %v",
tc.groupKey, got, errors.ErrDuplicateTask)
continue
}
gotTTL := r.client.TTL(ctx, tc.msg.UniqueKey).Val()
if !cmp.Equal(tc.ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) {
t.Errorf("TTL %q = %v, want %v", tc.msg.UniqueKey, gotTTL, tc.ttl)
continue
}
}
}
func TestAddToGroupUniqueTaskIdConflictError(t *testing.T) {
r := setup(t)
defer r.Close()
ctx := context.Background()
m1 := base.TaskMessage{
ID: "custom_id",
Type: "foo",
Payload: nil,
UniqueKey: "unique_key_one",
}
m2 := base.TaskMessage{
ID: "custom_id",
Type: "bar",
Payload: nil,
UniqueKey: "unique_key_two",
}
const groupKey = "mygroup"
const ttl = 30 * time.Second
tests := []struct {
firstMsg *base.TaskMessage
secondMsg *base.TaskMessage
}{
{firstMsg: &m1, secondMsg: &m2},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.AddToGroupUnique(ctx, tc.firstMsg, groupKey, ttl); err != nil {
t.Errorf("First message: AddToGroupUnique failed: %v", err)
continue
}
if err := r.AddToGroupUnique(ctx, tc.secondMsg, groupKey, ttl); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: AddToGroupUnique returned %v, want %v", err, errors.ErrTaskIdConflict)
continue
}
}
}
func TestSchedule(t *testing.T) { func TestSchedule(t *testing.T) {
r := setup(t) r := setup(t)
defer r.Close() defer r.Close()
@ -1183,8 +1409,7 @@ func TestSchedule(t *testing.T) {
err := r.Schedule(context.Background(), tc.msg, tc.processAt) err := r.Schedule(context.Background(), tc.msg, tc.processAt)
if err != nil { if err != nil {
t.Errorf("(*RDB).Schedule(%v, %v) = %v, want nil", t.Errorf("(*RDB).Schedule(%v, %v) = %v, want nil", tc.msg, tc.processAt, err)
tc.msg, tc.processAt, err)
continue continue
} }
@ -1192,13 +1417,11 @@ func TestSchedule(t *testing.T) {
scheduledKey := base.ScheduledKey(tc.msg.Queue) scheduledKey := base.ScheduledKey(tc.msg.Queue)
zs := r.client.ZRangeWithScores(context.Background(), scheduledKey, 0, -1).Val() zs := r.client.ZRangeWithScores(context.Background(), scheduledKey, 0, -1).Val()
if n := len(zs); n != 1 { if n := len(zs); n != 1 {
t.Errorf("Redis ZSET %q contains %d elements, want 1", t.Errorf("Redis ZSET %q contains %d elements, want 1", scheduledKey, n)
scheduledKey, n)
continue continue
} }
if got := zs[0].Member.(string); got != tc.msg.ID { if got := zs[0].Member.(string); got != tc.msg.ID {
t.Errorf("Redis ZSET %q member: got %v, want %v", t.Errorf("Redis ZSET %q member: got %v, want %v", scheduledKey, got, tc.msg.ID)
scheduledKey, got, tc.msg.ID)
continue continue
} }
if got := int64(zs[0].Score); got != tc.processAt.Unix() { if got := int64(zs[0].Score); got != tc.processAt.Unix() {
@ -1292,7 +1515,7 @@ func TestScheduleUnique(t *testing.T) {
desc := "(*RDB).ScheduleUnique(msg, processAt, ttl)" desc := "(*RDB).ScheduleUnique(msg, processAt, ttl)"
err := r.ScheduleUnique(context.Background(), tc.msg, tc.processAt, tc.ttl) err := r.ScheduleUnique(context.Background(), tc.msg, tc.processAt, tc.ttl)
if err != nil { if err != nil {
t.Errorf("Frist task: %s = %v, want nil", desc, err) t.Errorf("First task: %s = %v, want nil", desc, err)
continue continue
} }