mirror of
https://github.com/hibiken/asynq.git
synced 2024-11-10 11:31:58 +08:00
Add Use method to better support middlewares with ServeMux
This commit is contained in:
parent
cf7a677312
commit
516f95edff
25
servemux.go
25
servemux.go
@ -23,9 +23,10 @@ import (
|
||||
// "images:thumbnails" and the former will receive tasks with type name beginning
|
||||
// with "images".
|
||||
type ServeMux struct {
|
||||
mu sync.RWMutex
|
||||
m map[string]muxEntry
|
||||
es []muxEntry // slice of entries sorted from longest to shortest.
|
||||
mu sync.RWMutex
|
||||
m map[string]muxEntry
|
||||
es []muxEntry // slice of entries sorted from longest to shortest.
|
||||
mws []MiddlewareFunc
|
||||
}
|
||||
|
||||
type muxEntry struct {
|
||||
@ -33,6 +34,11 @@ type muxEntry struct {
|
||||
pattern string
|
||||
}
|
||||
|
||||
// MiddlewareFunc is a function which receives an asynq.Handler and returns another asynq.Handler.
|
||||
// Typically, the returned handler is a closure which does something with the context and task passed
|
||||
// to it, and then calls the handler passed as parameter to the MiddlewareFunc.
|
||||
type MiddlewareFunc func(Handler) Handler
|
||||
|
||||
// NewServeMux allocates and returns a new ServeMux.
|
||||
func NewServeMux() *ServeMux {
|
||||
return new(ServeMux)
|
||||
@ -60,6 +66,9 @@ func (mux *ServeMux) Handler(t *Task) (h Handler, pattern string) {
|
||||
if h == nil {
|
||||
h, pattern = NotFoundHandler(), ""
|
||||
}
|
||||
for i := len(mux.mws) - 1; i >= 0; i-- {
|
||||
h = mux.mws[i](h)
|
||||
}
|
||||
return h, pattern
|
||||
}
|
||||
|
||||
@ -130,6 +139,16 @@ func (mux *ServeMux) HandleFunc(pattern string, handler func(context.Context, *T
|
||||
mux.Handle(pattern, HandlerFunc(handler))
|
||||
}
|
||||
|
||||
// Use appends a MiddlewareFunc to the chain.
|
||||
// Middlewares are executed in the order that they are applied to the ServeMux.
|
||||
func (mux *ServeMux) Use(mws ...MiddlewareFunc) {
|
||||
mux.mu.Lock()
|
||||
defer mux.mu.Unlock()
|
||||
for _, fn := range mws {
|
||||
mux.mws = append(mux.mws, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// NotFound returns an error indicating that the handler was not found for the given task.
|
||||
func NotFound(ctx context.Context, task *Task) error {
|
||||
return fmt.Errorf("handler not found for task %q", task.Type)
|
||||
|
@ -7,9 +7,12 @@ package asynq
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
var called string
|
||||
var called string // identity of the handler that was called.
|
||||
var invoked []string // list of middlewares in the order they were invoked.
|
||||
|
||||
// makeFakeHandler returns a handler that updates the global called variable
|
||||
// to the given identity.
|
||||
@ -20,6 +23,17 @@ func makeFakeHandler(identity string) Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// makeFakeMiddleware returns a middleware function that appends the given identity
|
||||
//to the global invoked slice.
|
||||
func makeFakeMiddleware(identity string) MiddlewareFunc {
|
||||
return func(next Handler) Handler {
|
||||
return HandlerFunc(func(ctx context.Context, t *Task) error {
|
||||
invoked = append(invoked, identity)
|
||||
return next.ProcessTask(ctx, t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// A list of pattern, handler pair that is registered with mux.
|
||||
var serveMuxRegister = []struct {
|
||||
pattern string
|
||||
@ -114,3 +128,43 @@ func TestServeMuxNotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var middlewareTests = []struct {
|
||||
typename string // task's type name
|
||||
middlewares []string // middlewares to use. They should be called in this order.
|
||||
want string // identifier of the handler that should be called
|
||||
}{
|
||||
{"email:signup", []string{"logging", "expiration"}, "signup email handler"},
|
||||
{"csv:export", []string{}, "csv export handler"},
|
||||
{"email:daily", []string{"expiration", "logging"}, "default email handler"},
|
||||
}
|
||||
|
||||
func TestServeMuxMiddlewares(t *testing.T) {
|
||||
for _, tc := range middlewareTests {
|
||||
mux := NewServeMux()
|
||||
for _, e := range serveMuxRegister {
|
||||
mux.Handle(e.pattern, e.h)
|
||||
}
|
||||
var mws []MiddlewareFunc
|
||||
for _, s := range tc.middlewares {
|
||||
mws = append(mws, makeFakeMiddleware(s))
|
||||
}
|
||||
mux.Use(mws...)
|
||||
|
||||
invoked = []string{} // reset to empty slice
|
||||
called = "" // reset to zero value
|
||||
|
||||
task := NewTask(tc.typename, nil)
|
||||
if err := mux.ProcessTask(context.Background(), task); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(invoked, tc.middlewares); diff != "" {
|
||||
t.Errorf("invoked middlewares were %v, want %v", invoked, tc.middlewares)
|
||||
}
|
||||
|
||||
if called != tc.want {
|
||||
t.Errorf("%q handler was called for task %q, want %q to be called", called, task.Type, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user