mirror of
https://github.com/hibiken/asynq.git
synced 2024-12-26 07:42:17 +08:00
Update RDB Enqueue and Schedule methods to check for task ID conflict
This commit is contained in:
parent
2261c7c9a0
commit
dbdd9c6d5f
@ -170,6 +170,9 @@ var (
|
||||
|
||||
// ErrDuplicateTask indicates that another task with the same unique key holds the uniqueness lock.
|
||||
ErrDuplicateTask = errors.New("task already exists")
|
||||
|
||||
// ErrTaskIdConflict indicates that another task with the same task ID already exist
|
||||
ErrTaskIdConflict = errors.New("task id conflicts with another task")
|
||||
)
|
||||
|
||||
// TaskNotFoundError indicates that a task with the given ID does not exist
|
||||
|
@ -50,6 +50,19 @@ func (r *RDB) runScript(op errors.Op, script *redis.Script, keys []string, args
|
||||
return nil
|
||||
}
|
||||
|
||||
// Runs the given script with keys and args and retuns the script's return value as int64.
|
||||
func (r *RDB) runScriptWithErrorCode(op errors.Op, script *redis.Script, keys []string, args ...interface{}) (int64, error) {
|
||||
res, err := script.Run(context.Background(), r.client, keys, args...).Result()
|
||||
if err != nil {
|
||||
return 0, errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err))
|
||||
}
|
||||
n, ok := res.(int64)
|
||||
if !ok {
|
||||
return 0, errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from Lua script: %v", res))
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// enqueueCmd enqueues a given task message.
|
||||
//
|
||||
// Input:
|
||||
@ -63,7 +76,11 @@ func (r *RDB) runScript(op errors.Op, script *redis.Script, keys []string, args
|
||||
//
|
||||
// Output:
|
||||
// Returns 1 if successfully enqueued
|
||||
// Returns 0 if task ID already exists
|
||||
var enqueueCmd = redis.NewScript(`
|
||||
if redis.call("EXISTS", KEYS[1]) == 1 then
|
||||
return 0
|
||||
end
|
||||
redis.call("HSET", KEYS[1],
|
||||
"msg", ARGV[1],
|
||||
"state", "pending",
|
||||
@ -93,7 +110,14 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error {
|
||||
msg.Timeout,
|
||||
msg.Deadline,
|
||||
}
|
||||
return r.runScript(op, enqueueCmd, keys, argv...)
|
||||
n, err := r.runScriptWithErrorCode(op, enqueueCmd, keys, argv...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// enqueueUniqueCmd enqueues the task message if the task is unique.
|
||||
@ -110,10 +134,14 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error {
|
||||
//
|
||||
// Output:
|
||||
// Returns 1 if successfully enqueued
|
||||
// Returns 0 if task already exists
|
||||
// Returns 0 if task ID conflicts with another task
|
||||
// Returns -1 if task unique key already exists
|
||||
var enqueueUniqueCmd = redis.NewScript(`
|
||||
local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2])
|
||||
if not ok then
|
||||
return -1
|
||||
end
|
||||
if redis.call("EXISTS", KEYS[2]) == 1 then
|
||||
return 0
|
||||
end
|
||||
redis.call("HSET", KEYS[2],
|
||||
@ -149,16 +177,15 @@ func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error {
|
||||
msg.Timeout,
|
||||
msg.Deadline,
|
||||
}
|
||||
res, err := enqueueUniqueCmd.Run(context.Background(), r.client, keys, argv...).Result()
|
||||
n, err := r.runScriptWithErrorCode(op, enqueueUniqueCmd, keys, argv...)
|
||||
if err != nil {
|
||||
return errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err))
|
||||
return err
|
||||
}
|
||||
n, ok := res.(int64)
|
||||
if !ok {
|
||||
return errors.E(op, errors.Internal, fmt.Sprintf("unexpected return value from Lua script: %v", res))
|
||||
if n == -1 {
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask)
|
||||
}
|
||||
if n == 0 {
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask)
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -362,7 +389,14 @@ func (r *RDB) Requeue(msg *base.TaskMessage) error {
|
||||
// ARGV[3] -> task ID
|
||||
// ARGV[4] -> task timeout in seconds (0 if not timeout)
|
||||
// ARGV[5] -> task deadline in unix time (0 if no deadline)
|
||||
//
|
||||
// Output:
|
||||
// Returns 1 if successfully enqueued
|
||||
// Returns 0 if task ID already exists
|
||||
var scheduleCmd = redis.NewScript(`
|
||||
if redis.call("EXISTS", KEYS[1]) == 1 then
|
||||
return 0
|
||||
end
|
||||
redis.call("HSET", KEYS[1],
|
||||
"msg", ARGV[1],
|
||||
"state", "scheduled",
|
||||
@ -393,7 +427,14 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error {
|
||||
msg.Timeout,
|
||||
msg.Deadline,
|
||||
}
|
||||
return r.runScript(op, scheduleCmd, keys, argv...)
|
||||
n, err := r.runScriptWithErrorCode(op, scheduleCmd, keys, argv...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// KEYS[1] -> unique key
|
||||
@ -405,9 +446,17 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error {
|
||||
// ARGV[4] -> task message
|
||||
// ARGV[5] -> task timeout in seconds (0 if not timeout)
|
||||
// ARGV[6] -> task deadline in unix time (0 if no deadline)
|
||||
//
|
||||
// Output:
|
||||
// Returns 1 if successfully scheduled
|
||||
// Returns 0 if task ID already exists
|
||||
// Returns -1 if task unique key already exists
|
||||
var scheduleUniqueCmd = redis.NewScript(`
|
||||
local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2])
|
||||
if not ok then
|
||||
return -1
|
||||
end
|
||||
if redis.call("EXISTS", KEYS[2]) == 1 then
|
||||
return 0
|
||||
end
|
||||
redis.call("HSET", KEYS[2],
|
||||
@ -444,16 +493,15 @@ func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl tim
|
||||
msg.Timeout,
|
||||
msg.Deadline,
|
||||
}
|
||||
res, err := scheduleUniqueCmd.Run(context.Background(), r.client, keys, argv...).Result()
|
||||
n, err := r.runScriptWithErrorCode(op, scheduleUniqueCmd, keys, argv...)
|
||||
if err != nil {
|
||||
return errors.E(op, errors.Unknown, fmt.Sprintf("redis eval error: %v", err))
|
||||
return err
|
||||
}
|
||||
n, ok := res.(int64)
|
||||
if !ok {
|
||||
return errors.E(op, errors.Internal, fmt.Sprintf("cast error: unexpected return value from Lua script: %v", res))
|
||||
if n == -1 {
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask)
|
||||
}
|
||||
if n == 0 {
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrDuplicateTask)
|
||||
return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -123,6 +123,42 @@ func TestEnqueue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnqueueTaskIdConflictError(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
|
||||
m1 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "foo",
|
||||
Payload: nil,
|
||||
}
|
||||
m2 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "bar",
|
||||
Payload: nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
firstMsg *base.TaskMessage
|
||||
secondMsg *base.TaskMessage
|
||||
}{
|
||||
{firstMsg: &m1, secondMsg: &m2},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
h.FlushDB(t, r.client) // clean up db before each test case.
|
||||
|
||||
if err := r.Enqueue(tc.firstMsg); err != nil {
|
||||
t.Errorf("First message: Enqueue failed: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := r.Enqueue(tc.secondMsg); !errors.Is(err, errors.ErrTaskIdConflict) {
|
||||
t.Errorf("Second message: Enqueue returned %v, want %v", err, errors.ErrTaskIdConflict)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnqueueUnique(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
@ -218,6 +254,45 @@ func TestEnqueueUnique(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnqueueUniqueTaskIdConflictError(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
|
||||
m1 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "foo",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_one",
|
||||
}
|
||||
m2 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "bar",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_two",
|
||||
}
|
||||
const ttl = 30 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
firstMsg *base.TaskMessage
|
||||
secondMsg *base.TaskMessage
|
||||
}{
|
||||
{firstMsg: &m1, secondMsg: &m2},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
h.FlushDB(t, r.client) // clean up db before each test case.
|
||||
|
||||
if err := r.EnqueueUnique(tc.firstMsg, ttl); err != nil {
|
||||
t.Errorf("First message: EnqueueUnique failed: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := r.EnqueueUnique(tc.secondMsg, ttl); !errors.Is(err, errors.ErrTaskIdConflict) {
|
||||
t.Errorf("Second message: EnqueueUnique returned %v, want %v", err, errors.ErrTaskIdConflict)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDequeue(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
@ -946,6 +1021,45 @@ func TestSchedule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleTaskIdConflictError(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
|
||||
m1 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "foo",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_one",
|
||||
}
|
||||
m2 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "bar",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_two",
|
||||
}
|
||||
processAt := time.Now().Add(30 * time.Second)
|
||||
|
||||
tests := []struct {
|
||||
firstMsg *base.TaskMessage
|
||||
secondMsg *base.TaskMessage
|
||||
}{
|
||||
{firstMsg: &m1, secondMsg: &m2},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
h.FlushDB(t, r.client) // clean up db before each test case.
|
||||
|
||||
if err := r.Schedule(tc.firstMsg, processAt); err != nil {
|
||||
t.Errorf("First message: Schedule failed: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := r.Schedule(tc.secondMsg, processAt); !errors.Is(err, errors.ErrTaskIdConflict) {
|
||||
t.Errorf("Second message: Schedule returned %v, want %v", err, errors.ErrTaskIdConflict)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleUnique(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
@ -1040,6 +1154,46 @@ func TestScheduleUnique(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleUniqueTaskIdConflictError(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
|
||||
m1 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "foo",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_one",
|
||||
}
|
||||
m2 := base.TaskMessage{
|
||||
ID: "custom_id",
|
||||
Type: "bar",
|
||||
Payload: nil,
|
||||
UniqueKey: "unique_key_two",
|
||||
}
|
||||
const ttl = 30 * time.Second
|
||||
processAt := time.Now().Add(30 * time.Second)
|
||||
|
||||
tests := []struct {
|
||||
firstMsg *base.TaskMessage
|
||||
secondMsg *base.TaskMessage
|
||||
}{
|
||||
{firstMsg: &m1, secondMsg: &m2},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
h.FlushDB(t, r.client) // clean up db before each test case.
|
||||
|
||||
if err := r.ScheduleUnique(tc.firstMsg, processAt, ttl); err != nil {
|
||||
t.Errorf("First message: ScheduleUnique failed: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := r.ScheduleUnique(tc.secondMsg, processAt, ttl); !errors.Is(err, errors.ErrTaskIdConflict) {
|
||||
t.Errorf("Second message: ScheduleUnique returned %v, want %v", err, errors.ErrTaskIdConflict)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetry(t *testing.T) {
|
||||
r := setup(t)
|
||||
defer r.Close()
|
||||
|
Loading…
Reference in New Issue
Block a user