diff --git a/processor.go b/processor.go
index f6e198e..f09bb62 100644
--- a/processor.go
+++ b/processor.go
@@ -11,8 +11,13 @@ type processor struct {
 
 	handler TaskHandler
 
+	// timeout for blocking dequeue operation.
+	// dequeue needs to timeout to avoid blocking forever
+	// in case of a program shutdown or addition of a new queue.
+	dequeueTimeout time.Duration
+
 	// sema is a counting semaphore to ensure the number of active workers
-	// does not exceed the limit
+	// does not exceed the limit.
 	sema chan struct{}
 
 	// channel to communicate back to the long running "processor" goroutine.
@@ -21,13 +26,15 @@ type processor struct {
 
 func newProcessor(rdb *rdb, numWorkers int, handler TaskHandler) *processor {
 	return &processor{
-		rdb:     rdb,
-		handler: handler,
-		sema:    make(chan struct{}, numWorkers),
-		done:    make(chan struct{}),
+		rdb:            rdb,
+		handler:        handler,
+		dequeueTimeout: 5 * time.Second,
+		sema:           make(chan struct{}, numWorkers),
+		done:           make(chan struct{}),
 	}
 }
 
+// NOTE: once terminated, processor cannot be re-started.
 func (p *processor) terminate() {
 	log.Println("[INFO] Processor shutting down...")
 	// Signal the processor goroutine to stop processing tasks from the queue.
@@ -61,10 +68,7 @@ func (p *processor) start() {
 // exec pulls a task out of the queue and starts a worker goroutine to
 // process the task.
 func (p *processor) exec() {
-	// NOTE: dequeue needs to timeout to avoid blocking forever
-	// in case of a program shutdown or additon of a new queue.
-	const timeout = 5 * time.Second
-	msg, err := p.rdb.dequeue(defaultQueue, timeout)
+	msg, err := p.rdb.dequeue(defaultQueue, p.dequeueTimeout)
 	if err == errDequeueTimeout {
 		// timed out, this is a normal behavior.
 		return
diff --git a/processor_test.go b/processor_test.go
index 4e5c5c0..6d9a14d 100644
--- a/processor_test.go
+++ b/processor_test.go
@@ -2,9 +2,179 @@ package asynq
 
 import (
 	"fmt"
+	"sync"
 	"testing"
+	"time"
+
+	"github.com/google/go-cmp/cmp"
 )
 
+func TestProcessorSuccess(t *testing.T) {
+	r := setup(t)
+
+	m1 := randomTask("send_email", "default", nil)
+	m2 := randomTask("gen_thumbnail", "default", nil)
+	m3 := randomTask("reindex", "default", nil)
+	m4 := randomTask("sync", "default", nil)
+
+	t1 := &Task{Type: m1.Type, Payload: m1.Payload}
+	t2 := &Task{Type: m2.Type, Payload: m2.Payload}
+	t3 := &Task{Type: m3.Type, Payload: m3.Payload}
+	t4 := &Task{Type: m4.Type, Payload: m4.Payload}
+
+	tests := []struct {
+		initQueue     []*taskMessage // initial default queue state
+		incoming      []*taskMessage // tasks to be enqueued during run
+		wait          time.Duration  // wait duration between starting and stopping processor for this test case
+		wantProcessed []*Task        // tasks to be processed at the end
+	}{
+		{
+			initQueue:     []*taskMessage{m1},
+			incoming:      []*taskMessage{m2, m3, m4},
+			wait:          time.Second,
+			wantProcessed: []*Task{t1, t2, t3, t4},
+		},
+		{
+			initQueue:     []*taskMessage{},
+			incoming:      []*taskMessage{m1},
+			wait:          time.Second,
+			wantProcessed: []*Task{t1},
+		},
+	}
+
+	for _, tc := range tests {
+		// clean up db before each test case.
+		if err := r.client.FlushDB().Err(); err != nil {
+			t.Fatal(err)
+		}
+		// instantiate a new processor
+		var mu sync.Mutex
+		var processed []*Task
+		h := func(task *Task) error {
+			mu.Lock()
+			defer mu.Unlock()
+			processed = append(processed, task)
+			return nil
+		}
+		p := newProcessor(r, 10, h)
+		p.dequeueTimeout = time.Second // short timeout for test purposes
+		// initialize default queue.
+		for _, msg := range tc.initQueue {
+			err := r.enqueue(msg)
+			if err != nil {
+				t.Fatal(err)
+			}
+		}
+
+		p.start()
+
+		for _, msg := range tc.incoming {
+			err := r.enqueue(msg)
+			if err != nil {
+				p.terminate()
+				t.Fatal(err)
+			}
+		}
+		time.Sleep(tc.wait)
+		p.terminate()
+
+		if diff := cmp.Diff(tc.wantProcessed, processed, sortTaskOpt); diff != "" {
+			t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
+		}
+
+		if l := r.client.LLen(inProgress).Val(); l != 0 {
+			t.Errorf("%q has %d tasks, want 0", inProgress, l)
+		}
+	}
+}
+
+func TestProcessorRetry(t *testing.T) {
+	r := setup(t)
+
+	m1 := randomTask("send_email", "default", nil)
+	m1.Retried = m1.Retry // m1 has reached its max retry count
+	m2 := randomTask("gen_thumbnail", "default", nil)
+	m3 := randomTask("reindex", "default", nil)
+	m4 := randomTask("sync", "default", nil)
+
+	errMsg := "something went wrong"
+	// r* is m* after retry
+	r1 := *m1
+	r1.ErrorMsg = errMsg
+	r2 := *m2
+	r2.ErrorMsg = errMsg
+	r2.Retried = m2.Retried + 1
+	r3 := *m3
+	r3.ErrorMsg = errMsg
+	r3.Retried = m3.Retried + 1
+	r4 := *m4
+	r4.ErrorMsg = errMsg
+	r4.Retried = m4.Retried + 1
+
+	tests := []struct {
+		initQueue []*taskMessage // initial default queue state
+		incoming  []*taskMessage // tasks to be enqueued during run
+		wait      time.Duration  // wait duration between starting and stopping processor for this test case
+		wantRetry []*taskMessage // tasks in retry queue at the end
+		wantDead  []*taskMessage // tasks in dead queue at the end
+	}{
+		{
+			initQueue: []*taskMessage{m1, m2},
+			incoming:  []*taskMessage{m3, m4},
+			wait:      time.Second,
+			wantRetry: []*taskMessage{&r2, &r3, &r4},
+			wantDead:  []*taskMessage{&r1},
+		},
+	}
+
+	for _, tc := range tests {
+		// clean up db before each test case.
+		if err := r.client.FlushDB().Err(); err != nil {
+			t.Fatal(err)
+		}
+		// instantiate a new processor
+		h := func(task *Task) error {
+			return fmt.Errorf(errMsg)
+		}
+		p := newProcessor(r, 10, h)
+		p.dequeueTimeout = time.Second // short timeout for test purposes
+		// initialize default queue.
+		for _, msg := range tc.initQueue {
+			err := r.enqueue(msg)
+			if err != nil {
+				t.Fatal(err)
+			}
+		}
+
+		p.start()
+		for _, msg := range tc.incoming {
+			err := r.enqueue(msg)
+			if err != nil {
+				p.terminate()
+				t.Fatal(err)
+			}
+		}
+		time.Sleep(tc.wait)
+		p.terminate()
+
+		gotRetryRaw := r.client.ZRange(retry, 0, -1).Val()
+		gotRetry := mustUnmarshalSlice(t, gotRetryRaw)
+		if diff := cmp.Diff(tc.wantRetry, gotRetry, sortMsgOpt); diff != "" {
+			t.Errorf("mismatch found in %q after running processor; (-want, +got)\n%s", retry, diff)
+		}
+
+		gotDeadRaw := r.client.ZRange(dead, 0, -1).Val()
+		gotDead := mustUnmarshalSlice(t, gotDeadRaw)
+		if diff := cmp.Diff(tc.wantDead, gotDead, sortMsgOpt); diff != "" {
+			t.Errorf("mismatch found in %q after running processor; (-want, +got)\n%s", dead, diff)
+		}
+
+		if l := r.client.LLen(inProgress).Val(); l != 0 {
+			t.Errorf("%q has %d tasks, want 0", inProgress, l)
+		}
+	}
+}
+
 func TestPerform(t *testing.T) {
 	tests := []struct {
 		desc    string
@@ -15,7 +185,6 @@ func TestPerform(t *testing.T) {
 		{
 			desc: "handler returns nil",
 			handler: func(t *Task) error {
-				fmt.Println("processing...")
 				return nil
 			},
 			task: &Task{Type: "gen_thumbnail", Payload: map[string]interface{}{"src": "some/img/path"}},
@@ -24,7 +193,6 @@ func TestPerform(t *testing.T) {
 		{
 			desc: "handler returns error",
 			handler: func(t *Task) error {
-				fmt.Println("processing...")
 				return fmt.Errorf("something went wrong")
 			},
 			task: &Task{Type: "gen_thumbnail", Payload: map[string]interface{}{"src": "some/img/path"}},
@@ -33,7 +201,6 @@ func TestPerform(t *testing.T) {
 		{
 			desc: "handler panics",
 			handler: func(t *Task) error {
-				fmt.Println("processing...")
 				panic("something went terribly wrong")
 			},
 			task: &Task{Type: "gen_thumbnail", Payload: map[string]interface{}{"src": "some/img/path"}},
diff --git a/rdb_test.go b/rdb_test.go
index 0a5a474..17573a8 100644
--- a/rdb_test.go
+++ b/rdb_test.go
@@ -25,6 +25,14 @@ var sortMsgOpt = cmp.Transformer("SortMsg", func(in []*taskMessage) []*taskMessa
 	return out
 })
 
+var sortTaskOpt = cmp.Transformer("SortTask", func(in []*Task) []*Task {
+	out := append([]*Task(nil), in...) // Copy input to avoid mutating it
+	sort.Slice(out, func(i, j int) bool {
+		return out[i].Type < out[j].Type
+	})
+	return out
+})
+
 // setup connects to a redis database and flush all keys
 // before returning an instance of rdb.
 func setup(t *testing.T) *rdb {
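
The processor changes above lean on two Go idioms worth seeing in isolation: a buffered channel used as a counting semaphore to cap concurrent workers, and a timeout on the blocking dequeue so the loop can notice shutdown. Below is a minimal, self-contained sketch of that pattern; the names (jobs, numWorkers, the in-memory channel standing in for the Redis-backed queue) are illustrative and not asynq's API.

package main

import (
	"fmt"
	"time"
)

func main() {
	const numWorkers = 3
	sema := make(chan struct{}, numWorkers) // counting semaphore capping concurrent workers
	done := make(chan struct{})
	jobs := make(chan int, 10) // stand-in for the Redis-backed queue

	for i := 0; i < 10; i++ {
		jobs <- i
	}

	go func() {
		for {
			select {
			case <-done: // terminate() closed the channel; stop dequeuing
				return
			default:
			}
			select {
			case j := <-jobs:
				sema <- struct{}{} // acquire a slot before spawning a worker
				go func(j int) {
					defer func() { <-sema }() // release the slot when done
					fmt.Println("processing", j)
					time.Sleep(50 * time.Millisecond)
				}(j)
			case <-time.After(time.Second):
				// dequeue timed out; loop around so `done` gets another look.
				// Without this timeout, a blocked receive could ignore
				// shutdown forever -- the same reason dequeueTimeout exists.
			}
		}
	}()

	time.Sleep(2 * time.Second)
	close(done)
	for i := 0; i < numWorkers; i++ {
		sema <- struct{}{} // filling the semaphore waits for in-flight workers
	}
}

Making the timeout a struct field rather than a constant is what lets the tests above shrink it to one second and keep each test case fast.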
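
The tests compare results with sortMsgOpt/sortTaskOpt because workers finish in nondeterministic order, so the slices have to be compared as sets. A self-contained illustration of the go-cmp transformer approach follows; the trimmed-down Task type here is a stand-in, not asynq's actual definition.

package main

import (
	"fmt"
	"sort"

	"github.com/google/go-cmp/cmp"
)

// Task is a stand-in for asynq's Task; only Type matters for ordering here.
type Task struct {
	Type string
}

// Sorting both slices inside a transformer makes cmp.Diff order-insensitive.
var sortTaskOpt = cmp.Transformer("SortTask", func(in []*Task) []*Task {
	out := append([]*Task(nil), in...) // copy so the input is not mutated
	sort.Slice(out, func(i, j int) bool { return out[i].Type < out[j].Type })
	return out
})

func main() {
	want := []*Task{{Type: "reindex"}, {Type: "send_email"}}
	got := []*Task{{Type: "send_email"}, {Type: "reindex"}} // same set, different order
	if diff := cmp.Diff(want, got, sortTaskOpt); diff != "" {
		fmt.Println("mismatch (-want, +got):", diff)
	} else {
		fmt.Println("equal modulo order")
	}
}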