From d4f136ebc9ae9f37395c9861a0c1b98472695c9c Mon Sep 17 00:00:00 2001 From: Ken Hibino Date: Wed, 27 Nov 2019 14:03:04 -0800 Subject: [PATCH] Protect handler call against panic --- processor.go | 14 +++++++++++- processor_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 processor_test.go diff --git a/processor.go b/processor.go index f64737a..2ec4745 100644 --- a/processor.go +++ b/processor.go @@ -88,7 +88,7 @@ func (p *processor) exec() { } <-p.sema // release token }() - err := p.handler(task) // TODO(hibiken): maybe also handle panic? + err := perform(p.handler, task) if err != nil { retryTask(p.rdb, msg, err) } @@ -103,3 +103,15 @@ func (p *processor) restore() { log.Printf("[ERROR] could not move tasks from %q to %q\n", inProgress, defaultQueue) } } + +// perform calls the handler with the given task. +// If the call returns without panic, it simply returns the value, +// otherwise, it recovers from panic and returns an error. +func perform(handler TaskHandler, task *Task) (err error) { + defer func() { + if x := recover(); x != nil { + err = fmt.Errorf("panic: %v", x) + } + }() + return handler(task) +} diff --git a/processor_test.go b/processor_test.go new file mode 100644 index 0000000..4e5c5c0 --- /dev/null +++ b/processor_test.go @@ -0,0 +1,55 @@ +package asynq + +import ( + "fmt" + "testing" +) + +func TestPerform(t *testing.T) { + tests := []struct { + desc string + handler TaskHandler + task *Task + wantErr bool + }{ + { + desc: "handler returns nil", + handler: func(t *Task) error { + fmt.Println("processing...") + return nil + }, + task: &Task{Type: "gen_thumbnail", Payload: map[string]interface{}{"src": "some/img/path"}}, + wantErr: false, + }, + { + desc: "handler returns error", + handler: func(t *Task) error { + fmt.Println("processing...") + return fmt.Errorf("something went wrong") + }, + task: &Task{Type: "gen_thumbnail", Payload: map[string]interface{}{"src": "some/img/path"}}, + wantErr: true, + }, + { + desc: "handler panics", + handler: func(t *Task) error { + fmt.Println("processing...") + panic("something went terribly wrong") + }, + task: &Task{Type: "gen_thumbnail", Payload: map[string]interface{}{"src": "some/img/path"}}, + wantErr: true, + }, + } + + for _, tc := range tests { + got := perform(tc.handler, tc.task) + if !tc.wantErr && got != nil { + t.Errorf("%s: perform() = %v, want nil", tc.desc, got) + continue + } + if tc.wantErr && got == nil { + t.Errorf("%s: perform() = nil, want non-nil error", tc.desc) + continue + } + } +}