mirror of
https://github.com/hibiken/asynq.git
synced 2024-12-24 23:02:18 +08:00
Add redis pubsub subscriber for cancelation
This commit is contained in:
parent
133bb6c2c6
commit
6685827147
@ -40,6 +40,7 @@ type Background struct {
|
|||||||
processor *processor
|
processor *processor
|
||||||
syncer *syncer
|
syncer *syncer
|
||||||
heartbeater *heartbeater
|
heartbeater *heartbeater
|
||||||
|
subscriber *subscriber
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config specifies the background-task processing behavior.
|
// Config specifies the background-task processing behavior.
|
||||||
@ -120,10 +121,12 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background {
|
|||||||
pinfo := base.NewProcessInfo(host, pid, n, queues, cfg.StrictPriority)
|
pinfo := base.NewProcessInfo(host, pid, n, queues, cfg.StrictPriority)
|
||||||
rdb := rdb.NewRDB(createRedisClient(r))
|
rdb := rdb.NewRDB(createRedisClient(r))
|
||||||
syncRequestCh := make(chan *syncRequest)
|
syncRequestCh := make(chan *syncRequest)
|
||||||
|
cancelations := base.NewCancelations()
|
||||||
syncer := newSyncer(syncRequestCh, 5*time.Second)
|
syncer := newSyncer(syncRequestCh, 5*time.Second)
|
||||||
heartbeater := newHeartbeater(rdb, pinfo, 5*time.Second)
|
heartbeater := newHeartbeater(rdb, pinfo, 5*time.Second)
|
||||||
scheduler := newScheduler(rdb, 5*time.Second, queues)
|
scheduler := newScheduler(rdb, 5*time.Second, queues)
|
||||||
processor := newProcessor(rdb, pinfo, delayFunc, syncRequestCh)
|
processor := newProcessor(rdb, pinfo, delayFunc, syncRequestCh, cancelations)
|
||||||
|
subscriber := newSubscriber(rdb, cancelations)
|
||||||
return &Background{
|
return &Background{
|
||||||
pinfo: pinfo,
|
pinfo: pinfo,
|
||||||
rdb: rdb,
|
rdb: rdb,
|
||||||
@ -131,6 +134,7 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background {
|
|||||||
processor: processor,
|
processor: processor,
|
||||||
syncer: syncer,
|
syncer: syncer,
|
||||||
heartbeater: heartbeater,
|
heartbeater: heartbeater,
|
||||||
|
subscriber: subscriber,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,6 +202,7 @@ func (bg *Background) start(handler Handler) {
|
|||||||
bg.processor.handler = handler
|
bg.processor.handler = handler
|
||||||
|
|
||||||
bg.heartbeater.start()
|
bg.heartbeater.start()
|
||||||
|
bg.subscriber.start()
|
||||||
bg.syncer.start()
|
bg.syncer.start()
|
||||||
bg.scheduler.start()
|
bg.scheduler.start()
|
||||||
bg.processor.start()
|
bg.processor.start()
|
||||||
@ -216,6 +221,7 @@ func (bg *Background) stop() {
|
|||||||
// Note: processor and all worker goroutines need to be exited
|
// Note: processor and all worker goroutines need to be exited
|
||||||
// before shutting down syncer to avoid goroutine leak.
|
// before shutting down syncer to avoid goroutine leak.
|
||||||
bg.syncer.terminate()
|
bg.syncer.terminate()
|
||||||
|
bg.subscriber.terminate()
|
||||||
bg.heartbeater.terminate()
|
bg.heartbeater.terminate()
|
||||||
|
|
||||||
bg.rdb.ClearProcessInfo(bg.pinfo)
|
bg.rdb.ClearProcessInfo(bg.pinfo)
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package base
|
package base
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -30,6 +31,7 @@ const (
|
|||||||
RetryQueue = "asynq:retry" // ZSET
|
RetryQueue = "asynq:retry" // ZSET
|
||||||
DeadQueue = "asynq:dead" // ZSET
|
DeadQueue = "asynq:dead" // ZSET
|
||||||
InProgressQueue = "asynq:in_progress" // LIST
|
InProgressQueue = "asynq:in_progress" // LIST
|
||||||
|
CancelChannel = "asynq:cancel" // PubSub channel
|
||||||
)
|
)
|
||||||
|
|
||||||
// QueueKey returns a redis key string for the given queue name.
|
// QueueKey returns a redis key string for the given queue name.
|
||||||
@ -129,3 +131,50 @@ func (p *ProcessInfo) IncrActiveWorkerCount(delta int) {
|
|||||||
defer p.mu.Unlock()
|
defer p.mu.Unlock()
|
||||||
p.ActiveWorkerCount += delta
|
p.ActiveWorkerCount += delta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cancelations hold cancel functions for all in-progress tasks.
|
||||||
|
//
|
||||||
|
// Its methods are safe to be used in multiple concurrent goroutines
|
||||||
|
type Cancelations struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
cancelFuncs map[string]context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCancelations returns a Cancelations instance.
|
||||||
|
func NewCancelations() *Cancelations {
|
||||||
|
return &Cancelations{
|
||||||
|
cancelFuncs: make(map[string]context.CancelFunc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a new cancel func to the set.
|
||||||
|
func (c *Cancelations) Add(id string, fn context.CancelFunc) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.cancelFuncs[id] = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete deletes a cancel func from the set given an id.
|
||||||
|
func (c *Cancelations) Delete(id string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
delete(c.cancelFuncs, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a cancel func given an id.
|
||||||
|
func (c *Cancelations) Get(id string) context.CancelFunc {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.cancelFuncs[id]
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll returns all cancel funcs.
|
||||||
|
func (c *Cancelations) GetAll() []context.CancelFunc {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
var res []context.CancelFunc
|
||||||
|
for _, fn := range c.cancelFuncs {
|
||||||
|
res = append(res, fn)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
@ -410,3 +410,19 @@ func (r *RDB) ClearProcessInfo(ps *base.ProcessInfo) error {
|
|||||||
key := base.ProcessInfoKey(ps.Host, ps.PID)
|
key := base.ProcessInfoKey(ps.Host, ps.PID)
|
||||||
return clearProcessInfoCmd.Run(r.client, []string{base.AllProcesses, key}).Err()
|
return clearProcessInfoCmd.Run(r.client, []string{base.AllProcesses, key}).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CancelationPubSub returns a pubsub for cancelation messages.
|
||||||
|
func (r *RDB) CancelationPubSub() (*redis.PubSub, error) {
|
||||||
|
pubsub := r.client.Subscribe(base.CancelChannel)
|
||||||
|
_, err := pubsub.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return pubsub, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishCancelation publish cancelation message to all subscribers.
|
||||||
|
// The message is a string representing the task to be canceled.
|
||||||
|
func (r *RDB) PublishCancelation(id string) error {
|
||||||
|
return r.client.Publish(base.CancelChannel, id).Err()
|
||||||
|
}
|
||||||
|
28
processor.go
28
processor.go
@ -14,7 +14,6 @@ import (
|
|||||||
|
|
||||||
"github.com/hibiken/asynq/internal/base"
|
"github.com/hibiken/asynq/internal/base"
|
||||||
"github.com/hibiken/asynq/internal/rdb"
|
"github.com/hibiken/asynq/internal/rdb"
|
||||||
"github.com/rs/xid"
|
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,15 +52,14 @@ type processor struct {
|
|||||||
// quit channel communicates to the in-flight worker goroutines to stop.
|
// quit channel communicates to the in-flight worker goroutines to stop.
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
|
|
||||||
// cancelFuncs is a map of task ID to cancel function for all in-progress tasks.
|
// cancelations is a set of cancel functions for all in-progress tasks.
|
||||||
mu sync.Mutex
|
cancelations *base.Cancelations
|
||||||
cancelFuncs map[string]context.CancelFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type retryDelayFunc func(n int, err error, task *Task) time.Duration
|
type retryDelayFunc func(n int, err error, task *Task) time.Duration
|
||||||
|
|
||||||
// newProcessor constructs a new processor.
|
// newProcessor constructs a new processor.
|
||||||
func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRequestCh chan<- *syncRequest) *processor {
|
func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRequestCh chan<- *syncRequest, cancelations *base.Cancelations) *processor {
|
||||||
qcfg := normalizeQueueCfg(pinfo.Queues)
|
qcfg := normalizeQueueCfg(pinfo.Queues)
|
||||||
orderedQueues := []string(nil)
|
orderedQueues := []string(nil)
|
||||||
if pinfo.StrictPriority {
|
if pinfo.StrictPriority {
|
||||||
@ -74,12 +72,12 @@ func newProcessor(r *rdb.RDB, pinfo *base.ProcessInfo, fn retryDelayFunc, syncRe
|
|||||||
orderedQueues: orderedQueues,
|
orderedQueues: orderedQueues,
|
||||||
retryDelayFunc: fn,
|
retryDelayFunc: fn,
|
||||||
syncRequestCh: syncRequestCh,
|
syncRequestCh: syncRequestCh,
|
||||||
|
cancelations: cancelations,
|
||||||
errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1),
|
errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1),
|
||||||
sema: make(chan struct{}, pinfo.Concurrency),
|
sema: make(chan struct{}, pinfo.Concurrency),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
abort: make(chan struct{}),
|
abort: make(chan struct{}),
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
cancelFuncs: make(map[string]context.CancelFunc),
|
|
||||||
handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }),
|
handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -107,7 +105,7 @@ func (p *processor) terminate() {
|
|||||||
logger.info("Waiting for all workers to finish...")
|
logger.info("Waiting for all workers to finish...")
|
||||||
|
|
||||||
// send cancellation signal to all in-progress task handlers
|
// send cancellation signal to all in-progress task handlers
|
||||||
for _, cancel := range p.cancelFuncs {
|
for _, cancel := range p.cancelations.GetAll() {
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,10 +172,10 @@ func (p *processor) exec() {
|
|||||||
resCh := make(chan error, 1)
|
resCh := make(chan error, 1)
|
||||||
task := NewTask(msg.Type, msg.Payload)
|
task := NewTask(msg.Type, msg.Payload)
|
||||||
ctx, cancel := createContext(msg)
|
ctx, cancel := createContext(msg)
|
||||||
p.addCancelFunc(msg.ID, cancel)
|
p.cancelations.Add(msg.ID.String(), cancel)
|
||||||
go func() {
|
go func() {
|
||||||
resCh <- perform(ctx, task, p.handler)
|
resCh <- perform(ctx, task, p.handler)
|
||||||
p.deleteCancelFunc(msg.ID)
|
p.cancelations.Delete(msg.ID.String())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -268,18 +266,6 @@ func (p *processor) kill(msg *base.TaskMessage, e error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) addCancelFunc(id xid.ID, fn context.CancelFunc) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
p.cancelFuncs[id.String()] = fn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *processor) deleteCancelFunc(id xid.ID) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
delete(p.cancelFuncs, id.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// queues returns a list of queues to query.
|
// queues returns a list of queues to query.
|
||||||
// Order of the queue names is based on the priority of each queue.
|
// Order of the queue names is based on the priority of each queue.
|
||||||
// Queue names is sorted by their priority level if strict-priority is true.
|
// Queue names is sorted by their priority level if strict-priority is true.
|
||||||
|
@ -67,7 +67,8 @@ func TestProcessorSuccess(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false)
|
pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false)
|
||||||
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil)
|
cancelations := base.NewCancelations()
|
||||||
|
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil, cancelations)
|
||||||
p.handler = HandlerFunc(handler)
|
p.handler = HandlerFunc(handler)
|
||||||
|
|
||||||
p.start()
|
p.start()
|
||||||
@ -151,7 +152,8 @@ func TestProcessorRetry(t *testing.T) {
|
|||||||
return fmt.Errorf(errMsg)
|
return fmt.Errorf(errMsg)
|
||||||
}
|
}
|
||||||
pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false)
|
pi := base.NewProcessInfo("localhost", 1234, 10, defaultQueueConfig, false)
|
||||||
p := newProcessor(rdbClient, pi, delayFunc, nil)
|
cancelations := base.NewCancelations()
|
||||||
|
p := newProcessor(rdbClient, pi, delayFunc, nil, cancelations)
|
||||||
p.handler = HandlerFunc(handler)
|
p.handler = HandlerFunc(handler)
|
||||||
|
|
||||||
p.start()
|
p.start()
|
||||||
@ -211,7 +213,8 @@ func TestProcessorQueues(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
pi := base.NewProcessInfo("localhost", 1234, 10, tc.queueCfg, false)
|
pi := base.NewProcessInfo("localhost", 1234, 10, tc.queueCfg, false)
|
||||||
p := newProcessor(nil, pi, defaultDelayFunc, nil)
|
cancelations := base.NewCancelations()
|
||||||
|
p := newProcessor(nil, pi, defaultDelayFunc, nil, cancelations)
|
||||||
got := p.queues()
|
got := p.queues()
|
||||||
if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" {
|
if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" {
|
||||||
t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s",
|
t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s",
|
||||||
@ -278,7 +281,8 @@ func TestProcessorWithStrictPriority(t *testing.T) {
|
|||||||
}
|
}
|
||||||
// Note: Set concurrency to 1 to make sure tasks are processed one at a time.
|
// Note: Set concurrency to 1 to make sure tasks are processed one at a time.
|
||||||
pi := base.NewProcessInfo("localhost", 1234, 1 /*concurrency */, queueCfg, true /* strict */)
|
pi := base.NewProcessInfo("localhost", 1234, 1 /*concurrency */, queueCfg, true /* strict */)
|
||||||
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil)
|
cancelations := base.NewCancelations()
|
||||||
|
p := newProcessor(rdbClient, pi, defaultDelayFunc, nil, cancelations)
|
||||||
p.handler = HandlerFunc(handler)
|
p.handler = HandlerFunc(handler)
|
||||||
|
|
||||||
p.start()
|
p.start()
|
||||||
|
58
subscriber.go
Normal file
58
subscriber.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
// Copyright 2020 Kentaro Hibino. All rights reserved.
|
||||||
|
// Use of this source code is governed by a MIT license
|
||||||
|
// that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package asynq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/hibiken/asynq/internal/base"
|
||||||
|
"github.com/hibiken/asynq/internal/rdb"
|
||||||
|
)
|
||||||
|
|
||||||
|
type subscriber struct {
|
||||||
|
rdb *rdb.RDB
|
||||||
|
|
||||||
|
// channel to communicate back to the long running "subscriber" goroutine.
|
||||||
|
done chan struct{}
|
||||||
|
|
||||||
|
// cancelations hold cancel functions for all in-progress tasks.
|
||||||
|
cancelations *base.Cancelations
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSubscriber(rdb *rdb.RDB, cancelations *base.Cancelations) *subscriber {
|
||||||
|
return &subscriber{
|
||||||
|
rdb: rdb,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
cancelations: cancelations,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *subscriber) terminate() {
|
||||||
|
logger.info("Subscriber shutting down...")
|
||||||
|
// Signal the subscriber goroutine to stop.
|
||||||
|
s.done <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *subscriber) start() {
|
||||||
|
pubsub, err := s.rdb.CancelationPubSub()
|
||||||
|
cancelCh := pubsub.Channel()
|
||||||
|
if err != nil {
|
||||||
|
logger.error("cannot subscribe to cancelation channel: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
pubsub.Close()
|
||||||
|
logger.info("Subscriber done")
|
||||||
|
return
|
||||||
|
case msg := <-cancelCh:
|
||||||
|
cancel := s.cancelations.Get(msg.Payload)
|
||||||
|
if cancel != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
57
subscriber_test.go
Normal file
57
subscriber_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
// Copyright 2020 Kentaro Hibino. All rights reserved.
|
||||||
|
// Use of this source code is governed by a MIT license
|
||||||
|
// that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package asynq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hibiken/asynq/internal/base"
|
||||||
|
"github.com/hibiken/asynq/internal/rdb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSubscriber(t *testing.T) {
|
||||||
|
r := setup(t)
|
||||||
|
rdbClient := rdb.NewRDB(r)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
registeredID string // ID for which cancel func is registered
|
||||||
|
publishID string // ID to be published
|
||||||
|
wantCalled bool // whether cancel func should be called
|
||||||
|
}{
|
||||||
|
{"abc123", "abc123", true},
|
||||||
|
{"abc456", "abc123", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
called := false
|
||||||
|
fakeCancelFunc := func() {
|
||||||
|
called = true
|
||||||
|
}
|
||||||
|
cancelations := base.NewCancelations()
|
||||||
|
cancelations.Add(tc.registeredID, fakeCancelFunc)
|
||||||
|
|
||||||
|
subscriber := newSubscriber(rdbClient, cancelations)
|
||||||
|
subscriber.start()
|
||||||
|
|
||||||
|
if err := rdbClient.PublishCancelation(tc.publishID); err != nil {
|
||||||
|
subscriber.terminate()
|
||||||
|
t.Fatalf("could not publish cancelation message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow for redis to publish message
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
if called != tc.wantCalled {
|
||||||
|
if tc.wantCalled {
|
||||||
|
t.Errorf("fakeCancelFunc was not called, want the function to be called")
|
||||||
|
} else {
|
||||||
|
t.Errorf("fakeCancelFunc was called, want the function to not be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
subscriber.terminate()
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user