diff --git a/internal/base/base.go b/internal/base/base.go index 5b0023b..e3b57ce 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -663,6 +663,7 @@ type Broker interface { EnqueueUnique(msg *TaskMessage, ttl time.Duration) error Dequeue(qnames ...string) (*TaskMessage, time.Time, error) Done(msg *TaskMessage) error + MarkAsComplete(msg *TaskMessage) error Requeue(msg *TaskMessage) error Schedule(msg *TaskMessage, processAt time.Time) error ScheduleUnique(msg *TaskMessage, processAt time.Time, ttl time.Duration) error diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index 735c08c..bd1f650 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -81,6 +81,15 @@ func (tb *TestBroker) Done(msg *base.TaskMessage) error { return tb.real.Done(msg) } +func (tb *TestBroker) MarkAsComplete(msg *base.TaskMessage) error { + tb.mu.Lock() + defer tb.mu.Unlock() + if tb.sleeping { + return errRedisDown + } + return tb.real.MarkAsComplete(msg) +} + func (tb *TestBroker) Requeue(msg *base.TaskMessage) error { tb.mu.Lock() defer tb.mu.Unlock() diff --git a/processor.go b/processor.go index e4c50b1..5bf35be 100644 --- a/processor.go +++ b/processor.go @@ -201,7 +201,7 @@ func (p *processor) exec() { select { case <-ctx.Done(): // already canceled (e.g. deadline exceeded). - p.retryOrArchive(ctx, msg, ctx.Err()) + p.handleFailedMessage(ctx, msg, ctx.Err()) return default: } @@ -218,18 +218,14 @@ func (p *processor) exec() { p.requeue(msg) return case <-ctx.Done(): - p.retryOrArchive(ctx, msg, ctx.Err()) + p.handleFailedMessage(ctx, msg, ctx.Err()) return case resErr := <-resCh: - // Note: One of three things should happen. - // 1) Done -> Removes the message from Active - // 2) Retry -> Removes the message from Active & Adds the message to Retry - // 3) Archive -> Removes the message from Active & Adds the message to archive if resErr != nil { - p.retryOrArchive(ctx, msg, resErr) + p.handleFailedMessage(ctx, msg, resErr) return } - p.markAsDone(ctx, msg) + p.handleSucceededMessage(ctx, msg) } }() } @@ -244,6 +240,34 @@ func (p *processor) requeue(msg *base.TaskMessage) { } } +func (p *processor) handleSucceededMessage(ctx context.Context, msg *base.TaskMessage) { + if msg.ResultTTL > 0 { + p.markAsComplete(ctx, msg) + } else { + p.markAsDone(ctx, msg) + } +} + +func (p *processor) markAsComplete(ctx context.Context, msg *base.TaskMessage) { + err := p.broker.MarkAsComplete(msg) + if err != nil { + errMsg := fmt.Sprintf("Could not move task id=%s type=%q from %q to %q: %+v", + msg.ID, msg.Type, base.ActiveKey(msg.Queue), base.CompletedKey(msg.Queue), err) + deadline, ok := ctx.Deadline() + if !ok { + panic("asynq: internal error: missing deadline in context") + } + p.logger.Warnf("%s; Will retry syncing", errMsg) + p.syncRequestCh <- &syncRequest{ + fn: func() error { + return p.broker.MarkAsComplete(msg) + }, + errMsg: errMsg, + deadline: deadline, + } + } +} + func (p *processor) markAsDone(ctx context.Context, msg *base.TaskMessage) { err := p.broker.Done(msg) if err != nil { @@ -267,7 +291,7 @@ func (p *processor) markAsDone(ctx context.Context, msg *base.TaskMessage) { // the task should not be retried and should be archived instead. var SkipRetry = errors.New("skip retry for the task") -func (p *processor) retryOrArchive(ctx context.Context, msg *base.TaskMessage, err error) { +func (p *processor) handleFailedMessage(ctx context.Context, msg *base.TaskMessage, err error) { if p.errHandler != nil { p.errHandler.HandleError(ctx, NewTask(msg.Type, msg.Payload), err) } diff --git a/processor_test.go b/processor_test.go index 3708ecf..b545003 100644 --- a/processor_test.go +++ b/processor_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" h "github.com/hibiken/asynq/internal/asynqtest" "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/rdb" @@ -42,6 +43,34 @@ func fakeSyncer(syncCh <-chan *syncRequest, done <-chan struct{}) { } } +// Returns a processor instance configured for testing purpose. +func newProcessorForTest(t *testing.T, r *rdb.RDB, h Handler) *processor { + starting := make(chan *workerInfo) + finished := make(chan *base.TaskMessage) + syncCh := make(chan *syncRequest) + done := make(chan struct{}) + t.Cleanup(func() { close(done) }) + go fakeHeartbeater(starting, finished, done) + go fakeSyncer(syncCh, done) + p := newProcessor(processorParams{ + logger: testLogger, + broker: r, + retryDelayFunc: DefaultRetryDelayFunc, + isFailureFunc: defaultIsFailureFunc, + syncCh: syncCh, + cancelations: base.NewCancelations(), + concurrency: 10, + queues: defaultQueueConfig, + strictPriority: false, + errHandler: nil, + shutdownTimeout: defaultShutdownTimeout, + starting: starting, + finished: finished, + }) + p.handler = h + return p +} + func TestProcessorSuccessWithSingleQueue(t *testing.T) { r := setup(t) defer r.Close() @@ -87,29 +116,7 @@ func TestProcessorSuccessWithSingleQueue(t *testing.T) { processed = append(processed, task) return nil } - starting := make(chan *workerInfo) - finished := make(chan *base.TaskMessage) - syncCh := make(chan *syncRequest) - done := make(chan struct{}) - defer func() { close(done) }() - go fakeHeartbeater(starting, finished, done) - go fakeSyncer(syncCh, done) - p := newProcessor(processorParams{ - logger: testLogger, - broker: rdbClient, - retryDelayFunc: DefaultRetryDelayFunc, - isFailureFunc: defaultIsFailureFunc, - syncCh: syncCh, - cancelations: base.NewCancelations(), - concurrency: 10, - queues: defaultQueueConfig, - strictPriority: false, - errHandler: nil, - shutdownTimeout: defaultShutdownTimeout, - starting: starting, - finished: finished, - }) - p.handler = HandlerFunc(handler) + p := newProcessorForTest(t, rdbClient, HandlerFunc(handler)) p.start(&sync.WaitGroup{}) for _, msg := range tc.incoming { @@ -180,33 +187,12 @@ func TestProcessorSuccessWithMultipleQueues(t *testing.T) { processed = append(processed, task) return nil } - starting := make(chan *workerInfo) - finished := make(chan *base.TaskMessage) - syncCh := make(chan *syncRequest) - done := make(chan struct{}) - defer func() { close(done) }() - go fakeHeartbeater(starting, finished, done) - go fakeSyncer(syncCh, done) - p := newProcessor(processorParams{ - logger: testLogger, - broker: rdbClient, - retryDelayFunc: DefaultRetryDelayFunc, - isFailureFunc: defaultIsFailureFunc, - syncCh: syncCh, - cancelations: base.NewCancelations(), - concurrency: 10, - queues: map[string]int{ - "default": 2, - "high": 3, - "low": 1, - }, - strictPriority: false, - errHandler: nil, - shutdownTimeout: defaultShutdownTimeout, - starting: starting, - finished: finished, - }) - p.handler = HandlerFunc(handler) + p := newProcessorForTest(t, rdbClient, HandlerFunc(handler)) + p.queueConfig = map[string]int{ + "default": 2, + "high": 3, + "low": 1, + } p.start(&sync.WaitGroup{}) // Wait for two second to allow all pending tasks to be processed. @@ -267,29 +253,7 @@ func TestProcessTasksWithLargeNumberInPayload(t *testing.T) { processed = append(processed, task) return nil } - starting := make(chan *workerInfo) - finished := make(chan *base.TaskMessage) - syncCh := make(chan *syncRequest) - done := make(chan struct{}) - defer func() { close(done) }() - go fakeHeartbeater(starting, finished, done) - go fakeSyncer(syncCh, done) - p := newProcessor(processorParams{ - logger: testLogger, - broker: rdbClient, - retryDelayFunc: DefaultRetryDelayFunc, - isFailureFunc: defaultIsFailureFunc, - syncCh: syncCh, - cancelations: base.NewCancelations(), - concurrency: 10, - queues: defaultQueueConfig, - strictPriority: false, - errHandler: nil, - shutdownTimeout: defaultShutdownTimeout, - starting: starting, - finished: finished, - }) - p.handler = HandlerFunc(handler) + p := newProcessorForTest(t, rdbClient, HandlerFunc(handler)) p.start(&sync.WaitGroup{}) time.Sleep(2 * time.Second) // wait for two second to allow all pending tasks to be processed. @@ -389,27 +353,9 @@ func TestProcessorRetry(t *testing.T) { defer mu.Unlock() n++ } - starting := make(chan *workerInfo) - finished := make(chan *base.TaskMessage) - done := make(chan struct{}) - defer func() { close(done) }() - go fakeHeartbeater(starting, finished, done) - p := newProcessor(processorParams{ - logger: testLogger, - broker: rdbClient, - retryDelayFunc: delayFunc, - isFailureFunc: defaultIsFailureFunc, - syncCh: nil, - cancelations: base.NewCancelations(), - concurrency: 10, - queues: defaultQueueConfig, - strictPriority: false, - errHandler: ErrorHandlerFunc(errHandler), - shutdownTimeout: defaultShutdownTimeout, - starting: starting, - finished: finished, - }) - p.handler = tc.handler + p := newProcessorForTest(t, rdbClient, tc.handler) + p.errHandler = ErrorHandlerFunc(errHandler) + p.retryDelayFunc = delayFunc p.start(&sync.WaitGroup{}) runTime := time.Now() // time when processor is running @@ -453,6 +399,81 @@ func TestProcessorRetry(t *testing.T) { } } +func TestProcessorMarkAsComplete(t *testing.T) { + r := setup(t) + defer r.Close() + rdbClient := rdb.NewRDB(r) + + msg1 := h.NewTaskMessage("one", nil) + msg2 := h.NewTaskMessage("two", nil) + msg3 := h.NewTaskMessageWithQueue("three", nil, "custom") + msg1.ResultTTL = 3600 + msg3.ResultTTL = 7200 + + handler := func(ctx context.Context, task *Task) error { return nil } + + tests := []struct { + pending map[string][]*base.TaskMessage + completed map[string][]base.Z + queueCfg map[string]int + wantPending map[string][]*base.TaskMessage + wantCompleted func(completedAt time.Time) map[string][]base.Z + }{ + { + pending: map[string][]*base.TaskMessage{ + "default": {msg1, msg2}, + "custom": {msg3}, + }, + completed: map[string][]base.Z{ + "default": {}, + "custom": {}, + }, + queueCfg: map[string]int{ + "default": 1, + "custom": 1, + }, + wantPending: map[string][]*base.TaskMessage{ + "default": {}, + "custom": {}, + }, + wantCompleted: func(completedAt time.Time) map[string][]base.Z { + return map[string][]base.Z{ + "default": {{Message: h.TaskMessageWithCompletedAt(*msg1, completedAt), Score: completedAt.Unix() + msg1.ResultTTL}}, + "custom": {{Message: h.TaskMessageWithCompletedAt(*msg3, completedAt), Score: completedAt.Unix() + msg3.ResultTTL}}, + } + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r) + h.SeedAllPendingQueues(t, r, tc.pending) + h.SeedAllCompletedQueues(t, r, tc.completed) + + p := newProcessorForTest(t, rdbClient, HandlerFunc(handler)) + p.queueConfig = tc.queueCfg + + p.start(&sync.WaitGroup{}) + runTime := time.Now() // time when processor is running + time.Sleep(2 * time.Second) + p.shutdown() + + for qname, want := range tc.wantPending { + gotPending := h.GetPendingMessages(t, r, qname) + if diff := cmp.Diff(want, gotPending, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("diff found in %q pending set; want=%v, got=%v\n%s", qname, want, gotPending, diff) + } + } + + for qname, want := range tc.wantCompleted(runTime) { + gotCompleted := h.GetCompletedEntries(t, r, qname) + if diff := cmp.Diff(want, gotCompleted, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("diff found in %q completed set; want=%v, got=%v\n%s", qname, want, gotCompleted, diff) + } + } + } +} + func TestProcessorQueues(t *testing.T) { sortOpt := cmp.Transformer("SortStrings", func(in []string) []string { out := append([]string(nil), in...) // Copy input to avoid mutating it @@ -481,26 +502,10 @@ func TestProcessorQueues(t *testing.T) { } for _, tc := range tests { - starting := make(chan *workerInfo) - finished := make(chan *base.TaskMessage) - done := make(chan struct{}) - defer func() { close(done) }() - go fakeHeartbeater(starting, finished, done) - p := newProcessor(processorParams{ - logger: testLogger, - broker: nil, - retryDelayFunc: DefaultRetryDelayFunc, - isFailureFunc: defaultIsFailureFunc, - syncCh: nil, - cancelations: base.NewCancelations(), - concurrency: 10, - queues: tc.queueCfg, - strictPriority: false, - errHandler: nil, - shutdownTimeout: defaultShutdownTimeout, - starting: starting, - finished: finished, - }) + // Note: rdb and handler not needed for this test. + p := newProcessorForTest(t, nil, nil) + p.queueConfig = tc.queueCfg + got := p.queues() if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" { t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s", @@ -644,12 +649,9 @@ func TestProcessorPerform(t *testing.T) { wantErr: true, }, } - // Note: We don't need to fully initialize the processor since we are only testing + // Note: We don't need to fully initialized the processor since we are only testing // perform method. - p := newProcessor(processorParams{ - logger: testLogger, - queues: defaultQueueConfig, - }) + p := newProcessorForTest(t, nil, nil) for _, tc := range tests { p.handler = tc.handler