diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index b9292c9..faaa439 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -497,6 +497,126 @@ func (r *RDB) Requeue(ctx context.Context, msg *base.TaskMessage) error { return r.runScript(ctx, op, requeueCmd, keys, msg.ID) } +// KEYS[1] -> asynq:{}:t: +// KEYS[2] -> asynq:{}:g: +// KEYS[3] -> asynq:{}:groups +// ------- +// ARGV[1] -> task message data +// ARGV[2] -> task ID +// ARGV[3] -> current time in Unix time +// ARGV[4] -> group key +// +// Output: +// Returns 1 if successfully added +// Returns 0 if task ID already exists +var addToGroupCmd = redis.NewScript(` +if redis.call("EXISTS", KEYS[1]) == 1 then + return 0 +end +redis.call("HSET", KEYS[1], + "msg", ARGV[1], + "state", "aggregating") +redis.call("ZADD", KEYS[2], ARGV[3], ARGV[2]) +redis.call("SADD", KEYS[3], ARGV[4]) +return 1 +`) + +func (r *RDB) AddToGroup(ctx context.Context, msg *base.TaskMessage, groupKey string) error { + var op errors.Op = "rdb.AddToGroup" + encoded, err := base.EncodeMessage(msg) + if err != nil { + return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) + } + if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil { + return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) + } + keys := []string{ + base.TaskKey(msg.Queue, msg.ID), + base.GroupKey(msg.Queue, groupKey), + base.AllGroups(msg.Queue), + } + argv := []interface{}{ + encoded, + msg.ID, + r.clock.Now().Unix(), + groupKey, + } + n, err := r.runScriptWithErrorCode(ctx, op, addToGroupCmd, keys, argv...) + if err != nil { + return err + } + if n == 0 { + return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) + } + return nil +} + +// KEYS[1] -> asynq:{}:t: +// KEYS[2] -> asynq:{}:g: +// KEYS[3] -> asynq:{}:groups +// KEYS[4] -> unique key +// ------- +// ARGV[1] -> task message data +// ARGV[2] -> task ID +// ARGV[3] -> current time in Unix time +// ARGV[4] -> group key +// ARGV[5] -> uniqueness lock TTL +// +// Output: +// Returns 1 if successfully added +// Returns 0 if task ID already exists +// Returns -1 if task unique key already exists +var addToGroupUniqueCmd = redis.NewScript(` +local ok = redis.call("SET", KEYS[4], ARGV[2], "NX", "EX", ARGV[5]) +if not ok then + return -1 +end +if redis.call("EXISTS", KEYS[1]) == 1 then + return 0 +end +redis.call("HSET", KEYS[1], + "msg", ARGV[1], + "state", "aggregating") +redis.call("ZADD", KEYS[2], ARGV[3], ARGV[2]) +redis.call("SADD", KEYS[3], ARGV[4]) +return 1 +`) + +func (r *RDB) AddToGroupUnique(ctx context.Context, msg *base.TaskMessage, groupKey string, ttl time.Duration) error { + var op errors.Op = "rdb.AddToGroupUnique" + encoded, err := base.EncodeMessage(msg) + if err != nil { + return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err)) + } + if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil { + return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err}) + } + keys := []string{ + base.TaskKey(msg.Queue, msg.ID), + base.GroupKey(msg.Queue, groupKey), + base.AllGroups(msg.Queue), + base.UniqueKey(msg.Queue, msg.Type, msg.Payload), + } + argv := []interface{}{ + encoded, + msg.ID, + r.clock.Now().Unix(), + groupKey, + int(ttl.Seconds()), + } + n, err := r.runScriptWithErrorCode(ctx, op, addToGroupUniqueCmd, keys, argv...) + if err != nil { + return err + } + if n == -1 { + return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask) + } + if n == 0 { + return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) + } + return nil +} + // KEYS[1] -> asynq:{}:t: // KEYS[2] -> asynq:{}:scheduled // ------- diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index b7ad806..85838b8 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -1167,6 +1167,232 @@ func TestRequeue(t *testing.T) { } } +func TestAddToGroup(t *testing.T) { + r := setup(t) + defer r.Close() + + now := time.Now() + r.SetClock(timeutil.NewSimulatedClock(now)) + msg := h.NewTaskMessage("mytask", []byte("foo")) + ctx := context.Background() + + tests := []struct { + msg *base.TaskMessage + groupKey string + }{ + { + msg: msg, + groupKey: "mygroup", + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + + err := r.AddToGroup(ctx, tc.msg, tc.groupKey) + if err != nil { + t.Errorf("r.AddToGroup(ctx, msg, %q) returned error: %v", tc.groupKey, err) + continue + } + + // Check Group zset has task ID + gkey := base.GroupKey(tc.msg.Queue, tc.groupKey) + zs := r.client.ZRangeWithScores(ctx, gkey, 0, -1).Val() + if n := len(zs); n != 1 { + t.Errorf("Redis ZSET %q contains %d elements, want 1", gkey, n) + continue + } + if got := zs[0].Member.(string); got != tc.msg.ID { + t.Errorf("Redis ZSET %q member: got %v, want %v", gkey, got, tc.msg.ID) + continue + } + if got := int64(zs[0].Score); got != now.Unix() { + t.Errorf("Redis ZSET %q score: got %d, want %d", gkey, got, now.Unix()) + continue + } + + // Check the values under the task key. + taskKey := base.TaskKey(tc.msg.Queue, tc.msg.ID) + encoded := r.client.HGet(ctx, taskKey, "msg").Val() // "msg" field + decoded := h.MustUnmarshal(t, encoded) + if diff := cmp.Diff(tc.msg, decoded); diff != "" { + t.Errorf("persisted message was %v, want %v; (-want, +got)\n%s", decoded, tc.msg, diff) + } + state := r.client.HGet(ctx, taskKey, "state").Val() // "state" field + if want := "aggregating"; state != want { + t.Errorf("state field under task-key is set to %q, want %q", state, want) + } + + // Check queue is in the AllQueues set. + if !r.client.SIsMember(context.Background(), base.AllQueues, tc.msg.Queue).Val() { + t.Errorf("%q is not a member of SET %q", tc.msg.Queue, base.AllQueues) + } + } +} + +func TestAddToGroupeTaskIdConflictError(t *testing.T) { + r := setup(t) + defer r.Close() + + ctx := context.Background() + m1 := base.TaskMessage{ + ID: "custom_id", + Type: "foo", + Payload: nil, + UniqueKey: "unique_key_one", + } + m2 := base.TaskMessage{ + ID: "custom_id", + Type: "bar", + Payload: nil, + UniqueKey: "unique_key_two", + } + const groupKey = "mygroup" + + tests := []struct { + firstMsg *base.TaskMessage + secondMsg *base.TaskMessage + }{ + {firstMsg: &m1, secondMsg: &m2}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) // clean up db before each test case. + + if err := r.AddToGroup(ctx, tc.firstMsg, groupKey); err != nil { + t.Errorf("First message: AddToGroup failed: %v", err) + continue + } + if err := r.AddToGroup(ctx, tc.secondMsg, groupKey); !errors.Is(err, errors.ErrTaskIdConflict) { + t.Errorf("Second message: AddToGroup returned %v, want %v", err, errors.ErrTaskIdConflict) + continue + } + } + +} + +func TestAddToGroupUnique(t *testing.T) { + r := setup(t) + defer r.Close() + + now := time.Now() + r.SetClock(timeutil.NewSimulatedClock(now)) + msg := h.NewTaskMessage("mytask", []byte("foo")) + msg.UniqueKey = base.UniqueKey(msg.Queue, msg.Type, msg.Payload) + ctx := context.Background() + + tests := []struct { + msg *base.TaskMessage + groupKey string + ttl time.Duration + }{ + { + msg: msg, + groupKey: "mygroup", + ttl: 30 * time.Second, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + + err := r.AddToGroupUnique(ctx, tc.msg, tc.groupKey, tc.ttl) + if err != nil { + t.Errorf("First message: r.AddToGroupUnique(ctx, msg, %q) returned error: %v", tc.groupKey, err) + continue + } + + // Check Group zset has task ID + gkey := base.GroupKey(tc.msg.Queue, tc.groupKey) + zs := r.client.ZRangeWithScores(ctx, gkey, 0, -1).Val() + if n := len(zs); n != 1 { + t.Errorf("Redis ZSET %q contains %d elements, want 1", gkey, n) + continue + } + if got := zs[0].Member.(string); got != tc.msg.ID { + t.Errorf("Redis ZSET %q member: got %v, want %v", gkey, got, tc.msg.ID) + continue + } + if got := int64(zs[0].Score); got != now.Unix() { + t.Errorf("Redis ZSET %q score: got %d, want %d", gkey, got, now.Unix()) + continue + } + + // Check the values under the task key. + taskKey := base.TaskKey(tc.msg.Queue, tc.msg.ID) + encoded := r.client.HGet(ctx, taskKey, "msg").Val() // "msg" field + decoded := h.MustUnmarshal(t, encoded) + if diff := cmp.Diff(tc.msg, decoded); diff != "" { + t.Errorf("persisted message was %v, want %v; (-want, +got)\n%s", decoded, tc.msg, diff) + } + state := r.client.HGet(ctx, taskKey, "state").Val() // "state" field + if want := "aggregating"; state != want { + t.Errorf("state field under task-key is set to %q, want %q", state, want) + } + + // Check queue is in the AllQueues set. + if !r.client.SIsMember(context.Background(), base.AllQueues, tc.msg.Queue).Val() { + t.Errorf("%q is not a member of SET %q", tc.msg.Queue, base.AllQueues) + } + + got := r.AddToGroupUnique(ctx, tc.msg, tc.groupKey, tc.ttl) + if !errors.Is(got, errors.ErrDuplicateTask) { + t.Errorf("Second message: r.AddGroupUnique(ctx, msg, %q) = %v, want %v", + tc.groupKey, got, errors.ErrDuplicateTask) + continue + } + + gotTTL := r.client.TTL(ctx, tc.msg.UniqueKey).Val() + if !cmp.Equal(tc.ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) { + t.Errorf("TTL %q = %v, want %v", tc.msg.UniqueKey, gotTTL, tc.ttl) + continue + } + } + +} + +func TestAddToGroupUniqueTaskIdConflictError(t *testing.T) { + r := setup(t) + defer r.Close() + + ctx := context.Background() + m1 := base.TaskMessage{ + ID: "custom_id", + Type: "foo", + Payload: nil, + UniqueKey: "unique_key_one", + } + m2 := base.TaskMessage{ + ID: "custom_id", + Type: "bar", + Payload: nil, + UniqueKey: "unique_key_two", + } + const groupKey = "mygroup" + const ttl = 30 * time.Second + + tests := []struct { + firstMsg *base.TaskMessage + secondMsg *base.TaskMessage + }{ + {firstMsg: &m1, secondMsg: &m2}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) // clean up db before each test case. + + if err := r.AddToGroupUnique(ctx, tc.firstMsg, groupKey, ttl); err != nil { + t.Errorf("First message: AddToGroupUnique failed: %v", err) + continue + } + if err := r.AddToGroupUnique(ctx, tc.secondMsg, groupKey, ttl); !errors.Is(err, errors.ErrTaskIdConflict) { + t.Errorf("Second message: AddToGroupUnique returned %v, want %v", err, errors.ErrTaskIdConflict) + continue + } + } + +} + func TestSchedule(t *testing.T) { r := setup(t) defer r.Close() @@ -1183,8 +1409,7 @@ func TestSchedule(t *testing.T) { err := r.Schedule(context.Background(), tc.msg, tc.processAt) if err != nil { - t.Errorf("(*RDB).Schedule(%v, %v) = %v, want nil", - tc.msg, tc.processAt, err) + t.Errorf("(*RDB).Schedule(%v, %v) = %v, want nil", tc.msg, tc.processAt, err) continue } @@ -1192,13 +1417,11 @@ func TestSchedule(t *testing.T) { scheduledKey := base.ScheduledKey(tc.msg.Queue) zs := r.client.ZRangeWithScores(context.Background(), scheduledKey, 0, -1).Val() if n := len(zs); n != 1 { - t.Errorf("Redis ZSET %q contains %d elements, want 1", - scheduledKey, n) + t.Errorf("Redis ZSET %q contains %d elements, want 1", scheduledKey, n) continue } if got := zs[0].Member.(string); got != tc.msg.ID { - t.Errorf("Redis ZSET %q member: got %v, want %v", - scheduledKey, got, tc.msg.ID) + t.Errorf("Redis ZSET %q member: got %v, want %v", scheduledKey, got, tc.msg.ID) continue } if got := int64(zs[0].Score); got != tc.processAt.Unix() { @@ -1292,7 +1515,7 @@ func TestScheduleUnique(t *testing.T) { desc := "(*RDB).ScheduleUnique(msg, processAt, ttl)" err := r.ScheduleUnique(context.Background(), tc.msg, tc.processAt, tc.ttl) if err != nil { - t.Errorf("Frist task: %s = %v, want nil", desc, err) + t.Errorf("First task: %s = %v, want nil", desc, err) continue }