2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-09-20 11:05:58 +08:00

Use sync.WaitGroup for shutdown

This commit is contained in:
Ken Hibino 2020-02-15 23:14:30 -08:00
parent 2bcaea52ce
commit 3d9a222bb3
11 changed files with 51 additions and 18 deletions

View File

@ -37,6 +37,9 @@ type Background struct {
// channel to send state updates. // channel to send state updates.
stateCh chan<- string stateCh chan<- string
// wait group to wait for all goroutines to finish.
wg sync.WaitGroup
rdb *rdb.RDB rdb *rdb.RDB
scheduler *scheduler scheduler *scheduler
processor *processor processor *processor
@ -211,11 +214,11 @@ func (bg *Background) start(handler Handler) {
bg.running = true bg.running = true
bg.processor.handler = handler bg.processor.handler = handler
bg.heartbeater.start() bg.heartbeater.start(&bg.wg)
bg.subscriber.start() bg.subscriber.start(&bg.wg)
bg.syncer.start() bg.syncer.start(&bg.wg)
bg.scheduler.start() bg.scheduler.start(&bg.wg)
bg.processor.start() bg.processor.start(&bg.wg)
} }
// stops the background-task processing. // stops the background-task processing.
@ -234,6 +237,8 @@ func (bg *Background) stop() {
bg.subscriber.terminate() bg.subscriber.terminate()
bg.heartbeater.terminate() bg.heartbeater.terminate()
bg.wg.Wait()
bg.rdb.Close() bg.rdb.Close()
bg.processor.handler = nil bg.processor.handler = nil
bg.running = false bg.running = false

View File

@ -5,6 +5,7 @@
package asynq package asynq
import ( import (
"sync"
"time" "time"
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
@ -49,10 +50,12 @@ func (h *heartbeater) terminate() {
h.done <- struct{}{} h.done <- struct{}{}
} }
func (h *heartbeater) start() { func (h *heartbeater) start(wg *sync.WaitGroup) {
h.pinfo.Started = time.Now() h.pinfo.Started = time.Now()
h.pinfo.State = "running" h.pinfo.State = "running"
wg.Add(1)
go func() { go func() {
defer wg.Done()
h.beat() h.beat()
timer := time.NewTimer(h.interval) timer := time.NewTimer(h.interval)
for { for {

View File

@ -5,6 +5,7 @@
package asynq package asynq
import ( import (
"sync"
"testing" "testing"
"time" "time"
@ -15,6 +16,7 @@ import (
"github.com/hibiken/asynq/internal/rdb" "github.com/hibiken/asynq/internal/rdb"
) )
// FIXME: Make this test better.
func TestHeartbeater(t *testing.T) { func TestHeartbeater(t *testing.T) {
r := setup(t) r := setup(t)
rdbClient := rdb.NewRDB(r) rdbClient := rdb.NewRDB(r)
@ -46,7 +48,8 @@ func TestHeartbeater(t *testing.T) {
Started: time.Now(), Started: time.Now(),
State: "running", State: "running",
} }
hb.start() var wg sync.WaitGroup
hb.start(&wg)
// allow for heartbeater to write to redis // allow for heartbeater to write to redis
time.Sleep(tc.interval * 2) time.Sleep(tc.interval * 2)

View File

@ -119,11 +119,13 @@ func (p *processor) terminate() {
p.restore() // move any unfinished tasks back to the queue. p.restore() // move any unfinished tasks back to the queue.
} }
func (p *processor) start() { func (p *processor) start(wg *sync.WaitGroup) {
// NOTE: The call to "restore" needs to complete before starting // NOTE: The call to "restore" needs to complete before starting
// the processor goroutine. // the processor goroutine.
p.restore() p.restore()
wg.Add(1)
go func() { go func() {
defer wg.Done()
for { for {
select { select {
case <-p.done: case <-p.done:

View File

@ -72,7 +72,8 @@ func TestProcessorSuccess(t *testing.T) {
p := newProcessor(rdbClient, defaultQueueConfig, false, 10, defaultDelayFunc, nil, workerCh, cancelations) p := newProcessor(rdbClient, defaultQueueConfig, false, 10, defaultDelayFunc, nil, workerCh, cancelations)
p.handler = HandlerFunc(handler) p.handler = HandlerFunc(handler)
p.start() var wg sync.WaitGroup
p.start(&wg)
for _, msg := range tc.incoming { for _, msg := range tc.incoming {
err := rdbClient.Enqueue(msg) err := rdbClient.Enqueue(msg)
if err != nil { if err != nil {
@ -159,7 +160,8 @@ func TestProcessorRetry(t *testing.T) {
p := newProcessor(rdbClient, defaultQueueConfig, false, 10, delayFunc, nil, workerCh, cancelations) p := newProcessor(rdbClient, defaultQueueConfig, false, 10, delayFunc, nil, workerCh, cancelations)
p.handler = HandlerFunc(handler) p.handler = HandlerFunc(handler)
p.start() var wg sync.WaitGroup
p.start(&wg)
for _, msg := range tc.incoming { for _, msg := range tc.incoming {
err := rdbClient.Enqueue(msg) err := rdbClient.Enqueue(msg)
if err != nil { if err != nil {
@ -290,7 +292,8 @@ func TestProcessorWithStrictPriority(t *testing.T) {
defaultDelayFunc, nil, workerCh, cancelations) defaultDelayFunc, nil, workerCh, cancelations)
p.handler = HandlerFunc(handler) p.handler = HandlerFunc(handler)
p.start() var wg sync.WaitGroup
p.start(&wg)
time.Sleep(tc.wait) time.Sleep(tc.wait)
p.terminate() p.terminate()
close(workerCh) close(workerCh)

View File

@ -5,6 +5,7 @@
package asynq package asynq
import ( import (
"sync"
"time" "time"
"github.com/hibiken/asynq/internal/rdb" "github.com/hibiken/asynq/internal/rdb"
@ -43,8 +44,10 @@ func (s *scheduler) terminate() {
} }
// start starts the "scheduler" goroutine. // start starts the "scheduler" goroutine.
func (s *scheduler) start() { func (s *scheduler) start(wg *sync.WaitGroup) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
for { for {
select { select {
case <-s.done: case <-s.done:

View File

@ -5,6 +5,7 @@
package asynq package asynq
import ( import (
"sync"
"testing" "testing"
"time" "time"
@ -69,7 +70,8 @@ func TestScheduler(t *testing.T) {
h.SeedRetryQueue(t, r, tc.initRetry) // initialize retry queue h.SeedRetryQueue(t, r, tc.initRetry) // initialize retry queue
h.SeedEnqueuedQueue(t, r, tc.initQueue) // initialize default queue h.SeedEnqueuedQueue(t, r, tc.initQueue) // initialize default queue
s.start() var wg sync.WaitGroup
s.start(&wg)
time.Sleep(tc.wait) time.Sleep(tc.wait)
s.terminate() s.terminate()

View File

@ -5,6 +5,8 @@
package asynq package asynq
import ( import (
"sync"
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/rdb" "github.com/hibiken/asynq/internal/rdb"
) )
@ -33,14 +35,16 @@ func (s *subscriber) terminate() {
s.done <- struct{}{} s.done <- struct{}{}
} }
func (s *subscriber) start() { func (s *subscriber) start(wg *sync.WaitGroup) {
pubsub, err := s.rdb.CancelationPubSub() pubsub, err := s.rdb.CancelationPubSub()
cancelCh := pubsub.Channel() cancelCh := pubsub.Channel()
if err != nil { if err != nil {
logger.error("cannot subscribe to cancelation channel: %v", err) logger.error("cannot subscribe to cancelation channel: %v", err)
return return
} }
wg.Add(1)
go func() { go func() {
defer wg.Done()
for { for {
select { select {
case <-s.done: case <-s.done:

View File

@ -5,6 +5,7 @@
package asynq package asynq
import ( import (
"sync"
"testing" "testing"
"time" "time"
@ -34,7 +35,8 @@ func TestSubscriber(t *testing.T) {
cancelations.Add(tc.registeredID, fakeCancelFunc) cancelations.Add(tc.registeredID, fakeCancelFunc)
subscriber := newSubscriber(rdbClient, cancelations) subscriber := newSubscriber(rdbClient, cancelations)
subscriber.start() var wg sync.WaitGroup
subscriber.start(&wg)
if err := rdbClient.PublishCancelation(tc.publishID); err != nil { if err := rdbClient.PublishCancelation(tc.publishID); err != nil {
subscriber.terminate() subscriber.terminate()

View File

@ -5,6 +5,7 @@
package asynq package asynq
import ( import (
"sync"
"time" "time"
) )
@ -39,8 +40,10 @@ func (s *syncer) terminate() {
s.done <- struct{}{} s.done <- struct{}{}
} }
func (s *syncer) start() { func (s *syncer) start(wg *sync.WaitGroup) {
wg.Add(1)
go func() { go func() {
defer wg.Done()
var requests []*syncRequest var requests []*syncRequest
for { for {
select { select {

View File

@ -5,6 +5,7 @@
package asynq package asynq
import ( import (
"sync"
"testing" "testing"
"time" "time"
@ -27,7 +28,8 @@ func TestSyncer(t *testing.T) {
const interval = time.Second const interval = time.Second
syncRequestCh := make(chan *syncRequest) syncRequestCh := make(chan *syncRequest)
syncer := newSyncer(syncRequestCh, interval) syncer := newSyncer(syncRequestCh, interval)
syncer.start() var wg sync.WaitGroup
syncer.start(&wg)
defer syncer.terminate() defer syncer.terminate()
for _, msg := range inProgress { for _, msg := range inProgress {
@ -66,7 +68,8 @@ func TestSyncerRetry(t *testing.T) {
const interval = time.Second const interval = time.Second
syncRequestCh := make(chan *syncRequest) syncRequestCh := make(chan *syncRequest)
syncer := newSyncer(syncRequestCh, interval) syncer := newSyncer(syncRequestCh, interval)
syncer.start() var wg sync.WaitGroup
syncer.start(&wg)
defer syncer.terminate() defer syncer.terminate()
for _, msg := range inProgress { for _, msg := range inProgress {