2
0
mirror of https://github.com/hibiken/asynq.git synced 2025-10-25 23:06:12 +08:00

Fix JSON number ovewflow issue

This commit is contained in:
Ken Hibino
2020-06-11 20:58:27 -07:00
parent 81bb52b08c
commit a2abeedaa0
6 changed files with 273 additions and 72 deletions

View File

@@ -7,6 +7,7 @@ package base
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@@ -106,6 +107,26 @@ type TaskMessage struct {
UniqueKey string UniqueKey string
} }
// EncodeMessage marshals the given task message in JSON and returns an encoded string.
func EncodeMessage(msg *TaskMessage) (string, error) {
b, err := json.Marshal(msg)
if err != nil {
return "", err
}
return string(b), nil
}
// DecodeMessage unmarshals the given encoded string and returns a decoded task message.
func DecodeMessage(s string) (*TaskMessage, error) {
d := json.NewDecoder(strings.NewReader(s))
d.UseNumber()
var msg TaskMessage
if err := d.Decode(&msg); err != nil {
return nil, err
}
return &msg, nil
}
// ServerStatus represents status of a server. // ServerStatus represents status of a server.
// ServerStatus methods are concurrency safe. // ServerStatus methods are concurrency safe.
type ServerStatus struct { type ServerStatus struct {

View File

@@ -6,9 +6,13 @@ package base
import ( import (
"context" "context"
"encoding/json"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/rs/xid"
) )
func TestQueueKey(t *testing.T) { func TestQueueKey(t *testing.T) {
@@ -103,6 +107,52 @@ func TestWorkersKey(t *testing.T) {
} }
} }
func TestMessageEncoding(t *testing.T) {
id := xid.New()
tests := []struct {
in *TaskMessage
out *TaskMessage
}{
{
in: &TaskMessage{
Type: "task1",
Payload: map[string]interface{}{"a": 1, "b": "hello!", "c": true},
ID: id,
Queue: "default",
Retry: 10,
Retried: 0,
Timeout: "0",
},
out: &TaskMessage{
Type: "task1",
Payload: map[string]interface{}{"a": json.Number("1"), "b": "hello!", "c": true},
ID: id,
Queue: "default",
Retry: 10,
Retried: 0,
Timeout: "0",
},
},
}
for _, tc := range tests {
encoded, err := EncodeMessage(tc.in)
if err != nil {
t.Errorf("EncodeMessage(msg) returned error: %v", err)
continue
}
decoded, err := DecodeMessage(encoded)
if err != nil {
t.Errorf("DecodeMessage(encoded) returned error: %v", err)
continue
}
if diff := cmp.Diff(tc.out, decoded); diff != "" {
t.Errorf("Decoded message == %+v, want %+v;(-want,+got)\n%s",
decoded, tc.out, diff)
}
}
}
// Test for status being accessed by multiple goroutines. // Test for status being accessed by multiple goroutines.
// Run with -race flag to check for data race. // Run with -race flag to check for data race.
func TestStatusConcurrentAccess(t *testing.T) { func TestStatusConcurrentAccess(t *testing.T) {

View File

@@ -54,12 +54,12 @@ return 1`)
// Enqueue inserts the given task to the tail of the queue. // Enqueue inserts the given task to the tail of the queue.
func (r *RDB) Enqueue(msg *base.TaskMessage) error { func (r *RDB) Enqueue(msg *base.TaskMessage) error {
bytes, err := json.Marshal(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return err return err
} }
key := base.QueueKey(msg.Queue) key := base.QueueKey(msg.Queue)
return enqueueCmd.Run(r.client, []string{key, base.AllQueues}, bytes).Err() return enqueueCmd.Run(r.client, []string{key, base.AllQueues}, encoded).Err()
} }
// KEYS[1] -> unique key in the form <type>:<payload>:<qname> // KEYS[1] -> unique key in the form <type>:<payload>:<qname>
@@ -81,14 +81,14 @@ return 1
// EnqueueUnique inserts the given task if the task's uniqueness lock can be acquired. // EnqueueUnique inserts the given task if the task's uniqueness lock can be acquired.
// It returns ErrDuplicateTask if the lock cannot be acquired. // It returns ErrDuplicateTask if the lock cannot be acquired.
func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error { func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error {
bytes, err := json.Marshal(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return err return err
} }
key := base.QueueKey(msg.Queue) key := base.QueueKey(msg.Queue)
res, err := enqueueUniqueCmd.Run(r.client, res, err := enqueueUniqueCmd.Run(r.client,
[]string{msg.UniqueKey, key, base.AllQueues}, []string{msg.UniqueKey, key, base.AllQueues},
msg.ID.String(), int(ttl.Seconds()), bytes).Result() msg.ID.String(), int(ttl.Seconds()), encoded).Result()
if err != nil { if err != nil {
return err return err
} }
@@ -117,12 +117,7 @@ func (r *RDB) Dequeue(qnames ...string) (*base.TaskMessage, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
var msg base.TaskMessage return base.DecodeMessage(data)
err = json.Unmarshal([]byte(data), &msg)
if err != nil {
return nil, err
}
return &msg, nil
} }
// KEYS[1] -> asynq:in_progress // KEYS[1] -> asynq:in_progress
@@ -176,7 +171,7 @@ return redis.status_reply("OK")
// Done removes the task from in-progress queue to mark the task as done. // Done removes the task from in-progress queue to mark the task as done.
// It removes a uniqueness lock acquired by the task, if any. // It removes a uniqueness lock acquired by the task, if any.
func (r *RDB) Done(msg *base.TaskMessage) error { func (r *RDB) Done(msg *base.TaskMessage) error {
bytes, err := json.Marshal(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return err return err
} }
@@ -185,7 +180,7 @@ func (r *RDB) Done(msg *base.TaskMessage) error {
expireAt := now.Add(statsTTL) expireAt := now.Add(statsTTL)
return doneCmd.Run(r.client, return doneCmd.Run(r.client,
[]string{base.InProgressQueue, processedKey, msg.UniqueKey}, []string{base.InProgressQueue, processedKey, msg.UniqueKey},
bytes, expireAt.Unix(), msg.ID.String()).Err() encoded, expireAt.Unix(), msg.ID.String()).Err()
} }
// KEYS[1] -> asynq:in_progress // KEYS[1] -> asynq:in_progress
@@ -199,13 +194,13 @@ return redis.status_reply("OK")`)
// Requeue moves the task from in-progress queue to the specified queue. // Requeue moves the task from in-progress queue to the specified queue.
func (r *RDB) Requeue(msg *base.TaskMessage) error { func (r *RDB) Requeue(msg *base.TaskMessage) error {
bytes, err := json.Marshal(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return err return err
} }
return requeueCmd.Run(r.client, return requeueCmd.Run(r.client,
[]string{base.InProgressQueue, base.QueueKey(msg.Queue)}, []string{base.InProgressQueue, base.QueueKey(msg.Queue)},
string(bytes)).Err() encoded).Err()
} }
// KEYS[1] -> asynq:scheduled // KEYS[1] -> asynq:scheduled
@@ -221,7 +216,7 @@ return 1
// Schedule adds the task to the backlog queue to be processed in the future. // Schedule adds the task to the backlog queue to be processed in the future.
func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error {
bytes, err := json.Marshal(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return err return err
} }
@@ -229,7 +224,7 @@ func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error {
score := float64(processAt.Unix()) score := float64(processAt.Unix())
return scheduleCmd.Run(r.client, return scheduleCmd.Run(r.client,
[]string{base.ScheduledQueue, base.AllQueues}, []string{base.ScheduledQueue, base.AllQueues},
score, bytes, qkey).Err() score, encoded, qkey).Err()
} }
// KEYS[1] -> unique key in the format <type>:<payload>:<qname> // KEYS[1] -> unique key in the format <type>:<payload>:<qname>
@@ -253,7 +248,7 @@ return 1
// ScheduleUnique adds the task to the backlog queue to be processed in the future if the uniqueness lock can be acquired. // ScheduleUnique adds the task to the backlog queue to be processed in the future if the uniqueness lock can be acquired.
// It returns ErrDuplicateTask if the lock cannot be acquired. // It returns ErrDuplicateTask if the lock cannot be acquired.
func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error { func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error {
bytes, err := json.Marshal(msg) encoded, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return err return err
} }
@@ -261,7 +256,7 @@ func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl tim
score := float64(processAt.Unix()) score := float64(processAt.Unix())
res, err := scheduleUniqueCmd.Run(r.client, res, err := scheduleUniqueCmd.Run(r.client,
[]string{msg.UniqueKey, base.ScheduledQueue, base.AllQueues}, []string{msg.UniqueKey, base.ScheduledQueue, base.AllQueues},
msg.ID.String(), int(ttl.Seconds()), score, bytes, qkey).Result() msg.ID.String(), int(ttl.Seconds()), score, encoded, qkey).Result()
if err != nil { if err != nil {
return err return err
} }
@@ -302,14 +297,14 @@ return redis.status_reply("OK")`)
// Retry moves the task from in-progress to retry queue, incrementing retry count // Retry moves the task from in-progress to retry queue, incrementing retry count
// and assigning error message to the task message. // and assigning error message to the task message.
func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string) error { func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string) error {
bytesToRemove, err := json.Marshal(msg) msgToRemove, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return err return err
} }
modified := *msg modified := *msg
modified.Retried++ modified.Retried++
modified.ErrorMsg = errMsg modified.ErrorMsg = errMsg
bytesToAdd, err := json.Marshal(&modified) msgToAdd, err := base.EncodeMessage(&modified)
if err != nil { if err != nil {
return err return err
} }
@@ -319,7 +314,7 @@ func (r *RDB) Retry(msg *base.TaskMessage, processAt time.Time, errMsg string) e
expireAt := now.Add(statsTTL) expireAt := now.Add(statsTTL)
return retryCmd.Run(r.client, return retryCmd.Run(r.client,
[]string{base.InProgressQueue, base.RetryQueue, processedKey, failureKey}, []string{base.InProgressQueue, base.RetryQueue, processedKey, failureKey},
string(bytesToRemove), string(bytesToAdd), processAt.Unix(), expireAt.Unix()).Err() msgToRemove, msgToAdd, processAt.Unix(), expireAt.Unix()).Err()
} }
const ( const (
@@ -359,13 +354,13 @@ return redis.status_reply("OK")`)
// the error message to the task. // the error message to the task.
// It also trims the set by timestamp and set size. // It also trims the set by timestamp and set size.
func (r *RDB) Kill(msg *base.TaskMessage, errMsg string) error { func (r *RDB) Kill(msg *base.TaskMessage, errMsg string) error {
bytesToRemove, err := json.Marshal(msg) msgToRemove, err := base.EncodeMessage(msg)
if err != nil { if err != nil {
return err return err
} }
modified := *msg modified := *msg
modified.ErrorMsg = errMsg modified.ErrorMsg = errMsg
bytesToAdd, err := json.Marshal(&modified) msgToAdd, err := base.EncodeMessage(&modified)
if err != nil { if err != nil {
return err return err
} }
@@ -376,7 +371,7 @@ func (r *RDB) Kill(msg *base.TaskMessage, errMsg string) error {
expireAt := now.Add(statsTTL) expireAt := now.Add(statsTTL)
return killCmd.Run(r.client, return killCmd.Run(r.client,
[]string{base.InProgressQueue, base.DeadQueue, processedKey, failureKey}, []string{base.InProgressQueue, base.DeadQueue, processedKey, failureKey},
string(bytesToRemove), string(bytesToAdd), now.Unix(), limit, maxDeadTasks, expireAt.Unix()).Err() msgToRemove, msgToAdd, now.Unix(), limit, maxDeadTasks, expireAt.Unix()).Err()
} }
// KEYS[1] -> asynq:in_progress // KEYS[1] -> asynq:in_progress

View File

@@ -5,6 +5,7 @@
package asynq package asynq
import ( import (
"encoding/json"
"fmt" "fmt"
"time" "time"
@@ -30,6 +31,19 @@ func (p Payload) Has(key string) bool {
return ok return ok
} }
func toInt(v interface{}) (int, error) {
switch v := v.(type) {
case json.Number:
val, err := v.Int64()
if err != nil {
return 0, err
}
return int(val), nil
default:
return cast.ToIntE(v)
}
}
// GetString returns a string value if a string type is associated with // GetString returns a string value if a string type is associated with
// the key, otherwise reports an error. // the key, otherwise reports an error.
func (p Payload) GetString(key string) (string, error) { func (p Payload) GetString(key string) (string, error) {
@@ -47,7 +61,7 @@ func (p Payload) GetInt(key string) (int, error) {
if !ok { if !ok {
return 0, &errKeyNotFound{key} return 0, &errKeyNotFound{key}
} }
return cast.ToIntE(v) return toInt(v)
} }
// GetFloat64 returns a float64 value if a numeric type is associated with // GetFloat64 returns a float64 value if a numeric type is associated with
@@ -57,8 +71,13 @@ func (p Payload) GetFloat64(key string) (float64, error) {
if !ok { if !ok {
return 0, &errKeyNotFound{key} return 0, &errKeyNotFound{key}
} }
switch v := v.(type) {
case json.Number:
return v.Float64()
default:
return cast.ToFloat64E(v) return cast.ToFloat64E(v)
} }
}
// GetBool returns a boolean value if a boolean type is associated with // GetBool returns a boolean value if a boolean type is associated with
// the key, otherwise reports an error. // the key, otherwise reports an error.
@@ -87,8 +106,21 @@ func (p Payload) GetIntSlice(key string) ([]int, error) {
if !ok { if !ok {
return nil, &errKeyNotFound{key} return nil, &errKeyNotFound{key}
} }
switch v := v.(type) {
case []interface{}:
var res []int
for _, elem := range v {
val, err := toInt(elem)
if err != nil {
return nil, err
}
res = append(res, int(val))
}
return res, nil
default:
return cast.ToIntSliceE(v) return cast.ToIntSliceE(v)
} }
}
// GetStringMap returns a map of string to empty interface // GetStringMap returns a map of string to empty interface
// if a correct map type is associated with the key, // if a correct map type is associated with the key,
@@ -131,8 +163,21 @@ func (p Payload) GetStringMapInt(key string) (map[string]int, error) {
if !ok { if !ok {
return nil, &errKeyNotFound{key} return nil, &errKeyNotFound{key}
} }
switch v := v.(type) {
case map[string]interface{}:
res := make(map[string]int)
for key, val := range v {
ival, err := toInt(val)
if err != nil {
return nil, err
}
res[key] = ival
}
return res, nil
default:
return cast.ToStringMapIntE(v) return cast.ToStringMapIntE(v)
} }
}
// GetStringMapBool returns a map of string to boolean // GetStringMapBool returns a map of string to boolean
// if a correct map type is associated with the key, // if a correct map type is associated with the key,
@@ -162,5 +207,14 @@ func (p Payload) GetDuration(key string) (time.Duration, error) {
if !ok { if !ok {
return 0, &errKeyNotFound{key} return 0, &errKeyNotFound{key}
} }
switch v := v.(type) {
case json.Number:
val, err := v.Int64()
if err != nil {
return 0, err
}
return time.Duration(val), nil
default:
return cast.ToDurationE(v) return cast.ToDurationE(v)
} }
}

