2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-11-10 11:31:58 +08:00

Make Task type immutable

This change makes it impossible to mutate payload within Handler or
RetryDelayFunc.
This commit is contained in:
Ken Hibino 2020-01-04 13:13:46 -08:00
parent 899566e661
commit f3a23b9b12
14 changed files with 92 additions and 345 deletions

View File

@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
### Added
- NewTask constructor
### Changed
- Task type is now immutable (i.e., Payload is read-only)
## [0.1.0] - 2020-01-04 ## [0.1.0] - 2020-01-04
### Added ### Added

View File

@ -58,28 +58,23 @@ func main() {
} }
client := asynq.NewClient(r) client := asynq.NewClient(r)
t1 := asynq.Task{ // create a task with typename and payload.
Type: "send_welcome_email", t1 := asynq.NewTask(
Payload: map[string]interface{}{ "send_welcome_email",
"recipient_id": 1234, map[string]interface{}{"user_id": 42})
},
}
t2 := asynq.Task{ t2 := asynq.NewTask(
Type: "send_reminder_email", "send_reminder_email",
Payload: map[string]interface{}{ map[string]interface{}{"user_id": 42})
"recipient_id": 1234,
},
}
// process the task immediately. // process the task immediately.
err := client.Schedule(&t1, time.Now()) err := client.Schedule(t1, time.Now())
// process the task 24 hours later. // process the task 24 hours later.
err = client.Schedule(&t2, time.Now().Add(24 * time.Hour)) err = client.Schedule(t2, time.Now().Add(24 * time.Hour))
// specify the max number of retry (default: 25) // specify the max number of retry (default: 25)
err = client.Schedule(&t1, time.Now(), asynq.MaxRetry(1)) err = client.Schedule(t1, time.Now(), asynq.MaxRetry(1))
} }
``` ```
@ -120,7 +115,7 @@ The simplest way to implement a handler is to define a function with the same si
func handler(t *asynq.Task) error { func handler(t *asynq.Task) error {
switch t.Type { switch t.Type {
case "send_welcome_email": case "send_welcome_email":
id, err := t.Payload.GetInt("recipient_id") id, err := t.Payload.GetInt("user_id")
if err != nil { if err != nil {
return err return err
} }

View File

@ -11,9 +11,20 @@ TODOs:
// Task represents a task to be performed. // Task represents a task to be performed.
type Task struct { type Task struct {
// Type indicates the kind of the task to be performed. // Type indicates the type of task to be performed.
Type string Type string
// Payload holds data needed to process the task. // Payload holds data needed to process the task.
Payload Payload Payload Payload
} }
// NewTask returns a new instance of a task given a task type and payload.
//
// Since payload data gets serialized to JSON, the payload values must be
// composed of JSON supported data types.
func NewTask(typename string, payload map[string]interface{}) *Task {
return &Task{
Type: typename,
Payload: Payload{payload},
}
}

View File

@ -52,7 +52,7 @@ type Config struct {
// //
// n is the number of times the task has been retried. // n is the number of times the task has been retried.
// e is the error returned by the task handler. // e is the error returned by the task handler.
// t is the task in question. t is read-only, the function should not mutate t. // t is the task in question.
RetryDelayFunc func(n int, e error, t *Task) time.Duration RetryDelayFunc func(n int, e error, t *Task) time.Duration
} }
@ -91,9 +91,6 @@ func NewBackground(r *redis.Client, cfg *Config) *Background {
// //
// If ProcessTask return a non-nil error or panics, the task // If ProcessTask return a non-nil error or panics, the task
// will be retried after delay. // will be retried after delay.
//
// Note: The argument task is ready only, ProcessTask should
// not mutate the task.
type Handler interface { type Handler interface {
ProcessTask(*Task) error ProcessTask(*Task) error
} }

View File

@ -33,15 +33,9 @@ func TestBackground(t *testing.T) {
bg.start(HandlerFunc(h)) bg.start(HandlerFunc(h))
client.Schedule(&Task{ client.Schedule(NewTask("send_email", map[string]interface{}{"recipient_id": 123}), time.Now())
Type: "send_email",
Payload: map[string]interface{}{"recipient_id": 123},
}, time.Now())
client.Schedule(&Task{ client.Schedule(NewTask("send_email", map[string]interface{}{"recipient_id": 456}), time.Now().Add(time.Hour))
Type: "send_email",
Payload: map[string]interface{}{"recipient_id": 456},
}, time.Now().Add(time.Hour))
bg.stop() bg.stop()
} }

View File

@ -28,8 +28,8 @@ func BenchmarkEndToEndSimple(b *testing.B) {
}) })
// Create a bunch of tasks // Create a bunch of tasks
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
t := Task{Type: fmt.Sprintf("task%d", i), Payload: Payload{"data": i}} t := NewTask(fmt.Sprintf("task%d", i), map[string]interface{}{"data": i})
client.Schedule(&t, time.Now()) client.Schedule(t, time.Now())
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@ -65,12 +65,12 @@ func BenchmarkEndToEnd(b *testing.B) {
}) })
// Create a bunch of tasks // Create a bunch of tasks
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
t := Task{Type: fmt.Sprintf("task%d", i), Payload: Payload{"data": i}} t := NewTask(fmt.Sprintf("task%d", i), map[string]interface{}{"data": i})
client.Schedule(&t, time.Now()) client.Schedule(t, time.Now())
} }
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
t := Task{Type: fmt.Sprintf("scheduled%d", i), Payload: Payload{"data": i}} t := NewTask(fmt.Sprintf("scheduled%d", i), map[string]interface{}{"data": i})
client.Schedule(&t, time.Now().Add(time.Second)) client.Schedule(t, time.Now().Add(time.Second))
} }
var wg sync.WaitGroup var wg sync.WaitGroup

