diff --git a/background.go b/background.go index 48696b7..e878cbc 100644 --- a/background.go +++ b/background.go @@ -40,6 +40,7 @@ type Background struct { processor *processor syncer *syncer heartbeater *heartbeater + subscriber *subscriber } // Config specifies the background-task processing behavior. @@ -120,10 +121,12 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background { pinfo := base.NewProcessInfo(host, pid, n, queues, cfg.StrictPriority) rdb := rdb.NewRDB(createRedisClient(r)) syncRequestCh := make(chan *syncRequest) + cancelations := base.NewCancelations() syncer := newSyncer(syncRequestCh, 5*time.Second) heartbeater := newHeartbeater(rdb, pinfo, 5*time.Second) scheduler := newScheduler(rdb, 5*time.Second, queues) - processor := newProcessor(rdb, pinfo, delayFunc, syncRequestCh) + processor := newProcessor(rdb, pinfo, delayFunc, syncRequestCh, cancelations) + subscriber := newSubscriber(rdb, cancelations) return &Background{ pinfo: pinfo, rdb: rdb, @@ -131,6 +134,7 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background { processor: processor, syncer: syncer, heartbeater: heartbeater, + subscriber: subscriber, } } @@ -198,6 +202,7 @@ func (bg *Background) start(handler Handler) { bg.processor.handler = handler bg.heartbeater.start() + bg.subscriber.start() bg.syncer.start() bg.scheduler.start() bg.processor.start() @@ -216,6 +221,7 @@ func (bg *Background) stop() { // Note: processor and all worker goroutines need to be exited // before shutting down syncer to avoid goroutine leak. bg.syncer.terminate() + bg.subscriber.terminate() bg.heartbeater.terminate() bg.rdb.ClearProcessInfo(bg.pinfo) diff --git a/internal/base/base.go b/internal/base/base.go index 71932f7..a2959d7 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -6,6 +6,7 @@ package base import ( + "context" "fmt" "strings" "sync" @@ -30,6 +31,7 @@ const ( RetryQueue = "asynq:retry" // ZSET DeadQueue = "asynq:dead" // ZSET InProgressQueue = "asynq:in_progress" // LIST + CancelChannel = "asynq:cancel" // PubSub channel ) // QueueKey returns a redis key string for the given queue name. @@ -129,3 +131,50 @@ func (p *ProcessInfo) IncrActiveWorkerCount(delta int) { defer p.mu.Unlock() p.ActiveWorkerCount += delta } + +// Cancelations hold cancel functions for all in-progress tasks. +// +// Its methods are safe to be used in multiple concurrent goroutines +type Cancelations struct { + mu sync.Mutex + cancelFuncs map[string]context.CancelFunc +} + +// NewCancelations returns a Cancelations instance. +func NewCancelations() *Cancelations { + return &Cancelations{ + cancelFuncs: make(map[string]context.CancelFunc), + } +} + +// Add adds a new cancel func to the set. +func (c *Cancelations) Add(id string, fn context.CancelFunc) { + c.mu.Lock() + defer c.mu.Unlock() + c.cancelFuncs[id] = fn +} + +// Delete deletes a cancel func from the set given an id. +func (c *Cancelations) Delete(id string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.cancelFuncs, id) +} + +// Get returns a cancel func given an id. +func (c *Cancelations) Get(id string) context.CancelFunc { + c.mu.Lock() + defer c.mu.Unlock() + return c.cancelFuncs[id] +} + +// GetAll returns all cancel funcs. +func (c *Cancelations) GetAll() []context.CancelFunc { + c.mu.Lock() + defer c.mu.Unlock() + var res []context.CancelFunc + for _, fn := range c.cancelFuncs { + res = append(res, fn) + } + return res +} diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index b8e1259..fc719bc 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -410,3 +410,19 @@ func (r *RDB) ClearProcessInfo(ps *base.ProcessInfo) error { key := base.ProcessInfoKey(ps.Host, ps.PID) return clearProcessInfoCmd.Run(r.client, []string{base.AllProcesses, key}).Err() } + +// CancelationPubSub returns a pubsub for cancelation messages. +func (r *RDB) CancelationPubSub() (*redis.PubSub, error) { + pubsub := r.client.Subscribe(base.CancelChannel) + _, err := pubsub.Receive() + if err != nil { + return nil, err + } + return pubsub, nil +} + +// PublishCancelation publish cancelation message to all subscribers. +// The message is a string representing the task to be canceled. +func (r *RDB) PublishCancelation(id string) error { + return r.client.Publish(base.CancelChannel, id).Err() +} diff --git a/processor.go b/processor.go index 656c3d6..b91c4ca 100644 --- a/processor.go +++ b/processor.go @@ -14,7 +14,6 @@ import ( "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/rdb" - "github.com/rs/xid" "golang.org/x/time/rate" ) @@ -53,15 +52,14 @@ type processor struct { // quit channel communicates to the in-flight worker goroutines to stop. quit chan struct{} - // cancelFuncs is a map of task ID to cancel function for all in-progress tasks. - mu sync.Mutex - cancelFuncs map[string]context.CancelFunc + // cancelations is a set of cancel functions for all in-progress tasks. + cancelations *base.Cancelations } type retryDelayFunc func(n int, err error, task *Task) time.Duration // newProcessor constructs a new processor. -func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRequestCh chan<- *syncRequest) *processor { +func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRequestCh chan<- *syncRequest, cancelations *base.Cancelations) *processor { qcfg := normalizeQueueCfg(pinfo.Queues) orderedQueues := []string(nil) if pinfo.StrictPriority { @@ -74,12 +72,12 @@ func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRe orderedQueues: orderedQueues, retryDelayFunc: fn, syncRequestCh: syncRequestCh, + cancelations: cancelations, errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1), sema: make(chan struct{}, pinfo.Concurrency), done: make(chan struct{}), abort: make(chan struct{}), quit: make(chan struct{}), - cancelFuncs: make(map[string]context.CancelFunc), handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }), } } @@ -107,7 +105,7 @@ func (p *processor) terminate() { logger.info("Waiting for all workers to finish...") // send cancellation signal to all in-progress task handlers - for _, cancel := range p.cancelFuncs { + for _, cancel := range p.cancelations.GetAll() { cancel() } @@ -174,10 +172,10 @@ func (p *processor) exec() { resCh := make(chan error, 1) task := NewTask(msg.Type, msg.Payload) ctx, cancel := createContext(msg) - p.addCancelFunc(msg.ID, cancel) + p.cancelations.Add(msg.ID.String(), cancel) go func() { resCh <- perform(ctx, task, p.handler) - p.deleteCancelFunc(msg.ID) + p.cancelations.Delete(msg.ID.String()) }() select { @@ -268,18 +266,6 @@ func (p *processor) kill(msg *base.TaskMessage, e error) { } } -func (p *processor) addCancelFunc(id xid.ID, fn context.CancelFunc) { - p.mu.Lock() - defer p.mu.Unlock() - p.cancelFuncs[id.String()] = fn -} - -func (p *processor) deleteCancelFunc(id xid.ID) { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.cancelFuncs, id.String()) -} - // queues returns a list of queues to query. // Order of the queue names is based on the priority of each queue. // Queue names is sorted by their priority level if strict-priority is true. diff --git a/processor_test.go b/processor_test.go index 25b0abe..1bc5d99 100644 --- a/processor_test.go +++ b/processor_test.go @@ -67,7 +67,8 @@ func TestProcessorSuccess(t *testing.T) { return nil } pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false) - p := newProcessor(rdbClient, pi, defaultDelayFunc, nil) + cancelations := base.NewCancelations() + p := newProcessor(rdbClient, pi, defaultDelayFunc, nil, cancelations) p.handler = HandlerFunc(handler) p.start() @@ -151,7 +152,8 @@ func TestProcessorRetry(t *testing.T) { return fmt.Errorf(errMsg) } pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false) - p := newProcessor(rdbClient, pi, delayFunc, nil) + cancelations := base.NewCancelations() + p := newProcessor(rdbClient, pi, delayFunc, nil, cancelations) p.handler = HandlerFunc(handler) p.start() @@ -211,7 +213,8 @@ func TestProcessorQueues(t *testing.T) { for _, tc := range tests { pi := base.NewProcessInfo("localhost", 1234, 10, tc.queueCfg, false) - p := newProcessor(nil, pi, defaultDelayFunc, nil) + cancelations := base.NewCancelations() + p := newProcessor(nil, pi, defaultDelayFunc, nil, cancelations) got := p.queues() if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" { t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s", @@ -278,7 +281,8 @@ func TestProcessorWithStrictPriority(t *testing.T) { } // Note: Set concurrency to 1 to make sure tasks are processed one at a time. pi := base.NewProcessInfo("localhost", 1234, 1 /*concurrency */, queueCfg, true /* strict */) - p := newProcessor(rdbClient, pi, defaultDelayFunc, nil) + cancelations := base.NewCancelations() + p := newProcessor(rdbClient, pi, defaultDelayFunc, nil, cancelations) p.handler = HandlerFunc(handler) p.start() diff --git a/subscriber.go b/subscriber.go new file mode 100644 index 0000000..220116c --- /dev/null +++ b/subscriber.go @@ -0,0 +1,58 @@ +// Copyright 2020 Kentaro Hibino. All rights reserved. +// Use of this source code is governed by a MIT license +// that can be found in the LICENSE file. + +package asynq + +import ( + "github.com/hibiken/asynq/internal/base" + "github.com/hibiken/asynq/internal/rdb" +) + +type subscriber struct { + rdb *rdb.RDB + + // channel to communicate back to the long running "subscriber" goroutine. + done chan struct{} + + // cancelations hold cancel functions for all in-progress tasks. + cancelations *base.Cancelations +} + +func newSubscriber(rdb *rdb.RDB, cancelations *base.Cancelations) *subscriber { + return &subscriber{ + rdb: rdb, + done: make(chan struct{}), + cancelations: cancelations, + } +} + +func (s *subscriber) terminate() { + logger.info("Subscriber shutting down...") + // Signal the subscriber goroutine to stop. + s.done <- struct{}{} +} + +func (s *subscriber) start() { + pubsub, err := s.rdb.CancelationPubSub() + cancelCh := pubsub.Channel() + if err != nil { + logger.error("cannot subscribe to cancelation channel: %v", err) + return + } + go func() { + for { + select { + case <-s.done: + pubsub.Close() + logger.info("Subscriber done") + return + case msg := <-cancelCh: + cancel := s.cancelations.Get(msg.Payload) + if cancel != nil { + cancel() + } + } + } + }() +} diff --git a/subscriber_test.go b/subscriber_test.go new file mode 100644 index 0000000..2b7f5a0 --- /dev/null +++ b/subscriber_test.go @@ -0,0 +1,57 @@ +// Copyright 2020 Kentaro Hibino. All rights reserved. +// Use of this source code is governed by a MIT license +// that can be found in the LICENSE file. + +package asynq + +import ( + "testing" + "time" + + "github.com/hibiken/asynq/internal/base" + "github.com/hibiken/asynq/internal/rdb" +) + +func TestSubscriber(t *testing.T) { + r := setup(t) + rdbClient := rdb.NewRDB(r) + + tests := []struct { + registeredID string // ID for which cancel func is registered + publishID string // ID to be published + wantCalled bool // whether cancel func should be called + }{ + {"abc123", "abc123", true}, + {"abc456", "abc123", false}, + } + + for _, tc := range tests { + called := false + fakeCancelFunc := func() { + called = true + } + cancelations := base.NewCancelations() + cancelations.Add(tc.registeredID, fakeCancelFunc) + + subscriber := newSubscriber(rdbClient, cancelations) + subscriber.start() + + if err := rdbClient.PublishCancelation(tc.publishID); err != nil { + subscriber.terminate() + t.Fatalf("could not publish cancelation message: %v", err) + } + + // allow for redis to publish message + time.Sleep(time.Second) + + if called != tc.wantCalled { + if tc.wantCalled { + t.Errorf("fakeCancelFunc was not called, want the function to be called") + } else { + t.Errorf("fakeCancelFunc was called, want the function to not be called") + } + } + + subscriber.terminate() + } +}