View File

@@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
h "github.com/hibiken/asynq/internal/asynqtest" h "github.com/hibiken/asynq/internal/asynqtest"
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
) )
@@ -40,12 +41,11 @@ func TestPayloadString(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -85,12 +85,11 @@ func TestPayloadInt(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -130,12 +129,11 @@ func TestPayloadFloat64(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -175,12 +173,11 @@ func TestPayloadBool(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -221,12 +218,11 @@ func TestPayloadStringSlice(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -268,12 +264,11 @@ func TestPayloadIntSlice(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -315,21 +310,28 @@ func TestPayloadStringMap(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
payload = Payload{out.Payload} payload = Payload{out.Payload}
got, err = payload.GetStringMap(tc.key) got, err = payload.GetStringMap(tc.key)
diff = cmp.Diff(got, tc.data[tc.key]) ignoreOpt := cmpopts.IgnoreMapEntries(func(key string, val interface{}) bool {
switch val.(type) {
case json.Number:
return true
default:
return false
}
})
diff = cmp.Diff(got, tc.data[tc.key], ignoreOpt)
if err != nil || diff != "" { if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetStringMap(%q) = %v, %v, want %v, nil", t.Errorf("With Marshaling: Payload.GetStringMap(%q) = %v, %v, want %v, nil;(-want,+got)\n%s",
tc.key, got, err, tc.data[tc.key]) tc.key, got, err, tc.data[tc.key], diff)
} }
// access non-existent key. // access non-existent key.
@@ -362,12 +364,11 @@ func TestPayloadStringMapString(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -413,12 +414,11 @@ func TestPayloadStringMapStringSlice(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -465,12 +465,11 @@ func TestPayloadStringMapInt(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -517,12 +516,11 @@ func TestPayloadStringMapBool(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -564,12 +562,11 @@ func TestPayloadTime(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -611,12 +608,11 @@ func TestPayloadDuration(t *testing.T) {
// encode and then decode task messsage. // encode and then decode task messsage.
in := h.NewTaskMessage("testing", tc.data) in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in) encoded, err := base.EncodeMessage(in)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var out base.TaskMessage out, err := base.DecodeMessage(encoded)
err = json.Unmarshal(b, &out)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -31,6 +31,17 @@ func fakeHeartbeater(starting, finished <-chan *base.TaskMessage, done <-chan st
} }
} }
// fakeSyncer receives from sync channel and do nothing.
func fakeSyncer(syncCh <-chan *syncRequest, done <-chan struct{}) {
for {
select {
case <-syncCh:
case <-done:
return
}
}
}
func TestProcessorSuccess(t *testing.T) { func TestProcessorSuccess(t *testing.T) {
r := setup(t) r := setup(t)
rdbClient := rdb.NewRDB(r) rdbClient := rdb.NewRDB(r)
@@ -77,14 +88,16 @@ func TestProcessorSuccess(t *testing.T) {
} }
starting := make(chan *base.TaskMessage) starting := make(chan *base.TaskMessage)
finished := make(chan *base.TaskMessage) finished := make(chan *base.TaskMessage)
syncCh := make(chan *syncRequest)
done := make(chan struct{}) done := make(chan struct{})
defer func() { close(done) }() defer func() { close(done) }()
go fakeHeartbeater(starting, finished, done) go fakeHeartbeater(starting, finished, done)
go fakeSyncer(syncCh, done)
p := newProcessor(processorParams{ p := newProcessor(processorParams{
logger: testLogger, logger: testLogger,
broker: rdbClient, broker: rdbClient,
retryDelayFunc: defaultDelayFunc, retryDelayFunc: defaultDelayFunc,
syncCh: nil, syncCh: syncCh,
cancelations: base.NewCancelations(), cancelations: base.NewCancelations(),
concurrency: 10, concurrency: 10,
queues: defaultQueueConfig, queues: defaultQueueConfig,
@@ -105,6 +118,9 @@ func TestProcessorSuccess(t *testing.T) {
} }
} }
time.Sleep(2 * time.Second) // wait for two second to allow all enqueued tasks to be processed. time.Sleep(2 * time.Second) // wait for two second to allow all enqueued tasks to be processed.
if l := r.LLen(base.InProgressQueue).Val(); l != 0 {
t.Errorf("%q has %d tasks, want 0", base.InProgressQueue, l)
}
p.terminate() p.terminate()
mu.Lock() mu.Lock()
@@ -112,10 +128,79 @@ func TestProcessorSuccess(t *testing.T) {
t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff) t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
} }
mu.Unlock() mu.Unlock()
}
}
// https://github.com/hibiken/asynq/issues/166
func TestProcessTasksWithLargeNumberInPayload(t *testing.T) {
r := setup(t)
rdbClient := rdb.NewRDB(r)
m1 := h.NewTaskMessage("large_number", map[string]interface{}{"data": 111111111111111111})
t1 := NewTask(m1.Type, m1.Payload)
tests := []struct {
enqueued []*base.TaskMessage // initial default queue state
wantProcessed []*Task // tasks to be processed at the end
}{
{
enqueued: []*base.TaskMessage{m1},
wantProcessed: []*Task{t1},
},
}
for _, tc := range tests {
h.FlushDB(t, r) // clean up db before each test case.
h.SeedEnqueuedQueue(t, r, tc.enqueued) // initialize default queue.
var mu sync.Mutex
var processed []*Task
handler := func(ctx context.Context, task *Task) error {
mu.Lock()
defer mu.Unlock()
if data, err := task.Payload.GetInt("data"); err != nil {
t.Errorf("coult not get data from payload: %v", err)
} else {
t.Logf("data == %d", data)
}
processed = append(processed, task)
return nil
}
starting := make(chan *base.TaskMessage)
finished := make(chan *base.TaskMessage)
syncCh := make(chan *syncRequest)
done := make(chan struct{})
defer func() { close(done) }()
go fakeHeartbeater(starting, finished, done)
go fakeSyncer(syncCh, done)
p := newProcessor(processorParams{
logger: testLogger,
broker: rdbClient,
retryDelayFunc: defaultDelayFunc,
syncCh: syncCh,
cancelations: base.NewCancelations(),
concurrency: 10,
queues: defaultQueueConfig,
strictPriority: false,
errHandler: nil,
shutdownTimeout: defaultShutdownTimeout,
starting: starting,
finished: finished,
})
p.handler = HandlerFunc(handler)
p.start(&sync.WaitGroup{})
time.Sleep(2 * time.Second) // wait for two second to allow all enqueued tasks to be processed.
if l := r.LLen(base.InProgressQueue).Val(); l != 0 { if l := r.LLen(base.InProgressQueue).Val(); l != 0 {
t.Errorf("%q has %d tasks, want 0", base.InProgressQueue, l) t.Errorf("%q has %d tasks, want 0", base.InProgressQueue, l)
} }
p.terminate()
mu.Lock()
if diff := cmp.Diff(tc.wantProcessed, processed, sortTaskOpt, cmpopts.IgnoreUnexported(Payload{})); diff != "" {
t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
}
mu.Unlock()
} }
} }