diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 28cc574..5666be6 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -50,8 +50,8 @@ func (r *RDB) Ping() error { return r.client.Ping().Err() } -// KEYS[1] -> asynq:{qname}:t: -// KEYS[2] -> asynq:{qname}:pending +// KEYS[1] -> asynq:{}:t: +// KEYS[2] -> asynq:{}:pending // ARGV[1] -> task message data // ARGV[2] -> task ID var enqueueCmd = redis.NewScript(` @@ -75,7 +75,8 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error { } // KEYS[1] -> unique key -// KEYS[2] -> asynq:{}:pending +// KEYS[2] -> asynq:{}:t: +// KEYS[3] -> asynq:{}:pending // ARGV[1] -> task ID // ARGV[2] -> uniqueness lock TTL // ARGV[3] -> task message data @@ -84,7 +85,8 @@ local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2]) if not ok then return 0 end -redis.call("LPUSH", KEYS[2], ARGV[3]) +redis.call("SET", KEYS[2], ARGV[3]) +redis.call("LPUSH", KEYS[3], ARGV[1]) return 1 `) @@ -98,9 +100,9 @@ func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { if err := r.client.SAdd(base.AllQueues, msg.Queue).Err(); err != nil { return err } - res, err := enqueueUniqueCmd.Run(r.client, - []string{msg.UniqueKey, base.PendingKey(msg.Queue)}, - msg.ID.String(), int(ttl.Seconds()), encoded).Result() + keys := []string{msg.UniqueKey, base.TaskKey(msg.Queue, msg.ID.String()), base.PendingKey(msg.Queue)} + args := []interface{}{msg.ID.String(), int(ttl.Seconds()), encoded} + res, err := enqueueUniqueCmd.Run(r.client, keys, args...).Result() if err != nil { return err } diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 1e8a396..40b2885 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -101,7 +101,7 @@ func TestEnqueueUnique(t *testing.T) { m1 := base.TaskMessage{ ID: uuid.New(), Type: "email", - Payload: map[string]interface{}{"user_id": 123}, + Payload: map[string]interface{}{"user_id": float64(123)}, Queue: base.DefaultQueueName, UniqueKey: base.UniqueKey(base.DefaultQueueName, "email", map[string]interface{}{"user_id": 123}), } @@ -116,13 +116,26 @@ func TestEnqueueUnique(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case. + // Enqueue the first message, should succeed. err := r.EnqueueUnique(tc.msg, tc.ttl) if err != nil { t.Errorf("First message: (*RDB).EnqueueUnique(%v, %v) = %v, want nil", tc.msg, tc.ttl, err) continue } + gotPending := h.GetPendingMessages(t, r.client, tc.msg.Queue) + if len(gotPending) != 1 { + t.Errorf("%q has length %d, want 1", base.PendingKey(tc.msg.Queue), len(gotPending)) + continue + } + if diff := cmp.Diff(tc.msg, gotPending[0]); diff != "" { + t.Errorf("persisted data differed from the original input (-want, +got)\n%s", diff) + } + if !r.client.SIsMember(base.AllQueues, tc.msg.Queue).Val() { + t.Errorf("%q is not a member of SET %q", tc.msg.Queue, base.AllQueues) + } + // Enqueue the second message, should fail. got := r.EnqueueUnique(tc.msg, tc.ttl) if got != ErrDuplicateTask { t.Errorf("Second message: (*RDB).EnqueueUnique(%v, %v) = %v, want %v", @@ -134,9 +147,6 @@ func TestEnqueueUnique(t *testing.T) { t.Errorf("TTL %q = %v, want %v", tc.msg.UniqueKey, gotTTL, tc.ttl) continue } - if !r.client.SIsMember(base.AllQueues, tc.msg.Queue).Val() { - t.Errorf("%q is not a member of SET %q", tc.msg.Queue, base.AllQueues) - } } }