From 9699d196e575232b0803c0d4fdcf4ee215f8eec5 Mon Sep 17 00:00:00 2001
From: Ken Hibino
Date: Sun, 21 Jun 2020 07:05:57 -0700
Subject: [PATCH] Add recoverer

---
 internal/asynqtest/asynqtest.go   |  14 +++
 internal/base/base.go             |   1 +
 internal/rdb/rdb_test.go          | 149 ---------------------------
 internal/testbroker/testbroker.go |   9 ++
 recoverer.go                      | 110 ++++++++++++++++++++
 recoverer_test.go                 | 162 ++++++++++++++++++++++++++++++
 server.go                         |  10 ++
 7 files changed, 306 insertions(+), 149 deletions(-)
 create mode 100644 recoverer.go
 create mode 100644 recoverer_test.go

diff --git a/internal/asynqtest/asynqtest.go b/internal/asynqtest/asynqtest.go
index 0f18140..4df8d2b 100644
--- a/internal/asynqtest/asynqtest.go
+++ b/internal/asynqtest/asynqtest.go
@@ -97,6 +97,20 @@ func NewTaskMessageWithQueue(taskType string, payload map[string]interface{}, qn
 	}
 }
 
+// TaskMessageAfterRetry returns an updated copy of t after retry.
+// It increments retry count and sets the error message.
+func TaskMessageAfterRetry(t base.TaskMessage, errMsg string) *base.TaskMessage {
+	t.Retried = t.Retried + 1
+	t.ErrorMsg = errMsg
+	return &t
+}
+
+// TaskMessageWithError returns an updated copy of t with the given error message.
+func TaskMessageWithError(t base.TaskMessage, errMsg string) *base.TaskMessage {
+	t.ErrorMsg = errMsg
+	return &t
+}
+
 // MustMarshal marshals given task message and returns a json string.
 // Calling test will fail if marshaling errors out.
 func MustMarshal(tb testing.TB, msg *base.TaskMessage) string {
diff --git a/internal/base/base.go b/internal/base/base.go
index 08b96b1..209e8a3 100644
--- a/internal/base/base.go
+++ b/internal/base/base.go
@@ -264,6 +264,7 @@ type Broker interface {
 	Retry(msg *TaskMessage, processAt time.Time, errMsg string) error
 	Kill(msg *TaskMessage, errMsg string) error
 	CheckAndEnqueue() error
+	ListDeadlineExceeded(deadline time.Time) ([]*TaskMessage, error)
 	WriteServerState(info *ServerInfo, workers []*WorkerInfo, ttl time.Duration) error
 	ClearServerState(host string, pid int, serverID string) error
 	CancelationPubSub() (*redis.PubSub, error) // TODO: Need to decouple from redis to support other brokers
diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go
index 77e88c0..6c5ca2e 100644
--- a/internal/rdb/rdb_test.go
+++ b/internal/rdb/rdb_test.go
@@ -1192,155 +1192,6 @@ func TestListDeadlineExceeded(t *testing.T) {
 	}
 }
 
-/*
-func TestRequeueDeadlineExceeded(t *testing.T) {
-	t1 := h.NewTaskMessage("task1", nil)
-	t2 := h.NewTaskMessage("task2", nil)
-	t3 := h.NewTaskMessageWithQueue("task3", nil, "critical")
-
-	now := time.Now()
-	oneHourFromNow := now.Add(1 * time.Hour)
-	fiveMinutesFromNow := now.Add(5 * time.Minute)
-	fiveMinutesAgo := now.Add(-5 * time.Minute)
-	oneHourAgo := now.Add(-1 * time.Hour)
-
-	tests := []struct {
-		desc           string
-		inProgress     []*base.TaskMessage
-		deadlines      []h.ZSetEntry
-		enqueued       map[string][]*base.TaskMessage
-		wantN          int
-		wantInProgress []*base.TaskMessage
-		wantDeadlines  []h.ZSetEntry
-		wantEnqueued   map[string][]*base.TaskMessage
-	}{
-		{
-			desc:       "with one task in-progress",
-			inProgress: []*base.TaskMessage{t1},
-			deadlines: []h.ZSetEntry{
-				{Msg: t1, Score: float64(fiveMinutesAgo.Unix())},
-			},
-			enqueued: map[string][]*base.TaskMessage{
-				"default": {},
-			},
-			wantN:          1,
-			wantInProgress: []*base.TaskMessage{},
-			wantDeadlines:  []h.ZSetEntry{},
-			wantEnqueued: map[string][]*base.TaskMessage{
-				"default": {t1},
-			},
-		},
-		{
-			desc:       "with multiple tasks in-progress, and one expired",
-			inProgress: []*base.TaskMessage{t1, t2, t3},
-			deadlines: []h.ZSetEntry{
-				{Msg: t1, Score: float64(oneHourAgo.Unix())},
-				{Msg: t2, Score: float64(fiveMinutesFromNow.Unix())},
-				{Msg: t3, Score: float64(oneHourFromNow.Unix())},
-			},
-			enqueued: map[string][]*base.TaskMessage{
-				"default": {},
-			},
-			wantN:          1,
-			wantInProgress: []*base.TaskMessage{t2, t3},
-			wantDeadlines: []h.ZSetEntry{
-				{Msg: t2, Score: float64(fiveMinutesFromNow.Unix())},
-				{Msg: t3, Score: float64(oneHourFromNow.Unix())},
-			},
-			wantEnqueued: map[string][]*base.TaskMessage{
-				"default": {t1},
-			},
-		},
-		{
-			desc:       "with multiple expired tasks in-progress",
-			inProgress: []*base.TaskMessage{t1, t2, t3},
-			deadlines: []h.ZSetEntry{
-				{Msg: t1, Score: float64(oneHourAgo.Unix())},
-				{Msg: t2, Score: float64(fiveMinutesAgo.Unix())},
-				{Msg: t3, Score: float64(oneHourFromNow.Unix())},
-			},
-			enqueued: map[string][]*base.TaskMessage{
-				"default": {},
-			},
-			wantN:          2,
-			wantInProgress: []*base.TaskMessage{t3},
-			wantDeadlines: []h.ZSetEntry{
-				{Msg: t3, Score: float64(oneHourFromNow.Unix())},
-			},
-			wantEnqueued: map[string][]*base.TaskMessage{
-				"default": {t1, t2},
-			},
-		},
-		{
-			desc:       "with empty in-progress queue",
-			inProgress: []*base.TaskMessage{},
-			deadlines:  []h.ZSetEntry{},
-			enqueued: map[string][]*base.TaskMessage{
-				"default": {},
-			},
-			wantN:          0,
-			wantInProgress: []*base.TaskMessage{},
-			wantDeadlines:  []h.ZSetEntry{},
-			wantEnqueued: map[string][]*base.TaskMessage{
-				"default": {},
-			},
-		},
-		{
-			desc:       "push back task to appropriate queue",
-			inProgress: []*base.TaskMessage{t3},
-			deadlines: []h.ZSetEntry{
-				{Msg: t3, Score: float64(fiveMinutesAgo.Unix())},
-			},
-			enqueued: map[string][]*base.TaskMessage{
-				"default":  {},
-				"critical": {},
-			},
-			wantN:          1,
-			wantInProgress: []*base.TaskMessage{},
-			wantDeadlines:  []h.ZSetEntry{},
-			wantEnqueued: map[string][]*base.TaskMessage{
-				"default":  {},
-				"critical": {t3},
-			},
-		},
-	}
-
-	r := setup(t)
-	for _, tc := range tests {
-		h.FlushDB(t, r.client)
-		h.SeedInProgressQueue(t, r.client, tc.inProgress)
-		h.SeedDeadlines(t, r.client, tc.deadlines)
-		for queue, msgs := range tc.enqueued {
-			h.SeedEnqueuedQueue(t, r.client, msgs, queue)
-		}
-
-		gotN, err := r.RequeueDeadlineExceeded()
-		if err != nil {
-			t.Errorf("%s; RequeueDeadlineExceeded() returned error: %v", tc.desc, err)
-			continue
-		}
-		if gotN != tc.wantN {
-			t.Errorf("%s; RequeueDeadlineExeeded() == %d want %d", tc.desc, gotN, tc.wantN)
-		}
-
-		gotInProgress := h.GetInProgressMessages(t, r.client)
-		if diff := cmp.Diff(tc.wantInProgress, gotInProgress, h.SortMsgOpt); diff != "" {
-			t.Errorf("%s; mismatch found in %q; (-want,+got)\n%s", tc.desc, base.InProgressQueue, diff)
-		}
-		gotDeadlines := h.GetDeadlinesEntries(t, r.client)
-		if diff := cmp.Diff(tc.wantDeadlines, gotDeadlines, h.SortZSetEntryOpt); diff != "" {
-			t.Errorf("%s; mismatch found in %q; (-want,+got)\n%s", tc.desc, base.KeyDeadlines, diff)
-		}
-		for queue, want := range tc.wantEnqueued {
-			gotEnqueued := h.GetEnqueuedMessages(t, r.client, queue)
-			if diff := cmp.Diff(want, gotEnqueued, h.SortMsgOpt); diff != "" {
-				t.Errorf("%s; mismatch found in %q: (-want,+got):\n%s", tc.desc, base.QueueKey(queue), diff)
-			}
-		}
-	}
-}
-*/
-
 func TestWriteServerState(t *testing.T) {
 	r := setup(t)
 
diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go
index 2b0660a..227f61e 100644
--- a/internal/testbroker/testbroker.go
+++ b/internal/testbroker/testbroker.go
@@ -132,6 +132,15 @@ func (tb *TestBroker) CheckAndEnqueue() error {
 	return tb.real.CheckAndEnqueue()
 }
 
+func (tb *TestBroker) ListDeadlineExceeded(deadline time.Time) ([]*base.TaskMessage, error) {
+	tb.mu.Lock()
+	defer tb.mu.Unlock()
+	if tb.sleeping {
+		return nil, errRedisDown
+	}
+	return tb.real.ListDeadlineExceeded(deadline)
+}
+
 func (tb *TestBroker) WriteServerState(info *base.ServerInfo, workers []*base.WorkerInfo, ttl time.Duration) error {
 	tb.mu.Lock()
 	defer tb.mu.Unlock()
diff --git a/recoverer.go b/recoverer.go
new file mode 100644
index 0000000..7fde8a4
--- /dev/null
+++ b/recoverer.go
@@ -0,0 +1,110 @@
+// Copyright 2020 Kentaro Hibino. All rights reserved.
+// Use of this source code is governed by a MIT license
+// that can be found in the LICENSE file.
+
+package asynq
+
+import (
+	"errors"
+	"sync"
+	"time"
+
+	"github.com/hibiken/asynq/internal/base"
+	"github.com/hibiken/asynq/internal/log"
+)
+
+// recoverer periodically lists tasks whose deadline has been exceeded and
+// either schedules them for retry or moves them to the dead queue.
+type recoverer struct {
+	logger         *log.Logger
+	broker         base.Broker
+	retryDelayFunc retryDelayFunc
+
+	// channel to communicate back to the long running "recoverer" goroutine.
+	done chan struct{}
+
+	// poll interval.
+	interval time.Duration
+}
+
+type recovererParams struct {
+	logger         *log.Logger
+	broker         base.Broker
+	interval       time.Duration
+	retryDelayFunc retryDelayFunc
+}
+
+func newRecoverer(params recovererParams) *recoverer {
+	return &recoverer{
+		logger:         params.logger,
+		broker:         params.broker,
+		done:           make(chan struct{}),
+		interval:       params.interval,
+		retryDelayFunc: params.retryDelayFunc,
+	}
+}
+
+// terminate stops the polling goroutine. The send on the unbuffered done
+// channel blocks until that goroutine receives it, so call it only after start.
+func (r *recoverer) terminate() {
+	r.logger.Debug("Recoverer shutting down...")
+	// Signal the recoverer goroutine to stop polling.
+	r.done <- struct{}{}
+}
+
+// start launches the polling goroutine and registers it with wg.
+func (r *recoverer) start(wg *sync.WaitGroup) {
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		timer := time.NewTimer(r.interval)
+		for {
+			select {
+			case <-r.done:
+				r.logger.Debug("Recoverer done")
+				timer.Stop()
+				return
+			case <-timer.C:
+				r.recoverExpired()
+				// A Timer fires only once; re-arm it so polling continues.
+				timer.Reset(r.interval)
+			}
+		}
+	}()
+}
+
+// recoverExpired retries or kills every task whose deadline was exceeded
+// at least 30 seconds ago.
+func (r *recoverer) recoverExpired() {
+	// Get all tasks which have expired 30 seconds ago or earlier.
+	deadline := time.Now().Add(-30 * time.Second)
+	msgs, err := r.broker.ListDeadlineExceeded(deadline)
+	if err != nil {
+		r.logger.Warnf("recoverer: could not list deadline exceeded tasks: %v", err)
+		return
+	}
+	const errMsg = "deadline exceeded" // TODO: better error message
+	for _, msg := range msgs {
+		if msg.Retried >= msg.Retry {
+			r.kill(msg, errMsg)
+		} else {
+			r.retry(msg, errMsg)
+		}
+	}
+}
+
+// retry schedules msg to be retried after the delay given by retryDelayFunc.
+func (r *recoverer) retry(msg *base.TaskMessage, errMsg string) {
+	delay := r.retryDelayFunc(msg.Retried, errors.New(errMsg), NewTask(msg.Type, msg.Payload))
+	retryAt := time.Now().Add(delay)
+	if err := r.broker.Retry(msg, retryAt, errMsg); err != nil {
+		r.logger.Warnf("recoverer: could not retry deadline exceeded task: %v", err)
+	}
+}
+
+// kill moves msg to the dead queue.
+func (r *recoverer) kill(msg *base.TaskMessage, errMsg string) {
+	if err := r.broker.Kill(msg, errMsg); err != nil {
+		r.logger.Warnf("recoverer: could not move task to dead queue: %v", err)
+	}
+}
diff --git a/recoverer_test.go b/recoverer_test.go
new file mode 100644
index 0000000..5d33f9a
--- /dev/null
+++ b/recoverer_test.go
@@ -0,0 +1,162 @@
+// Copyright 2020 Kentaro Hibino. All rights reserved.
+// Use of this source code is governed by a MIT license
+// that can be found in the LICENSE file.
+
+package asynq
+
+import (
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/google/go-cmp/cmp"
+	h "github.com/hibiken/asynq/internal/asynqtest"
+	"github.com/hibiken/asynq/internal/base"
+	"github.com/hibiken/asynq/internal/rdb"
+)
+
+func TestRecoverer(t *testing.T) {
+	r := setup(t)
+	rdbClient := rdb.NewRDB(r)
+
+	t1 := h.NewTaskMessage("task1", nil)
+	t2 := h.NewTaskMessage("task2", nil)
+	t3 := h.NewTaskMessageWithQueue("task3", nil, "critical")
+	t4 := h.NewTaskMessage("task4", nil)
+	t4.Retried = t4.Retry // t4 has reached its max retry count
+
+	now := time.Now()
+	oneHourFromNow := now.Add(1 * time.Hour)
+	fiveMinutesFromNow := now.Add(5 * time.Minute)
+	fiveMinutesAgo := now.Add(-5 * time.Minute)
+	oneHourAgo := now.Add(-1 * time.Hour)
+
+	tests := []struct {
+		desc           string
+		inProgress     []*base.TaskMessage
+		deadlines      []h.ZSetEntry
+		retry          []h.ZSetEntry
+		dead           []h.ZSetEntry
+		wantInProgress []*base.TaskMessage
+		wantDeadlines  []h.ZSetEntry
+		wantRetry      []*base.TaskMessage
+		wantDead       []*base.TaskMessage
+	}{
+		{
+			desc:       "with one task in-progress",
+			inProgress: []*base.TaskMessage{t1},
+			deadlines: []h.ZSetEntry{
+				{Msg: t1, Score: float64(fiveMinutesAgo.Unix())},
+			},
+			retry:          []h.ZSetEntry{},
+			dead:           []h.ZSetEntry{},
+			wantInProgress: []*base.TaskMessage{},
+			wantDeadlines:  []h.ZSetEntry{},
+			wantRetry: []*base.TaskMessage{
+				h.TaskMessageAfterRetry(*t1, "deadline exceeded"),
+			},
+			wantDead: []*base.TaskMessage{},
+		},
+		{
+			desc:       "with a task with max-retry reached",
+			inProgress: []*base.TaskMessage{t4},
+			deadlines: []h.ZSetEntry{
+				{Msg: t4, Score: float64(fiveMinutesAgo.Unix())},
+			},
+			retry:          []h.ZSetEntry{},
+			dead:           []h.ZSetEntry{},
+			wantInProgress: []*base.TaskMessage{},
+			wantDeadlines:  []h.ZSetEntry{},
+			wantRetry:      []*base.TaskMessage{},
+			wantDead:       []*base.TaskMessage{h.TaskMessageWithError(*t4, "deadline exceeded")},
+		},
+		{
+			desc:       "with multiple tasks in-progress, and one expired",
+			inProgress: []*base.TaskMessage{t1, t2, t3},
+			deadlines: []h.ZSetEntry{
+				{Msg: t1, Score: float64(oneHourAgo.Unix())},
+				{Msg: t2, Score: float64(fiveMinutesFromNow.Unix())},
+				{Msg: t3, Score: float64(oneHourFromNow.Unix())},
+			},
+			retry:          []h.ZSetEntry{},
+			dead:           []h.ZSetEntry{},
+			wantInProgress: []*base.TaskMessage{t2, t3},
+			wantDeadlines: []h.ZSetEntry{
+				{Msg: t2, Score: float64(fiveMinutesFromNow.Unix())},
+				{Msg: t3, Score: float64(oneHourFromNow.Unix())},
+			},
+			wantRetry: []*base.TaskMessage{
+				h.TaskMessageAfterRetry(*t1, "deadline exceeded"),
+			},
+			wantDead: []*base.TaskMessage{},
+		},
+		{
+			desc:       "with multiple expired tasks in-progress",
+			inProgress: []*base.TaskMessage{t1, t2, t3},
+			deadlines: []h.ZSetEntry{
+				{Msg: t1, Score: float64(oneHourAgo.Unix())},
+				{Msg: t2, Score: float64(fiveMinutesAgo.Unix())},
+				{Msg: t3, Score: float64(oneHourFromNow.Unix())},
+			},
+			retry:          []h.ZSetEntry{},
+			dead:           []h.ZSetEntry{},
+			wantInProgress: []*base.TaskMessage{t3},
+			wantDeadlines: []h.ZSetEntry{
+				{Msg: t3, Score: float64(oneHourFromNow.Unix())},
+			},
+			wantRetry: []*base.TaskMessage{
+				h.TaskMessageAfterRetry(*t1, "deadline exceeded"),
+				h.TaskMessageAfterRetry(*t2, "deadline exceeded"),
+			},
+			wantDead: []*base.TaskMessage{},
+		},
+		{
+			desc:           "with empty in-progress queue",
+			inProgress:     []*base.TaskMessage{},
+			deadlines:      []h.ZSetEntry{},
+			retry:          []h.ZSetEntry{},
+			dead:           []h.ZSetEntry{},
+			wantInProgress: []*base.TaskMessage{},
+			wantDeadlines:  []h.ZSetEntry{},
+			wantRetry:      []*base.TaskMessage{},
+			wantDead:       []*base.TaskMessage{},
+		},
+	}
+
+	for _, tc := range tests {
+		h.FlushDB(t, r)
+		h.SeedInProgressQueue(t, r, tc.inProgress)
+		h.SeedDeadlines(t, r, tc.deadlines)
+		h.SeedRetryQueue(t, r, tc.retry)
+		h.SeedDeadQueue(t, r, tc.dead)
+
+		recoverer := newRecoverer(recovererParams{
+			logger:         testLogger,
+			broker:         rdbClient,
+			interval:       1 * time.Second,
+			retryDelayFunc: func(n int, err error, task *Task) time.Duration { return 30 * time.Second },
+		})
+
+		var wg sync.WaitGroup
+		recoverer.start(&wg)
+		time.Sleep(2 * time.Second)
+		recoverer.terminate()
+
+		gotInProgress := h.GetInProgressMessages(t, r)
+		if diff := cmp.Diff(tc.wantInProgress, gotInProgress, h.SortMsgOpt); diff != "" {
+			t.Errorf("%s; mismatch found in %q; (-want,+got)\n%s", tc.desc, base.InProgressQueue, diff)
+		}
+		gotDeadlines := h.GetDeadlinesEntries(t, r)
+		if diff := cmp.Diff(tc.wantDeadlines, gotDeadlines, h.SortZSetEntryOpt); diff != "" {
+			t.Errorf("%s; mismatch found in %q; (-want,+got)\n%s", tc.desc, base.KeyDeadlines, diff)
+		}
+		gotRetry := h.GetRetryMessages(t, r)
+		if diff := cmp.Diff(tc.wantRetry, gotRetry, h.SortMsgOpt); diff != "" {
+			t.Errorf("%s; mismatch found in %q: (-want, +got)\n%s", tc.desc, base.RetryQueue, diff)
+		}
+		gotDead := h.GetDeadMessages(t, r)
+		if diff := cmp.Diff(tc.wantDead, gotDead, h.SortMsgOpt); diff != "" {
+			t.Errorf("%s; mismatch found in %q: (-want, +got)\n%s", tc.desc, base.DeadQueue, diff)
+		}
+	}
+}
diff --git a/server.go b/server.go
index 896c03c..95e7ee3 100644
--- a/server.go
+++ b/server.go
@@ -46,6 +46,7 @@ type Server struct {
 	syncer      *syncer
 	heartbeater *heartbeater
 	subscriber  *subscriber
+	recoverer   *recoverer
 }
 
 // Config specifies the server's background-task processing behavior.
@@ -329,6 +330,12 @@ func NewServer(r RedisConnOpt, cfg Config) *Server {
 		starting: starting,
 		finished: finished,
 	})
+	recoverer := newRecoverer(recovererParams{
+		logger:         logger,
+		broker:         rdb,
+		retryDelayFunc: delayFunc,
+		interval:       1 * time.Minute,
+	})
 	return &Server{
 		logger:      logger,
 		broker:      rdb,
@@ -338,6 +345,7 @@ func NewServer(r RedisConnOpt, cfg Config) *Server {
 		syncer:      syncer,
 		heartbeater: heartbeater,
 		subscriber:  subscriber,
+		recoverer:   recoverer,
 	}
 }
 
@@ -407,6 +415,7 @@ func (srv *Server) Start(handler Handler) error {
 	srv.heartbeater.start(&srv.wg)
 	srv.subscriber.start(&srv.wg)
 	srv.syncer.start(&srv.wg)
+	srv.recoverer.start(&srv.wg)
 	srv.scheduler.start(&srv.wg)
 	srv.processor.start(&srv.wg)
 	return nil
@@ -430,6 +439,7 @@ func (srv *Server) Stop() {
 	// processor -> heartbeater (via starting, finished channels)
 	srv.scheduler.terminate()
 	srv.processor.terminate()
+	srv.recoverer.terminate()
 	srv.syncer.terminate()
 	srv.subscriber.terminate()
 	srv.heartbeater.terminate()