diff --git a/background.go b/background.go index 0d0cccf..91635bf 100644 --- a/background.go +++ b/background.go @@ -37,6 +37,9 @@ type Background struct { // channel to send state updates. stateCh chan<- string + // wait group to wait for all goroutines to finish. + wg sync.WaitGroup + rdb *rdb.RDB scheduler *scheduler processor *processor @@ -211,11 +214,11 @@ func (bg *Background) start(handler Handler) { bg.running = true bg.processor.handler = handler - bg.heartbeater.start() - bg.subscriber.start() - bg.syncer.start() - bg.scheduler.start() - bg.processor.start() + bg.heartbeater.start(&bg.wg) + bg.subscriber.start(&bg.wg) + bg.syncer.start(&bg.wg) + bg.scheduler.start(&bg.wg) + bg.processor.start(&bg.wg) } // stops the background-task processing. @@ -234,6 +237,8 @@ func (bg *Background) stop() { bg.subscriber.terminate() bg.heartbeater.terminate() + bg.wg.Wait() + bg.rdb.Close() bg.processor.handler = nil bg.running = false diff --git a/heartbeat.go b/heartbeat.go index f8fc326..9c78b74 100644 --- a/heartbeat.go +++ b/heartbeat.go @@ -5,6 +5,7 @@ package asynq import ( + "sync" "time" "github.com/hibiken/asynq/internal/base" @@ -49,10 +50,12 @@ func (h *heartbeater) terminate() { h.done <- struct{}{} } -func (h *heartbeater) start() { +func (h *heartbeater) start(wg *sync.WaitGroup) { h.pinfo.Started = time.Now() h.pinfo.State = "running" + wg.Add(1) go func() { + defer wg.Done() h.beat() timer := time.NewTimer(h.interval) for { diff --git a/heartbeat_test.go b/heartbeat_test.go index ccbd114..cba8f18 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -5,6 +5,7 @@ package asynq import ( + "sync" "testing" "time" @@ -15,6 +16,7 @@ import ( "github.com/hibiken/asynq/internal/rdb" ) +// FIXME: Make this test better. func TestHeartbeater(t *testing.T) { r := setup(t) rdbClient := rdb.NewRDB(r) @@ -46,7 +48,8 @@ func TestHeartbeater(t *testing.T) { Started: time.Now(), State: "running", } - hb.start() + var wg sync.WaitGroup + hb.start(&wg) // allow for heartbeater to write to redis time.Sleep(tc.interval * 2) diff --git a/processor.go b/processor.go index dfdd6ae..edf8de6 100644 --- a/processor.go +++ b/processor.go @@ -119,11 +119,13 @@ func (p *processor) terminate() { p.restore() // move any unfinished tasks back to the queue. } -func (p *processor) start() { +func (p *processor) start(wg *sync.WaitGroup) { // NOTE: The call to "restore" needs to complete before starting // the processor goroutine. p.restore() + wg.Add(1) go func() { + defer wg.Done() for { select { case <-p.done: diff --git a/processor_test.go b/processor_test.go index 1615571..77224c0 100644 --- a/processor_test.go +++ b/processor_test.go @@ -72,7 +72,8 @@ func TestProcessorSuccess(t *testing.T) { p := newProcessor(rdbClient, defaultQueueConfig, false, 10, defaultDelayFunc, nil, workerCh, cancelations) p.handler = HandlerFunc(handler) - p.start() + var wg sync.WaitGroup + p.start(&wg) for _, msg := range tc.incoming { err := rdbClient.Enqueue(msg) if err != nil { @@ -159,7 +160,8 @@ func TestProcessorRetry(t *testing.T) { p := newProcessor(rdbClient, defaultQueueConfig, false, 10, delayFunc, nil, workerCh, cancelations) p.handler = HandlerFunc(handler) - p.start() + var wg sync.WaitGroup + p.start(&wg) for _, msg := range tc.incoming { err := rdbClient.Enqueue(msg) if err != nil { @@ -290,7 +292,8 @@ func TestProcessorWithStrictPriority(t *testing.T) { defaultDelayFunc, nil, workerCh, cancelations) p.handler = HandlerFunc(handler) - p.start() + var wg sync.WaitGroup + p.start(&wg) time.Sleep(tc.wait) p.terminate() close(workerCh) diff --git a/scheduler.go b/scheduler.go index 7276f4c..c5f28d2 100644 --- a/scheduler.go +++ b/scheduler.go @@ -5,6 +5,7 @@ package asynq import ( + "sync" "time" "github.com/hibiken/asynq/internal/rdb" @@ -43,8 +44,10 @@ func (s *scheduler) terminate() { } // start starts the "scheduler" goroutine. -func (s *scheduler) start() { +func (s *scheduler) start(wg *sync.WaitGroup) { + wg.Add(1) go func() { + defer wg.Done() for { select { case <-s.done: diff --git a/scheduler_test.go b/scheduler_test.go index b16ee04..4f7575f 100644 --- a/scheduler_test.go +++ b/scheduler_test.go @@ -5,6 +5,7 @@ package asynq import ( + "sync" "testing" "time" @@ -69,7 +70,8 @@ func TestScheduler(t *testing.T) { h.SeedRetryQueue(t, r, tc.initRetry) // initialize retry queue h.SeedEnqueuedQueue(t, r, tc.initQueue) // initialize default queue - s.start() + var wg sync.WaitGroup + s.start(&wg) time.Sleep(tc.wait) s.terminate() diff --git a/subscriber.go b/subscriber.go index 220116c..fc50c06 100644 --- a/subscriber.go +++ b/subscriber.go @@ -5,6 +5,8 @@ package asynq import ( + "sync" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/rdb" ) @@ -33,14 +35,16 @@ func (s *subscriber) terminate() { s.done <- struct{}{} } -func (s *subscriber) start() { +func (s *subscriber) start(wg *sync.WaitGroup) { pubsub, err := s.rdb.CancelationPubSub() cancelCh := pubsub.Channel() if err != nil { logger.error("cannot subscribe to cancelation channel: %v", err) return } + wg.Add(1) go func() { + defer wg.Done() for { select { case <-s.done: diff --git a/subscriber_test.go b/subscriber_test.go index 2b7f5a0..1a67159 100644 --- a/subscriber_test.go +++ b/subscriber_test.go @@ -5,6 +5,7 @@ package asynq import ( + "sync" "testing" "time" @@ -34,7 +35,8 @@ func TestSubscriber(t *testing.T) { cancelations.Add(tc.registeredID, fakeCancelFunc) subscriber := newSubscriber(rdbClient, cancelations) - subscriber.start() + var wg sync.WaitGroup + subscriber.start(&wg) if err := rdbClient.PublishCancelation(tc.publishID); err != nil { subscriber.terminate() diff --git a/syncer.go b/syncer.go index 2d6c2b0..7494cbf 100644 --- a/syncer.go +++ b/syncer.go @@ -5,6 +5,7 @@ package asynq import ( + "sync" "time" ) @@ -39,8 +40,10 @@ func (s *syncer) terminate() { s.done <- struct{}{} } -func (s *syncer) start() { +func (s *syncer) start(wg *sync.WaitGroup) { + wg.Add(1) go func() { + defer wg.Done() var requests []*syncRequest for { select { diff --git a/syncer_test.go b/syncer_test.go index 8eae321..24793df 100644 --- a/syncer_test.go +++ b/syncer_test.go @@ -5,6 +5,7 @@ package asynq import ( + "sync" "testing" "time" @@ -27,7 +28,8 @@ func TestSyncer(t *testing.T) { const interval = time.Second syncRequestCh := make(chan *syncRequest) syncer := newSyncer(syncRequestCh, interval) - syncer.start() + var wg sync.WaitGroup + syncer.start(&wg) defer syncer.terminate() for _, msg := range inProgress { @@ -66,7 +68,8 @@ func TestSyncerRetry(t *testing.T) { const interval = time.Second syncRequestCh := make(chan *syncRequest) syncer := newSyncer(syncRequestCh, interval) - syncer.start() + var wg sync.WaitGroup + syncer.start(&wg) defer syncer.terminate() for _, msg := range inProgress {