From 86c772f6c13fe5da87405683db23153e9b2dab52 Mon Sep 17 00:00:00 2001 From: kanzihuang Date: Wed, 1 May 2024 23:34:03 +0800 Subject: [PATCH] fix: stop active tasks before server shutdown --- internal/timeutil/sleep_test.go | 6 ++-- processor.go | 58 ++++++++++++++++++++++----------- processor_test.go | 3 +- server_test.go | 4 +-- 4 files changed, 48 insertions(+), 23 deletions(-) diff --git a/internal/timeutil/sleep_test.go b/internal/timeutil/sleep_test.go index e6e6aae..97d13d3 100644 --- a/internal/timeutil/sleep_test.go +++ b/internal/timeutil/sleep_test.go @@ -2,7 +2,7 @@ package timeutil import ( "context" - "github.com/stretchr/testify/require" + "github.com/hibiken/asynq/internal/errors" "sync" "testing" "time" @@ -37,7 +37,9 @@ func TestSleep(t *testing.T) { go func() { defer wg.Done() err := Sleep(ctx, tc.sleep) - require.ErrorIs(t, tc.wantErr, err) + if !errors.Is(err, tc.wantErr) { + t.Errorf("timeutil.Sleep: got %v, want %v", err, tc.wantErr) + } }() time.Sleep(20 * time.Millisecond) cancel() diff --git a/processor.go b/processor.go index 4c6471a..20e9534 100644 --- a/processor.go +++ b/processor.go @@ -62,6 +62,9 @@ type processor struct { // quit channel is closed when the shutdown of the "processor" goroutine starts. quit chan struct{} + // terminate channel is closed by shutdown so that workers cancel the contexts of in-flight tasks before the abort timeout. + terminate chan struct{} + // abort channel communicates to the in-flight worker goroutines to stop. 
abort chan struct{} @@ -113,6 +116,7 @@ func newProcessor(params processorParams) *processor { sema: make(chan struct{}, params.concurrency), done: make(chan struct{}), quit: make(chan struct{}), + terminate: make(chan struct{}), abort: make(chan struct{}), errHandler: params.errHandler, handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }), @@ -139,6 +143,7 @@ func (p *processor) stop() { func (p *processor) shutdown() { p.stop() + close(p.terminate) time.AfterFunc(p.shutdownTimeout, func() { close(p.abort) }) p.logger.Info("Waiting for all workers to finish...") @@ -232,25 +237,44 @@ func (p *processor) exec() { resCh <- p.perform(ctx, task) }() - select { - case <-p.abort: - // time is up, push the message back to queue and quit this worker goroutine. - p.logger.Warnf("Quitting worker. task id=%s", msg.ID) - p.requeue(lease, msg) - return - case <-lease.Done(): - cancel() - p.handleFailedMessage(ctx, lease, msg, ErrLeaseExpired) - return - case <-ctx.Done(): - p.handleFailedMessage(ctx, lease, msg, ctx.Err()) - return - case resErr := <-resCh: - if resErr != nil { - p.handleFailedMessage(ctx, lease, msg, resErr) + // Local copies so each signal is handled once: a receive from a closed channel never blocks, + // so a case that has fired is disabled by setting its channel to nil to keep the loop from spinning. + terminateCh, leaseCh := p.terminate, lease.Done() + var leaseDone, terminated bool + for { + select { + case <-terminateCh: + terminateCh = nil + terminated = true + cancel() + case <-leaseCh: + leaseCh = nil + leaseDone = true + cancel() + case <-p.abort: + // time is up, push the message back to queue and quit this worker goroutine. + p.logger.Warnf("Quitting worker. task id=%s", msg.ID) + p.requeue(lease, msg) + return + case resErr := <-resCh: + switch { + case resErr == nil: + p.handleSucceededMessage(lease, msg) + case errors.Is(resErr, context.Canceled): + switch { + case leaseDone: + p.handleFailedMessage(ctx, lease, msg, ErrLeaseExpired) + case terminated: + p.logger.Warnf("Quitting worker. 
task id=%s", msg.ID) + p.requeue(lease, msg) + default: + p.handleFailedMessage(ctx, lease, msg, resErr) + } + default: + p.handleFailedMessage(ctx, lease, msg, resErr) + } return } - p.handleSucceededMessage(lease, msg) } }() } diff --git a/processor_test.go b/processor_test.go index 9be4729..280396a 100644 --- a/processor_test.go +++ b/processor_test.go @@ -505,8 +505,7 @@ func TestProcessorWithExpiredLease(t *testing.T) { handler: HandlerFunc(func(ctx context.Context, task *Task) error { // make sure the task processing time exceeds lease duration // to test expired lease. - time.Sleep(rdb.LeaseDuration + 10*time.Second) - return nil + return timeutil.Sleep(ctx, rdb.LeaseDuration+10*time.Second) }), wantErrCount: 1, // ErrorHandler should still be called with ErrLeaseExpired }, diff --git a/server_test.go b/server_test.go index 5e35584..f3c2cbf 100644 --- a/server_test.go +++ b/server_test.go @@ -7,8 +7,6 @@ package asynq import ( "context" "fmt" - "github.com/hibiken/asynq/internal/timeutil" - "github.com/redis/go-redis/v9" "syscall" "testing" "time" @@ -16,6 +14,8 @@ import ( "github.com/hibiken/asynq/internal/rdb" "github.com/hibiken/asynq/internal/testbroker" "github.com/hibiken/asynq/internal/testutil" + "github.com/hibiken/asynq/internal/timeutil" + "github.com/redis/go-redis/v9" "go.uber.org/goleak" )