2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-12-27 00:02:19 +08:00

Update RDB Enqueue and Schedule methods to check for task ID conflict

This commit is contained in:
Ken Hibino 2021-09-10 16:47:00 -07:00
parent 2261c7c9a0
commit dbdd9c6d5f
3 changed files with 220 additions and 15 deletions

View File

@ -170,6 +170,9 @@ var (
// ErrDuplicateTask indicates that another task with the same unique key holds the uniqueness lock. // ErrDuplicateTask indicates that another task with the same unique key holds the uniqueness lock.
ErrDuplicateTask = errors.New("task already exists") ErrDuplicateTask = errors.New("task already exists")
// ErrTaskIdConflict indicates that another task with the same task ID already exist
ErrTaskIdConflict = errors.New("task id conflicts with another task")
) )
// TaskNotFoundError indicates that a task with the given ID does not exist // TaskNotFoundError indicates that a task with the given ID does not exist

View File

@ -50,6 +50,19 @@ func (r *RDB) runScript(op errors.Op, script *redis.Script, keys []string, args
return nil return nil
} }
// Runs the given script with keys and args and retuns the script's return value as int64.
func (r *RDB) runScriptWithErrorCode(op errors.Op, script *redis.Script, keys []string, args ...interface{}) (int64, error) {
res, err := script.Run(context.Background(), r.client, keys, args...).Result()
if err != nil {
return 0, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err))
}
n, ok := res.(int64)
if !ok {
return 0, errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from Lua script: %v", res))
}
return n, nil
}
// enqueueCmd enqueues a given task message. // enqueueCmd enqueues a given task message.
// //
// Input: // Input:
@ -63,7 +76,11 @@ func (r *RDB) runScript(op errors.Op, script *redis.Script, keys []string, args
// //
// Output: // Output:
// Returns 1 if successfully enqueued // Returns 1 if successfully enqueued
// Returns 0 if task ID already exists
var enqueueCmd = redis.NewScript(` var enqueueCmd = redis.NewScript(`
if redis.call("EXISTS", KEYS[1]) == 1 then
return 0
end
redis.call("HSET", KEYS[1], redis.call("HSET", KEYS[1],
"msg", ARGV[1], "msg", ARGV[1],
"state", "pending", "state", "pending",
@ -93,7 +110,14 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error {
msg.Timeout, msg.Timeout,
msg.Deadline, msg.Deadline,
} }
return r.runScript(op, enqueueCmd, keys, argv...) n, err := r.runScriptWithErrorCode(op, enqueueCmd, keys, argv...)
if err != nil {
return err
}
if n == 0 {
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
}
return nil
} }
// enqueueUniqueCmd enqueues the task message if the task is unique. // enqueueUniqueCmd enqueues the task message if the task is unique.
@ -110,10 +134,14 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error {
// //
// Output: // Output:
// Returns 1 if successfully enqueued // Returns 1 if successfully enqueued
// Returns 0 if task already exists // Returns 0 if task ID conflicts with another task
// Returns -1 if task unique key already exists
var enqueueUniqueCmd = redis.NewScript(` var enqueueUniqueCmd = redis.NewScript(`
local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2]) local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2])
if not ok then if not ok then
return -1
end
if redis.call("EXISTS", KEYS[2]) == 1 then
return 0 return 0
end end
redis.call("HSET", KEYS[2], redis.call("HSET", KEYS[2],
@ -149,16 +177,15 @@ func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error {
msg.Timeout, msg.Timeout,
msg.Deadline, msg.Deadline,
} }
res, err := enqueueUniqueCmd.Run(context.Background(), r.client, keys, argv...).Result() n, err := r.runScriptWithErrorCode(op, enqueueUniqueCmd, keys, argv...)
if err != nil { if err != nil {
return errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) return err
} }
n, ok := res.(int64) if n == -1 {
if !ok { return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask)
return errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from Lua script: %v", res))
} }
if n == 0 { if n == 0 {
return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask) return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
} }
return nil return nil
} }
@ -362,7 +389,14 @@ func (r *RDB) Requeue(msg *base.TaskMessage) error {
// ARGV[3] -> task ID // ARGV[3] -> task ID
// ARGV[4] -> task timeout in seconds (0 if not timeout) // ARGV[4] -> task timeout in seconds (0 if not timeout)
// ARGV[5] -> task deadline in unix time (0 if no deadline) // ARGV[5] -> task deadline in unix time (0 if no deadline)
//
// Output:
// Returns 1 if successfully enqueued
// Returns 0 if task ID already exists
var scheduleCmd = redis.NewScript(` var scheduleCmd = redis.NewScript(`
if redis.call("EXISTS", KEYS[1]) == 1 then
return 0
end
redis.call("HSET", KEYS[1], redis.call("HSET", KEYS[1],
"msg", ARGV[1], "msg", ARGV[1],
"state", "scheduled", "state", "scheduled",
@ -393,7 +427,14 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error {
msg.Timeout, msg.Timeout,
msg.Deadline, msg.Deadline,
} }
return r.runScript(op, scheduleCmd, keys, argv...) n, err := r.runScriptWithErrorCode(op, scheduleCmd, keys, argv...)
if err != nil {
return err
}
if n == 0 {
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
}
return nil
} }
// KEYS[1] -> unique key // KEYS[1] -> unique key
@ -405,9 +446,17 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error {
// ARGV[4] -> task message // ARGV[4] -> task message
// ARGV[5] -> task timeout in seconds (0 if not timeout) // ARGV[5] -> task timeout in seconds (0 if not timeout)
// ARGV[6] -> task deadline in unix time (0 if no deadline) // ARGV[6] -> task deadline in unix time (0 if no deadline)
//
// Output:
// Returns 1 if successfully scheduled
// Returns 0 if task ID already exists
// Returns -1 if task unique key already exists
var scheduleUniqueCmd = redis.NewScript(` var scheduleUniqueCmd = redis.NewScript(`
local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2]) local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2])
if not ok then if not ok then
return -1
end
if redis.call("EXISTS", KEYS[2]) == 1 then
return 0 return 0
end end
redis.call("HSET", KEYS[2], redis.call("HSET", KEYS[2],
@ -444,16 +493,15 @@ func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl tim
msg.Timeout, msg.Timeout,
msg.Deadline, msg.Deadline,
} }
res, err := scheduleUniqueCmd.Run(context.Background(), r.client, keys, argv...).Result() n, err := r.runScriptWithErrorCode(op, scheduleUniqueCmd, keys, argv...)
if err != nil { if err != nil {
return errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err)) return err
} }
n, ok := res.(int64) if n == -1 {
if !ok { return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask)
return errors.E(op, errors.Internal, fmt.Sprintf("cast error: unexpected return value from Lua script: %v", res))
} }
if n == 0 { if n == 0 {
return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask) return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
} }
return nil return nil
} }

