2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-11-10 11:31:58 +08:00

Add redis pubsub subscriber for cancelation

This commit is contained in:
Ken Hibino 2020-02-12 17:12:09 -08:00
parent 133bb6c2c6
commit 6685827147
7 changed files with 202 additions and 26 deletions

View File

@ -40,6 +40,7 @@ type Background struct {
processor *processor processor *processor
syncer *syncer syncer *syncer
heartbeater *heartbeater heartbeater *heartbeater
subscriber *subscriber
} }
// Config specifies the background-task processing behavior. // Config specifies the background-task processing behavior.
@ -120,10 +121,12 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background {
pinfo := base.NewProcessInfo(host, pid, n, queues, cfg.StrictPriority) pinfo := base.NewProcessInfo(host, pid, n, queues, cfg.StrictPriority)
rdb := rdb.NewRDB(createRedisClient(r)) rdb := rdb.NewRDB(createRedisClient(r))
syncRequestCh := make(chan *syncRequest) syncRequestCh := make(chan *syncRequest)
cancelations := base.NewCancelations()
syncer := newSyncer(syncRequestCh, 5*time.Second) syncer := newSyncer(syncRequestCh, 5*time.Second)
heartbeater := newHeartbeater(rdb, pinfo, 5*time.Second) heartbeater := newHeartbeater(rdb, pinfo, 5*time.Second)
scheduler := newScheduler(rdb, 5*time.Second, queues) scheduler := newScheduler(rdb, 5*time.Second, queues)
processor := newProcessor(rdb, pinfo, delayFunc, syncRequestCh) processor := newProcessor(rdb, pinfo, delayFunc, syncRequestCh, cancelations)
subscriber := newSubscriber(rdb, cancelations)
return &Background{ return &Background{
pinfo: pinfo, pinfo: pinfo,
rdb: rdb, rdb: rdb,
@ -131,6 +134,7 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background {
processor: processor, processor: processor,
syncer: syncer, syncer: syncer,
heartbeater: heartbeater, heartbeater: heartbeater,
subscriber: subscriber,
} }
} }
@ -198,6 +202,7 @@ func (bg *Background) start(handler Handler) {
bg.processor.handler = handler bg.processor.handler = handler
bg.heartbeater.start() bg.heartbeater.start()
bg.subscriber.start()
bg.syncer.start() bg.syncer.start()
bg.scheduler.start() bg.scheduler.start()
bg.processor.start() bg.processor.start()
@ -216,6 +221,7 @@ func (bg *Background) stop() {
// Note: processor and all worker goroutines need to be exited // Note: processor and all worker goroutines need to be exited
// before shutting down syncer to avoid goroutine leak. // before shutting down syncer to avoid goroutine leak.
bg.syncer.terminate() bg.syncer.terminate()
bg.subscriber.terminate()
bg.heartbeater.terminate() bg.heartbeater.terminate()
bg.rdb.ClearProcessInfo(bg.pinfo) bg.rdb.ClearProcessInfo(bg.pinfo)

View File

@ -6,6 +6,7 @@
package base package base
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@ -30,6 +31,7 @@ const (
RetryQueue = "asynq:retry" // ZSET RetryQueue = "asynq:retry" // ZSET
DeadQueue = "asynq:dead" // ZSET DeadQueue = "asynq:dead" // ZSET
InProgressQueue = "asynq:in_progress" // LIST InProgressQueue = "asynq:in_progress" // LIST
CancelChannel = "asynq:cancel" // PubSub channel
) )
// QueueKey returns a redis key string for the given queue name. // QueueKey returns a redis key string for the given queue name.
@ -129,3 +131,50 @@ func (p *ProcessInfo) IncrActiveWorkerCount(delta int) {
defer p.mu.Unlock() defer p.mu.Unlock()
p.ActiveWorkerCount += delta p.ActiveWorkerCount += delta
} }
// Cancelations hold cancel functions for all in-progress tasks.
//
// Its methods are safe to be used in multiple concurrent goroutines
type Cancelations struct {
mu sync.Mutex
cancelFuncs map[string]context.CancelFunc
}
// NewCancelations returns a Cancelations instance.
func NewCancelations() *Cancelations {
return &Cancelations{
cancelFuncs: make(map[string]context.CancelFunc),
}
}
// Add adds a new cancel func to the set.
func (c *Cancelations) Add(id string, fn context.CancelFunc) {
c.mu.Lock()
defer c.mu.Unlock()
c.cancelFuncs[id] = fn
}
// Delete deletes a cancel func from the set given an id.
func (c *Cancelations) Delete(id string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.cancelFuncs, id)
}
// Get returns a cancel func given an id.
func (c *Cancelations) Get(id string) context.CancelFunc {
c.mu.Lock()
defer c.mu.Unlock()
return c.cancelFuncs[id]
}
// GetAll returns all cancel funcs.
func (c *Cancelations) GetAll() []context.CancelFunc {
c.mu.Lock()
defer c.mu.Unlock()
var res []context.CancelFunc
for _, fn := range c.cancelFuncs {
res = append(res, fn)
}
return res
}

