diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 5f415ac..939893d 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -170,6 +170,9 @@ var ( // ErrDuplicateTask indicates that another task with the same unique key holds the uniqueness lock. ErrDuplicateTask = errors.New("task already exists") + + // ErrTaskIdConflict indicates that another task with the same task ID already exist + ErrTaskIdConflict = errors.New("task id conflicts with another task") ) // TaskNotFoundError indicates that a task with the given ID does not exist diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index de12ca7..f546253 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -50,6 +50,19 @@ func (r *RDB) runScript(op errors.Op, script *redis.Script, keys []string, args return nil } +// Runs the given script with keys and args and retuns the script's return value as int64. +func (r *RDB) runScriptWithErrorCode(op errors.Op, script *redis.Script, keys []string, args ...interface{}) (int64, error) { + res, err := script.Run(context.Background(), r.client, keys, args...).Result() + if err != nil { + return 0, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) + } + n, ok := res.(int64) + if !ok { + return 0, errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from Lua script: %v", res)) + } + return n, nil +} + // enqueueCmd enqueues a given task message. // // Input: @@ -63,7 +76,11 @@ func (r *RDB) runScript(op errors.Op, script *redis.Script, keys []string, args // // Output: // Returns 1 if successfully enqueued +// Returns 0 if task ID already exists var enqueueCmd = redis.NewScript(` +if redis.call("EXISTS", KEYS[1]) == 1 then + return 0 +end redis.call("HSET", KEYS[1], "msg", ARGV[1], "state", "pending", @@ -93,7 +110,14 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error { msg.Timeout, msg.Deadline, } - return r.runScript(op, enqueueCmd, keys, argv...) + n, err := r.runScriptWithErrorCode(op, enqueueCmd, keys, argv...) + if err != nil { + return err + } + if n == 0 { + return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) + } + return nil } // enqueueUniqueCmd enqueues the task message if the task is unique. @@ -110,10 +134,14 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error { // // Output: // Returns 1 if successfully enqueued -// Returns 0 if task already exists +// Returns 0 if task ID conflicts with another task +// Returns -1 if task unique key already exists var enqueueUniqueCmd = redis.NewScript(` local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2]) if not ok then + return -1 +end +if redis.call("EXISTS", KEYS[2]) == 1 then return 0 end redis.call("HSET", KEYS[2], @@ -149,16 +177,15 @@ func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { msg.Timeout, msg.Deadline, } - res, err := enqueueUniqueCmd.Run(context.Background(), r.client, keys, argv...).Result() + n, err := r.runScriptWithErrorCode(op, enqueueUniqueCmd, keys, argv...) if err != nil { - return errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) + return err } - n, ok := res.(int64) - if !ok { - return errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from Lua script: %v", res)) + if n == -1 { + return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask) } if n == 0 { - return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask) + return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) } return nil } @@ -362,7 +389,14 @@ func (r *RDB) Requeue(msg *base.TaskMessage) error { // ARGV[3] -> task ID // ARGV[4] -> task timeout in seconds (0 if not timeout) // ARGV[5] -> task deadline in unix time (0 if no deadline) +// +// Output: +// Returns 1 if successfully enqueued +// Returns 0 if task ID already exists var scheduleCmd = redis.NewScript(` +if redis.call("EXISTS", KEYS[1]) == 1 then + return 0 +end redis.call("HSET", KEYS[1], "msg", ARGV[1], "state", "scheduled", @@ -393,7 +427,14 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { msg.Timeout, msg.Deadline, } - return r.runScript(op, scheduleCmd, keys, argv...) + n, err := r.runScriptWithErrorCode(op, scheduleCmd, keys, argv...) + if err != nil { + return err + } + if n == 0 { + return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) + } + return nil } // KEYS[1] -> unique key @@ -405,9 +446,17 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { // ARGV[4] -> task message // ARGV[5] -> task timeout in seconds (0 if not timeout) // ARGV[6] -> task deadline in unix time (0 if no deadline) +// +// Output: +// Returns 1 if successfully scheduled +// Returns 0 if task ID already exists +// Returns -1 if task unique key already exists var scheduleUniqueCmd = redis.NewScript(` local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2]) if not ok then + return -1 +end +if redis.call("EXISTS", KEYS[2]) == 1 then return 0 end redis.call("HSET", KEYS[2], @@ -444,16 +493,15 @@ func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl tim msg.Timeout, msg.Deadline, } - res, err := scheduleUniqueCmd.Run(context.Background(), r.client, keys, argv...).Result() + n, err := r.runScriptWithErrorCode(op, scheduleUniqueCmd, keys, argv...) if err != nil { - return errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) + return err } - n, ok := res.(int64) - if !ok { - return errors.E(op, errors.Internal, fmt.Sprintf("cast error: unexpected return value from Lua script: %v", res)) + if n == -1 { + return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask) } if n == 0 { - return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask) + return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) } return nil } diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 8dbc179..87a29c4 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -123,6 +123,42 @@ func TestEnqueue(t *testing.T) { } } +func TestEnqueueTaskIdConflictError(t *testing.T) { + r := setup(t) + defer r.Close() + + m1 := base.TaskMessage{ + ID: "custom_id", + Type: "foo", + Payload: nil, + } + m2 := base.TaskMessage{ + ID: "custom_id", + Type: "bar", + Payload: nil, + } + + tests := []struct { + firstMsg *base.TaskMessage + secondMsg *base.TaskMessage + }{ + {firstMsg: &m1, secondMsg: &m2}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) // clean up db before each test case. + + if err := r.Enqueue(tc.firstMsg); err != nil { + t.Errorf("First message: Enqueue failed: %v", err) + continue + } + if err := r.Enqueue(tc.secondMsg); !errors.Is(err, errors.ErrTaskIdConflict) { + t.Errorf("Second message: Enqueue returned %v, want %v", err, errors.ErrTaskIdConflict) + continue + } + } +} + func TestEnqueueUnique(t *testing.T) { r := setup(t) defer r.Close() @@ -218,6 +254,45 @@ func TestEnqueueUnique(t *testing.T) { } } +func TestEnqueueUniqueTaskIdConflictError(t *testing.T) { + r := setup(t) + defer r.Close() + + m1 := base.TaskMessage{ + ID: "custom_id", + Type: "foo", + Payload: nil, + UniqueKey: "unique_key_one", + } + m2 := base.TaskMessage{ + ID: "custom_id", + Type: "bar", + Payload: nil, + UniqueKey: "unique_key_two", + } + const ttl = 30 * time.Second + + tests := []struct { + firstMsg *base.TaskMessage + secondMsg *base.TaskMessage + }{ + {firstMsg: &m1, secondMsg: &m2}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) // clean up db before each test case. + + if err := r.EnqueueUnique(tc.firstMsg, ttl); err != nil { + t.Errorf("First message: EnqueueUnique failed: %v", err) + continue + } + if err := r.EnqueueUnique(tc.secondMsg, ttl); !errors.Is(err, errors.ErrTaskIdConflict) { + t.Errorf("Second message: EnqueueUnique returned %v, want %v", err, errors.ErrTaskIdConflict) + continue + } + } +} + func TestDequeue(t *testing.T) { r := setup(t) defer r.Close() @@ -946,6 +1021,45 @@ func TestSchedule(t *testing.T) { } } +func TestScheduleTaskIdConflictError(t *testing.T) { + r := setup(t) + defer r.Close() + + m1 := base.TaskMessage{ + ID: "custom_id", + Type: "foo", + Payload: nil, + UniqueKey: "unique_key_one", + } + m2 := base.TaskMessage{ + ID: "custom_id", + Type: "bar", + Payload: nil, + UniqueKey: "unique_key_two", + } + processAt := time.Now().Add(30 * time.Second) + + tests := []struct { + firstMsg *base.TaskMessage + secondMsg *base.TaskMessage + }{ + {firstMsg: &m1, secondMsg: &m2}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) // clean up db before each test case. + + if err := r.Schedule(tc.firstMsg, processAt); err != nil { + t.Errorf("First message: Schedule failed: %v", err) + continue + } + if err := r.Schedule(tc.secondMsg, processAt); !errors.Is(err, errors.ErrTaskIdConflict) { + t.Errorf("Second message: Schedule returned %v, want %v", err, errors.ErrTaskIdConflict) + continue + } + } +} + func TestScheduleUnique(t *testing.T) { r := setup(t) defer r.Close() @@ -1040,6 +1154,46 @@ func TestScheduleUnique(t *testing.T) { } } +func TestScheduleUniqueTaskIdConflictError(t *testing.T) { + r := setup(t) + defer r.Close() + + m1 := base.TaskMessage{ + ID: "custom_id", + Type: "foo", + Payload: nil, + UniqueKey: "unique_key_one", + } + m2 := base.TaskMessage{ + ID: "custom_id", + Type: "bar", + Payload: nil, + UniqueKey: "unique_key_two", + } + const ttl = 30 * time.Second + processAt := time.Now().Add(30 * time.Second) + + tests := []struct { + firstMsg *base.TaskMessage + secondMsg *base.TaskMessage + }{ + {firstMsg: &m1, secondMsg: &m2}, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) // clean up db before each test case. + + if err := r.ScheduleUnique(tc.firstMsg, processAt, ttl); err != nil { + t.Errorf("First message: ScheduleUnique failed: %v", err) + continue + } + if err := r.ScheduleUnique(tc.secondMsg, processAt, ttl); !errors.Is(err, errors.ErrTaskIdConflict) { + t.Errorf("Second message: ScheduleUnique returned %v, want %v", err, errors.ErrTaskIdConflict) + continue + } + } +} + func TestRetry(t *testing.T) { r := setup(t) defer r.Close()