From f0c12cc6e36e8d558e2b1c36d18404fba0aa5138 Mon Sep 17 00:00:00 2001 From: Ken Hibino Date: Thu, 1 Apr 2021 16:11:56 -0700 Subject: [PATCH] Update RDB.ListPending and RDB.ListActive to return list of TaskInfo --- internal/rdb/inspect.go | 96 ++++++++++++++++++++---------------- internal/rdb/inspect_test.go | 41 +++++++++------ 2 files changed, 80 insertions(+), 57 deletions(-) diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go index a82770b..795fab8 100644 --- a/internal/rdb/inspect.go +++ b/internal/rdb/inspect.go @@ -281,13 +281,49 @@ func parseInfo(infoStr string) (map[string]string, error) { return info, nil } -func reverse(x []string) { +func reverse(x []interface{}) { for i := len(x)/2 - 1; i >= 0; i-- { opp := len(x) - 1 - i x[i], x[opp] = x[opp], x[i] } } +// makeTaskInfo takes values returned from HMGET(TASK_KEY, "msg", "state", "process_at", "last_failed_at") +// command and return a TaskInfo. It assumes that `vals` contains four values for each field. +func makeTaskInfo(vals []interface{}) (*base.TaskInfo, error) { + if len(vals) != 4 { + return nil, fmt.Errorf("asynq internal error: HMGET command returned %d elements", len(vals)) + } + // Note: The "msg", "state" fields are non-nil; + // whereas the "process_at", "last_failed_at" fields can be nil. + encoded := vals[0] + if encoded == nil { + return nil, fmt.Errorf("asynq internal error: HMGET field 'msg' was nil") + } + msg, err := base.DecodeMessage([]byte(encoded.(string))) + if err != nil { + return nil, err + } + state := vals[1] + if state == nil { + return nil, fmt.Errorf("asynq internal error: HMGET field 'state' was nil") + } + processAt, err := parseIntOrDefault(vals[2], 0) + if err != nil { + return nil, err + } + lastFailedAt, err := parseIntOrDefault(vals[3], 0) + if err != nil { + return nil, err + } + return &base.TaskInfo{ + TaskMessage: msg, + State: strings.ToLower(state.(string)), + NextProcessAt: processAt, + LastFailedAt: lastFailedAt, + }, nil +} + // Parses val as base10 64-bit integer if val contains a value. // Uses default value if val is nil. // @@ -310,41 +346,11 @@ func (r *RDB) GetTaskInfo(qname, id string) (*base.TaskInfo, error) { if exists == 0 { return nil, ErrTaskNotFound } - // The "msg", "state" fields are non-nil; - // whereas the "process_at", "last_failed_at" fields can be nil. res, err := r.client.HMGet(key, "msg", "state", "process_at", "last_failed_at").Result() if err != nil { return nil, err } - if len(res) != 4 { - return nil, fmt.Errorf("asynq internal error: HMGET command returned %d elements", len(res)) - } - encoded := res[0] - if encoded == nil { - return nil, fmt.Errorf("asynq internal error: HMGET field 'msg' was nil") - } - msg, err := base.DecodeMessage([]byte(encoded.(string))) - if err != nil { - return nil, err - } - state := res[1] - if state == nil { - return nil, fmt.Errorf("asynq internal error: HMGET field 'state' was nil") - } - processAt, err := parseIntOrDefault(res[2], 0) - if err != nil { - return nil, err - } - lastFailedAt, err := parseIntOrDefault(res[3], 0) - if err != nil { - return nil, err - } - return &base.TaskInfo{ - TaskMessage: msg, - State: strings.ToLower(state.(string)), - NextProcessAt: processAt, - LastFailedAt: lastFailedAt, - }, nil + return makeTaskInfo(res) } // Pagination specifies the page size and page number @@ -366,7 +372,7 @@ func (p Pagination) stop() int64 { } // ListPending returns pending tasks that are ready to be processed. -func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskMessage, error) { +func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskInfo, error) { if !r.client.SIsMember(base.AllQueues, qname).Val() { return nil, fmt.Errorf("queue %q does not exist", qname) } @@ -374,7 +380,7 @@ func (r *RDB) ListPending(qname string, pgn Pagination) ([]*base.TaskMessage, er } // ListActive returns all tasks that are currently being processed for the given queue. -func (r *RDB) ListActive(qname string, pgn Pagination) ([]*base.TaskMessage, error) { +func (r *RDB) ListActive(qname string, pgn Pagination) ([]*base.TaskInfo, error) { if !r.client.SIsMember(base.AllQueues, qname).Val() { return nil, fmt.Errorf("queue %q does not exist", qname) } @@ -390,13 +396,13 @@ local ids = redis.call("LRange", KEYS[1], ARGV[1], ARGV[2]) local res = {} for _, id in ipairs(ids) do local key = ARGV[3] .. id - table.insert(res, redis.call("HGET", key, "msg")) + table.insert(res, redis.call("HMGET", key, "msg", "state", "process_at", "last_failed_at")) end return res `) -// listMessages returns a list of TaskMessage in Redis list with the given key. -func (r *RDB) listMessages(key, qname string, pgn Pagination) ([]*base.TaskMessage, error) { +// listMessages returns a list of TaskInfo in Redis list with the given key. +func (r *RDB) listMessages(key, qname string, pgn Pagination) ([]*base.TaskInfo, error) { // Note: Because we use LPUSH to redis list, we need to calculate the // correct range and reverse the list to get the tasks with pagination. stop := -pgn.start() - 1 @@ -406,20 +412,24 @@ func (r *RDB) listMessages(key, qname string, pgn Pagination) ([]*base.TaskMessa if err != nil { return nil, err } - data, err := cast.ToStringSliceE(res) + data, err := cast.ToSliceE(res) if err != nil { return nil, err } reverse(data) - var msgs []*base.TaskMessage + var tasks []*base.TaskInfo for _, s := range data { - m, err := base.DecodeMessage([]byte(s)) + vals, err := cast.ToSliceE(s) if err != nil { - continue // bad data, ignore and continue + return nil, err } - msgs = append(msgs, m) + info, err := makeTaskInfo(vals) + if err != nil { + return nil, err + } + tasks = append(tasks, info) } - return msgs, nil + return tasks, nil } diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index ec22b1d..148b4a0 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -474,24 +474,29 @@ func TestListPending(t *testing.T) { m3 := h.NewTaskMessageWithQueue("important_notification", nil, "critical") m4 := h.NewTaskMessageWithQueue("minor_notification", nil, "low") + now := time.Now() + tests := []struct { pending map[string][]*base.TaskMessage qname string - want []*base.TaskMessage + want []*base.TaskInfo }{ { pending: map[string][]*base.TaskMessage{ base.DefaultQueueName: {m1, m2}, }, qname: base.DefaultQueueName, - want: []*base.TaskMessage{m1, m2}, + want: []*base.TaskInfo{ + {TaskMessage: m1, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0}, + {TaskMessage: m2, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0}, + }, }, { pending: map[string][]*base.TaskMessage{ base.DefaultQueueName: nil, }, qname: base.DefaultQueueName, - want: []*base.TaskMessage(nil), + want: []*base.TaskInfo(nil), }, { pending: map[string][]*base.TaskMessage{ @@ -500,7 +505,10 @@ func TestListPending(t *testing.T) { "low": {m4}, }, qname: base.DefaultQueueName, - want: []*base.TaskMessage{m1, m2}, + want: []*base.TaskInfo{ + {TaskMessage: m1, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0}, + {TaskMessage: m2, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0}, + }, }, { pending: map[string][]*base.TaskMessage{ @@ -509,7 +517,9 @@ func TestListPending(t *testing.T) { "low": {m4}, }, qname: "critical", - want: []*base.TaskMessage{m3}, + want: []*base.TaskInfo{ + {TaskMessage: m3, State: "pending", NextProcessAt: now.Unix(), LastFailedAt: 0}, + }, }, } @@ -607,36 +617,39 @@ func TestListActive(t *testing.T) { m4 := h.NewTaskMessageWithQueue("task2", nil, "low") tests := []struct { - inProgress map[string][]*base.TaskMessage - qname string - want []*base.TaskMessage + active map[string][]*base.TaskMessage + qname string + want []*base.TaskInfo }{ { - inProgress: map[string][]*base.TaskMessage{ + active: map[string][]*base.TaskMessage{ "default": {m1, m2}, "critical": {m3}, "low": {m4}, }, qname: "default", - want: []*base.TaskMessage{m1, m2}, + want: []*base.TaskInfo{ + {TaskMessage: m1, State: "active", NextProcessAt: 0, LastFailedAt: 0}, + {TaskMessage: m2, State: "active", NextProcessAt: 0, LastFailedAt: 0}, + }, }, { - inProgress: map[string][]*base.TaskMessage{ + active: map[string][]*base.TaskMessage{ "default": {}, }, qname: "default", - want: []*base.TaskMessage(nil), + want: []*base.TaskInfo(nil), }, } for _, tc := range tests { h.FlushDB(t, r.client) // clean up db before each test case - h.SeedAllActiveQueues(t, r.client, tc.inProgress) + h.SeedAllActiveQueues(t, r.client, tc.active) got, err := r.ListActive(tc.qname, Pagination{Size: 20, Page: 0}) op := fmt.Sprintf("r.ListActive(%q, Pagination{Size: 20, Page: 0})", tc.qname) if err != nil { - t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.inProgress) + t.Errorf("%s = %v, %v, want %v, nil", op, got, err, tc.active) continue } if diff := cmp.Diff(tc.want, got); diff != "" {