mirror of
https://github.com/hibiken/asynq.git
synced 2024-12-25 23:32:17 +08:00
Add RDB.AddToGroup and RDB.AddToGroupUnique methods
This commit is contained in:
parent
e3d2939a4c
commit
8b582899ad
@ -497,6 +497,126 @@ func (r *RDB) Requeue(ctx context.Context, msg *base.TaskMessage) error {
|
||||
return r.runScript(ctx, op, requeueCmd, keys, msg.ID)
|
||||
}
|
||||
|
||||
// KEYS[1] -> asynq:{<qname>}:t:<task_id>
|
||||
// KEYS[2] -> asynq:{<qname>}:g:<group_key>
|
||||
// KEYS[3] -> asynq:{<qname>}:groups
|
||||
// -------
|
||||
// ARGV[1] -> task message data
|
||||
// ARGV[2] -> task ID
|
||||
// ARGV[3] -> current time in Unix time
|
||||
// ARGV[4] -> group key
|
||||
//
|
||||
// Output:
|
||||
// Returns 1 if successfully added
|
||||
// Returns 0 if task ID already exists
|
||||
var addToGroupCmd = redis.NewScript(`
|
||||
if redis.call("EXISTS", KEYS[1]) == 1 then
|
||||
return 0
|
||||
end
|
||||
redis.call("HSET", KEYS[1],
|
||||
"msg", ARGV[1],
|
||||
"state", "aggregating")
|
||||
redis.call("ZADD", KEYS[2], ARGV[3], ARGV[2])
|
||||
redis.call("SADD", KEYS[3], ARGV[4])
|
||||
return 1
|
||||
`)
|
||||
|
||||
func (r *RDB) AddToGroup(ctx context.Context, msg *base.TaskMessage, groupKey string) error {
|
||||
var op errors.Op = "rdb.AddToGroup"
|
||||
encoded, err := base.EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
|
||||
}
|
||||
if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
|
||||
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
|
||||
}
|
||||
keys := []string{
|
||||
base.TaskKey(msg.Queue, msg.ID),
|
||||
base.GroupKey(msg.Queue, groupKey),
|
||||
base.AllGroups(msg.Queue),
|
||||
}
|
||||
argv := []interface{}{
|
||||
encoded,
|
||||
msg.ID,
|
||||
r.clock.Now().Unix(),
|
||||
groupKey,
|
||||
}
|
||||
n, err := r.runScriptWithErrorCode(ctx, op, addToGroupCmd, keys, argv...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// KEYS[1] -> asynq:{<qname>}:t:<task_id>
|
||||
// KEYS[2] -> asynq:{<qname>}:g:<group_key>
|
||||
// KEYS[3] -> asynq:{<qname>}:groups
|
||||
// KEYS[4] -> unique key
|
||||
// -------
|
||||
// ARGV[1] -> task message data
|
||||
// ARGV[2] -> task ID
|
||||
// ARGV[3] -> current time in Unix time
|
||||
// ARGV[4] -> group key
|
||||
// ARGV[5] -> uniqueness lock TTL
|
||||
//
|
||||
// Output:
|
||||
// Returns 1 if successfully added
|
||||
// Returns 0 if task ID already exists
|
||||
// Returns -1 if task unique key already exists
|
||||
var addToGroupUniqueCmd = redis.NewScript(`
|
||||
local ok = redis.call("SET", KEYS[4], ARGV[2], "NX", "EX", ARGV[5])
|
||||
if not ok then
|
||||
return -1
|
||||
end
|
||||
if redis.call("EXISTS", KEYS[1]) == 1 then
|
||||
return 0
|
||||
end
|
||||
redis.call("HSET", KEYS[1],
|
||||
"msg", ARGV[1],
|
||||
"state", "aggregating")
|
||||
redis.call("ZADD", KEYS[2], ARGV[3], ARGV[2])
|
||||
redis.call("SADD", KEYS[3], ARGV[4])
|
||||
return 1
|
||||
`)
|
||||
|
||||
func (r *RDB) AddToGroupUnique(ctx context.Context, msg *base.TaskMessage, groupKey string, ttl time.Duration) error {
|
||||
var op errors.Op = "rdb.AddToGroupUnique"
|
||||
encoded, err := base.EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
|
||||
}
|
||||
if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
|
||||
return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
|
||||
}
|
||||
keys := []string{
|
||||
base.TaskKey(msg.Queue, msg.ID),
|
||||
base.GroupKey(msg.Queue, groupKey),
|
||||
base.AllGroups(msg.Queue),
|
||||
base.UniqueKey(msg.Queue, msg.Type, msg.Payload),
|
||||
}
|
||||
argv := []interface{}{
|
||||
encoded,
|
||||
msg.ID,
|
||||
r.clock.Now().Unix(),
|
||||
groupKey,
|
||||
int(ttl.Seconds()),
|
||||
}
|
||||
n, err := r.runScriptWithErrorCode(ctx, op, addToGroupUniqueCmd, keys, argv...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == -1 {
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask)
|
||||
}
|
||||
if n == 0 {
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// KEYS[1] -> asynq:{<qname>}:t:<task_id>
|
||||
// KEYS[2] -> asynq:{<qname>}:scheduled
|
||||
// -------
|
||||
|
@ -1167,6 +1167,232 @@ func TestRequeue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddToGroup(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
|
||||
now := time.Now()
|
||||
r.SetClock(timeutil.NewSimulatedClock(now))
|
||||
msg := h.NewTaskMessage("mytask", []byte("foo"))
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
msg *base.TaskMessage
|
||||
groupKey string
|
||||
}{
|
||||
{
|
||||
msg: msg,
|
||||
groupKey: "mygroup",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
h.FlushDB(t, r.client)
|
||||
|
||||
err := r.AddToGroup(ctx, tc.msg, tc.groupKey)
|
||||
if err != nil {
|
||||
t.Errorf("r.AddToGroup(ctx, msg, %q) returned error: %v", tc.groupKey, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check Group zset has task ID
|
||||
gkey := base.GroupKey(tc.msg.Queue, tc.groupKey)
|
||||
zs := r.client.ZRangeWithScores(ctx, gkey, 0, -1).Val()
|
||||
if n := len(zs); n != 1 {
|
||||
t.Errorf("Redis ZSET %q contains %d elements, want 1", gkey, n)
|
||||
continue
|
||||
}
|
||||
if got := zs[0].Member.(string); got != tc.msg.ID {
|
||||
t.Errorf("Redis ZSET %q member: got %v, want %v", gkey, got, tc.msg.ID)
|
||||
continue
|
||||
}
|
||||
if got := int64(zs[0].Score); got != now.Unix() {
|
||||
t.Errorf("Redis ZSET %q score: got %d, want %d", gkey, got, now.Unix())
|
||||
continue
|
||||
}
|
||||
|
||||
// Check the values under the task key.
|
||||
taskKey := base.TaskKey(tc.msg.Queue, tc.msg.ID)
|
||||
encoded := r.client.HGet(ctx, taskKey, "msg").Val() // "msg" field
|
||||
decoded := h.MustUnmarshal(t, encoded)
|
||||
if diff := cmp.Diff(tc.msg, decoded); diff != "" {
|
||||
t.Errorf("persisted message was %v, want %v; (-want, +got)\n%s", decoded, tc.msg, diff)
|
||||
}
|
||||
state := r.client.HGet(ctx, taskKey, "state").Val() // "state" field
|
||||
if want := "aggregating"; state != want {
|
||||
t.Errorf("state field under task-key is set to %q, want %q", state, want)
|
||||
}
|
||||
|
||||
// Check queue is in the AllQueues set.
|
||||
if !r.client.SIsMember(context.Background(), base.AllQueues, tc.msg.Queue).Val() {
|
||||
t.Errorf("%q is not a member of SET %q", tc.msg.Queue, base.AllQueues)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddToGroupeTaskIdConflictError(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
m1 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "foo",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_one",
|
||||
}
|
||||
m2 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "bar",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_two",
|
||||
}
|
||||
const groupKey = "mygroup"
|
||||
|
||||
tests := []struct {
|
||||
firstMsg *base.TaskMessage
|
||||
secondMsg *base.TaskMessage
|
||||
}{
|
||||
{firstMsg: &m1, secondMsg: &m2},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
h.FlushDB(t, r.client) // clean up db before each test case.
|
||||
|
||||
if err := r.AddToGroup(ctx, tc.firstMsg, groupKey); err != nil {
|
||||
t.Errorf("First message: AddToGroup failed: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := r.AddToGroup(ctx, tc.secondMsg, groupKey); !errors.Is(err, errors.ErrTaskIdConflict) {
|
||||
t.Errorf("Second message: AddToGroup returned %v, want %v", err, errors.ErrTaskIdConflict)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAddToGroupUnique(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
|
||||
now := time.Now()
|
||||
r.SetClock(timeutil.NewSimulatedClock(now))
|
||||
msg := h.NewTaskMessage("mytask", []byte("foo"))
|
||||
msg.UniqueKey = base.UniqueKey(msg.Queue, msg.Type, msg.Payload)
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
msg *base.TaskMessage
|
||||
groupKey string
|
||||
ttl time.Duration
|
||||
}{
|
||||
{
|
||||
msg: msg,
|
||||
groupKey: "mygroup",
|
||||
ttl: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
h.FlushDB(t, r.client)
|
||||
|
||||
err := r.AddToGroupUnique(ctx, tc.msg, tc.groupKey, tc.ttl)
|
||||
if err != nil {
|
||||
t.Errorf("First message: r.AddToGroupUnique(ctx, msg, %q) returned error: %v", tc.groupKey, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check Group zset has task ID
|
||||
gkey := base.GroupKey(tc.msg.Queue, tc.groupKey)
|
||||
zs := r.client.ZRangeWithScores(ctx, gkey, 0, -1).Val()
|
||||
if n := len(zs); n != 1 {
|
||||
t.Errorf("Redis ZSET %q contains %d elements, want 1", gkey, n)
|
||||
continue
|
||||
}
|
||||
if got := zs[0].Member.(string); got != tc.msg.ID {
|
||||
t.Errorf("Redis ZSET %q member: got %v, want %v", gkey, got, tc.msg.ID)
|
||||
continue
|
||||
}
|
||||
if got := int64(zs[0].Score); got != now.Unix() {
|
||||
t.Errorf("Redis ZSET %q score: got %d, want %d", gkey, got, now.Unix())
|
||||
continue
|
||||
}
|
||||
|
||||
// Check the values under the task key.
|
||||
taskKey := base.TaskKey(tc.msg.Queue, tc.msg.ID)
|
||||
encoded := r.client.HGet(ctx, taskKey, "msg").Val() // "msg" field
|
||||
decoded := h.MustUnmarshal(t, encoded)
|
||||
if diff := cmp.Diff(tc.msg, decoded); diff != "" {
|
||||
t.Errorf("persisted message was %v, want %v; (-want, +got)\n%s", decoded, tc.msg, diff)
|
||||
}
|
||||
state := r.client.HGet(ctx, taskKey, "state").Val() // "state" field
|
||||
if want := "aggregating"; state != want {
|
||||
t.Errorf("state field under task-key is set to %q, want %q", state, want)
|
||||
}
|
||||
|
||||
// Check queue is in the AllQueues set.
|
||||
if !r.client.SIsMember(context.Background(), base.AllQueues, tc.msg.Queue).Val() {
|
||||
t.Errorf("%q is not a member of SET %q", tc.msg.Queue, base.AllQueues)
|
||||
}
|
||||
|
||||
got := r.AddToGroupUnique(ctx, tc.msg, tc.groupKey, tc.ttl)
|
||||
if !errors.Is(got, errors.ErrDuplicateTask) {
|
||||
t.Errorf("Second message: r.AddGroupUnique(ctx, msg, %q) = %v, want %v",
|
||||
tc.groupKey, got, errors.ErrDuplicateTask)
|
||||
continue
|
||||
}
|
||||
|
||||
gotTTL := r.client.TTL(ctx, tc.msg.UniqueKey).Val()
|
||||
if !cmp.Equal(tc.ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) {
|
||||
t.Errorf("TTL %q = %v, want %v", tc.msg.UniqueKey, gotTTL, tc.ttl)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAddToGroupUniqueTaskIdConflictError(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
m1 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "foo",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_one",
|
||||
}
|
||||
m2 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "bar",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_two",
|
||||
}
|
||||
const groupKey = "mygroup"
|
||||
const ttl = 30 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
firstMsg *base.TaskMessage
|
||||
secondMsg *base.TaskMessage
|
||||
}{
|
||||
{firstMsg: &m1, secondMsg: &m2},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
h.FlushDB(t, r.client) // clean up db before each test case.
|
||||
|
||||
if err := r.AddToGroupUnique(ctx, tc.firstMsg, groupKey, ttl); err != nil {
|
||||
t.Errorf("First message: AddToGroupUnique failed: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := r.AddToGroupUnique(ctx, tc.secondMsg, groupKey, ttl); !errors.Is(err, errors.ErrTaskIdConflict) {
|
||||
t.Errorf("Second message: AddToGroupUnique returned %v, want %v", err, errors.ErrTaskIdConflict)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSchedule(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
@ -1183,8 +1409,7 @@ func TestSchedule(t *testing.T) {
|
||||
|
||||
err := r.Schedule(context.Background(), tc.msg, tc.processAt)
|
||||
if err != nil {
|
||||
t.Errorf("(*RDB).Schedule(%v, %v) = %v, want nil",
|
||||
tc.msg, tc.processAt, err)
|
||||
t.Errorf("(*RDB).Schedule(%v, %v) = %v, want nil", tc.msg, tc.processAt, err)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -1192,13 +1417,11 @@ func TestSchedule(t *testing.T) {
|
||||
scheduledKey := base.ScheduledKey(tc.msg.Queue)
|
||||
zs := r.client.ZRangeWithScores(context.Background(), scheduledKey, 0, -1).Val()
|
||||
if n := len(zs); n != 1 {
|
||||
t.Errorf("Redis ZSET %q contains %d elements, want 1",
|
||||
scheduledKey, n)
|
||||
t.Errorf("Redis ZSET %q contains %d elements, want 1", scheduledKey, n)
|
||||
continue
|
||||
}
|
||||
if got := zs[0].Member.(string); got != tc.msg.ID {
|
||||
t.Errorf("Redis ZSET %q member: got %v, want %v",
|
||||
scheduledKey, got, tc.msg.ID)
|
||||
t.Errorf("Redis ZSET %q member: got %v, want %v", scheduledKey, got, tc.msg.ID)
|
||||
continue
|
||||
}
|
||||
if got := int64(zs[0].Score); got != tc.processAt.Unix() {
|
||||
@ -1292,7 +1515,7 @@ func TestScheduleUnique(t *testing.T) {
|
||||
desc := "(*RDB).ScheduleUnique(msg, processAt, ttl)"
|
||||
err := r.ScheduleUnique(context.Background(), tc.msg, tc.processAt, tc.ttl)
|
||||
if err != nil {
|
||||
t.Errorf("Frist task: %s = %v, want nil", desc, err)
|
||||
t.Errorf("First task: %s = %v, want nil", desc, err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user