2
0
mirror of https://github.com/hibiken/asynq.git synced 2025-04-22 16:50:18 +08:00

Update RDB.ListPending and RDB.ListActive to return list of TaskInfo

This commit is contained in:
Ken Hibino 2021-04-01 16:11:56 -07:00
parent 57fc8b86e2
commit f0c12cc6e3
2 changed files with 80 additions and 57 deletions

View File

@ -281,13 +281,49 @@ func parseInfo(infoStr string) (map[string]string, error) {
return info, nil
}
func reverse(x []string) {
func reverse(x []interface{}) {
for i := len(x)/2 - 1; i >= 0; i-- {
opp := len(x) - 1 - i
x[i], x[opp] = x[opp], x[i]
}
}
// makeTaskInfo takes values returned from HMGET(TASK_KEY, "msg", "state", "process_at", "last_failed_at")
// command and return a TaskInfo. It assumes that `vals` contains four values for each field.
func makeTaskInfo(vals []interface{}) (*base.TaskInfo, error) {
if len(vals) != 4 {
return nil, fmt.Errorf("asynq internal error: HMGET command returned %d elements", len(vals))
}
// Note: The "msg", "state" fields are non-nil;
// whereas the "process_at", "last_failed_at" fields can be nil.
encoded := vals[0]
if encoded == nil {
return nil, fmt.Errorf("asynq internal error: HMGET field 'msg' was nil")
}
msg, err := base.DecodeMessage([]byte(encoded.(string)))
if err != nil {
return nil, err
}
state := vals[1]
if state == nil {
return nil, fmt.Errorf("asynq internal error: HMGET field 'state' was nil")
}
processAt, err := parseIntOrDefault(vals[2], 0)
if err != nil {
return nil, err
}
lastFailedAt, err := parseIntOrDefault(vals[3], 0)
if err != nil {
return nil, err
}
return &base.TaskInfo{
TaskMessage: msg,
State: strings.ToLower(state.(string)),
NextProcessAt: processAt,
LastFailedAt: lastFailedAt,
}, nil
}
// Parses val as base10 64-bit integer if val contains a value.
// Uses default value if val is nil.
//
@ -310,41 +346,11 @@ func (r *RDB) GetTaskInfo(qname, id string) (*base.TaskInfo, error) {
if exists == 0 {
return nil, ErrTaskNotFound
}
// The "msg", "state" fields are non-nil;
// whereas the "process_at", "last_failed_at" fields can be nil.
res, err := r.client.HMGet(key, "msg", "state", "process_at", "last_failed_at").Result()
if err != nil {
return nil, err
}
if len(res) != 4 {
return nil, fmt.Errorf("asynq internal error: HMGET command returned %d elements", len(res))
}
encoded := res[0]
if encoded == nil {
return nil, fmt.Errorf("asynq internal error: HMGET field 'msg' was nil")
}
msg, err := base.DecodeMessage([]byte(encoded.(string)))
if err != nil {
return nil, err
}
state := res[1]
if state == nil {
return nil, fmt.Errorf("asynq internal error: HMGET field 'state' was nil")
}
processAt, err := parseIntOrDefault(res[2], 0)
if err != nil {
return nil, err
}
lastFailedAt, err := parseIntOrDefault(res[3], 0)
if err != nil {
return nil, err
}
return &base.TaskInfo{
TaskMessage: msg,
State: strings.ToLower(state.(string)),
NextProcessAt: processAt,
LastFailedAt: lastFailedAt,
}, nil
return makeTaskInfo(res)
}
// Pagination specifies the page size and page number
@ -366,7 +372,7 @@ func (p Pagination) stop() int64 {
}
// ListPending returns pending tasks that are ready to be processed.
func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskMessage, error) {
func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error) {
if !r.client.SIsMember(base.AllQueues, qname).Val() {
return nil, fmt.Errorf("queue %q does not exist", qname)
}
@ -374,7 +380,7 @@ func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskMessage, er
}
// ListActive returns all tasks that are currently being processed for the given queue.
func (r *RDB) ListActive(qname string, pgn Pagination) ([]*base.TaskMessage, error) {
func (r *RDB) ListActive(qname string, pgn Pagination) ([]*base.TaskInfo, error) {
if !r.client.SIsMember(base.AllQueues, qname).Val() {
return nil, fmt.Errorf("queue %q does not exist", qname)
}
@ -390,13 +396,13 @@ local ids = redis.call("LRange", KEYS[1], ARGV[1], ARGV[2])
local res = {}
for _, id in ipairs(ids) do
local key = ARGV[3] .. id
table.insert(res, redis.call("HGET", key, "msg"))
table.insert(res, redis.call("HMGET", key, "msg", "state", "process_at", "last_failed_at"))
end
return res
`)
// listMessages returns a list of TaskMessage in Redis list with the given key.
func (r *RDB) listMessages(key, qname string, pgn Pagination) ([]*base.TaskMessage, error) {
// listMessages returns a list of TaskInfo in Redis list with the given key.
func (r *RDB) listMessages(key, qname string, pgn Pagination) ([]*base.TaskInfo, error) {
// Note: Because we use LPUSH to redis list, we need to calculate the
// correct range and reverse the list to get the tasks with pagination.
stop := -pgn.start() - 1
@ -406,20 +412,24 @@ func (r *RDB) listMessages(key, qname string, pgn Pagination) ([]*base.TaskMessa
if err != nil {
return nil, err
}
data, err := cast.ToStringSliceE(res)
data, err := cast.ToSliceE(res)
if err != nil {
return nil, err
}
reverse(data)
var msgs []*base.TaskMessage
var tasks []*base.TaskInfo
for _, s := range data {
m, err := base.DecodeMessage([]byte(s))
vals, err := cast.ToSliceE(s)
if err != nil {
continue // bad data, ignore and continue
return nil, err
}
msgs = append(msgs, m)
info, err := makeTaskInfo(vals)
if err != nil {
return nil, err
}
tasks = append(tasks, info)
}
return msgs, nil
return tasks, nil
}

View File

@ -474,24 +474,29 @@ func TestListPending(t *testing.T) {
m3 := h.NewTaskMessageWithQueue("important_notification", nil, "critical")
m4 := h.NewTaskMessageWithQueue("minor_notification", nil, "low")
now := time.Now()
tests := []struct {
pending map[string][]*base.TaskMessage
qname string
want []*base.TaskMessage
want []*base.TaskInfo
}{
{
pending: map[string][]*base.TaskMessage{
base.DefaultQueueName: {m1, m2},
},
qname: base.DefaultQueueName,
want: []*base.TaskMessage{m1, m2},
want: []*base.TaskInfo{
{TaskMessage: m1, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0},
{TaskMessage: m2, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0},
},
},
{
pending: map[string][]*base.TaskMessage{
base.DefaultQueueName: nil,
},
qname: base.DefaultQueueName,
want: []*base.TaskMessage(nil),
want: []*base.TaskInfo(nil),
},
{
pending: map[string][]*base.TaskMessage{
@ -500,7 +505,10 @@ func TestListPending(t *testing.T) {
"low": {m4},
},
qname: base.DefaultQueueName,
want: []*base.TaskMessage{m1, m2},
want: []*base.TaskInfo{
{TaskMessage: m1, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0},
{TaskMessage: m2, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0},
},
},
{
pending: map[string][]*base.TaskMessage{
@ -509,7 +517,9 @@ func TestListPending(t *testing.T) {
"low": {m4},
},
qname: "critical",
want: []*base.TaskMessage{m3},
want: []*base.TaskInfo{
{TaskMessage: m3, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0},
},
},
}
@ -607,36 +617,39 @@ func TestListActive(t *testing.T) {
m4 := h.NewTaskMessageWithQueue("task2", nil, "low")
tests := []struct {
inProgress map[string][]*base.TaskMessage
qname string
want []*base.TaskMessage
active map[string][]*base.TaskMessage
qname string
want []*base.TaskInfo
}{
{
inProgress: map[string][]*base.TaskMessage{
active: map[string][]*base.TaskMessage{
"default": {m1, m2},
"critical": {m3},
"low": {m4},
},
qname: "default",
want: []*base.TaskMessage{m1, m2},
want: []*base.TaskInfo{
{TaskMessage: m1, State: "active", NextProcessAt: 0, LastFailedAt: 0},
{TaskMessage: m2, State: "active", NextProcessAt: 0, LastFailedAt: 0},
},
},
{
inProgress: map[string][]*base.TaskMessage{
active: map[string][]*base.TaskMessage{
"default": {},
},
qname: "default",
want: []*base.TaskMessage(nil),
want: []*base.TaskInfo(nil),
},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case
h.SeedAllActiveQueues(t, r.client, tc.inProgress)
h.SeedAllActiveQueues(t, r.client, tc.active)
got, err := r.ListActive(tc.qname, Pagination{Size: 20, Page: 0})
op := fmt.Sprintf("r.ListActive(%q, Pagination{Size: 20, Page: 0})", tc.qname)
if err != nil {
t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.inProgress)
t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.active)
continue
}
if diff := cmp.Diff(tc.want, got); diff != "" {