View File

@ -123,6 +123,42 @@ func TestEnqueue(t *testing.T) {
} }
} }
func TestEnqueueTaskIdConflictError(t *testing.T) {
r := setup(t)
defer r.Close()
m1 := base.TaskMessage{
ID: "custom_id",
Type: "foo",
Payload: nil,
}
m2 := base.TaskMessage{
ID: "custom_id",
Type: "bar",
Payload: nil,
}
tests := []struct {
firstMsg *base.TaskMessage
secondMsg *base.TaskMessage
}{
{firstMsg: &m1, secondMsg: &m2},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.Enqueue(tc.firstMsg); err != nil {
t.Errorf("First message: Enqueue failed: %v", err)
continue
}
if err := r.Enqueue(tc.secondMsg); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: Enqueue returned %v, want %v", err, errors.ErrTaskIdConflict)
continue
}
}
}
func TestEnqueueUnique(t *testing.T) { func TestEnqueueUnique(t *testing.T) {
r := setup(t) r := setup(t)
defer r.Close() defer r.Close()
@ -218,6 +254,45 @@ func TestEnqueueUnique(t *testing.T) {
} }
} }
func TestEnqueueUniqueTaskIdConflictError(t *testing.T) {
r := setup(t)
defer r.Close()
m1 := base.TaskMessage{
ID: "custom_id",
Type: "foo",
Payload: nil,
UniqueKey: "unique_key_one",
}
m2 := base.TaskMessage{
ID: "custom_id",
Type: "bar",
Payload: nil,
UniqueKey: "unique_key_two",
}
const ttl = 30 * time.Second
tests := []struct {
firstMsg *base.TaskMessage
secondMsg *base.TaskMessage
}{
{firstMsg: &m1, secondMsg: &m2},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.EnqueueUnique(tc.firstMsg, ttl); err != nil {
t.Errorf("First message: EnqueueUnique failed: %v", err)
continue
}
if err := r.EnqueueUnique(tc.secondMsg, ttl); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: EnqueueUnique returned %v, want %v", err, errors.ErrTaskIdConflict)
continue
}
}
}
func TestDequeue(t *testing.T) { func TestDequeue(t *testing.T) {
r := setup(t) r := setup(t)
defer r.Close() defer r.Close()
@ -946,6 +1021,45 @@ func TestSchedule(t *testing.T) {
} }
} }
func TestScheduleTaskIdConflictError(t *testing.T) {
r := setup(t)
defer r.Close()
m1 := base.TaskMessage{
ID: "custom_id",
Type: "foo",
Payload: nil,
UniqueKey: "unique_key_one",
}
m2 := base.TaskMessage{
ID: "custom_id",
Type: "bar",
Payload: nil,
UniqueKey: "unique_key_two",
}
processAt := time.Now().Add(30 * time.Second)
tests := []struct {
firstMsg *base.TaskMessage
secondMsg *base.TaskMessage
}{
{firstMsg: &m1, secondMsg: &m2},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.Schedule(tc.firstMsg, processAt); err != nil {
t.Errorf("First message: Schedule failed: %v", err)
continue
}
if err := r.Schedule(tc.secondMsg, processAt); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: Schedule returned %v, want %v", err, errors.ErrTaskIdConflict)
continue
}
}
}
func TestScheduleUnique(t *testing.T) { func TestScheduleUnique(t *testing.T) {
r := setup(t) r := setup(t)
defer r.Close() defer r.Close()
@ -1040,6 +1154,46 @@ func TestScheduleUnique(t *testing.T) {
} }
} }
func TestScheduleUniqueTaskIdConflictError(t *testing.T) {
r := setup(t)
defer r.Close()
m1 := base.TaskMessage{
ID: "custom_id",
Type: "foo",
Payload: nil,
UniqueKey: "unique_key_one",
}
m2 := base.TaskMessage{
ID: "custom_id",
Type: "bar",
Payload: nil,
UniqueKey: "unique_key_two",
}
const ttl = 30 * time.Second
processAt := time.Now().Add(30 * time.Second)
tests := []struct {
firstMsg *base.TaskMessage
secondMsg *base.TaskMessage
}{
{firstMsg: &m1, secondMsg: &m2},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case.
if err := r.ScheduleUnique(tc.firstMsg, processAt, ttl); err != nil {
t.Errorf("First message: ScheduleUnique failed: %v", err)
continue
}
if err := r.ScheduleUnique(tc.secondMsg, processAt, ttl); !errors.Is(err, errors.ErrTaskIdConflict) {
t.Errorf("Second message: ScheduleUnique returned %v, want %v", err, errors.ErrTaskIdConflict)
continue
}
}
}
func TestRetry(t *testing.T) { func TestRetry(t *testing.T) {
r := setup(t) r := setup(t)
defer r.Close() defer r.Close()