diff --git a/internal/base/base.go b/internal/base/base.go index 2210410..f97a964 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -7,6 +7,7 @@ package base import ( "context" + "encoding/json" "fmt" "strings" "sync" @@ -106,6 +107,26 @@ type TaskMessage struct { UniqueKey string } +// EncodeMessage marshals the given task message in JSON and returns an encoded string. +func EncodeMessage(msg *TaskMessage) (string, error) { + b, err := json.Marshal(msg) + if err != nil { + return "", err + } + return string(b), nil +} + +// DecodeMessage unmarshals the given encoded string and returns a decoded task message. +func DecodeMessage(s string) (*TaskMessage, error) { + d := json.NewDecoder(strings.NewReader(s)) + d.UseNumber() + var msg TaskMessage + if err := d.Decode(&msg); err != nil { + return nil, err + } + return &msg, nil +} + // ServerStatus represents status of a server. // ServerStatus methods are concurrency safe. type ServerStatus struct { diff --git a/internal/base/base_test.go b/internal/base/base_test.go index cfe8414..700ae53 100644 --- a/internal/base/base_test.go +++ b/internal/base/base_test.go @@ -6,9 +6,13 @@ package base import ( "context" + "encoding/json" "sync" "testing" "time" + + "github.com/google/go-cmp/cmp" + "github.com/rs/xid" ) func TestQueueKey(t *testing.T) { @@ -103,6 +107,52 @@ func TestWorkersKey(t *testing.T) { } } +func TestMessageEncoding(t *testing.T) { + id := xid.New() + tests := []struct { + in *TaskMessage + out *TaskMessage + }{ + { + in: &TaskMessage{ + Type: "task1", + Payload: map[string]interface{}{"a": 1, "b": "hello!", "c": true}, + ID: id, + Queue: "default", + Retry: 10, + Retried: 0, + Timeout: "0", + }, + out: &TaskMessage{ + Type: "task1", + Payload: map[string]interface{}{"a": json.Number("1"), "b": "hello!", "c": true}, + ID: id, + Queue: "default", + Retry: 10, + Retried: 0, + Timeout: "0", + }, + }, + } + + for _, tc := range tests { + encoded, err := EncodeMessage(tc.in) + if err != nil { + t.Errorf("EncodeMessage(msg) returned error: %v", err) + continue + } + decoded, err := DecodeMessage(encoded) + if err != nil { + t.Errorf("DecodeMessage(encoded) returned error: %v", err) + continue + } + if diff := cmp.Diff(tc.out, decoded); diff != "" { + t.Errorf("Decoded message == %+v, want %+v;(-want,+got)\n%s", + decoded, tc.out, diff) + } + } +} + // Test for status being accessed by multiple goroutines. // Run with -race flag to check for data race. func TestStatusConcurrentAccess(t *testing.T) { diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index e3a493e..5112242 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -54,12 +54,12 @@ return 1`) // Enqueue inserts the given task to the tail of the queue. func (r *RDB) Enqueue(msg *base.TaskMessage) error { - bytes, err := json.Marshal(msg) + encoded, err := base.EncodeMessage(msg) if err != nil { return err } key := base.QueueKey(msg.Queue) - return enqueueCmd.Run(r.client, []string{key, base.AllQueues}, bytes).Err() + return enqueueCmd.Run(r.client, []string{key, base.AllQueues}, encoded).Err() } // KEYS[1] -> unique key in the form :: @@ -81,14 +81,14 @@ return 1 // EnqueueUnique inserts the given task if the task's uniqueness lock can be acquired. // It returns ErrDuplicateTask if the lock cannot be acquired. func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { - bytes, err := json.Marshal(msg) + encoded, err := base.EncodeMessage(msg) if err != nil { return err } key := base.QueueKey(msg.Queue) res, err := enqueueUniqueCmd.Run(r.client, []string{msg.UniqueKey, key, base.AllQueues}, - msg.ID.String(), int(ttl.Seconds()), bytes).Result() + msg.ID.String(), int(ttl.Seconds()), encoded).Result() if err != nil { return err } @@ -117,12 +117,7 @@ func (r *RDB) Dequeue(qnames ...string) (*base.TaskMessage, error) { if err != nil { return nil, err } - var msg base.TaskMessage - err = json.Unmarshal([]byte(data), &msg) - if err != nil { - return nil, err - } - return &msg, nil + return base.DecodeMessage(data) } // KEYS[1] -> asynq:in_progress @@ -176,7 +171,7 @@ return redis.status_reply("OK") // Done removes the task from in-progress queue to mark the task as done. // It removes a uniqueness lock acquired by the task, if any. func (r *RDB) Done(msg *base.TaskMessage) error { - bytes, err := json.Marshal(msg) + encoded, err := base.EncodeMessage(msg) if err != nil { return err } @@ -185,7 +180,7 @@ func (r *RDB) Done(msg *base.TaskMessage) error { expireAt := now.Add(statsTTL) return doneCmd.Run(r.client, []string{base.InProgressQueue, processedKey, msg.UniqueKey}, - bytes, expireAt.Unix(), msg.ID.String()).Err() + encoded, expireAt.Unix(), msg.ID.String()).Err() } // KEYS[1] -> asynq:in_progress @@ -199,13 +194,13 @@ return redis.status_reply("OK")`) // Requeue moves the task from in-progress queue to the specified queue. func (r *RDB) Requeue(msg *base.TaskMessage) error { - bytes, err := json.Marshal(msg) + encoded, err := base.EncodeMessage(msg) if err != nil { return err } return requeueCmd.Run(r.client, []string{base.InProgressQueue, base.QueueKey(msg.Queue)}, - string(bytes)).Err() + encoded).Err() } // KEYS[1] -> asynq:scheduled @@ -221,7 +216,7 @@ return 1 // Schedule adds the task to the backlog queue to be processed in the future. func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { - bytes, err := json.Marshal(msg) + encoded, err := base.EncodeMessage(msg) if err != nil { return err } @@ -229,7 +224,7 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { score := float64(processAt.Unix()) return scheduleCmd.Run(r.client, []string{base.ScheduledQueue, base.AllQueues}, - score, bytes, qkey).Err() + score, encoded, qkey).Err() } // KEYS[1] -> unique key in the format :: @@ -253,7 +248,7 @@ return 1 // ScheduleUnique adds the task to the backlog queue to be processed in the future if the uniqueness lock can be acquired. // It returns ErrDuplicateTask if the lock cannot be acquired. func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error { - bytes, err := json.Marshal(msg) + encoded, err := base.EncodeMessage(msg) if err != nil { return err } @@ -261,7 +256,7 @@ func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl tim score := float64(processAt.Unix()) res, err := scheduleUniqueCmd.Run(r.client, []string{msg.UniqueKey, base.ScheduledQueue, base.AllQueues}, - msg.ID.String(), int(ttl.Seconds()), score, bytes, qkey).Result() + msg.ID.String(), int(ttl.Seconds()), score, encoded, qkey).Result() if err != nil { return err } @@ -302,14 +297,14 @@ return redis.status_reply("OK")`) // Retry moves the task from in-progress to retry queue, incrementing retry count // and assigning error message to the task message. func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string) error { - bytesToRemove, err := json.Marshal(msg) + msgToRemove, err := base.EncodeMessage(msg) if err != nil { return err } modified := *msg modified.Retried++ modified.ErrorMsg = errMsg - bytesToAdd, err := json.Marshal(&modified) + msgToAdd, err := base.EncodeMessage(&modified) if err != nil { return err } @@ -319,7 +314,7 @@ func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string) e expireAt := now.Add(statsTTL) return retryCmd.Run(r.client, []string{base.InProgressQueue, base.RetryQueue, processedKey, failureKey}, - string(bytesToRemove), string(bytesToAdd), processAt.Unix(), expireAt.Unix()).Err() + msgToRemove, msgToAdd, processAt.Unix(), expireAt.Unix()).Err() } const ( @@ -359,13 +354,13 @@ return redis.status_reply("OK")`) // the error message to the task. // It also trims the set by timestamp and set size. func (r *RDB) Kill(msg *base.TaskMessage, errMsg string) error { - bytesToRemove, err := json.Marshal(msg) + msgToRemove, err := base.EncodeMessage(msg) if err != nil { return err } modified := *msg modified.ErrorMsg = errMsg - bytesToAdd, err := json.Marshal(&modified) + msgToAdd, err := base.EncodeMessage(&modified) if err != nil { return err } @@ -376,7 +371,7 @@ func (r *RDB) Kill(msg *base.TaskMessage, errMsg string) error { expireAt := now.Add(statsTTL) return killCmd.Run(r.client, []string{base.InProgressQueue, base.DeadQueue, processedKey, failureKey}, - string(bytesToRemove), string(bytesToAdd), now.Unix(), limit, maxDeadTasks, expireAt.Unix()).Err() + msgToRemove, msgToAdd, now.Unix(), limit, maxDeadTasks, expireAt.Unix()).Err() } // KEYS[1] -> asynq:in_progress diff --git a/payload.go b/payload.go index 340e673..85d63f1 100644 --- a/payload.go +++ b/payload.go @@ -5,6 +5,7 @@ package asynq import ( + "encoding/json" "fmt" "time" @@ -30,6 +31,19 @@ func (p Payload) Has(key string) bool { return ok } +func toInt(v interface{}) (int, error) { + switch v := v.(type) { + case json.Number: + val, err := v.Int64() + if err != nil { + return 0, err + } + return int(val), nil + default: + return cast.ToIntE(v) + } +} + // GetString returns a string value if a string type is associated with // the key, otherwise reports an error. func (p Payload) GetString(key string) (string, error) { @@ -47,7 +61,7 @@ func (p Payload) GetInt(key string) (int, error) { if !ok { return 0, &errKeyNotFound{key} } - return cast.ToIntE(v) + return toInt(v) } // GetFloat64 returns a float64 value if a numeric type is associated with @@ -57,7 +71,12 @@ func (p Payload) GetFloat64(key string) (float64, error) { if !ok { return 0, &errKeyNotFound{key} } - return cast.ToFloat64E(v) + switch v := v.(type) { + case json.Number: + return v.Float64() + default: + return cast.ToFloat64E(v) + } } // GetBool returns a boolean value if a boolean type is associated with @@ -87,7 +106,20 @@ func (p Payload) GetIntSlice(key string) ([]int, error) { if !ok { return nil, &errKeyNotFound{key} } - return cast.ToIntSliceE(v) + switch v := v.(type) { + case []interface{}: + var res []int + for _, elem := range v { + val, err := toInt(elem) + if err != nil { + return nil, err + } + res = append(res, int(val)) + } + return res, nil + default: + return cast.ToIntSliceE(v) + } } // GetStringMap returns a map of string to empty interface @@ -131,7 +163,20 @@ func (p Payload) GetStringMapInt(key string) (map[string]int, error) { if !ok { return nil, &errKeyNotFound{key} } - return cast.ToStringMapIntE(v) + switch v := v.(type) { + case map[string]interface{}: + res := make(map[string]int) + for key, val := range v { + ival, err := toInt(val) + if err != nil { + return nil, err + } + res[key] = ival + } + return res, nil + default: + return cast.ToStringMapIntE(v) + } } // GetStringMapBool returns a map of string to boolean @@ -162,5 +207,14 @@ func (p Payload) GetDuration(key string) (time.Duration, error) { if !ok { return 0, &errKeyNotFound{key} } - return cast.ToDurationE(v) + switch v := v.(type) { + case json.Number: + val, err := v.Int64() + if err != nil { + return 0, err + } + return time.Duration(val), nil + default: + return cast.ToDurationE(v) + } } diff --git a/payload_test.go b/payload_test.go index c8245b3..9808374 100644 --- a/payload_test.go +++ b/payload_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" h "github.com/hibiken/asynq/internal/asynqtest" "github.com/hibiken/asynq/internal/base" ) @@ -40,12 +41,11 @@ func TestPayloadString(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -85,12 +85,11 @@ func TestPayloadInt(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -130,12 +129,11 @@ func TestPayloadFloat64(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -175,12 +173,11 @@ func TestPayloadBool(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -221,12 +218,11 @@ func TestPayloadStringSlice(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -268,12 +264,11 @@ func TestPayloadIntSlice(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -315,21 +310,28 @@ func TestPayloadStringMap(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } payload = Payload{out.Payload} got, err = payload.GetStringMap(tc.key) - diff = cmp.Diff(got, tc.data[tc.key]) + ignoreOpt := cmpopts.IgnoreMapEntries(func(key string, val interface{}) bool { + switch val.(type) { + case json.Number: + return true + default: + return false + } + }) + diff = cmp.Diff(got, tc.data[tc.key], ignoreOpt) if err != nil || diff != "" { - t.Errorf("With Marshaling: Payload.GetStringMap(%q) = %v, %v, want %v, nil", - tc.key, got, err, tc.data[tc.key]) + t.Errorf("With Marshaling: Payload.GetStringMap(%q) = %v, %v, want %v, nil;(-want,+got)\n%s", + tc.key, got, err, tc.data[tc.key], diff) } // access non-existent key. @@ -362,12 +364,11 @@ func TestPayloadStringMapString(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -413,12 +414,11 @@ func TestPayloadStringMapStringSlice(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -465,12 +465,11 @@ func TestPayloadStringMapInt(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -517,12 +516,11 @@ func TestPayloadStringMapBool(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -564,12 +562,11 @@ func TestPayloadTime(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } @@ -611,12 +608,11 @@ func TestPayloadDuration(t *testing.T) { // encode and then decode task messsage. in := h.NewTaskMessage("testing", tc.data) - b, err := json.Marshal(in) + encoded, err := base.EncodeMessage(in) if err != nil { t.Fatal(err) } - var out base.TaskMessage - err = json.Unmarshal(b, &out) + out, err := base.DecodeMessage(encoded) if err != nil { t.Fatal(err) } diff --git a/processor_test.go b/processor_test.go index da1b577..7a7c8a0 100644 --- a/processor_test.go +++ b/processor_test.go @@ -31,6 +31,17 @@ func fakeHeartbeater(starting, finished <-chan *base.TaskMessage, done <-chan st } } +// fakeSyncer receives from sync channel and do nothing. +func fakeSyncer(syncCh <-chan *syncRequest, done <-chan struct{}) { + for { + select { + case <-syncCh: + case <-done: + return + } + } +} + func TestProcessorSuccess(t *testing.T) { r := setup(t) rdbClient := rdb.NewRDB(r) @@ -77,14 +88,16 @@ func TestProcessorSuccess(t *testing.T) { } starting := make(chan *base.TaskMessage) finished := make(chan *base.TaskMessage) + syncCh := make(chan *syncRequest) done := make(chan struct{}) defer func() { close(done) }() go fakeHeartbeater(starting, finished, done) + go fakeSyncer(syncCh, done) p := newProcessor(processorParams{ logger: testLogger, broker: rdbClient, retryDelayFunc: defaultDelayFunc, - syncCh: nil, + syncCh: syncCh, cancelations: base.NewCancelations(), concurrency: 10, queues: defaultQueueConfig, @@ -105,6 +118,9 @@ func TestProcessorSuccess(t *testing.T) { } } time.Sleep(2 * time.Second) // wait for two second to allow all enqueued tasks to be processed. + if l := r.LLen(base.InProgressQueue).Val(); l != 0 { + t.Errorf("%q has %d tasks, want 0", base.InProgressQueue, l) + } p.terminate() mu.Lock() @@ -112,10 +128,79 @@ func TestProcessorSuccess(t *testing.T) { t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff) } mu.Unlock() + } +} +// https://github.com/hibiken/asynq/issues/166 +func TestProcessTasksWithLargeNumberInPayload(t *testing.T) { + r := setup(t) + rdbClient := rdb.NewRDB(r) + + m1 := h.NewTaskMessage("large_number", map[string]interface{}{"data": 111111111111111111}) + t1 := NewTask(m1.Type, m1.Payload) + + tests := []struct { + enqueued []*base.TaskMessage // initial default queue state + wantProcessed []*Task // tasks to be processed at the end + }{ + { + enqueued: []*base.TaskMessage{m1}, + wantProcessed: []*Task{t1}, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r) // clean up db before each test case. + h.SeedEnqueuedQueue(t, r, tc.enqueued) // initialize default queue. + + var mu sync.Mutex + var processed []*Task + handler := func(ctx context.Context, task *Task) error { + mu.Lock() + defer mu.Unlock() + if data, err := task.Payload.GetInt("data"); err != nil { + t.Errorf("coult not get data from payload: %v", err) + } else { + t.Logf("data == %d", data) + } + processed = append(processed, task) + return nil + } + starting := make(chan *base.TaskMessage) + finished := make(chan *base.TaskMessage) + syncCh := make(chan *syncRequest) + done := make(chan struct{}) + defer func() { close(done) }() + go fakeHeartbeater(starting, finished, done) + go fakeSyncer(syncCh, done) + p := newProcessor(processorParams{ + logger: testLogger, + broker: rdbClient, + retryDelayFunc: defaultDelayFunc, + syncCh: syncCh, + cancelations: base.NewCancelations(), + concurrency: 10, + queues: defaultQueueConfig, + strictPriority: false, + errHandler: nil, + shutdownTimeout: defaultShutdownTimeout, + starting: starting, + finished: finished, + }) + p.handler = HandlerFunc(handler) + + p.start(&sync.WaitGroup{}) + time.Sleep(2 * time.Second) // wait for two second to allow all enqueued tasks to be processed. if l := r.LLen(base.InProgressQueue).Val(); l != 0 { t.Errorf("%q has %d tasks, want 0", base.InProgressQueue, l) } + p.terminate() + + mu.Lock() + if diff := cmp.Diff(tc.wantProcessed, processed, sortTaskOpt, cmpopts.IgnoreUnexported(Payload{})); diff != "" { + t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff) + } + mu.Unlock() } }