mirror of
				https://github.com/hibiken/asynq.git
				synced 2025-10-26 11:16:12 +08:00 
			
		
		
		
	Allow client to enqueue a task with unique option
Changes: - Added Unique option for clients - Require go v.13 or above (to use new errors wrapping functions) - Fixed adding queue key to all-queues set (asynq:queues) when scheduling.
This commit is contained in:
		| @@ -22,6 +22,9 @@ var ( | ||||
|  | ||||
| 	// ErrTaskNotFound indicates that a task that matches the given identifier was not found. | ||||
| 	ErrTaskNotFound = errors.New("could not find a task") | ||||
|  | ||||
| 	// ErrDuplicateTask indicates that another task with the same unique key holds the uniqueness lock. | ||||
| 	ErrDuplicateTask = errors.New("task already exists") | ||||
| ) | ||||
|  | ||||
| const statsTTL = 90 * 24 * time.Hour // 90 days | ||||
| @@ -59,6 +62,46 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error { | ||||
| 	return enqueueCmd.Run(r.client, []string{key, base.AllQueues}, bytes).Err() | ||||
| } | ||||
|  | ||||
| // KEYS[1] -> unique key in the form <type>:<payload>:<qname> | ||||
| // KEYS[2] -> asynq:queues:<qname> | ||||
| // KEYS[2] -> asynq:queues | ||||
| // ARGV[1] -> task ID | ||||
| // ARGV[2] -> uniqueness lock TTL | ||||
| // ARGV[3] -> task message data | ||||
| var enqueueUniqueCmd = redis.NewScript(` | ||||
| local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2]) | ||||
| if not ok then | ||||
|   return 0 | ||||
| end | ||||
| redis.call("LPUSH", KEYS[2], ARGV[3]) | ||||
| redis.call("SADD", KEYS[3], KEYS[2]) | ||||
| return 1 | ||||
| `) | ||||
|  | ||||
| // EnqueueUnique inserts the given task if the task's uniqueness lock can be acquired. | ||||
| // It returns ErrDuplicateTask if the lock cannot be acquired. | ||||
| func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { | ||||
| 	bytes, err := json.Marshal(msg) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	key := base.QueueKey(msg.Queue) | ||||
| 	res, err := enqueueUniqueCmd.Run(r.client, | ||||
| 		[]string{msg.UniqueKey, key, base.AllQueues}, | ||||
| 		msg.ID.String(), int(ttl.Seconds()), bytes).Result() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	n, ok := res.(int64) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("could not cast %v to int64", res) | ||||
| 	} | ||||
| 	if n == 0 { | ||||
| 		return ErrDuplicateTask | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Dequeue queries given queues in order and pops a task message if there is one and returns it. | ||||
| // If all queues are empty, ErrNoProcessableTask error is returned. | ||||
| func (r *RDB) Dequeue(qnames ...string) (*base.TaskMessage, error) { | ||||
| @@ -118,8 +161,10 @@ func (r *RDB) dequeue(queues ...string) (data string, err error) { | ||||
|  | ||||
| // KEYS[1] -> asynq:in_progress | ||||
| // KEYS[2] -> asynq:processed:<yyyy-mm-dd> | ||||
| // KEYS[3] -> unique key in the format <type>:<payload>:<qname> | ||||
| // ARGV[1] -> base.TaskMessage value | ||||
| // ARGV[2] -> stats expiration timestamp | ||||
| // ARGV[3] -> task ID | ||||
| // Note: LREM count ZERO means "remove all elements equal to val" | ||||
| var doneCmd = redis.NewScript(` | ||||
| redis.call("LREM", KEYS[1], 0, ARGV[1])  | ||||
| @@ -127,10 +172,14 @@ local n = redis.call("INCR", KEYS[2]) | ||||
| if tonumber(n) == 1 then | ||||
| 	redis.call("EXPIREAT", KEYS[2], ARGV[2]) | ||||
| end | ||||
| if string.len(KEYS[3]) > 0 and redis.call("GET", KEYS[3]) == ARGV[3] then | ||||
|   redis.call("DEL", KEYS[3]) | ||||
| end | ||||
| return redis.status_reply("OK") | ||||
| `) | ||||
|  | ||||
| // Done removes the task from in-progress queue to mark the task as done. | ||||
| // It removes a uniqueness lock acquired by the task, if any. | ||||
| func (r *RDB) Done(msg *base.TaskMessage) error { | ||||
| 	bytes, err := json.Marshal(msg) | ||||
| 	if err != nil { | ||||
| @@ -140,8 +189,8 @@ func (r *RDB) Done(msg *base.TaskMessage) error { | ||||
| 	processedKey := base.ProcessedKey(now) | ||||
| 	expireAt := now.Add(statsTTL) | ||||
| 	return doneCmd.Run(r.client, | ||||
| 		[]string{base.InProgressQueue, processedKey}, | ||||
| 		bytes, expireAt.Unix()).Err() | ||||
| 		[]string{base.InProgressQueue, processedKey, msg.UniqueKey}, | ||||
| 		bytes, expireAt.Unix(), msg.ID.String()).Err() | ||||
| } | ||||
|  | ||||
| // KEYS[1] -> asynq:in_progress | ||||
| @@ -164,15 +213,71 @@ func (r *RDB) Requeue(msg *base.TaskMessage) error { | ||||
| 		string(bytes)).Err() | ||||
| } | ||||
|  | ||||
| // KEYS[1] -> asynq:scheduled | ||||
| // KEYS[2] -> asynq:queues | ||||
| // ARGV[1] -> score (process_at timestamp) | ||||
| // ARGV[2] -> task message | ||||
| // ARGV[3] -> queue key | ||||
| var scheduleCmd = redis.NewScript(` | ||||
| redis.call("ZADD", KEYS[1], ARGV[1], ARGV[2]) | ||||
| redis.call("SADD", KEYS[2], ARGV[3]) | ||||
| return 1 | ||||
| `) | ||||
|  | ||||
| // Schedule adds the task to the backlog queue to be processed in the future. | ||||
| func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { | ||||
| 	bytes, err := json.Marshal(msg) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	qkey := base.QueueKey(msg.Queue) | ||||
| 	score := float64(processAt.Unix()) | ||||
| 	return r.client.ZAdd(base.ScheduledQueue, | ||||
| 		&redis.Z{Member: string(bytes), Score: score}).Err() | ||||
| 	return scheduleCmd.Run(r.client, | ||||
| 		[]string{base.ScheduledQueue, base.AllQueues}, | ||||
| 		score, bytes, qkey).Err() | ||||
| } | ||||
|  | ||||
| // KEYS[1] -> unique key in the format <type>:<payload>:<qname> | ||||
| // KEYS[2] -> asynq:scheduled | ||||
| // KEYS[3] -> asynq:queues | ||||
| // ARGV[1] -> task ID | ||||
| // ARGV[2] -> uniqueness lock TTL | ||||
| // ARGV[3] -> score (process_at timestamp) | ||||
| // ARGV[4] -> task message | ||||
| // ARGV[5] -> queue key | ||||
| var scheduleUniqueCmd = redis.NewScript(` | ||||
| local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2]) | ||||
| if not ok then | ||||
|   return 0 | ||||
| end | ||||
| redis.call("ZADD", KEYS[2], ARGV[3], ARGV[4]) | ||||
| redis.call("SADD", KEYS[3], ARGV[5]) | ||||
| return 1 | ||||
| `) | ||||
|  | ||||
| // Schedule adds the task to the backlog queue to be processed in the future if the uniqueness lock can be acquired. | ||||
| // It returns ErrDuplicateTask if the lock cannot be acquired. | ||||
| func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error { | ||||
| 	bytes, err := json.Marshal(msg) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	qkey := base.QueueKey(msg.Queue) | ||||
| 	score := float64(processAt.Unix()) | ||||
| 	res, err := scheduleUniqueCmd.Run(r.client, | ||||
| 		[]string{msg.UniqueKey, base.ScheduledQueue, base.AllQueues}, | ||||
| 		msg.ID.String(), int(ttl.Seconds()), score, bytes, qkey).Result() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	n, ok := res.(int64) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("could not cast %v to int64", res) | ||||
| 	} | ||||
| 	if n == 0 { | ||||
| 		return ErrDuplicateTask | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // KEYS[1] -> asynq:in_progress | ||||
|   | ||||
| @@ -16,6 +16,7 @@ import ( | ||||
| 	"github.com/google/go-cmp/cmp/cmpopts" | ||||
| 	h "github.com/hibiken/asynq/internal/asynqtest" | ||||
| 	"github.com/hibiken/asynq/internal/base" | ||||
| 	"github.com/rs/xid" | ||||
| ) | ||||
|  | ||||
| // TODO(hibiken): Get Redis address and db number from ENV variables. | ||||
| @@ -69,6 +70,48 @@ func TestEnqueue(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEnqueueUnique(t *testing.T) { | ||||
| 	r := setup(t) | ||||
| 	m1 := base.TaskMessage{ | ||||
| 		ID:        xid.New(), | ||||
| 		Type:      "email", | ||||
| 		Payload:   map[string]interface{}{"user_id": 123}, | ||||
| 		Queue:     base.DefaultQueueName, | ||||
| 		UniqueKey: "email:user_id=123:default", | ||||
| 	} | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		msg *base.TaskMessage | ||||
| 		ttl time.Duration // uniqueness ttl | ||||
| 	}{ | ||||
| 		{&m1, time.Minute}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range tests { | ||||
| 		h.FlushDB(t, r.client) // clean up db before each test case. | ||||
|  | ||||
| 		err := r.EnqueueUnique(tc.msg, tc.ttl) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("First message: (*RDB).EnqueueUnique(%v, %v) = %v, want nil", | ||||
| 				tc.msg, tc.ttl, err) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		got := r.EnqueueUnique(tc.msg, tc.ttl) | ||||
| 		if got != ErrDuplicateTask { | ||||
| 			t.Errorf("Second message: (*RDB).EnqueueUnique(%v, %v) = %v, want %v", | ||||
| 				tc.msg, tc.ttl, got, ErrDuplicateTask) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		gotTTL := r.client.TTL(tc.msg.UniqueKey).Val() | ||||
| 		if !cmp.Equal(tc.ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) { | ||||
| 			t.Errorf("TTL %q = %v, want %v", tc.msg.UniqueKey, gotTTL, tc.ttl) | ||||
| 			continue | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDequeue(t *testing.T) { | ||||
| 	r := setup(t) | ||||
| 	t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "hello!"}) | ||||
| @@ -188,6 +231,13 @@ func TestDone(t *testing.T) { | ||||
| 	r := setup(t) | ||||
| 	t1 := h.NewTaskMessage("send_email", nil) | ||||
| 	t2 := h.NewTaskMessage("export_csv", nil) | ||||
| 	t3 := &base.TaskMessage{ | ||||
| 		ID:        xid.New(), | ||||
| 		Type:      "reindex", | ||||
| 		Payload:   nil, | ||||
| 		UniqueKey: "reindex:nil:default", | ||||
| 		Queue:     "default", | ||||
| 	} | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		inProgress     []*base.TaskMessage // initial state of the in-progress list | ||||
| @@ -204,11 +254,25 @@ func TestDone(t *testing.T) { | ||||
| 			target:         t1, | ||||
| 			wantInProgress: []*base.TaskMessage{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			inProgress:     []*base.TaskMessage{t1, t2, t3}, | ||||
| 			target:         t3, | ||||
| 			wantInProgress: []*base.TaskMessage{t1, t2}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range tests { | ||||
| 		h.FlushDB(t, r.client) // clean up db before each test case | ||||
| 		h.SeedInProgressQueue(t, r.client, tc.inProgress) | ||||
| 		for _, msg := range tc.inProgress { | ||||
| 			// Set uniqueness lock if unique key is present. | ||||
| 			if len(msg.UniqueKey) > 0 { | ||||
| 				err := r.client.SetNX(msg.UniqueKey, msg.ID.String(), time.Minute).Err() | ||||
| 				if err != nil { | ||||
| 					t.Fatal(err) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		err := r.Done(tc.target) | ||||
| 		if err != nil { | ||||
| @@ -232,6 +296,10 @@ func TestDone(t *testing.T) { | ||||
| 		if gotTTL > statsTTL { | ||||
| 			t.Errorf("TTL %q = %v, want less than or equal to %v", processedKey, gotTTL, statsTTL) | ||||
| 		} | ||||
|  | ||||
| 		if len(tc.target.UniqueKey) > 0 && r.client.Exists(tc.target.UniqueKey).Val() != 0 { | ||||
| 			t.Errorf("Uniqueness lock %q still exists", tc.target.UniqueKey) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -344,6 +412,58 @@ func TestSchedule(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestScheduleUnique(t *testing.T) { | ||||
| 	r := setup(t) | ||||
| 	m1 := base.TaskMessage{ | ||||
| 		ID:        xid.New(), | ||||
| 		Type:      "email", | ||||
| 		Payload:   map[string]interface{}{"user_id": 123}, | ||||
| 		Queue:     base.DefaultQueueName, | ||||
| 		UniqueKey: "email:user_id=123:default", | ||||
| 	} | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		msg       *base.TaskMessage | ||||
| 		processAt time.Time | ||||
| 		ttl       time.Duration // uniqueness lock ttl | ||||
| 	}{ | ||||
| 		{&m1, time.Now().Add(15 * time.Minute), time.Minute}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range tests { | ||||
| 		h.FlushDB(t, r.client) // clean up db before each test case | ||||
|  | ||||
| 		desc := fmt.Sprintf("(*RDB).ScheduleUnique(%v, %v, %v)", tc.msg, tc.processAt, tc.ttl) | ||||
| 		err := r.ScheduleUnique(tc.msg, tc.processAt, tc.ttl) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("Frist task: %s = %v, want nil", desc, err) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		gotScheduled := h.GetScheduledEntries(t, r.client) | ||||
| 		if len(gotScheduled) != 1 { | ||||
| 			t.Errorf("%s inserted %d items to %q, want 1 items inserted", desc, len(gotScheduled), base.ScheduledQueue) | ||||
| 			continue | ||||
| 		} | ||||
| 		if int64(gotScheduled[0].Score) != tc.processAt.Unix() { | ||||
| 			t.Errorf("%s inserted an item with score %d, want %d", desc, int64(gotScheduled[0].Score), tc.processAt.Unix()) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		got := r.ScheduleUnique(tc.msg, tc.processAt, tc.ttl) | ||||
| 		if got != ErrDuplicateTask { | ||||
| 			t.Errorf("Second task: %s = %v, want %v", | ||||
| 				desc, got, ErrDuplicateTask) | ||||
| 		} | ||||
|  | ||||
| 		gotTTL := r.client.TTL(tc.msg.UniqueKey).Val() | ||||
| 		if !cmp.Equal(tc.ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) { | ||||
| 			t.Errorf("TTL %q = %v, want %v", tc.msg.UniqueKey, gotTTL, tc.ttl) | ||||
| 			continue | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRetry(t *testing.T) { | ||||
| 	r := setup(t) | ||||
| 	t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "Hola!"}) | ||||
| @@ -784,8 +904,7 @@ func TestWriteProcessState(t *testing.T) { | ||||
| 	} | ||||
| 	// Check ProcessInfo TTL was set correctly | ||||
| 	gotTTL := r.client.TTL(pkey).Val() | ||||
| 	timeCmpOpt := cmpopts.EquateApproxTime(time.Second) | ||||
| 	if !cmp.Equal(ttl, gotTTL, timeCmpOpt) { | ||||
| 	if !cmp.Equal(ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) { | ||||
| 		t.Errorf("TTL of %q was %v, want %v", pkey, gotTTL, ttl) | ||||
| 	} | ||||
| 	// Check ProcessInfo key was added to the set correctly | ||||
| @@ -858,8 +977,7 @@ func TestWriteProcessStateWithWorkers(t *testing.T) { | ||||
| 	} | ||||
| 	// Check ProcessInfo TTL was set correctly | ||||
| 	gotTTL := r.client.TTL(pkey).Val() | ||||
| 	timeCmpOpt := cmpopts.EquateApproxTime(time.Second) | ||||
| 	if !cmp.Equal(ttl, gotTTL, timeCmpOpt) { | ||||
| 	if !cmp.Equal(ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) { | ||||
| 		t.Errorf("TTL of %q was %v, want %v", pkey, gotTTL, ttl) | ||||
| 	} | ||||
| 	// Check ProcessInfo key was added to the set correctly | ||||
|   | ||||
		Reference in New Issue
	
	Block a user