From 1acd62c76038196d1e58f2ee06f40467d062a560 Mon Sep 17 00:00:00 2001
From: Ken Hibino
Date: Sat, 19 Mar 2022 07:12:41 -0700
Subject: [PATCH] Move test helpers to asynqtest package

---
 internal/asynqtest/asynqtest.go |  94 ++++++++++++
 internal/rdb/inspect_test.go    |  70 ++++-----
 internal/rdb/rdb.go             |  17 +++
 internal/rdb/rdb_test.go        | 248 ++++++++++----------------------
 4 files changed, 221 insertions(+), 208 deletions(-)

diff --git a/internal/asynqtest/asynqtest.go b/internal/asynqtest/asynqtest.go
index 6511fee..cb00a4a 100644
--- a/internal/asynqtest/asynqtest.go
+++ b/internal/asynqtest/asynqtest.go
@@ -93,6 +93,20 @@ var SortStringSliceOpt = cmp.Transformer("SortStringSlice", func(in []string) []
 	return out
 })
 
+var SortRedisZSetEntryOpt = cmp.Transformer("SortZSetEntries", func(in []redis.Z) []redis.Z {
+	out := append([]redis.Z(nil), in...) // Copy input to avoid mutating it
+	sort.Slice(out, func(i, j int) bool {
+		// TODO: If member is a comparable type (int, string, etc.), compare by the member.
+		// Use a generic comparable type here once we update to go1.18.
+		if _, ok := out[i].Member.(string); ok {
+			// If member is a string, compare the member
+			return out[i].Member.(string) < out[j].Member.(string)
+		}
+		return out[i].Score < out[j].Score
+	})
+	return out
+})
+
 // IgnoreIDOpt is an cmp.Option to ignore ID field in task messages when comparing.
 var IgnoreIDOpt = cmpopts.IgnoreFields(base.TaskMessage{}, "ID")
 
@@ -522,3 +536,83 @@ func getMessagesFromZSetWithScores(tb testing.TB, r redis.UniversalClient,
 	}
 	return res
 }
+
+// TaskSeedData holds the data required to seed tasks under the task key in tests.
+type TaskSeedData struct {
+	Msg          *base.TaskMessage
+	State        base.TaskState
+	PendingSince time.Time
+}
+
+// SeedTasks seeds the given task data under the task keys in redis,
+// setting a unique lock for each task that has a unique key.
+func SeedTasks(tb testing.TB, r redis.UniversalClient, taskData []*TaskSeedData) {
+	for _, data := range taskData {
+		msg := data.Msg
+		ctx := context.Background()
+		key := base.TaskKey(msg.Queue, msg.ID)
+		v := map[string]interface{}{
+			"msg":        MustMarshal(tb, msg),
+			"state":      data.State.String(),
+			"unique_key": msg.UniqueKey,
+			"group":      msg.GroupKey,
+		}
+		if !data.PendingSince.IsZero() {
+			v["pending_since"] = data.PendingSince.Unix()
+		}
+		if err := r.HSet(ctx, key, v).Err(); err != nil {
+			tb.Fatalf("Failed to write task data in redis: %v", err)
+		}
+		if len(msg.UniqueKey) > 0 {
+			err := r.SetNX(ctx, msg.UniqueKey, msg.ID, 1*time.Minute).Err()
+			if err != nil {
+				tb.Fatalf("Failed to set unique lock in redis: %v", err)
+			}
+		}
+	}
+}
+
+// SeedRedisZSets seeds the given redis instance with the given sorted sets.
+func SeedRedisZSets(tb testing.TB, r redis.UniversalClient, zsets map[string][]*redis.Z) {
+	for key, zs := range zsets {
+		// FIXME: How come we can't simply do ZAdd(ctx, key, zs...) here?
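+		// One possible answer: some fixtures seed empty zsets (e.g. a group
+		// key with no members), and a variadic ZAdd(ctx, key, zs...) with
+		// zero members is a "wrong number of arguments" error in redis,
+		// whereas this loop is simply a no-op for an empty slice.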
+		for _, z := range zs {
+			if err := r.ZAdd(context.Background(), key, z).Err(); err != nil {
+				tb.Fatalf("Failed to seed zset (key=%q): %v", key, err)
+			}
+		}
+	}
+}
+
+// SeedRedisSets seeds the given redis instance with the given sets.
+func SeedRedisSets(tb testing.TB, r redis.UniversalClient, sets map[string][]string) {
+	for key, set := range sets {
+		SeedRedisSet(tb, r, key, set)
+	}
+}
+
+// SeedRedisSet adds the given members to the set under the given key.
+func SeedRedisSet(tb testing.TB, r redis.UniversalClient, key string, members []string) {
+	for _, mem := range members {
+		if err := r.SAdd(context.Background(), key, mem).Err(); err != nil {
+			tb.Fatalf("Failed to seed set (key=%q): %v", key, err)
+		}
+	}
+}
+
+// SeedRedisLists seeds the given redis instance with the given lists.
+func SeedRedisLists(tb testing.TB, r redis.UniversalClient, lists map[string][]string) {
+	for key, vals := range lists {
+		for _, v := range vals {
+			if err := r.LPush(context.Background(), key, v).Err(); err != nil {
+				tb.Fatalf("Failed to seed list (key=%q): %v", key, err)
+			}
+		}
+	}
+}
+
+// AssertRedisZSets asserts that the redis instance has the expected entries
+// under each of the given zset keys.
+func AssertRedisZSets(t *testing.T, r redis.UniversalClient, wantZSets map[string][]redis.Z) {
+	for key, want := range wantZSets {
+		got, err := r.ZRangeWithScores(context.Background(), key, 0, -1).Result()
+		if err != nil {
+			t.Fatalf("Failed to read zset (key=%q): %v", key, err)
+		}
+		if diff := cmp.Diff(want, got, SortRedisZSetEntryOpt); diff != "" {
+			t.Errorf("mismatch found in zset (key=%q): (-want,+got)\n%s", key, diff)
+		}
+	}
+}
diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go
index 3df09a5..5d5aea2 100644
--- a/internal/rdb/inspect_test.go
+++ b/internal/rdb/inspect_test.go
@@ -68,7 +68,7 @@ func TestCurrentStats(t *testing.T) {
 	r.SetClock(timeutil.NewSimulatedClock(now))
 
 	tests := []struct {
-		tasks     []*taskData
+		tasks     []*h.TaskSeedData
 		allQueues []string
 		allGroups map[string][]string
 		pending   map[string][]string
@@ -88,14 +88,14 @@ func TestCurrentStats(t *testing.T) {
 		want *Stats
 	}{
 		{
-			tasks: []*taskData{
-				{msg: m1, state: base.TaskStatePending},
-				{msg: m2, state: base.TaskStateActive},
-				{msg: m3, state: base.TaskStateScheduled},
-				{msg: m4, state: base.TaskStateScheduled},
-				{msg: m5, state: base.TaskStatePending},
-				{msg: m6, state: base.TaskStatePending},
-				{msg: m7, state: base.TaskStateAggregating},
+			tasks: []*h.TaskSeedData{
+				{Msg: m1, State: base.TaskStatePending},
+				{Msg: m2, State: base.TaskStateActive},
+				{Msg: m3, State: base.TaskStateScheduled},
+				{Msg: m4, State: base.TaskStateScheduled},
+				{Msg: m5, State: base.TaskStatePending},
+				{Msg: m6, State: base.TaskStatePending},
+				{Msg: m7, State: base.TaskStateAggregating},
 			},
 			allQueues: []string{"default", "critical", "low"},
 			allGroups: map[string][]string{
@@ -187,12 +187,12 @@ func TestCurrentStats(t *testing.T) {
 			},
 		},
 		{
-			tasks: []*taskData{
-				{msg: m1, state: base.TaskStatePending},
-				{msg: m2, state: base.TaskStateActive},
-				{msg: m3, state: base.TaskStateScheduled},
-				{msg: m4, state: base.TaskStateScheduled},
-				{msg: m6, state: base.TaskStatePending},
+			tasks: []*h.TaskSeedData{
+				{Msg: m1, State: base.TaskStatePending},
+				{Msg: m2, State: base.TaskStateActive},
+				{Msg: m3, State: base.TaskStateScheduled},
+				{Msg: m4, State: base.TaskStateScheduled},
+				{Msg: m6, State: base.TaskStatePending},
 			},
 			allQueues: []string{"default", "critical", "low"},
 			pending: map[string][]string{
@@ -284,16 +284,16 @@ func TestCurrentStats(t *testing.T) {
 				t.Fatal(err)
 			}
 		}
-		SeedSet(t, r.client, base.AllQueues, tc.allQueues)
-		SeedSets(t, r.client, tc.allGroups)
-		SeedTasks(t, r.client, tc.tasks)
-		SeedLists(t, r.client, tc.pending)
-		SeedLists(t, r.client, tc.active)
-		SeedZSets(t, r.client, tc.scheduled)
-		SeedZSets(t, r.client, tc.retry)
-		SeedZSets(t, r.client, tc.archived)
-		SeedZSets(t, r.client, tc.completed)
-		SeedZSets(t, r.client, tc.groups)
+		h.SeedRedisSet(t, r.client, base.AllQueues, tc.allQueues)
+		h.SeedRedisSets(t, r.client, tc.allGroups)
+		h.SeedTasks(t, r.client, tc.tasks)
+		h.SeedRedisLists(t, r.client, tc.pending)
+		h.SeedRedisLists(t, r.client, tc.active)
+		h.SeedRedisZSets(t, r.client, tc.scheduled)
+		h.SeedRedisZSets(t, r.client, tc.retry)
+		h.SeedRedisZSets(t, r.client, tc.archived)
+		h.SeedRedisZSets(t, r.client, tc.completed)
+		h.SeedRedisZSets(t, r.client, tc.groups)
 		ctx := context.Background()
 		for qname, n := range tc.processed {
 			r.client.Set(ctx, base.ProcessedKey(qname, now), n, 0)
@@ -434,16 +434,16 @@ func TestGroupStats(t *testing.T) {
 	now := time.Now()
 
 	fixtures := struct {
-		tasks     []*taskData
+		tasks     []*h.TaskSeedData
 		allGroups map[string][]string
 		groups    map[string][]*redis.Z
 	}{
-		tasks: []*taskData{
-			{msg: m1, state: base.TaskStateAggregating},
-			{msg: m2, state: base.TaskStateAggregating},
-			{msg: m3, state: base.TaskStateAggregating},
-			{msg: m4, state: base.TaskStateAggregating},
-			{msg: m5, state: base.TaskStateAggregating},
+		tasks: []*h.TaskSeedData{
+			{Msg: m1, State: base.TaskStateAggregating},
+			{Msg: m2, State: base.TaskStateAggregating},
+			{Msg: m3, State: base.TaskStateAggregating},
+			{Msg: m4, State: base.TaskStateAggregating},
+			{Msg: m5, State: base.TaskStateAggregating},
 		},
 		allGroups: map[string][]string{
 			base.AllGroups("default"): {"group1", "group2"},
@@ -499,9 +499,9 @@ func TestGroupStats(t *testing.T) {
 	for _, tc := range tests {
 		h.FlushDB(t, r.client)
-		SeedTasks(t, r.client, fixtures.tasks)
-		SeedSets(t, r.client, fixtures.allGroups)
-		SeedZSets(t, r.client, fixtures.groups)
+		h.SeedTasks(t, r.client, fixtures.tasks)
+		h.SeedRedisSets(t, r.client, fixtures.allGroups)
+		h.SeedRedisZSets(t, r.client, fixtures.groups)
 
 		t.Run(tc.desc, func(t *testing.T) {
 			got, err := r.GroupStats(tc.qname)
diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go
index e986fb6..7b25f15 100644
--- a/internal/rdb/rdb.go
+++ b/internal/rdb/rdb.go
@@ -1017,6 +1017,10 @@ func (r *RDB) ListGroups(qname string) ([]string, error) {
 // Output:
 // Returns 0 if no aggregation set was created
 // Returns 1 if an aggregation set was created
+//
+// Time Complexity:
+// O(log(N) + M) with N being the number of tasks in the group zset
+// and M being the max size.
 var aggregationCheckCmd = redis.NewScript(`
 local size = redis.call("ZCARD", KEYS[1])
 if size == 0 then
@@ -1118,6 +1122,12 @@ func (r *RDB) AggregationCheck(qname, gname string, t time.Time, gracePeriod, ma
 // KEYS[1] -> asynq:{<qname>}:g:<gname>:<aggregation_set_id>
 // ------
 // ARGV[1] -> task key prefix
+//
+// Output:
+// Array of encoded task messages
+//
+// Time Complexity:
+// O(N) with N being the number of tasks in the aggregation set.
 var readAggregationSetCmd = redis.NewScript(`
 local msgs = {}
 local ids = redis.call("ZRANGE", KEYS[1], 0, -1)
@@ -1162,6 +1172,13 @@ func (r *RDB) ReadAggregationSet(qname, gname, setID string) ([]*base.TaskMessag
 // KEYS[2] -> asynq:{<qname>}:aggregation_sets
 // -------
 // ARGV[1] -> task key prefix
+//
+// Output:
+// Redis status reply
+//
+// Time Complexity:
+// max(O(N), O(log(M))) with N being the number of tasks in the aggregation set
+// and M being the number of elements in the all-aggregation-sets list.
var deleteAggregationSetCmd = redis.NewScript(` local ids = redis.call("ZRANGE", KEYS[1], 0, -1) for _, id in ipairs(ids) do diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 528d917..cf4be25 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -9,7 +9,6 @@ import ( "encoding/json" "flag" "math" - "sort" "strconv" "strings" "sync" @@ -3122,7 +3121,7 @@ func TestAggregationCheck(t *testing.T) { tests := []struct { desc string // initial data - tasks []*taskData + tasks []*h.TaskSeedData groups map[string][]*redis.Z allGroups map[string][]string @@ -3141,7 +3140,7 @@ func TestAggregationCheck(t *testing.T) { }{ { desc: "with an empty group", - tasks: []*taskData{}, + tasks: []*h.TaskSeedData{}, groups: map[string][]*redis.Z{ base.GroupKey("default", "mygroup"): {}, }, @@ -3162,12 +3161,12 @@ func TestAggregationCheck(t *testing.T) { }, { desc: "with a group size reaching the max size", - tasks: []*taskData{ - {msg: msg1, state: base.TaskStateAggregating}, - {msg: msg2, state: base.TaskStateAggregating}, - {msg: msg3, state: base.TaskStateAggregating}, - {msg: msg4, state: base.TaskStateAggregating}, - {msg: msg5, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: msg1, State: base.TaskStateAggregating}, + {Msg: msg2, State: base.TaskStateAggregating}, + {Msg: msg3, State: base.TaskStateAggregating}, + {Msg: msg4, State: base.TaskStateAggregating}, + {Msg: msg5, State: base.TaskStateAggregating}, }, groups: map[string][]*redis.Z{ base.GroupKey("default", "mygroup"): { @@ -3195,12 +3194,12 @@ func TestAggregationCheck(t *testing.T) { }, { desc: "with group size greater than max size", - tasks: []*taskData{ - {msg: msg1, state: base.TaskStateAggregating}, - {msg: msg2, state: base.TaskStateAggregating}, - {msg: msg3, state: base.TaskStateAggregating}, - {msg: msg4, state: base.TaskStateAggregating}, - {msg: msg5, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: msg1, State: base.TaskStateAggregating}, + {Msg: msg2, State: base.TaskStateAggregating}, + {Msg: msg3, State: base.TaskStateAggregating}, + {Msg: msg4, State: base.TaskStateAggregating}, + {Msg: msg5, State: base.TaskStateAggregating}, }, groups: map[string][]*redis.Z{ base.GroupKey("default", "mygroup"): { @@ -3231,10 +3230,10 @@ func TestAggregationCheck(t *testing.T) { }, { desc: "with the most recent task older than grace period", - tasks: []*taskData{ - {msg: msg1, state: base.TaskStateAggregating}, - {msg: msg2, state: base.TaskStateAggregating}, - {msg: msg3, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: msg1, State: base.TaskStateAggregating}, + {Msg: msg2, State: base.TaskStateAggregating}, + {Msg: msg3, State: base.TaskStateAggregating}, }, groups: map[string][]*redis.Z{ base.GroupKey("default", "mygroup"): { @@ -3260,12 +3259,12 @@ func TestAggregationCheck(t *testing.T) { }, { desc: "with the oldest task older than max delay", - tasks: []*taskData{ - {msg: msg1, state: base.TaskStateAggregating}, - {msg: msg2, state: base.TaskStateAggregating}, - {msg: msg3, state: base.TaskStateAggregating}, - {msg: msg4, state: base.TaskStateAggregating}, - {msg: msg5, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: msg1, State: base.TaskStateAggregating}, + {Msg: msg2, State: base.TaskStateAggregating}, + {Msg: msg3, State: base.TaskStateAggregating}, + {Msg: msg4, State: base.TaskStateAggregating}, + {Msg: msg5, State: base.TaskStateAggregating}, }, groups: map[string][]*redis.Z{ base.GroupKey("default", 
"mygroup"): { @@ -3293,12 +3292,12 @@ func TestAggregationCheck(t *testing.T) { }, { desc: "with unlimited size", - tasks: []*taskData{ - {msg: msg1, state: base.TaskStateAggregating}, - {msg: msg2, state: base.TaskStateAggregating}, - {msg: msg3, state: base.TaskStateAggregating}, - {msg: msg4, state: base.TaskStateAggregating}, - {msg: msg5, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: msg1, State: base.TaskStateAggregating}, + {Msg: msg2, State: base.TaskStateAggregating}, + {Msg: msg3, State: base.TaskStateAggregating}, + {Msg: msg4, State: base.TaskStateAggregating}, + {Msg: msg5, State: base.TaskStateAggregating}, }, groups: map[string][]*redis.Z{ base.GroupKey("default", "mygroup"): { @@ -3332,12 +3331,12 @@ func TestAggregationCheck(t *testing.T) { }, { desc: "with unlimited size and passed grace period", - tasks: []*taskData{ - {msg: msg1, state: base.TaskStateAggregating}, - {msg: msg2, state: base.TaskStateAggregating}, - {msg: msg3, state: base.TaskStateAggregating}, - {msg: msg4, state: base.TaskStateAggregating}, - {msg: msg5, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: msg1, State: base.TaskStateAggregating}, + {Msg: msg2, State: base.TaskStateAggregating}, + {Msg: msg3, State: base.TaskStateAggregating}, + {Msg: msg4, State: base.TaskStateAggregating}, + {Msg: msg5, State: base.TaskStateAggregating}, }, groups: map[string][]*redis.Z{ base.GroupKey("default", "mygroup"): { @@ -3365,12 +3364,12 @@ func TestAggregationCheck(t *testing.T) { }, { desc: "with unlimited delay", - tasks: []*taskData{ - {msg: msg1, state: base.TaskStateAggregating}, - {msg: msg2, state: base.TaskStateAggregating}, - {msg: msg3, state: base.TaskStateAggregating}, - {msg: msg4, state: base.TaskStateAggregating}, - {msg: msg5, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: msg1, State: base.TaskStateAggregating}, + {Msg: msg2, State: base.TaskStateAggregating}, + {Msg: msg3, State: base.TaskStateAggregating}, + {Msg: msg4, State: base.TaskStateAggregating}, + {Msg: msg5, State: base.TaskStateAggregating}, }, groups: map[string][]*redis.Z{ base.GroupKey("default", "mygroup"): { @@ -3408,9 +3407,9 @@ func TestAggregationCheck(t *testing.T) { h.FlushDB(t, r.client) t.Run(tc.desc, func(t *testing.T) { - SeedTasks(t, r.client, tc.tasks) - SeedZSets(t, r.client, tc.groups) - SeedSets(t, r.client, tc.allGroups) + h.SeedTasks(t, r.client, tc.tasks) + h.SeedRedisZSets(t, r.client, tc.groups) + h.SeedRedisSets(t, r.client, tc.allGroups) aggregationSetID, err := r.AggregationCheck(tc.qname, tc.gname, now, tc.gracePeriod, tc.maxDelay, tc.maxSize) if err != nil { @@ -3438,7 +3437,7 @@ func TestAggregationCheck(t *testing.T) { } } - AssertZSets(t, r.client, tc.wantGroups) + h.AssertRedisZSets(t, r.client, tc.wantGroups) if tc.shouldClearGroup { if key := base.GroupKey(tc.qname, tc.gname); r.client.Exists(ctx, key).Val() != 0 { @@ -3473,7 +3472,7 @@ func TestDeleteAggregationSet(t *testing.T) { tests := []struct { desc string // initial data - tasks []*taskData + tasks []*h.TaskSeedData aggregationSets map[string][]*redis.Z allAggregationSets map[string][]*redis.Z @@ -3490,10 +3489,10 @@ func TestDeleteAggregationSet(t *testing.T) { }{ { desc: "with a sigle active aggregation set", - tasks: []*taskData{ - {msg: m1, state: base.TaskStateAggregating}, - {msg: m2, state: base.TaskStateAggregating}, - {msg: m3, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: m1, State: base.TaskStateAggregating}, + {Msg: m2, State: 
base.TaskStateAggregating}, + {Msg: m3, State: base.TaskStateAggregating}, }, aggregationSets: map[string][]*redis.Z{ base.AggregationSetKey("default", "mygroup", setID): { @@ -3524,10 +3523,10 @@ func TestDeleteAggregationSet(t *testing.T) { }, { desc: "with multiple active aggregation sets", - tasks: []*taskData{ - {msg: m1, state: base.TaskStateAggregating}, - {msg: m2, state: base.TaskStateAggregating}, - {msg: m3, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: m1, State: base.TaskStateAggregating}, + {Msg: m2, State: base.TaskStateAggregating}, + {Msg: m3, State: base.TaskStateAggregating}, }, aggregationSets: map[string][]*redis.Z{ base.AggregationSetKey("default", "mygroup", setID): { @@ -3569,9 +3568,9 @@ func TestDeleteAggregationSet(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) t.Run(tc.desc, func(t *testing.T) { - SeedTasks(t, r.client, tc.tasks) - SeedZSets(t, r.client, tc.aggregationSets) - SeedZSets(t, r.client, tc.allAggregationSets) + h.SeedTasks(t, r.client, tc.tasks) + h.SeedRedisZSets(t, r.client, tc.aggregationSets) + h.SeedRedisZSets(t, r.client, tc.allAggregationSets) if err := r.DeleteAggregationSet(tc.ctx, tc.qname, tc.gname, tc.setID); err != nil { t.Fatalf("DeleteAggregationSet returned error: %v", err) @@ -3582,7 +3581,7 @@ func TestDeleteAggregationSet(t *testing.T) { t.Errorf("key=%q still exists, want deleted", key) } } - AssertZSets(t, r.client, tc.wantAllAggregationSets) + h.AssertRedisZSets(t, r.client, tc.wantAllAggregationSets) }) } } @@ -3602,7 +3601,7 @@ func TestDeleteAggregationSetError(t *testing.T) { tests := []struct { desc string // initial data - tasks []*taskData + tasks []*h.TaskSeedData aggregationSets map[string][]*redis.Z allAggregationSets map[string][]*redis.Z @@ -3618,10 +3617,10 @@ func TestDeleteAggregationSetError(t *testing.T) { }{ { desc: "with deadline exceeded context", - tasks: []*taskData{ - {msg: m1, state: base.TaskStateAggregating}, - {msg: m2, state: base.TaskStateAggregating}, - {msg: m3, state: base.TaskStateAggregating}, + tasks: []*h.TaskSeedData{ + {Msg: m1, State: base.TaskStateAggregating}, + {Msg: m2, State: base.TaskStateAggregating}, + {Msg: m3, State: base.TaskStateAggregating}, }, aggregationSets: map[string][]*redis.Z{ base.AggregationSetKey("default", "mygroup", setID): { @@ -3659,17 +3658,17 @@ func TestDeleteAggregationSetError(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) t.Run(tc.desc, func(t *testing.T) { - SeedTasks(t, r.client, tc.tasks) - SeedZSets(t, r.client, tc.aggregationSets) - SeedZSets(t, r.client, tc.allAggregationSets) + h.SeedTasks(t, r.client, tc.tasks) + h.SeedRedisZSets(t, r.client, tc.aggregationSets) + h.SeedRedisZSets(t, r.client, tc.allAggregationSets) if err := r.DeleteAggregationSet(tc.ctx, tc.qname, tc.gname, tc.setID); err == nil { t.Fatal("DeleteAggregationSet returned nil, want non-nil error") } // Make sure zsets are unchanged. 
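+			// With an already-exceeded deadline, the delete script should never
+			// reach redis, so both the aggregation sets and the global index
+			// are expected to remain intact.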
- AssertZSets(t, r.client, tc.wantAggregationSets) - AssertZSets(t, r.client, tc.wantAllAggregationSets) + h.AssertRedisZSets(t, r.client, tc.wantAggregationSets) + h.AssertRedisZSets(t, r.client, tc.wantAllAggregationSets) }) } } @@ -3746,118 +3745,21 @@ func TestReclaimStaleAggregationSets(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) - SeedZSets(t, r.client, tc.groups) - SeedZSets(t, r.client, tc.aggregationSets) - SeedZSets(t, r.client, tc.allAggregationSets) + h.SeedRedisZSets(t, r.client, tc.groups) + h.SeedRedisZSets(t, r.client, tc.aggregationSets) + h.SeedRedisZSets(t, r.client, tc.allAggregationSets) if err := r.ReclaimStaleAggregationSets(tc.qname); err != nil { t.Errorf("ReclaimStaleAggregationSets returned error: %v", err) continue } - AssertZSets(t, r.client, tc.wantGroups) - AssertZSets(t, r.client, tc.wantAggregationSets) - AssertZSets(t, r.client, tc.wantAllAggregationSets) + h.AssertRedisZSets(t, r.client, tc.wantGroups) + h.AssertRedisZSets(t, r.client, tc.wantAggregationSets) + h.AssertRedisZSets(t, r.client, tc.wantAllAggregationSets) } } -// taskData holds the data required to seed tasks under the task key in test. -type taskData struct { - msg *base.TaskMessage - state base.TaskState - pendingSince time.Time -} - -// TODO: move this helper somewhere more canonical -func SeedTasks(tb testing.TB, r redis.UniversalClient, taskData []*taskData) { - for _, data := range taskData { - msg := data.msg - ctx := context.Background() - key := base.TaskKey(msg.Queue, msg.ID) - v := map[string]interface{}{ - "msg": h.MustMarshal(tb, msg), - "state": data.state.String(), - "unique_key": msg.UniqueKey, - "group": msg.GroupKey, - } - if !data.pendingSince.IsZero() { - v["pending_since"] = data.pendingSince.Unix() - } - if err := r.HSet(ctx, key, v).Err(); err != nil { - tb.Fatalf("Failed to write task data in redis: %v", err) - } - if len(msg.UniqueKey) > 0 { - err := r.SetNX(ctx, msg.UniqueKey, msg.ID, 1*time.Minute).Err() - if err != nil { - tb.Fatalf("Failed to set unique lock in redis: %v", err) - } - } - } -} - -// TODO: move this helper somewhere more canonical -func SeedZSets(tb testing.TB, r redis.UniversalClient, zsets map[string][]*redis.Z) { - for key, zs := range zsets { - // FIXME: How come we can't simply do ZAdd(ctx, key, zs...) here? 
- for _, z := range zs { - if err := r.ZAdd(context.Background(), key, z).Err(); err != nil { - tb.Fatalf("Failed to seed zset (key=%q): %v", key, err) - } - } - } -} - -func SeedSets(tb testing.TB, r redis.UniversalClient, sets map[string][]string) { - for key, set := range sets { - SeedSet(tb, r, key, set) - } -} - -func SeedSet(tb testing.TB, r redis.UniversalClient, key string, members []string) { - for _, mem := range members { - if err := r.SAdd(context.Background(), key, mem).Err(); err != nil { - tb.Fatalf("Failed to seed set (key=%q): %v", key, err) - } - } -} - -func SeedLists(tb testing.TB, r redis.UniversalClient, lists map[string][]string) { - for key, vals := range lists { - for _, v := range vals { - if err := r.LPush(context.Background(), key, v).Err(); err != nil { - tb.Fatalf("Failed to seed list (key=%q): %v", key, err) - } - } - } -} - -// TODO: move this helper somewhere more canonical -func AssertZSets(t *testing.T, r redis.UniversalClient, wantZSets map[string][]redis.Z) { - for key, want := range wantZSets { - got, err := r.ZRangeWithScores(context.Background(), key, 0, -1).Result() - if err != nil { - t.Fatalf("Failed to read zset (key=%q): %v", key, err) - } - if diff := cmp.Diff(want, got, SortZSetEntryOpt); diff != "" { - t.Errorf("mismatch found in zset (key=%q): (-want,+got)\n%s", key, diff) - } - } -} - -var SortZSetEntryOpt = cmp.Transformer("SortZSetEntries", func(in []redis.Z) []redis.Z { - out := append([]redis.Z(nil), in...) // Copy input to avoid mutating it - sort.Slice(out, func(i, j int) bool { - // TODO: If member is a comparable type (int, string, etc) compare by the member - // Use generic comparable type here once update to go1.18 - if _, ok := out[i].Member.(string); ok { - // If member is a string, compare the member - return out[i].Member.(string) < out[j].Member.(string) - } - return out[i].Score < out[j].Score - }) - return out -}) - func TestListGroups(t *testing.T) { r := setup(t) defer r.Close()