View File

@ -410,3 +410,19 @@ func (r *RDB) ClearProcessInfo(ps *base.ProcessInfo) error {
key := base.ProcessInfoKey(ps.Host, ps.PID) key := base.ProcessInfoKey(ps.Host, ps.PID)
return clearProcessInfoCmd.Run(r.client, []string{base.AllProcesses, key}).Err() return clearProcessInfoCmd.Run(r.client, []string{base.AllProcesses, key}).Err()
} }
// CancelationPubSub returns a pubsub for cancelation messages.
func (r *RDB) CancelationPubSub() (*redis.PubSub, error) {
pubsub := r.client.Subscribe(base.CancelChannel)
_, err := pubsub.Receive()
if err != nil {
return nil, err
}
return pubsub, nil
}
// PublishCancelation publish cancelation message to all subscribers.
// The message is a string representing the task to be canceled.
func (r *RDB) PublishCancelation(id string) error {
return r.client.Publish(base.CancelChannel, id).Err()
}

View File

@ -14,7 +14,6 @@ import (
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/rdb" "github.com/hibiken/asynq/internal/rdb"
"github.com/rs/xid"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
@ -53,15 +52,14 @@ type processor struct {
// quit channel communicates to the in-flight worker goroutines to stop. // quit channel communicates to the in-flight worker goroutines to stop.
quit chan struct{} quit chan struct{}
// cancelFuncs is a map of task ID to cancel function for all in-progress tasks. // cancelations is a set of cancel functions for all in-progress tasks.
mu sync.Mutex cancelations *base.Cancelations
cancelFuncs map[string]context.CancelFunc
} }
type retryDelayFunc func(n int, err error, task *Task) time.Duration type retryDelayFunc func(n int, err error, task *Task) time.Duration
// newProcessor constructs a new processor. // newProcessor constructs a new processor.
func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRequestCh chan<- *syncRequest) *processor { func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRequestCh chan<- *syncRequest, cancelations *base.Cancelations) *processor {
qcfg := normalizeQueueCfg(pinfo.Queues) qcfg := normalizeQueueCfg(pinfo.Queues)
orderedQueues := []string(nil) orderedQueues := []string(nil)
if pinfo.StrictPriority { if pinfo.StrictPriority {
@ -74,12 +72,12 @@ func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRe
orderedQueues: orderedQueues, orderedQueues: orderedQueues,
retryDelayFunc: fn, retryDelayFunc: fn,
syncRequestCh: syncRequestCh, syncRequestCh: syncRequestCh,
cancelations: cancelations,
errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1), errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1),
sema: make(chan struct{}, pinfo.Concurrency), sema: make(chan struct{}, pinfo.Concurrency),
done: make(chan struct{}), done: make(chan struct{}),
abort: make(chan struct{}), abort: make(chan struct{}),
quit: make(chan struct{}), quit: make(chan struct{}),
cancelFuncs: make(map[string]context.CancelFunc),
handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }), handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }),
} }
} }
@ -107,7 +105,7 @@ func (p *processor) terminate() {
logger.info("Waiting for all workers to finish...") logger.info("Waiting for all workers to finish...")
// send cancellation signal to all in-progress task handlers // send cancellation signal to all in-progress task handlers
for _, cancel := range p.cancelFuncs { for _, cancel := range p.cancelations.GetAll() {
cancel() cancel()
} }
@ -174,10 +172,10 @@ func (p *processor) exec() {
resCh := make(chan error, 1) resCh := make(chan error, 1)
task := NewTask(msg.Type, msg.Payload) task := NewTask(msg.Type, msg.Payload)
ctx, cancel := createContext(msg) ctx, cancel := createContext(msg)
p.addCancelFunc(msg.ID, cancel) p.cancelations.Add(msg.ID.String(), cancel)
go func() { go func() {
resCh <- perform(ctx, task, p.handler) resCh <- perform(ctx, task, p.handler)
p.deleteCancelFunc(msg.ID) p.cancelations.Delete(msg.ID.String())
}() }()
select { select {
@ -268,18 +266,6 @@ func (p *processor) kill(msg *base.TaskMessage, e error) {
} }
} }
func (p *processor) addCancelFunc(id xid.ID, fn context.CancelFunc) {
p.mu.Lock()
defer p.mu.Unlock()
p.cancelFuncs[id.String()] = fn
}
func (p *processor) deleteCancelFunc(id xid.ID) {
p.mu.Lock()
defer p.mu.Unlock()
delete(p.cancelFuncs, id.String())
}
// queues returns a list of queues to query. // queues returns a list of queues to query.
// Order of the queue names is based on the priority of each queue. // Order of the queue names is based on the priority of each queue.
// Queue names is sorted by their priority level if strict-priority is true. // Queue names is sorted by their priority level if strict-priority is true.

View File

@ -67,7 +67,8 @@ func TestProcessorSuccess(t *testing.T) {
return nil return nil
} }
pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false) pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false)
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil) cancelations := base.NewCancelations()
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil, cancelations)
p.handler = HandlerFunc(handler) p.handler = HandlerFunc(handler)
p.start() p.start()
@ -151,7 +152,8 @@ func TestProcessorRetry(t *testing.T) {
return fmt.Errorf(errMsg) return fmt.Errorf(errMsg)
} }
pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false) pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false)
p := newProcessor(rdbClient, pi, delayFunc, nil) cancelations := base.NewCancelations()
p := newProcessor(rdbClient, pi, delayFunc, nil, cancelations)
p.handler = HandlerFunc(handler) p.handler = HandlerFunc(handler)
p.start() p.start()
@ -211,7 +213,8 @@ func TestProcessorQueues(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
pi := base.NewProcessInfo("localhost", 1234, 10, tc.queueCfg, false) pi := base.NewProcessInfo("localhost", 1234, 10, tc.queueCfg, false)
p := newProcessor(nil, pi, defaultDelayFunc, nil) cancelations := base.NewCancelations()
p := newProcessor(nil, pi, defaultDelayFunc, nil, cancelations)
got := p.queues() got := p.queues()
if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" { if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" {
t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s", t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s",
@ -278,7 +281,8 @@ func TestProcessorWithStrictPriority(t *testing.T) {
} }
// Note: Set concurrency to 1 to make sure tasks are processed one at a time. // Note: Set concurrency to 1 to make sure tasks are processed one at a time.
pi := base.NewProcessInfo("localhost", 1234, 1 /*concurrency */, queueCfg, true /* strict */) pi := base.NewProcessInfo("localhost", 1234, 1 /*concurrency */, queueCfg, true /* strict */)
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil) cancelations := base.NewCancelations()
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil, cancelations)
p.handler = HandlerFunc(handler) p.handler = HandlerFunc(handler)
p.start() p.start()

