diff --git a/servemux.go b/servemux.go index f9e13b4..2142ef6 100644 --- a/servemux.go +++ b/servemux.go @@ -23,9 +23,10 @@ import ( // "images:thumbnails" and the former will receive tasks with type name beginning // with "images". type ServeMux struct { - mu sync.RWMutex - m map[string]muxEntry - es []muxEntry // slice of entries sorted from longest to shortest. + mu sync.RWMutex + m map[string]muxEntry + es []muxEntry // slice of entries sorted from longest to shortest. + mws []MiddlewareFunc } type muxEntry struct { @@ -33,6 +34,11 @@ type muxEntry struct { pattern string } +// MiddlewareFunc is a function which receives an asynq.Handler and returns another asynq.Handler. +// Typically, the returned handler is a closure which does something with the context and task passed +// to it, and then calls the handler passed as parameter to the MiddlewareFunc. +type MiddlewareFunc func(Handler) Handler + // NewServeMux allocates and returns a new ServeMux. func NewServeMux() *ServeMux { return new(ServeMux) @@ -60,6 +66,9 @@ func (mux *ServeMux) Handler(t *Task) (h Handler, pattern string) { if h == nil { h, pattern = NotFoundHandler(), "" } + for i := len(mux.mws) - 1; i >= 0; i-- { + h = mux.mws[i](h) + } return h, pattern } @@ -130,6 +139,16 @@ func (mux *ServeMux) HandleFunc(pattern string, handler func(context.Context, *T mux.Handle(pattern, HandlerFunc(handler)) } +// Use appends a MiddlewareFunc to the chain. +// Middlewares are executed in the order that they are applied to the ServeMux. +func (mux *ServeMux) Use(mws ...MiddlewareFunc) { + mux.mu.Lock() + defer mux.mu.Unlock() + for _, fn := range mws { + mux.mws = append(mux.mws, fn) + } +} + // NotFound returns an error indicating that the handler was not found for the given task. func NotFound(ctx context.Context, task *Task) error { return fmt.Errorf("handler not found for task %q", task.Type) diff --git a/servemux_test.go b/servemux_test.go index 7b0c4cc..98dd52f 100644 --- a/servemux_test.go +++ b/servemux_test.go @@ -7,9 +7,12 @@ package asynq import ( "context" "testing" + + "github.com/google/go-cmp/cmp" ) -var called string +var called string // identity of the handler that was called. +var invoked []string // list of middlewares in the order they were invoked. // makeFakeHandler returns a handler that updates the global called variable // to the given identity. @@ -20,6 +23,17 @@ func makeFakeHandler(identity string) Handler { }) } +// makeFakeMiddleware returns a middleware function that appends the given identity +//to the global invoked slice. +func makeFakeMiddleware(identity string) MiddlewareFunc { + return func(next Handler) Handler { + return HandlerFunc(func(ctx context.Context, t *Task) error { + invoked = append(invoked, identity) + return next.ProcessTask(ctx, t) + }) + } +} + // A list of pattern, handler pair that is registered with mux. var serveMuxRegister = []struct { pattern string @@ -114,3 +128,43 @@ func TestServeMuxNotFound(t *testing.T) { } } } + +var middlewareTests = []struct { + typename string // task's type name + middlewares []string // middlewares to use. They should be called in this order. + want string // identifier of the handler that should be called +}{ + {"email:signup", []string{"logging", "expiration"}, "signup email handler"}, + {"csv:export", []string{}, "csv export handler"}, + {"email:daily", []string{"expiration", "logging"}, "default email handler"}, +} + +func TestServeMuxMiddlewares(t *testing.T) { + for _, tc := range middlewareTests { + mux := NewServeMux() + for _, e := range serveMuxRegister { + mux.Handle(e.pattern, e.h) + } + var mws []MiddlewareFunc + for _, s := range tc.middlewares { + mws = append(mws, makeFakeMiddleware(s)) + } + mux.Use(mws...) + + invoked = []string{} // reset to empty slice + called = "" // reset to zero value + + task := NewTask(tc.typename, nil) + if err := mux.ProcessTask(context.Background(), task); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(invoked, tc.middlewares); diff != "" { + t.Errorf("invoked middlewares were %v, want %v", invoked, tc.middlewares) + } + + if called != tc.want { + t.Errorf("%q handler was called for task %q, want %q to be called", called, task.Type, tc.want) + } + } +}