mirror of
https://github.com/hibiken/asynq.git
synced 2024-11-10 11:31:58 +08:00
Add redis pubsub subscriber for cancelation
This commit is contained in:
parent
133bb6c2c6
commit
6685827147
@ -40,6 +40,7 @@ type Background struct {
|
||||
processor *processor
|
||||
syncer *syncer
|
||||
heartbeater *heartbeater
|
||||
subscriber *subscriber
|
||||
}
|
||||
|
||||
// Config specifies the background-task processing behavior.
|
||||
@ -120,10 +121,12 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background {
|
||||
pinfo := base.NewProcessInfo(host, pid, n, queues, cfg.StrictPriority)
|
||||
rdb := rdb.NewRDB(createRedisClient(r))
|
||||
syncRequestCh := make(chan *syncRequest)
|
||||
cancelations := base.NewCancelations()
|
||||
syncer := newSyncer(syncRequestCh, 5*time.Second)
|
||||
heartbeater := newHeartbeater(rdb, pinfo, 5*time.Second)
|
||||
scheduler := newScheduler(rdb, 5*time.Second, queues)
|
||||
processor := newProcessor(rdb, pinfo, delayFunc, syncRequestCh)
|
||||
processor := newProcessor(rdb, pinfo, delayFunc, syncRequestCh, cancelations)
|
||||
subscriber := newSubscriber(rdb, cancelations)
|
||||
return &Background{
|
||||
pinfo: pinfo,
|
||||
rdb: rdb,
|
||||
@ -131,6 +134,7 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background {
|
||||
processor: processor,
|
||||
syncer: syncer,
|
||||
heartbeater: heartbeater,
|
||||
subscriber: subscriber,
|
||||
}
|
||||
}
|
||||
|
||||
@ -198,6 +202,7 @@ func (bg *Background) start(handler Handler) {
|
||||
bg.processor.handler = handler
|
||||
|
||||
bg.heartbeater.start()
|
||||
bg.subscriber.start()
|
||||
bg.syncer.start()
|
||||
bg.scheduler.start()
|
||||
bg.processor.start()
|
||||
@ -216,6 +221,7 @@ func (bg *Background) stop() {
|
||||
// Note: processor and all worker goroutines need to be exited
|
||||
// before shutting down syncer to avoid goroutine leak.
|
||||
bg.syncer.terminate()
|
||||
bg.subscriber.terminate()
|
||||
bg.heartbeater.terminate()
|
||||
|
||||
bg.rdb.ClearProcessInfo(bg.pinfo)
|
||||
|
@ -6,6 +6,7 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -30,6 +31,7 @@ const (
|
||||
RetryQueue = "asynq:retry" // ZSET
|
||||
DeadQueue = "asynq:dead" // ZSET
|
||||
InProgressQueue = "asynq:in_progress" // LIST
|
||||
CancelChannel = "asynq:cancel" // PubSub channel
|
||||
)
|
||||
|
||||
// QueueKey returns a redis key string for the given queue name.
|
||||
@ -129,3 +131,50 @@ func (p *ProcessInfo) IncrActiveWorkerCount(delta int) {
|
||||
defer p.mu.Unlock()
|
||||
p.ActiveWorkerCount += delta
|
||||
}
|
||||
|
||||
// Cancelations hold cancel functions for all in-progress tasks.
|
||||
//
|
||||
// Its methods are safe to be used in multiple concurrent goroutines
|
||||
type Cancelations struct {
|
||||
mu sync.Mutex
|
||||
cancelFuncs map[string]context.CancelFunc
|
||||
}
|
||||
|
||||
// NewCancelations returns a Cancelations instance.
|
||||
func NewCancelations() *Cancelations {
|
||||
return &Cancelations{
|
||||
cancelFuncs: make(map[string]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new cancel func to the set.
|
||||
func (c *Cancelations) Add(id string, fn context.CancelFunc) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.cancelFuncs[id] = fn
|
||||
}
|
||||
|
||||
// Delete deletes a cancel func from the set given an id.
|
||||
func (c *Cancelations) Delete(id string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.cancelFuncs, id)
|
||||
}
|
||||
|
||||
// Get returns a cancel func given an id.
|
||||
func (c *Cancelations) Get(id string) context.CancelFunc {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.cancelFuncs[id]
|
||||
}
|
||||
|
||||
// GetAll returns all cancel funcs.
|
||||
func (c *Cancelations) GetAll() []context.CancelFunc {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var res []context.CancelFunc
|
||||
for _, fn := range c.cancelFuncs {
|
||||
res = append(res, fn)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
@ -410,3 +410,19 @@ func (r *RDB) ClearProcessInfo(ps *base.ProcessInfo) error {
|
||||
key := base.ProcessInfoKey(ps.Host, ps.PID)
|
||||
return clearProcessInfoCmd.Run(r.client, []string{base.AllProcesses, key}).Err()
|
||||
}
|
||||
|
||||
// CancelationPubSub returns a pubsub for cancelation messages.
|
||||
func (r *RDB) CancelationPubSub() (*redis.PubSub, error) {
|
||||
pubsub := r.client.Subscribe(base.CancelChannel)
|
||||
_, err := pubsub.Receive()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pubsub, nil
|
||||
}
|
||||
|
||||
// PublishCancelation publish cancelation message to all subscribers.
|
||||
// The message is a string representing the task to be canceled.
|
||||
func (r *RDB) PublishCancelation(id string) error {
|
||||
return r.client.Publish(base.CancelChannel, id).Err()
|
||||
}
|
||||
|
28
processor.go
28
processor.go
@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
"github.com/hibiken/asynq/internal/rdb"
|
||||
"github.com/rs/xid"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
@ -53,15 +52,14 @@ type processor struct {
|
||||
// quit channel communicates to the in-flight worker goroutines to stop.
|
||||
quit chan struct{}
|
||||
|
||||
// cancelFuncs is a map of task ID to cancel function for all in-progress tasks.
|
||||
mu sync.Mutex
|
||||
cancelFuncs map[string]context.CancelFunc
|
||||
// cancelations is a set of cancel functions for all in-progress tasks.
|
||||
cancelations *base.Cancelations
|
||||
}
|
||||
|
||||
type retryDelayFunc func(n int, err error, task *Task) time.Duration
|
||||
|
||||
// newProcessor constructs a new processor.
|
||||
func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRequestCh chan<- *syncRequest) *processor {
|
||||
func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRequestCh chan<- *syncRequest, cancelations *base.Cancelations) *processor {
|
||||
qcfg := normalizeQueueCfg(pinfo.Queues)
|
||||
orderedQueues := []string(nil)
|
||||
if pinfo.StrictPriority {
|
||||
@ -74,12 +72,12 @@ func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRe
|
||||
orderedQueues: orderedQueues,
|
||||
retryDelayFunc: fn,
|
||||
syncRequestCh: syncRequestCh,
|
||||
cancelations: cancelations,
|
||||
errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1),
|
||||
sema: make(chan struct{}, pinfo.Concurrency),
|
||||
done: make(chan struct{}),
|
||||
abort: make(chan struct{}),
|
||||
quit: make(chan struct{}),
|
||||
cancelFuncs: make(map[string]context.CancelFunc),
|
||||
handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }),
|
||||
}
|
||||
}
|
||||
@ -107,7 +105,7 @@ func (p *processor) terminate() {
|
||||
logger.info("Waiting for all workers to finish...")
|
||||
|
||||
// send cancellation signal to all in-progress task handlers
|
||||
for _, cancel := range p.cancelFuncs {
|
||||
for _, cancel := range p.cancelations.GetAll() {
|
||||
cancel()
|
||||
}
|
||||
|
||||
@ -174,10 +172,10 @@ func (p *processor) exec() {
|
||||
resCh := make(chan error, 1)
|
||||
task := NewTask(msg.Type, msg.Payload)
|
||||
ctx, cancel := createContext(msg)
|
||||
p.addCancelFunc(msg.ID, cancel)
|
||||
p.cancelations.Add(msg.ID.String(), cancel)
|
||||
go func() {
|
||||
resCh <- perform(ctx, task, p.handler)
|
||||
p.deleteCancelFunc(msg.ID)
|
||||
p.cancelations.Delete(msg.ID.String())
|
||||
}()
|
||||
|
||||
select {
|
||||
@ -268,18 +266,6 @@ func (p *processor) kill(msg *base.TaskMessage, e error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processor) addCancelFunc(id xid.ID, fn context.CancelFunc) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.cancelFuncs[id.String()] = fn
|
||||
}
|
||||
|
||||
func (p *processor) deleteCancelFunc(id xid.ID) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
delete(p.cancelFuncs, id.String())
|
||||
}
|
||||
|
||||
// queues returns a list of queues to query.
|
||||
// Order of the queue names is based on the priority of each queue.
|
||||
// Queue names is sorted by their priority level if strict-priority is true.
|
||||
|
@ -67,7 +67,8 @@ func TestProcessorSuccess(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false)
|
||||
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil)
|
||||
cancelations := base.NewCancelations()
|
||||
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil, cancelations)
|
||||
p.handler = HandlerFunc(handler)
|
||||
|
||||
p.start()
|
||||
@ -151,7 +152,8 @@ func TestProcessorRetry(t *testing.T) {
|
||||
return fmt.Errorf(errMsg)
|
||||
}
|
||||
pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false)
|
||||
p := newProcessor(rdbClient, pi, delayFunc, nil)
|
||||
cancelations := base.NewCancelations()
|
||||
p := newProcessor(rdbClient, pi, delayFunc, nil, cancelations)
|
||||
p.handler = HandlerFunc(handler)
|
||||
|
||||
p.start()
|
||||
@ -211,7 +213,8 @@ func TestProcessorQueues(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
pi := base.NewProcessInfo("localhost", 1234, 10, tc.queueCfg, false)
|
||||
p := newProcessor(nil, pi, defaultDelayFunc, nil)
|
||||
cancelations := base.NewCancelations()
|
||||
p := newProcessor(nil, pi, defaultDelayFunc, nil, cancelations)
|
||||
got := p.queues()
|
||||
if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" {
|
||||
t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s",
|
||||
@ -278,7 +281,8 @@ func TestProcessorWithStrictPriority(t *testing.T) {
|
||||
}
|
||||
// Note: Set concurrency to 1 to make sure tasks are processed one at a time.
|
||||
pi := base.NewProcessInfo("localhost", 1234, 1 /*concurrency */, queueCfg, true /* strict */)
|
||||
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil)
|
||||
cancelations := base.NewCancelations()
|
||||
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil, cancelations)
|
||||
p.handler = HandlerFunc(handler)
|
||||
|
||||
p.start()
|
||||
|
58
subscriber.go
Normal file
58
subscriber.go
Normal file
@ -0,0 +1,58 @@
|
||||
// Copyright 2020 Kentaro Hibino. All rights reserved.
|
||||
// Use of this source code is governed by a MIT license
|
||||
// that can be found in the LICENSE file.
|
||||
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
"github.com/hibiken/asynq/internal/rdb"
|
||||
)
|
||||
|
||||
type subscriber struct {
|
||||
rdb *rdb.RDB
|
||||
|
||||
// channel to communicate back to the long running "subscriber" goroutine.
|
||||
done chan struct{}
|
||||
|
||||
// cancelations hold cancel functions for all in-progress tasks.
|
||||
cancelations *base.Cancelations
|
||||
}
|
||||
|
||||
func newSubscriber(rdb *rdb.RDB, cancelations *base.Cancelations) *subscriber {
|
||||
return &subscriber{
|
||||
rdb: rdb,
|
||||
done: make(chan struct{}),
|
||||
cancelations: cancelations,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *subscriber) terminate() {
|
||||
logger.info("Subscriber shutting down...")
|
||||
// Signal the subscriber goroutine to stop.
|
||||
s.done <- struct{}{}
|
||||
}
|
||||
|
||||
func (s *subscriber) start() {
|
||||
pubsub, err := s.rdb.CancelationPubSub()
|
||||
cancelCh := pubsub.Channel()
|
||||
if err != nil {
|
||||
logger.error("cannot subscribe to cancelation channel: %v", err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
pubsub.Close()
|
||||
logger.info("Subscriber done")
|
||||
return
|
||||
case msg := <-cancelCh:
|
||||
cancel := s.cancelations.Get(msg.Payload)
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
57
subscriber_test.go
Normal file
57
subscriber_test.go
Normal file
@ -0,0 +1,57 @@
|
||||
// Copyright 2020 Kentaro Hibino. All rights reserved.
|
||||
// Use of this source code is governed by a MIT license
|
||||
// that can be found in the LICENSE file.
|
||||
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
"github.com/hibiken/asynq/internal/rdb"
|
||||
)
|
||||
|
||||
func TestSubscriber(t *testing.T) {
|
||||
r := setup(t)
|
||||
rdbClient := rdb.NewRDB(r)
|
||||
|
||||
tests := []struct {
|
||||
registeredID string // ID for which cancel func is registered
|
||||
publishID string // ID to be published
|
||||
wantCalled bool // whether cancel func should be called
|
||||
}{
|
||||
{"abc123", "abc123", true},
|
||||
{"abc456", "abc123", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
called := false
|
||||
fakeCancelFunc := func() {
|
||||
called = true
|
||||
}
|
||||
cancelations := base.NewCancelations()
|
||||
cancelations.Add(tc.registeredID, fakeCancelFunc)
|
||||
|
||||
subscriber := newSubscriber(rdbClient, cancelations)
|
||||
subscriber.start()
|
||||
|
||||
if err := rdbClient.PublishCancelation(tc.publishID); err != nil {
|
||||
subscriber.terminate()
|
||||
t.Fatalf("could not publish cancelation message: %v", err)
|
||||
}
|
||||
|
||||
// allow for redis to publish message
|
||||
time.Sleep(time.Second)
|
||||
|
||||
if called != tc.wantCalled {
|
||||
if tc.wantCalled {
|
||||
t.Errorf("fakeCancelFunc was not called, want the function to be called")
|
||||
} else {
|
||||
t.Errorf("fakeCancelFunc was called, want the function to not be called")
|
||||
}
|
||||
}
|
||||
|
||||
subscriber.terminate()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user