View File

@ -82,7 +82,7 @@ func (c *Client) Schedule(task *Task, processAt time.Time, opts ...Option) error
msg := &base.TaskMessage{ msg := &base.TaskMessage{
ID: xid.New(), ID: xid.New(),
Type: task.Type, Type: task.Type,
Payload: task.Payload, Payload: task.Payload.data,
Queue: "default", Queue: "default",
Retry: opt.retry, Retry: opt.retry,
} }

View File

@ -17,7 +17,7 @@ func TestClient(t *testing.T) {
r := setup(t) r := setup(t)
client := NewClient(r) client := NewClient(r)
task := &Task{Type: "send_email", Payload: map[string]interface{}{"to": "customer@gmail.com", "from": "merchant@example.com"}} task := NewTask("send_email", map[string]interface{}{"to": "customer@gmail.com", "from": "merchant@example.com"})
tests := []struct { tests := []struct {
desc string desc string
@ -35,7 +35,7 @@ func TestClient(t *testing.T) {
wantEnqueued: []*base.TaskMessage{ wantEnqueued: []*base.TaskMessage{
&base.TaskMessage{ &base.TaskMessage{
Type: task.Type, Type: task.Type,
Payload: task.Payload, Payload: task.Payload.data,
Retry: defaultMaxRetry, Retry: defaultMaxRetry,
Queue: "default", Queue: "default",
}, },
@ -52,7 +52,7 @@ func TestClient(t *testing.T) {
{ {
Msg: &base.TaskMessage{ Msg: &base.TaskMessage{
Type: task.Type, Type: task.Type,
Payload: task.Payload, Payload: task.Payload.data,
Retry: defaultMaxRetry, Retry: defaultMaxRetry,
Queue: "default", Queue: "default",
}, },
@ -70,7 +70,7 @@ func TestClient(t *testing.T) {
wantEnqueued: []*base.TaskMessage{ wantEnqueued: []*base.TaskMessage{
&base.TaskMessage{ &base.TaskMessage{
Type: task.Type, Type: task.Type,
Payload: task.Payload, Payload: task.Payload.data,
Retry: 3, Retry: 3,
Queue: "default", Queue: "default",
}, },
@ -87,7 +87,7 @@ func TestClient(t *testing.T) {
wantEnqueued: []*base.TaskMessage{ wantEnqueued: []*base.TaskMessage{
&base.TaskMessage{ &base.TaskMessage{
Type: task.Type, Type: task.Type,
Payload: task.Payload, Payload: task.Payload.data,
Retry: 0, // Retry count should be set to zero Retry: 0, // Retry count should be set to zero
Queue: "default", Queue: "default",
}, },
@ -105,7 +105,7 @@ func TestClient(t *testing.T) {
wantEnqueued: []*base.TaskMessage{ wantEnqueued: []*base.TaskMessage{
&base.TaskMessage{ &base.TaskMessage{
Type: task.Type, Type: task.Type,
Payload: task.Payload, Payload: task.Payload.data,
Retry: 10, // Last option takes precedence Retry: 10, // Last option takes precedence
Queue: "default", Queue: "default",
}, },

9
doc.go
View File

@ -9,12 +9,11 @@ The Client is used to register a task to be processed at the specified time.
client := asynq.NewClient(redis) client := asynq.NewClient(redis)
t := asynq.Task{ t := asynq.NewTask(
Type: "send_email", "send_email",
Payload: map[string]interface{}{"user_id": 42}, map[string]interface{}{"user_id": 42})
}
err := client.Schedule(&t, time.Now().Add(time.Minute)) err := client.Schedule(t, time.Now().Add(time.Minute))
The Background is used to run the background task processing with a given The Background is used to run the background task processing with a given
handler. handler.

View File

@ -148,64 +148,6 @@ func TestDone(t *testing.T) {
} }
} }
// Note: User should not mutate task payload in Handler
// However, we should handle even if the user mutates the task
// in Handler. This test case is to make sure that we remove task
// from in-progress queue when we call Done for the task.
func TestDoneWithMutatedTask(t *testing.T) {
r := setup(t)
t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "hello"})
t2 := h.NewTaskMessage("export_csv", map[string]interface{}{"subjct": "hola"})
tests := []struct {
inProgress []*base.TaskMessage // initial state of the in-progress list
target *base.TaskMessage // task to remove
wantInProgress []*base.TaskMessage // final state of the in-progress list
}{
{
inProgress: []*base.TaskMessage{t1, t2},
target: t1,
wantInProgress: []*base.TaskMessage{t2},
},
{
inProgress: []*base.TaskMessage{t1},
target: t1,
wantInProgress: []*base.TaskMessage{},
},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case
h.SeedInProgressQueue(t, r.client, tc.inProgress)
// Mutate payload map!
tc.target.Payload["newkey"] = 123
err := r.Done(tc.target)
if err != nil {
t.Errorf("(*RDB).Done(task) = %v, want nil", err)
continue
}
gotInProgress := h.GetInProgressMessages(t, r.client)
if diff := cmp.Diff(tc.wantInProgress, gotInProgress, h.SortMsgOpt); diff != "" {
t.Errorf("mismatch found in %q: (-want, +got):\n%s", base.InProgressQueue, diff)
continue
}
processedKey := base.ProcessedKey(time.Now())
gotProcessed := r.client.Get(processedKey).Val()
if gotProcessed != "1" {
t.Errorf("GET %q = %q, want 1", processedKey, gotProcessed)
}
gotTTL := r.client.TTL(processedKey).Val()
if gotTTL > statsTTL {
t.Errorf("TTL %q = %v, want less than or equal to %v", processedKey, gotTTL, statsTTL)
}
}
}
func TestRequeue(t *testing.T) { func TestRequeue(t *testing.T) {
r := setup(t) r := setup(t)
t1 := h.NewTaskMessage("send_email", nil) t1 := h.NewTaskMessage("send_email", nil)
@ -384,104 +326,6 @@ func TestRetry(t *testing.T) {
} }
} }
func TestRetryWithMutatedTask(t *testing.T) {
r := setup(t)
t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "Hola!"})
t2 := h.NewTaskMessage("gen_thumbnail", map[string]interface{}{"path": "some/path/to/image.jpg"})
t3 := h.NewTaskMessage("reindex", map[string]interface{}{})
t1.Retried = 10
errMsg := "SMTP server is not responding"
t1AfterRetry := &base.TaskMessage{
ID: t1.ID,
Type: t1.Type,
Payload: t1.Payload,
Queue: t1.Queue,
Retry: t1.Retry,
Retried: t1.Retried + 1,
ErrorMsg: errMsg,
}
now := time.Now()
tests := []struct {
inProgress []*base.TaskMessage
retry []h.ZSetEntry
msg *base.TaskMessage
processAt time.Time
errMsg string
wantInProgress []*base.TaskMessage
wantRetry []h.ZSetEntry
}{
{
inProgress: []*base.TaskMessage{t1, t2},
retry: []h.ZSetEntry{
{
Msg: t3,
Score: now.Add(time.Minute).Unix(),
},
},
msg: t1,
processAt: now.Add(5 * time.Minute),
errMsg: errMsg,
wantInProgress: []*base.TaskMessage{t2},
wantRetry: []h.ZSetEntry{
{
Msg: t1AfterRetry,
Score: now.Add(5 * time.Minute).Unix(),
},
{
Msg: t3,
Score: now.Add(time.Minute).Unix(),
},
},
},
}
for _, tc := range tests {
h.FlushDB(t, r.client)
h.SeedInProgressQueue(t, r.client, tc.inProgress)
h.SeedRetryQueue(t, r.client, tc.retry)
// Mutate paylod map!
tc.msg.Payload["newkey"] = "newvalue"
err := r.Retry(tc.msg, tc.processAt, tc.errMsg)
if err != nil {
t.Errorf("(*RDB).Retry = %v, want nil", err)
continue
}
gotInProgress := h.GetInProgressMessages(t, r.client)
if diff := cmp.Diff(tc.wantInProgress, gotInProgress, h.SortMsgOpt); diff != "" {
t.Errorf("mismatch found in %q; (-want, +got)\n%s", base.InProgressQueue, diff)
}
gotRetry := h.GetRetryEntries(t, r.client)
if diff := cmp.Diff(tc.wantRetry, gotRetry, h.SortZSetEntryOpt); diff != "" {
t.Errorf("mismatch found in %q; (-want, +got)\n%s", base.RetryQueue, diff)
}
processedKey := base.ProcessedKey(time.Now())
gotProcessed := r.client.Get(processedKey).Val()
if gotProcessed != "1" {
t.Errorf("GET %q = %q, want 1", processedKey, gotProcessed)
}
gotTTL := r.client.TTL(processedKey).Val()
if gotTTL > statsTTL {
t.Errorf("TTL %q = %v, want less than or equal to %v", processedKey, gotTTL, statsTTL)
}
failureKey := base.FailureKey(time.Now())
gotFailure := r.client.Get(failureKey).Val()
if gotFailure != "1" {
t.Errorf("GET %q = %q, want 1", failureKey, gotFailure)
}
gotTTL = r.client.TTL(processedKey).Val()
if gotTTL > statsTTL {
t.Errorf("TTL %q = %v, want less than or equal to %v", failureKey, gotTTL, statsTTL)
}
}
}
func TestKill(t *testing.T) { func TestKill(t *testing.T) {
r := setup(t) r := setup(t)
t1 := h.NewTaskMessage("send_email", nil) t1 := h.NewTaskMessage("send_email", nil)
@ -585,112 +429,6 @@ func TestKill(t *testing.T) {
} }
} }
func TestKillWithMutatedTask(t *testing.T) {
r := setup(t)
t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "hello"})
t2 := h.NewTaskMessage("reindex", map[string]interface{}{})
t3 := h.NewTaskMessage("generate_csv", map[string]interface{}{"path": "some/path/to/img"})
errMsg := "SMTP server not responding"
t1AfterKill := &base.TaskMessage{
ID: t1.ID,
Type: t1.Type,
Payload: t1.Payload,
Queue: t1.Queue,
Retry: t1.Retry,
Retried: t1.Retried,
ErrorMsg: errMsg,
}
now := time.Now()
// TODO(hibiken): add test cases for trimming
tests := []struct {
inProgress []*base.TaskMessage
dead []h.ZSetEntry
target *base.TaskMessage // task to kill
wantInProgress []*base.TaskMessage
wantDead []h.ZSetEntry
}{
{
inProgress: []*base.TaskMessage{t1, t2},
dead: []h.ZSetEntry{
{
Msg: t3,
Score: now.Add(-time.Hour).Unix(),
},
},
target: t1,
wantInProgress: []*base.TaskMessage{t2},
wantDead: []h.ZSetEntry{
{
Msg: t1AfterKill,
Score: now.Unix(),
},
{
Msg: t3,
Score: now.Add(-time.Hour).Unix(),
},
},
},
{
inProgress: []*base.TaskMessage{t1, t2, t3},
dead: []h.ZSetEntry{},
target: t1,
wantInProgress: []*base.TaskMessage{t2, t3},
wantDead: []h.ZSetEntry{
{
Msg: t1AfterKill,
Score: now.Unix(),
},
},
},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case
h.SeedInProgressQueue(t, r.client, tc.inProgress)
h.SeedDeadQueue(t, r.client, tc.dead)
// Mutate payload map!
tc.target.Payload["newkey"] = "newvalue"
err := r.Kill(tc.target, errMsg)
if err != nil {
t.Errorf("(*RDB).Kill(%v, %v) = %v, want nil", tc.target, errMsg, err)
continue
}
gotInProgress := h.GetInProgressMessages(t, r.client)
if diff := cmp.Diff(tc.wantInProgress, gotInProgress, h.SortMsgOpt); diff != "" {
t.Errorf("mismatch found in %q: (-want, +got)\n%s", base.InProgressQueue, diff)
}
gotDead := h.GetDeadEntries(t, r.client)
if diff := cmp.Diff(tc.wantDead, gotDead, h.SortZSetEntryOpt); diff != "" {
t.Errorf("mismatch found in %q after calling (*RDB).Kill: (-want, +got):\n%s", base.DeadQueue, diff)
}
processedKey := base.ProcessedKey(time.Now())
gotProcessed := r.client.Get(processedKey).Val()
if gotProcessed != "1" {
t.Errorf("GET %q = %q, want 1", processedKey, gotProcessed)
}
gotTTL := r.client.TTL(processedKey).Val()
if gotTTL > statsTTL {
t.Errorf("TTL %q = %v, want less than or equal to %v", processedKey, gotTTL, statsTTL)
}
failureKey := base.FailureKey(time.Now())
gotFailure := r.client.Get(failureKey).Val()
if gotFailure != "1" {
t.Errorf("GET %q = %q, want 1", failureKey, gotFailure)
}
gotTTL = r.client.TTL(processedKey).Val()
if gotTTL > statsTTL {
t.Errorf("TTL %q = %v, want less than or equal to %v", failureKey, gotTTL, statsTTL)
}
}
}
func TestRestoreUnfinished(t *testing.T) { func TestRestoreUnfinished(t *testing.T) {
r := setup(t) r := setup(t)
t1 := h.NewTaskMessage("send_email", nil) t1 := h.NewTaskMessage("send_email", nil)

View File

@ -12,8 +12,9 @@ import (
) )
// Payload is an arbitrary data needed for task execution. // Payload is an arbitrary data needed for task execution.
// The values have to be JSON serializable. type Payload struct {
type Payload map[string]interface{} data map[string]interface{}
}
type errKeyNotFound struct { type errKeyNotFound struct {
key string key string
@ -25,14 +26,14 @@ func (e *errKeyNotFound) Error() string {
// Has reports whether key exists. // Has reports whether key exists.
func (p Payload) Has(key string) bool { func (p Payload) Has(key string) bool {
_, ok := p[key] _, ok := p.data[key]
return ok return ok
} }
// GetString returns a string value if a string type is associated with // GetString returns a string value if a string type is associated with
// the key, otherwise reports an error. // the key, otherwise reports an error.
func (p Payload) GetString(key string) (string, error) { func (p Payload) GetString(key string) (string, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return "", &errKeyNotFound{key} return "", &errKeyNotFound{key}
} }
@ -42,7 +43,7 @@ func (p Payload) GetString(key string) (string, error) {
// GetInt returns an int value if a numeric type is associated with // GetInt returns an int value if a numeric type is associated with
// the key, otherwise reports an error. // the key, otherwise reports an error.
func (p Payload) GetInt(key string) (int, error) { func (p Payload) GetInt(key string) (int, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return 0, &errKeyNotFound{key} return 0, &errKeyNotFound{key}
} }
@ -52,7 +53,7 @@ func (p Payload) GetInt(key string) (int, error) {
// GetFloat64 returns a float64 value if a numeric type is associated with // GetFloat64 returns a float64 value if a numeric type is associated with
// the key, otherwise reports an error. // the key, otherwise reports an error.
func (p Payload) GetFloat64(key string) (float64, error) { func (p Payload) GetFloat64(key string) (float64, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return 0, &errKeyNotFound{key} return 0, &errKeyNotFound{key}
} }
@ -62,7 +63,7 @@ func (p Payload) GetFloat64(key string) (float64, error) {
// GetBool returns a boolean value if a boolean type is associated with // GetBool returns a boolean value if a boolean type is associated with
// the key, otherwise reports an error. // the key, otherwise reports an error.
func (p Payload) GetBool(key string) (bool, error) { func (p Payload) GetBool(key string) (bool, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return false, &errKeyNotFound{key} return false, &errKeyNotFound{key}
} }
@ -72,7 +73,7 @@ func (p Payload) GetBool(key string) (bool, error) {
// GetStringSlice returns a slice of strings if a string slice type is associated with // GetStringSlice returns a slice of strings if a string slice type is associated with
// the key, otherwise reports an error. // the key, otherwise reports an error.
func (p Payload) GetStringSlice(key string) ([]string, error) { func (p Payload) GetStringSlice(key string) ([]string, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return nil, &errKeyNotFound{key} return nil, &errKeyNotFound{key}
} }
@ -82,7 +83,7 @@ func (p Payload) GetStringSlice(key string) ([]string, error) {
// GetIntSlice returns a slice of ints if a int slice type is associated with // GetIntSlice returns a slice of ints if a int slice type is associated with
// the key, otherwise reports an error. // the key, otherwise reports an error.
func (p Payload) GetIntSlice(key string) ([]int, error) { func (p Payload) GetIntSlice(key string) ([]int, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return nil, &errKeyNotFound{key} return nil, &errKeyNotFound{key}
} }
@ -93,7 +94,7 @@ func (p Payload) GetIntSlice(key string) ([]int, error) {
// if a correct map type is associated with the key, // if a correct map type is associated with the key,
// otherwise reports an error. // otherwise reports an error.
func (p Payload) GetStringMap(key string) (map[string]interface{}, error) { func (p Payload) GetStringMap(key string) (map[string]interface{}, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return nil, &errKeyNotFound{key} return nil, &errKeyNotFound{key}
} }
@ -104,7 +105,7 @@ func (p Payload) GetStringMap(key string) (map[string]interface{}, error) {
// if a correct map type is associated with the key, // if a correct map type is associated with the key,
// otherwise reports an error. // otherwise reports an error.
func (p Payload) GetStringMapString(key string) (map[string]string, error) { func (p Payload) GetStringMapString(key string) (map[string]string, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return nil, &errKeyNotFound{key} return nil, &errKeyNotFound{key}
} }
@ -115,7 +116,7 @@ func (p Payload) GetStringMapString(key string) (map[string]string, error) {
// if a correct map type is associated with the key, // if a correct map type is associated with the key,
// otherwise reports an error. // otherwise reports an error.
func (p Payload) GetStringMapStringSlice(key string) (map[string][]string, error) { func (p Payload) GetStringMapStringSlice(key string) (map[string][]string, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return nil, &errKeyNotFound{key} return nil, &errKeyNotFound{key}
} }
@ -126,7 +127,7 @@ func (p Payload) GetStringMapStringSlice(key string) (map[string][]string, error
// if a correct map type is associated with the key, // if a correct map type is associated with the key,
// otherwise reports an error. // otherwise reports an error.
func (p Payload) GetStringMapInt(key string) (map[string]int, error) { func (p Payload) GetStringMapInt(key string) (map[string]int, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return nil, &errKeyNotFound{key} return nil, &errKeyNotFound{key}
} }
@ -137,7 +138,7 @@ func (p Payload) GetStringMapInt(key string) (map[string]int, error) {
// if a correct map type is associated with the key, // if a correct map type is associated with the key,
// otherwise reports an error. // otherwise reports an error.
func (p Payload) GetStringMapBool(key string) (map[string]bool, error) { func (p Payload) GetStringMapBool(key string) (map[string]bool, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return nil, &errKeyNotFound{key} return nil, &errKeyNotFound{key}
} }
@ -147,7 +148,7 @@ func (p Payload) GetStringMapBool(key string) (map[string]bool, error) {
// GetTime returns a time value if a correct map type is associated with the key, // GetTime returns a time value if a correct map type is associated with the key,
// otherwise reports an error. // otherwise reports an error.
func (p Payload) GetTime(key string) (time.Time, error) { func (p Payload) GetTime(key string) (time.Time, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return time.Time{}, &errKeyNotFound{key} return time.Time{}, &errKeyNotFound{key}
} }
@ -157,7 +158,7 @@ func (p Payload) GetTime(key string) (time.Time, error) {
// GetDuration returns a duration value if a correct map type is associated with the key, // GetDuration returns a duration value if a correct map type is associated with the key,
// otherwise reports an error. // otherwise reports an error.
func (p Payload) GetDuration(key string) (time.Duration, error) { func (p Payload) GetDuration(key string) (time.Duration, error) {
v, ok := p[key] v, ok := p.data[key]
if !ok { if !ok {
return 0, &errKeyNotFound{key} return 0, &errKeyNotFound{key}
} }

View File

@ -10,6 +10,8 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
h "github.com/hibiken/asynq/internal/asynqtest"
"github.com/hibiken/asynq/internal/base"
) )
func TestPayloadGet(t *testing.T) { func TestPayloadGet(t *testing.T) {
@ -34,7 +36,7 @@ func TestPayloadGet(t *testing.T) {
now := time.Now() now := time.Now()
duration := 15 * time.Minute duration := 15 * time.Minute
payload := Payload{ data := map[string]interface{}{
"greeting": "Hello", "greeting": "Hello",
"user_id": 9876, "user_id": 9876,
"pi": 3.1415, "pi": 3.1415,
@ -49,6 +51,7 @@ func TestPayloadGet(t *testing.T) {
"timestamp": now, "timestamp": now,
"duration": duration, "duration": duration,
} }
payload := Payload{data}
gotStr, err := payload.GetString("greeting") gotStr, err := payload.GetString("greeting")
if gotStr != "Hello" || err != nil { if gotStr != "Hello" || err != nil {
@ -151,7 +154,7 @@ func TestPayloadGetWithMarshaling(t *testing.T) {
now := time.Now() now := time.Now()
duration := 15 * time.Minute duration := 15 * time.Minute
in := Payload{ in := Payload{map[string]interface{}{
"subject": "Hello", "subject": "Hello",
"recipient_id": 9876, "recipient_id": 9876,
"pi": 3.14, "pi": 3.14,
@ -165,18 +168,19 @@ func TestPayloadGetWithMarshaling(t *testing.T) {
"features": features, "features": features,
"timestamp": now, "timestamp": now,
"duration": duration, "duration": duration,
} }}
// encode and then decode task messsage
// encode and then decode inMsg := h.NewTaskMessage("testing", in.data)
data, err := json.Marshal(in) data, err := json.Marshal(inMsg)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out Payload var outMsg base.TaskMessage
err = json.Unmarshal(data, &out) err = json.Unmarshal(data, &outMsg)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
out := Payload{outMsg.Payload}
gotStr, err := out.GetString("subject") gotStr, err := out.GetString("subject")
if gotStr != "Hello" || err != nil { if gotStr != "Hello" || err != nil {
@ -258,9 +262,9 @@ func TestPayloadGetWithMarshaling(t *testing.T) {
} }
func TestPayloadHas(t *testing.T) { func TestPayloadHas(t *testing.T) {
payload := Payload{ payload := Payload{map[string]interface{}{
"user_id": 123, "user_id": 123,
} }}
if !payload.Has("user_id") { if !payload.Has("user_id") {
t.Errorf("Payload.Has(%q) = false, want true", "user_id") t.Errorf("Payload.Has(%q) = false, want true", "user_id")

View File

@ -126,7 +126,7 @@ func (p *processor) exec() {
defer func() { <-p.sema /* release token */ }() defer func() { <-p.sema /* release token */ }()
resCh := make(chan error, 1) resCh := make(chan error, 1)
task := &Task{Type: msg.Type, Payload: msg.Payload} task := NewTask(msg.Type, msg.Payload)
go func() { go func() {
resCh <- perform(p.handler, task) resCh <- perform(p.handler, task)
}() }()
@ -182,7 +182,7 @@ func (p *processor) markAsDone(msg *base.TaskMessage) {
} }
func (p *processor) retry(msg *base.TaskMessage, e error) { func (p *processor) retry(msg *base.TaskMessage, e error) {
d := p.retryDelayFunc(msg.Retried, e, &Task{Type: msg.Type, Payload: msg.Payload}) d := p.retryDelayFunc(msg.Retried, e, NewTask(msg.Type, msg.Payload))
retryAt := time.Now().Add(d) retryAt := time.Now().Add(d)
err := p.rdb.Retry(msg, retryAt, e.Error()) err := p.rdb.Retry(msg, retryAt, e.Error())
if err != nil { if err != nil {

View File

@ -25,10 +25,10 @@ func TestProcessorSuccess(t *testing.T) {
m3 := h.NewTaskMessage("reindex", nil) m3 := h.NewTaskMessage("reindex", nil)
m4 := h.NewTaskMessage("sync", nil) m4 := h.NewTaskMessage("sync", nil)
t1 := &Task{Type: m1.Type, Payload: m1.Payload} t1 := NewTask(m1.Type, m1.Payload)
t2 := &Task{Type: m2.Type, Payload: m2.Payload} t2 := NewTask(m2.Type, m2.Payload)
t3 := &Task{Type: m3.Type, Payload: m3.Payload} t3 := NewTask(m3.Type, m3.Payload)
t4 := &Task{Type: m4.Type, Payload: m4.Payload} t4 := NewTask(m4.Type, m4.Payload)
tests := []struct { tests := []struct {
enqueued []*base.TaskMessage // initial default queue state enqueued []*base.TaskMessage // initial default queue state
@ -78,7 +78,7 @@ func TestProcessorSuccess(t *testing.T) {
time.Sleep(tc.wait) time.Sleep(tc.wait)
p.terminate() p.terminate()
if diff := cmp.Diff(tc.wantProcessed, processed, sortTaskOpt); diff != "" { if diff := cmp.Diff(tc.wantProcessed, processed, sortTaskOpt, cmp.AllowUnexported(Payload{})); diff != "" {
t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff) t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
} }
@ -190,7 +190,7 @@ func TestPerform(t *testing.T) {
handler: func(t *Task) error { handler: func(t *Task) error {
return nil return nil
}, },
task: &Task{Type: "gen_thumbnail", Payload: map[string]interface{}{"src": "some/img/path"}}, task: NewTask("gen_thumbnail", map[string]interface{}{"src": "some/img/path"}),
wantErr: false, wantErr: false,
}, },
{ {
@ -198,7 +198,7 @@ func TestPerform(t *testing.T) {
handler: func(t *Task) error { handler: func(t *Task) error {
return fmt.Errorf("something went wrong") return fmt.Errorf("something went wrong")
}, },
task: &Task{Type: "gen_thumbnail", Payload: map[string]interface{}{"src": "some/img/path"}}, task: NewTask("gen_thumbnail", map[string]interface{}{"src": "some/img/path"}),
wantErr: true, wantErr: true,
}, },
{ {
@ -206,7 +206,7 @@ func TestPerform(t *testing.T) {
handler: func(t *Task) error { handler: func(t *Task) error {
panic("something went terribly wrong") panic("something went terribly wrong")
}, },
task: &Task{Type: "gen_thumbnail", Payload: map[string]interface{}{"src": "some/img/path"}}, task: NewTask("gen_thumbnail", map[string]interface{}{"src": "some/img/path"}),
wantErr: true, wantErr: true,
}, },
} }