58
subscriber.go Normal file
View File

@ -0,0 +1,58 @@
// Copyright 2020 Kentaro Hibino. All rights reserved.
// Use of this source code is governed by a MIT license
// that can be found in the LICENSE file.
package asynq
import (
"github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/rdb"
)
type subscriber struct {
rdb *rdb.RDB
// channel to communicate back to the long running "subscriber" goroutine.
done chan struct{}
// cancelations hold cancel functions for all in-progress tasks.
cancelations *base.Cancelations
}
func newSubscriber(rdb *rdb.RDB, cancelations *base.Cancelations) *subscriber {
return &subscriber{
rdb: rdb,
done: make(chan struct{}),
cancelations: cancelations,
}
}
func (s *subscriber) terminate() {
logger.info("Subscriber shutting down...")
// Signal the subscriber goroutine to stop.
s.done <- struct{}{}
}
func (s *subscriber) start() {
pubsub, err := s.rdb.CancelationPubSub()
cancelCh := pubsub.Channel()
if err != nil {
logger.error("cannot subscribe to cancelation channel: %v", err)
return
}
go func() {
for {
select {
case <-s.done:
pubsub.Close()
logger.info("Subscriber done")
return
case msg := <-cancelCh:
cancel := s.cancelations.Get(msg.Payload)
if cancel != nil {
cancel()
}
}
}
}()
}

57
subscriber_test.go Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2020 Kentaro Hibino. All rights reserved.
// Use of this source code is governed by a MIT license
// that can be found in the LICENSE file.
package asynq
import (
"testing"
"time"
"github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/rdb"
)
func TestSubscriber(t *testing.T) {
r := setup(t)
rdbClient := rdb.NewRDB(r)
tests := []struct {
registeredID string // ID for which cancel func is registered
publishID string // ID to be published
wantCalled bool // whether cancel func should be called
}{
{"abc123", "abc123", true},
{"abc456", "abc123", false},
}
for _, tc := range tests {
called := false
fakeCancelFunc := func() {
called = true
}
cancelations := base.NewCancelations()
cancelations.Add(tc.registeredID, fakeCancelFunc)
subscriber := newSubscriber(rdbClient, cancelations)
subscriber.start()
if err := rdbClient.PublishCancelation(tc.publishID); err != nil {
subscriber.terminate()
t.Fatalf("could not publish cancelation message: %v", err)
}
// allow for redis to publish message
time.Sleep(time.Second)
if called != tc.wantCalled {
if tc.wantCalled {
t.Errorf("fakeCancelFunc was not called, want the function to be called")
} else {
t.Errorf("fakeCancelFunc was called, want the function to not be called")
}
}
subscriber.terminate()
}
}