diff --git a/processor.go b/processor.go index bf6ece4..4235d85 100644 --- a/processor.go +++ b/processor.go @@ -14,6 +14,7 @@ import ( "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/rdb" + "github.com/rs/xid" "golang.org/x/time/rate" ) @@ -51,6 +52,10 @@ type processor struct { // quit channel communicates to the in-flight worker goroutines to stop. quit chan struct{} + + // cancelFuncs is a map of task ID to cancel function for all in-progress tasks. + mu sync.Mutex + cancelFuncs map[string]context.CancelFunc } type retryDelayFunc func(n int, err error, task *Task) time.Duration @@ -74,6 +79,7 @@ func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRe done: make(chan struct{}), abort: make(chan struct{}), quit: make(chan struct{}), + cancelFuncs: make(map[string]context.CancelFunc), handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }), } } @@ -99,6 +105,12 @@ func (p *processor) terminate() { const timeout = 8 * time.Second time.AfterFunc(timeout, func() { close(p.quit) }) logger.info("Waiting for all workers to finish...") + + // send cancellation signal to all in-progress task handlers + for _, cancel := range p.cancelFuncs { + cancel() + } + // block until all workers have released the token for i := 0; i < cap(p.sema); i++ { p.sema <- struct{}{} @@ -162,9 +174,11 @@ func (p *processor) exec() { resCh := make(chan error, 1) task := NewTask(msg.Type, msg.Payload) // TODO: Set timeout if provided - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + p.addCancelFunc(msg.ID, cancel) go func() { resCh <- perform(ctx, task, p.handler) + p.deleteCancelFunc(msg.ID) }() select { @@ -255,6 +269,18 @@ func (p *processor) kill(msg *base.TaskMessage, e error) { } } +func (p *processor) addCancelFunc(id xid.ID, fn context.CancelFunc) { + p.mu.Lock() + defer p.mu.Unlock() + p.cancelFuncs[id.String()] = fn +} + +func (p *processor) deleteCancelFunc(id xid.ID) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.cancelFuncs, id.String()) +} + // queues returns a list of queues to query. // Order of the queue names is based on the priority of each queue. // Queue names is sorted by their priority level if strict-priority is true.