diff --git a/README.md b/README.md index 0bee5e6..3242832 100644 --- a/README.md +++ b/README.md @@ -66,11 +66,11 @@ func main() { func handler(t *asynq.Task) error { switch t.Type { case "send_welcome_email": - rid, ok := t.Payload["recipient_id"] - if !ok { - return fmt.Errorf("recipient_id not found in payload") + id, err := t.Payload.GetInt("recipient_id") + if err != nil{ + return err } - fmt.Printf("Send Welcome Email to %d\n", rid.(int)) + fmt.Printf("Send Welcome Email to %d\n", id) // ... handle other task types. diff --git a/asynq.go b/asynq.go index 8ff40a5..06f6cc2 100644 --- a/asynq.go +++ b/asynq.go @@ -24,9 +24,8 @@ type Task struct { // Type indicates the kind of the task to be performed. Type string - // Payload is an arbitrary data needed for task execution. - // The value has to be serializable. - Payload map[string]interface{} + // Payload holds data needed for the task execution. + Payload Payload } // RedisConfig specifies redis configurations. diff --git a/go.mod b/go.mod index 5118097..6966a3c 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/pelletier/go-toml v1.6.0 // indirect github.com/rs/xid v1.2.1 github.com/spf13/afero v1.2.2 // indirect + github.com/spf13/cast v1.3.1 github.com/spf13/cobra v0.0.5 github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/go.sum b/go.sum index b9e9336..a409284 100644 --- a/go.sum +++ b/go.sum @@ -115,6 +115,8 @@ github.com/spf13/afero v1.2.2 h1:5jhuqJyZCZf2JRofRvN/nIFgIWNzPa3/Vz8mYylgbWc= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= +github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cobra v0.0.5 h1:f0B+LkLX6DtmRH1isoNA9VTtNUK9K8xYd28JNNfOv/s= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/jwalterweatherman v1.0.0 h1:XHEdyB+EcvlqZamSM4ZOMGlc93t6AcsBEu9Gc1vn7yk= diff --git a/payload.go b/payload.go new file mode 100644 index 0000000..0ca0393 --- /dev/null +++ b/payload.go @@ -0,0 +1,155 @@ +package asynq + +import ( + "fmt" + "time" + + "github.com/spf13/cast" +) + +// Payload is an arbitrary data needed for task execution. +// The values have to be JSON serializable. +type Payload map[string]interface{} + +type errKeyNotFound struct { + key string +} + +func (e *errKeyNotFound) Error() string { + return fmt.Sprintf("key %q does not exist", e.key) +} + +// GetString returns a string value if a string type is associated with +// the key, otherwise reports an error. +func (p Payload) GetString(key string) (string, error) { + v, ok := p[key] + if !ok { + return "", &errKeyNotFound{key} + } + return cast.ToStringE(v) +} + +// GetInt returns an int value if a numeric type is associated with +// the key, otherwise reports an error. +func (p Payload) GetInt(key string) (int, error) { + v, ok := p[key] + if !ok { + return 0, &errKeyNotFound{key} + } + return cast.ToIntE(v) +} + +// GetFloat64 returns a float64 value if a numeric type is associated with +// the key, otherwise reports an error. +func (p Payload) GetFloat64(key string) (float64, error) { + v, ok := p[key] + if !ok { + return 0, &errKeyNotFound{key} + } + return cast.ToFloat64E(v) +} + +// GetBool returns a boolean value if a boolean type is associated with +// the key, otherwise reports an error. +func (p Payload) GetBool(key string) (bool, error) { + v, ok := p[key] + if !ok { + return false, &errKeyNotFound{key} + } + return cast.ToBoolE(v) +} + +// GetStringSlice returns a slice of strings if a string slice type is associated with +// the key, otherwise reports an error. +func (p Payload) GetStringSlice(key string) ([]string, error) { + v, ok := p[key] + if !ok { + return nil, &errKeyNotFound{key} + } + return cast.ToStringSliceE(v) +} + +// GetIntSlice returns a slice of ints if a int slice type is associated with +// the key, otherwise reports an error. +func (p Payload) GetIntSlice(key string) ([]int, error) { + v, ok := p[key] + if !ok { + return nil, &errKeyNotFound{key} + } + return cast.ToIntSliceE(v) +} + +// GetStringMap returns a map of string to empty interface +// if a correct map type is associated with the key, +// otherwise reports an error. +func (p Payload) GetStringMap(key string) (map[string]interface{}, error) { + v, ok := p[key] + if !ok { + return nil, &errKeyNotFound{key} + } + return cast.ToStringMapE(v) +} + +// GetStringMapString returns a map of string to string +// if a correct map type is associated with the key, +// otherwise reports an error. +func (p Payload) GetStringMapString(key string) (map[string]string, error) { + v, ok := p[key] + if !ok { + return nil, &errKeyNotFound{key} + } + return cast.ToStringMapStringE(v) +} + +// GetStringMapStringSlice returns a map of string to string slice +// if a correct map type is associated with the key, +// otherwise reports an error. +func (p Payload) GetStringMapStringSlice(key string) (map[string][]string, error) { + v, ok := p[key] + if !ok { + return nil, &errKeyNotFound{key} + } + return cast.ToStringMapStringSliceE(v) +} + +// GetStringMapInt returns a map of string to int +// if a correct map type is associated with the key, +// otherwise reports an error. +func (p Payload) GetStringMapInt(key string) (map[string]int, error) { + v, ok := p[key] + if !ok { + return nil, &errKeyNotFound{key} + } + return cast.ToStringMapIntE(v) +} + +// GetStringMapBool returns a map of string to boolean +// if a correct map type is associated with the key, +// otherwise reports an error. +func (p Payload) GetStringMapBool(key string) (map[string]bool, error) { + v, ok := p[key] + if !ok { + return nil, &errKeyNotFound{key} + } + return cast.ToStringMapBoolE(v) +} + +// GetTime returns a time value if a correct map type is associated with the key, +// otherwise reports an error. +func (p Payload) GetTime(key string) (time.Time, error) { + v, ok := p[key] + if !ok { + return time.Time{}, &errKeyNotFound{key} + } + return cast.ToTimeE(v) +} + +// GetDuration returns a duration value if a correct map type is associated with the key, +// otherwise reports an error. +func (p Payload) GetDuration(key string) (time.Duration, error) { + v, ok := p[key] + if !ok { + return 0, &errKeyNotFound{key} + } + return cast.ToDurationE(v) +} diff --git a/payload_test.go b/payload_test.go new file mode 100644 index 0000000..e943bb8 --- /dev/null +++ b/payload_test.go @@ -0,0 +1,254 @@ +package asynq + +import ( + "encoding/json" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestPayload(t *testing.T) { + names := []string{"luke", "anakin", "ray"} + primes := []int{2, 3, 5, 7, 11, 13, 17} + user := map[string]interface{}{"name": "Ken", "score": 3.14} + location := map[string]string{"address": "123 Main St.", "state": "NY", "zipcode": "10002"} + favs := map[string][]string{ + "movies": []string{"forrest gump", "star wars"}, + "tv_shows": []string{"game of throwns", "HIMYM", "breaking bad"}, + } + counter := map[string]int{ + "a": 1, + "b": 101, + "c": 42, + } + features := map[string]bool{ + "A": false, + "B": true, + "C": true, + } + now := time.Now() + duration := 15 * time.Minute + + payload := Payload{ + "greeting": "Hello", + "user_id": 9876, + "pi": 3.1415, + "enabled": false, + "names": names, + "primes": primes, + "user": user, + "location": location, + "favs": favs, + "counter": counter, + "features": features, + "timestamp": now, + "duration": duration, + } + + gotStr, err := payload.GetString("greeting") + if gotStr != "Hello" || err != nil { + t.Errorf("Payload.GetString(%q) = %v, %v, want %v, nil", + "greeting", gotStr, err, "Hello") + } + + gotInt, err := payload.GetInt("user_id") + if gotInt != 9876 || err != nil { + t.Errorf("Payload.GetInt(%q) = %v, %v, want, %v, nil", + "user_id", gotInt, err, 9876) + } + + gotFloat, err := payload.GetFloat64("pi") + if gotFloat != 3.1415 || err != nil { + t.Errorf("Payload.GetFloat64(%q) = %v, %v, want, %v, nil", + "pi", gotFloat, err, 3.141592) + } + + gotBool, err := payload.GetBool("enabled") + if gotBool != false || err != nil { + t.Errorf("Payload.GetBool(%q) = %v, %v, want, %v, nil", + "enabled", gotBool, err, false) + } + + gotStrSlice, err := payload.GetStringSlice("names") + if diff := cmp.Diff(gotStrSlice, names); diff != "" { + t.Errorf("Payload.GetStringSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "names", gotStrSlice, err, names, diff) + } + + gotIntSlice, err := payload.GetIntSlice("primes") + if diff := cmp.Diff(gotIntSlice, primes); diff != "" { + t.Errorf("Payload.GetIntSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "primes", gotIntSlice, err, primes, diff) + } + + gotStrMap, err := payload.GetStringMap("user") + if diff := cmp.Diff(gotStrMap, user); diff != "" { + t.Errorf("Payload.GetStringMap(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "user", gotStrMap, err, user, diff) + } + + gotStrMapStr, err := payload.GetStringMapString("location") + if diff := cmp.Diff(gotStrMapStr, location); diff != "" { + t.Errorf("Payload.GetStringMapString(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "location", gotStrMapStr, err, location, diff) + } + + gotStrMapStrSlice, err := payload.GetStringMapStringSlice("favs") + if diff := cmp.Diff(gotStrMapStrSlice, favs); diff != "" { + t.Errorf("Payload.GetStringMapStringSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "favs", gotStrMapStrSlice, err, favs, diff) + } + + gotStrMapInt, err := payload.GetStringMapInt("counter") + if diff := cmp.Diff(gotStrMapInt, counter); diff != "" { + t.Errorf("Payload.GetStringMapInt(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "counter", gotStrMapInt, err, counter, diff) + } + + gotStrMapBool, err := payload.GetStringMapBool("features") + if diff := cmp.Diff(gotStrMapBool, features); diff != "" { + t.Errorf("Payload.GetStringMapBool(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "features", gotStrMapBool, err, features, diff) + } + + gotTime, err := payload.GetTime("timestamp") + if !gotTime.Equal(now) { + t.Errorf("Payload.GetTime(%q) = %v, %v, want %v, nil", + "timestamp", gotTime, err, now) + } + + gotDuration, err := payload.GetDuration("duration") + if gotDuration != duration { + t.Errorf("Payload.GetDuration(%q) = %v, %v, want %v, nil", + "duration", gotDuration, err, duration) + } +} + +func TestPayloadWithMarshaling(t *testing.T) { + names := []string{"luke", "anakin", "ray"} + primes := []int{2, 3, 5, 7, 11, 13, 17} + user := map[string]interface{}{"name": "Ken", "score": 3.14} + location := map[string]string{"address": "123 Main St.", "state": "NY", "zipcode": "10002"} + favs := map[string][]string{ + "movies": []string{"forrest gump", "star wars"}, + "tv_shows": []string{"game of throwns", "HIMYM", "breaking bad"}, + } + counter := map[string]int{ + "a": 1, + "b": 101, + "c": 42, + } + features := map[string]bool{ + "A": false, + "B": true, + "C": true, + } + now := time.Now() + duration := 15 * time.Minute + + in := Payload{ + "subject": "Hello", + "recipient_id": 9876, + "pi": 3.14, + "enabled": true, + "names": names, + "primes": primes, + "user": user, + "location": location, + "favs": favs, + "counter": counter, + "features": features, + "timestamp": now, + "duration": duration, + } + + // encode and then decode + data, err := json.Marshal(in) + if err != nil { + t.Fatal(err) + } + var out Payload + err = json.Unmarshal(data, &out) + if err != nil { + t.Fatal(err) + } + + gotStr, err := out.GetString("subject") + if gotStr != "Hello" || err != nil { + t.Errorf("Payload.GetString(%q) = %v, %v; want %q, nil", + "subject", gotStr, err, "Hello") + } + + gotInt, err := out.GetInt("recipient_id") + if gotInt != 9876 || err != nil { + t.Errorf("Payload.GetInt(%q) = %v, %v; want %v, nil", + "recipient_id", gotInt, err, 9876) + } + + gotFloat, err := out.GetFloat64("pi") + if gotFloat != 3.14 || err != nil { + t.Errorf("Payload.GetFloat64(%q) = %v, %v; want %v, nil", + "pi", gotFloat, err, 3.14) + } + + gotBool, err := out.GetBool("enabled") + if gotBool != true || err != nil { + t.Errorf("Payload.GetBool(%q) = %v, %v; want %v, nil", + "enabled", gotBool, err, true) + } + + gotStrSlice, err := out.GetStringSlice("names") + if diff := cmp.Diff(gotStrSlice, names); diff != "" { + t.Errorf("Payload.GetStringSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "names", gotStrSlice, err, names, diff) + } + + gotIntSlice, err := out.GetIntSlice("primes") + if diff := cmp.Diff(gotIntSlice, primes); diff != "" { + t.Errorf("Payload.GetIntSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "primes", gotIntSlice, err, primes, diff) + } + + gotStrMap, err := out.GetStringMap("user") + if diff := cmp.Diff(gotStrMap, user); diff != "" { + t.Errorf("Payload.GetStringMap(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "user", gotStrMap, err, user, diff) + } + + gotStrMapStr, err := out.GetStringMapString("location") + if diff := cmp.Diff(gotStrMapStr, location); diff != "" { + t.Errorf("Payload.GetStringMapString(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "location", gotStrMapStr, err, location, diff) + } + + gotStrMapStrSlice, err := out.GetStringMapStringSlice("favs") + if diff := cmp.Diff(gotStrMapStrSlice, favs); diff != "" { + t.Errorf("Payload.GetStringMapStringSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "favs", gotStrMapStrSlice, err, favs, diff) + } + + gotStrMapInt, err := out.GetStringMapInt("counter") + if diff := cmp.Diff(gotStrMapInt, counter); diff != "" { + t.Errorf("Payload.GetStringMapInt(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "counter", gotStrMapInt, err, counter, diff) + } + + gotStrMapBool, err := out.GetStringMapBool("features") + if diff := cmp.Diff(gotStrMapBool, features); diff != "" { + t.Errorf("Payload.GetStringMapBool(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", + "features", gotStrMapBool, err, features, diff) + } + + gotTime, err := out.GetTime("timestamp") + if !gotTime.Equal(now) { + t.Errorf("Payload.GetTime(%q) = %v, %v, want %v, nil", + "timestamp", gotTime, err, now) + } + + gotDuration, err := out.GetDuration("duration") + if gotDuration != duration { + t.Errorf("Payload.GetDuration(%q) = %v, %v, want %v, nil", + "duration", gotDuration, err, duration) + } +}