2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-12-25 23:32:17 +08:00

Move test helpers to asynqtest package

This commit is contained in:
Ken Hibino 2022-03-19 07:12:41 -07:00
parent 0149396bae
commit 1acd62c760
4 changed files with 221 additions and 208 deletions

View File

@ -93,6 +93,20 @@ var SortStringSliceOpt = cmp.Transformer("SortStringSlice", func(in []string) []
return out return out
}) })
var SortRedisZSetEntryOpt = cmp.Transformer("SortZSetEntries", func(in []redis.Z) []redis.Z {
out := append([]redis.Z(nil), in...) // Copy input to avoid mutating it
sort.Slice(out, func(i, j int) bool {
// TODO: If member is a comparable type (int, string, etc) compare by the member
// Use generic comparable type here once update to go1.18
if _, ok := out[i].Member.(string); ok {
// If member is a string, compare the member
return out[i].Member.(string) < out[j].Member.(string)
}
return out[i].Score < out[j].Score
})
return out
})
// IgnoreIDOpt is an cmp.Option to ignore ID field in task messages when comparing. // IgnoreIDOpt is an cmp.Option to ignore ID field in task messages when comparing.
var IgnoreIDOpt = cmpopts.IgnoreFields(base.TaskMessage{}, "ID") var IgnoreIDOpt = cmpopts.IgnoreFields(base.TaskMessage{}, "ID")
@ -522,3 +536,83 @@ func getMessagesFromZSetWithScores(tb testing.TB, r redis.UniversalClient,
} }
return res return res
} }
// TaskSeedData holds the data required to seed tasks under the task key in test.
type TaskSeedData struct {
Msg *base.TaskMessage
State base.TaskState
PendingSince time.Time
}
func SeedTasks(tb testing.TB, r redis.UniversalClient, taskData []*TaskSeedData) {
for _, data := range taskData {
msg := data.Msg
ctx := context.Background()
key := base.TaskKey(msg.Queue, msg.ID)
v := map[string]interface{}{
"msg": MustMarshal(tb, msg),
"state": data.State.String(),
"unique_key": msg.UniqueKey,
"group": msg.GroupKey,
}
if !data.PendingSince.IsZero() {
v["pending_since"] = data.PendingSince.Unix()
}
if err := r.HSet(ctx, key, v).Err(); err != nil {
tb.Fatalf("Failed to write task data in redis: %v", err)
}
if len(msg.UniqueKey) > 0 {
err := r.SetNX(ctx, msg.UniqueKey, msg.ID, 1*time.Minute).Err()
if err != nil {
tb.Fatalf("Failed to set unique lock in redis: %v", err)
}
}
}
}
func SeedRedisZSets(tb testing.TB, r redis.UniversalClient, zsets map[string][]*redis.Z) {
for key, zs := range zsets {
// FIXME: How come we can't simply do ZAdd(ctx, key, zs...) here?
for _, z := range zs {
if err := r.ZAdd(context.Background(), key, z).Err(); err != nil {
tb.Fatalf("Failed to seed zset (key=%q): %v", key, err)
}
}
}
}
func SeedRedisSets(tb testing.TB, r redis.UniversalClient, sets map[string][]string) {
for key, set := range sets {
SeedRedisSet(tb, r, key, set)
}
}
func SeedRedisSet(tb testing.TB, r redis.UniversalClient, key string, members []string) {
for _, mem := range members {
if err := r.SAdd(context.Background(), key, mem).Err(); err != nil {
tb.Fatalf("Failed to seed set (key=%q): %v", key, err)
}
}
}
func SeedRedisLists(tb testing.TB, r redis.UniversalClient, lists map[string][]string) {
for key, vals := range lists {
for _, v := range vals {
if err := r.LPush(context.Background(), key, v).Err(); err != nil {
tb.Fatalf("Failed to seed list (key=%q): %v", key, err)
}
}
}
}
func AssertRedisZSets(t *testing.T, r redis.UniversalClient, wantZSets map[string][]redis.Z) {
for key, want := range wantZSets {
got, err := r.ZRangeWithScores(context.Background(), key, 0, -1).Result()
if err != nil {
t.Fatalf("Failed to read zset (key=%q): %v", key, err)
}
if diff := cmp.Diff(want, got, SortRedisZSetEntryOpt); diff != "" {
t.Errorf("mismatch found in zset (key=%q): (-want,+got)\n%s", key, diff)
}
}
}

View File

@ -68,7 +68,7 @@ func TestCurrentStats(t *testing.T) {
r.SetClock(timeutil.NewSimulatedClock(now)) r.SetClock(timeutil.NewSimulatedClock(now))
tests := []struct { tests := []struct {
tasks []*taskData tasks []*h.TaskSeedData
allQueues []string allQueues []string
allGroups map[string][]string allGroups map[string][]string
pending map[string][]string pending map[string][]string
@ -88,14 +88,14 @@ func TestCurrentStats(t *testing.T) {
want *Stats want *Stats
}{ }{
{ {
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: m1, state: base.TaskStatePending}, {Msg: m1, State: base.TaskStatePending},
{msg: m2, state: base.TaskStateActive}, {Msg: m2, State: base.TaskStateActive},
{msg: m3, state: base.TaskStateScheduled}, {Msg: m3, State: base.TaskStateScheduled},
{msg: m4, state: base.TaskStateScheduled}, {Msg: m4, State: base.TaskStateScheduled},
{msg: m5, state: base.TaskStatePending}, {Msg: m5, State: base.TaskStatePending},
{msg: m6, state: base.TaskStatePending}, {Msg: m6, State: base.TaskStatePending},
{msg: m7, state: base.TaskStateAggregating}, {Msg: m7, State: base.TaskStateAggregating},
}, },
allQueues: []string{"default", "critical", "low"}, allQueues: []string{"default", "critical", "low"},
allGroups: map[string][]string{ allGroups: map[string][]string{
@ -187,12 +187,12 @@ func TestCurrentStats(t *testing.T) {
}, },
}, },
{ {
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: m1, state: base.TaskStatePending}, {Msg: m1, State: base.TaskStatePending},
{msg: m2, state: base.TaskStateActive}, {Msg: m2, State: base.TaskStateActive},
{msg: m3, state: base.TaskStateScheduled}, {Msg: m3, State: base.TaskStateScheduled},
{msg: m4, state: base.TaskStateScheduled}, {Msg: m4, State: base.TaskStateScheduled},
{msg: m6, state: base.TaskStatePending}, {Msg: m6, State: base.TaskStatePending},
}, },
allQueues: []string{"default", "critical", "low"}, allQueues: []string{"default", "critical", "low"},
pending: map[string][]string{ pending: map[string][]string{
@ -284,16 +284,16 @@ func TestCurrentStats(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
SeedSet(t, r.client, base.AllQueues, tc.allQueues) h.SeedRedisSet(t, r.client, base.AllQueues, tc.allQueues)
SeedSets(t, r.client, tc.allGroups) h.SeedRedisSets(t, r.client, tc.allGroups)
SeedTasks(t, r.client, tc.tasks) h.SeedTasks(t, r.client, tc.tasks)
SeedLists(t, r.client, tc.pending) h.SeedRedisLists(t, r.client, tc.pending)
SeedLists(t, r.client, tc.active) h.SeedRedisLists(t, r.client, tc.active)
SeedZSets(t, r.client, tc.scheduled) h.SeedRedisZSets(t, r.client, tc.scheduled)
SeedZSets(t, r.client, tc.retry) h.SeedRedisZSets(t, r.client, tc.retry)
SeedZSets(t, r.client, tc.archived) h.SeedRedisZSets(t, r.client, tc.archived)
SeedZSets(t, r.client, tc.completed) h.SeedRedisZSets(t, r.client, tc.completed)
SeedZSets(t, r.client, tc.groups) h.SeedRedisZSets(t, r.client, tc.groups)
ctx := context.Background() ctx := context.Background()
for qname, n := range tc.processed { for qname, n := range tc.processed {
r.client.Set(ctx, base.ProcessedKey(qname, now), n, 0) r.client.Set(ctx, base.ProcessedKey(qname, now), n, 0)
@ -434,16 +434,16 @@ func TestGroupStats(t *testing.T) {
now := time.Now() now := time.Now()
fixtures := struct { fixtures := struct {
tasks []*taskData tasks []*h.TaskSeedData
allGroups map[string][]string allGroups map[string][]string
groups map[string][]*redis.Z groups map[string][]*redis.Z
}{ }{
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: m1, state: base.TaskStateAggregating}, {Msg: m1, State: base.TaskStateAggregating},
{msg: m2, state: base.TaskStateAggregating}, {Msg: m2, State: base.TaskStateAggregating},
{msg: m3, state: base.TaskStateAggregating}, {Msg: m3, State: base.TaskStateAggregating},
{msg: m4, state: base.TaskStateAggregating}, {Msg: m4, State: base.TaskStateAggregating},
{msg: m5, state: base.TaskStateAggregating}, {Msg: m5, State: base.TaskStateAggregating},
}, },
allGroups: map[string][]string{ allGroups: map[string][]string{
base.AllGroups("default"): {"group1", "group2"}, base.AllGroups("default"): {"group1", "group2"},
@ -499,9 +499,9 @@ func TestGroupStats(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) h.FlushDB(t, r.client)
SeedTasks(t, r.client, fixtures.tasks) h.SeedTasks(t, r.client, fixtures.tasks)
SeedSets(t, r.client, fixtures.allGroups) h.SeedRedisSets(t, r.client, fixtures.allGroups)
SeedZSets(t, r.client, fixtures.groups) h.SeedRedisZSets(t, r.client, fixtures.groups)
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
got, err := r.GroupStats(tc.qname) got, err := r.GroupStats(tc.qname)

View File

@ -1017,6 +1017,10 @@ func (r *RDB) ListGroups(qname string) ([]string, error) {
// Output: // Output:
// Returns 0 if no aggregation set was created // Returns 0 if no aggregation set was created
// Returns 1 if an aggregation set was created // Returns 1 if an aggregation set was created
//
// Time Complexity:
// O(log(N) + M) with N being the number tasks in the group zset
// and M being the max size.
var aggregationCheckCmd = redis.NewScript(` var aggregationCheckCmd = redis.NewScript(`
local size = redis.call("ZCARD", KEYS[1]) local size = redis.call("ZCARD", KEYS[1])
if size == 0 then if size == 0 then
@ -1118,6 +1122,12 @@ func (r *RDB) AggregationCheck(qname, gname string, t time.Time, gracePeriod, ma
// KEYS[1] -> asynq:{<qname>}:g:<gname>:<aggregation_set_id> // KEYS[1] -> asynq:{<qname>}:g:<gname>:<aggregation_set_id>
// ------ // ------
// ARGV[1] -> task key prefix // ARGV[1] -> task key prefix
//
// Output:
// Array of encoded task messages
//
// Time Complexity:
// O(N) with N being the number of tasks in the aggregation set.
var readAggregationSetCmd = redis.NewScript(` var readAggregationSetCmd = redis.NewScript(`
local msgs = {} local msgs = {}
local ids = redis.call("ZRANGE", KEYS[1], 0, -1) local ids = redis.call("ZRANGE", KEYS[1], 0, -1)
@ -1162,6 +1172,13 @@ func (r *RDB) ReadAggregationSet(qname, gname, setID string) ([]*base.TaskMessag
// KEYS[2] -> asynq:{<qname>}:aggregation_sets // KEYS[2] -> asynq:{<qname>}:aggregation_sets
// ------- // -------
// ARGV[1] -> task key prefix // ARGV[1] -> task key prefix
//
// Output:
// Redis status reply
//
// Time Complexity:
// max(O(N), O(log(M))) with N being the number of tasks in the aggregation set
// and M being the number of elements in the all-aggregation-sets list.
var deleteAggregationSetCmd = redis.NewScript(` var deleteAggregationSetCmd = redis.NewScript(`
local ids = redis.call("ZRANGE", KEYS[1], 0, -1) local ids = redis.call("ZRANGE", KEYS[1], 0, -1)
for _, id in ipairs(ids) do for _, id in ipairs(ids) do

View File

@ -9,7 +9,6 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"math" "math"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -3122,7 +3121,7 @@ func TestAggregationCheck(t *testing.T) {
tests := []struct { tests := []struct {
desc string desc string
// initial data // initial data
tasks []*taskData tasks []*h.TaskSeedData
groups map[string][]*redis.Z groups map[string][]*redis.Z
allGroups map[string][]string allGroups map[string][]string
@ -3141,7 +3140,7 @@ func TestAggregationCheck(t *testing.T) {
}{ }{
{ {
desc: "with an empty group", desc: "with an empty group",
tasks: []*taskData{}, tasks: []*h.TaskSeedData{},
groups: map[string][]*redis.Z{ groups: map[string][]*redis.Z{
base.GroupKey("default", "mygroup"): {}, base.GroupKey("default", "mygroup"): {},
}, },
@ -3162,12 +3161,12 @@ func TestAggregationCheck(t *testing.T) {
}, },
{ {
desc: "with a group size reaching the max size", desc: "with a group size reaching the max size",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: msg1, state: base.TaskStateAggregating}, {Msg: msg1, State: base.TaskStateAggregating},
{msg: msg2, state: base.TaskStateAggregating}, {Msg: msg2, State: base.TaskStateAggregating},
{msg: msg3, state: base.TaskStateAggregating}, {Msg: msg3, State: base.TaskStateAggregating},
{msg: msg4, state: base.TaskStateAggregating}, {Msg: msg4, State: base.TaskStateAggregating},
{msg: msg5, state: base.TaskStateAggregating}, {Msg: msg5, State: base.TaskStateAggregating},
}, },
groups: map[string][]*redis.Z{ groups: map[string][]*redis.Z{
base.GroupKey("default", "mygroup"): { base.GroupKey("default", "mygroup"): {
@ -3195,12 +3194,12 @@ func TestAggregationCheck(t *testing.T) {
}, },
{ {
desc: "with group size greater than max size", desc: "with group size greater than max size",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: msg1, state: base.TaskStateAggregating}, {Msg: msg1, State: base.TaskStateAggregating},
{msg: msg2, state: base.TaskStateAggregating}, {Msg: msg2, State: base.TaskStateAggregating},
{msg: msg3, state: base.TaskStateAggregating}, {Msg: msg3, State: base.TaskStateAggregating},
{msg: msg4, state: base.TaskStateAggregating}, {Msg: msg4, State: base.TaskStateAggregating},
{msg: msg5, state: base.TaskStateAggregating}, {Msg: msg5, State: base.TaskStateAggregating},
}, },
groups: map[string][]*redis.Z{ groups: map[string][]*redis.Z{
base.GroupKey("default", "mygroup"): { base.GroupKey("default", "mygroup"): {
@ -3231,10 +3230,10 @@ func TestAggregationCheck(t *testing.T) {
}, },
{ {
desc: "with the most recent task older than grace period", desc: "with the most recent task older than grace period",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: msg1, state: base.TaskStateAggregating}, {Msg: msg1, State: base.TaskStateAggregating},
{msg: msg2, state: base.TaskStateAggregating}, {Msg: msg2, State: base.TaskStateAggregating},
{msg: msg3, state: base.TaskStateAggregating}, {Msg: msg3, State: base.TaskStateAggregating},
}, },
groups: map[string][]*redis.Z{ groups: map[string][]*redis.Z{
base.GroupKey("default", "mygroup"): { base.GroupKey("default", "mygroup"): {
@ -3260,12 +3259,12 @@ func TestAggregationCheck(t *testing.T) {
}, },
{ {
desc: "with the oldest task older than max delay", desc: "with the oldest task older than max delay",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: msg1, state: base.TaskStateAggregating}, {Msg: msg1, State: base.TaskStateAggregating},
{msg: msg2, state: base.TaskStateAggregating}, {Msg: msg2, State: base.TaskStateAggregating},
{msg: msg3, state: base.TaskStateAggregating}, {Msg: msg3, State: base.TaskStateAggregating},
{msg: msg4, state: base.TaskStateAggregating}, {Msg: msg4, State: base.TaskStateAggregating},
{msg: msg5, state: base.TaskStateAggregating}, {Msg: msg5, State: base.TaskStateAggregating},
}, },
groups: map[string][]*redis.Z{ groups: map[string][]*redis.Z{
base.GroupKey("default", "mygroup"): { base.GroupKey("default", "mygroup"): {
@ -3293,12 +3292,12 @@ func TestAggregationCheck(t *testing.T) {
}, },
{ {
desc: "with unlimited size", desc: "with unlimited size",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: msg1, state: base.TaskStateAggregating}, {Msg: msg1, State: base.TaskStateAggregating},
{msg: msg2, state: base.TaskStateAggregating}, {Msg: msg2, State: base.TaskStateAggregating},
{msg: msg3, state: base.TaskStateAggregating}, {Msg: msg3, State: base.TaskStateAggregating},
{msg: msg4, state: base.TaskStateAggregating}, {Msg: msg4, State: base.TaskStateAggregating},
{msg: msg5, state: base.TaskStateAggregating}, {Msg: msg5, State: base.TaskStateAggregating},
}, },
groups: map[string][]*redis.Z{ groups: map[string][]*redis.Z{
base.GroupKey("default", "mygroup"): { base.GroupKey("default", "mygroup"): {
@ -3332,12 +3331,12 @@ func TestAggregationCheck(t *testing.T) {
}, },
{ {
desc: "with unlimited size and passed grace period", desc: "with unlimited size and passed grace period",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: msg1, state: base.TaskStateAggregating}, {Msg: msg1, State: base.TaskStateAggregating},
{msg: msg2, state: base.TaskStateAggregating}, {Msg: msg2, State: base.TaskStateAggregating},
{msg: msg3, state: base.TaskStateAggregating}, {Msg: msg3, State: base.TaskStateAggregating},
{msg: msg4, state: base.TaskStateAggregating}, {Msg: msg4, State: base.TaskStateAggregating},
{msg: msg5, state: base.TaskStateAggregating}, {Msg: msg5, State: base.TaskStateAggregating},
}, },
groups: map[string][]*redis.Z{ groups: map[string][]*redis.Z{
base.GroupKey("default", "mygroup"): { base.GroupKey("default", "mygroup"): {
@ -3365,12 +3364,12 @@ func TestAggregationCheck(t *testing.T) {
}, },
{ {
desc: "with unlimited delay", desc: "with unlimited delay",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: msg1, state: base.TaskStateAggregating}, {Msg: msg1, State: base.TaskStateAggregating},
{msg: msg2, state: base.TaskStateAggregating}, {Msg: msg2, State: base.TaskStateAggregating},
{msg: msg3, state: base.TaskStateAggregating}, {Msg: msg3, State: base.TaskStateAggregating},
{msg: msg4, state: base.TaskStateAggregating}, {Msg: msg4, State: base.TaskStateAggregating},
{msg: msg5, state: base.TaskStateAggregating}, {Msg: msg5, State: base.TaskStateAggregating},
}, },
groups: map[string][]*redis.Z{ groups: map[string][]*redis.Z{
base.GroupKey("default", "mygroup"): { base.GroupKey("default", "mygroup"): {
@ -3408,9 +3407,9 @@ func TestAggregationCheck(t *testing.T) {
h.FlushDB(t, r.client) h.FlushDB(t, r.client)
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
SeedTasks(t, r.client, tc.tasks) h.SeedTasks(t, r.client, tc.tasks)
SeedZSets(t, r.client, tc.groups) h.SeedRedisZSets(t, r.client, tc.groups)
SeedSets(t, r.client, tc.allGroups) h.SeedRedisSets(t, r.client, tc.allGroups)
aggregationSetID, err := r.AggregationCheck(tc.qname, tc.gname, now, tc.gracePeriod, tc.maxDelay, tc.maxSize) aggregationSetID, err := r.AggregationCheck(tc.qname, tc.gname, now, tc.gracePeriod, tc.maxDelay, tc.maxSize)
if err != nil { if err != nil {
@ -3438,7 +3437,7 @@ func TestAggregationCheck(t *testing.T) {
} }
} }
AssertZSets(t, r.client, tc.wantGroups) h.AssertRedisZSets(t, r.client, tc.wantGroups)
if tc.shouldClearGroup { if tc.shouldClearGroup {
if key := base.GroupKey(tc.qname, tc.gname); r.client.Exists(ctx, key).Val() != 0 { if key := base.GroupKey(tc.qname, tc.gname); r.client.Exists(ctx, key).Val() != 0 {
@ -3473,7 +3472,7 @@ func TestDeleteAggregationSet(t *testing.T) {
tests := []struct { tests := []struct {
desc string desc string
// initial data // initial data
tasks []*taskData tasks []*h.TaskSeedData
aggregationSets map[string][]*redis.Z aggregationSets map[string][]*redis.Z
allAggregationSets map[string][]*redis.Z allAggregationSets map[string][]*redis.Z
@ -3490,10 +3489,10 @@ func TestDeleteAggregationSet(t *testing.T) {
}{ }{
{ {
desc: "with a sigle active aggregation set", desc: "with a sigle active aggregation set",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: m1, state: base.TaskStateAggregating}, {Msg: m1, State: base.TaskStateAggregating},
{msg: m2, state: base.TaskStateAggregating}, {Msg: m2, State: base.TaskStateAggregating},
{msg: m3, state: base.TaskStateAggregating}, {Msg: m3, State: base.TaskStateAggregating},
}, },
aggregationSets: map[string][]*redis.Z{ aggregationSets: map[string][]*redis.Z{
base.AggregationSetKey("default", "mygroup", setID): { base.AggregationSetKey("default", "mygroup", setID): {
@ -3524,10 +3523,10 @@ func TestDeleteAggregationSet(t *testing.T) {
}, },
{ {
desc: "with multiple active aggregation sets", desc: "with multiple active aggregation sets",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: m1, state: base.TaskStateAggregating}, {Msg: m1, State: base.TaskStateAggregating},
{msg: m2, state: base.TaskStateAggregating}, {Msg: m2, State: base.TaskStateAggregating},
{msg: m3, state: base.TaskStateAggregating}, {Msg: m3, State: base.TaskStateAggregating},
}, },
aggregationSets: map[string][]*redis.Z{ aggregationSets: map[string][]*redis.Z{
base.AggregationSetKey("default", "mygroup", setID): { base.AggregationSetKey("default", "mygroup", setID): {
@ -3569,9 +3568,9 @@ func TestDeleteAggregationSet(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) h.FlushDB(t, r.client)
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
SeedTasks(t, r.client, tc.tasks) h.SeedTasks(t, r.client, tc.tasks)
SeedZSets(t, r.client, tc.aggregationSets) h.SeedRedisZSets(t, r.client, tc.aggregationSets)
SeedZSets(t, r.client, tc.allAggregationSets) h.SeedRedisZSets(t, r.client, tc.allAggregationSets)
if err := r.DeleteAggregationSet(tc.ctx, tc.qname, tc.gname, tc.setID); err != nil { if err := r.DeleteAggregationSet(tc.ctx, tc.qname, tc.gname, tc.setID); err != nil {
t.Fatalf("DeleteAggregationSet returned error: %v", err) t.Fatalf("DeleteAggregationSet returned error: %v", err)
@ -3582,7 +3581,7 @@ func TestDeleteAggregationSet(t *testing.T) {
t.Errorf("key=%q still exists, want deleted", key) t.Errorf("key=%q still exists, want deleted", key)
} }
} }
AssertZSets(t, r.client, tc.wantAllAggregationSets) h.AssertRedisZSets(t, r.client, tc.wantAllAggregationSets)
}) })
} }
} }
@ -3602,7 +3601,7 @@ func TestDeleteAggregationSetError(t *testing.T) {
tests := []struct { tests := []struct {
desc string desc string
// initial data // initial data
tasks []*taskData tasks []*h.TaskSeedData
aggregationSets map[string][]*redis.Z aggregationSets map[string][]*redis.Z
allAggregationSets map[string][]*redis.Z allAggregationSets map[string][]*redis.Z
@ -3618,10 +3617,10 @@ func TestDeleteAggregationSetError(t *testing.T) {
}{ }{
{ {
desc: "with deadline exceeded context", desc: "with deadline exceeded context",
tasks: []*taskData{ tasks: []*h.TaskSeedData{
{msg: m1, state: base.TaskStateAggregating}, {Msg: m1, State: base.TaskStateAggregating},
{msg: m2, state: base.TaskStateAggregating}, {Msg: m2, State: base.TaskStateAggregating},
{msg: m3, state: base.TaskStateAggregating}, {Msg: m3, State: base.TaskStateAggregating},
}, },
aggregationSets: map[string][]*redis.Z{ aggregationSets: map[string][]*redis.Z{
base.AggregationSetKey("default", "mygroup", setID): { base.AggregationSetKey("default", "mygroup", setID): {
@ -3659,17 +3658,17 @@ func TestDeleteAggregationSetError(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) h.FlushDB(t, r.client)
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
SeedTasks(t, r.client, tc.tasks) h.SeedTasks(t, r.client, tc.tasks)
SeedZSets(t, r.client, tc.aggregationSets) h.SeedRedisZSets(t, r.client, tc.aggregationSets)
SeedZSets(t, r.client, tc.allAggregationSets) h.SeedRedisZSets(t, r.client, tc.allAggregationSets)
if err := r.DeleteAggregationSet(tc.ctx, tc.qname, tc.gname, tc.setID); err == nil { if err := r.DeleteAggregationSet(tc.ctx, tc.qname, tc.gname, tc.setID); err == nil {
t.Fatal("DeleteAggregationSet returned nil, want non-nil error") t.Fatal("DeleteAggregationSet returned nil, want non-nil error")
} }
// Make sure zsets are unchanged. // Make sure zsets are unchanged.
AssertZSets(t, r.client, tc.wantAggregationSets) h.AssertRedisZSets(t, r.client, tc.wantAggregationSets)
AssertZSets(t, r.client, tc.wantAllAggregationSets) h.AssertRedisZSets(t, r.client, tc.wantAllAggregationSets)
}) })
} }
} }
@ -3746,118 +3745,21 @@ func TestReclaimStaleAggregationSets(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) h.FlushDB(t, r.client)
SeedZSets(t, r.client, tc.groups) h.SeedRedisZSets(t, r.client, tc.groups)
SeedZSets(t, r.client, tc.aggregationSets) h.SeedRedisZSets(t, r.client, tc.aggregationSets)
SeedZSets(t, r.client, tc.allAggregationSets) h.SeedRedisZSets(t, r.client, tc.allAggregationSets)
if err := r.ReclaimStaleAggregationSets(tc.qname); err != nil { if err := r.ReclaimStaleAggregationSets(tc.qname); err != nil {
t.Errorf("ReclaimStaleAggregationSets returned error: %v", err) t.Errorf("ReclaimStaleAggregationSets returned error: %v", err)
continue continue
} }
AssertZSets(t, r.client, tc.wantGroups) h.AssertRedisZSets(t, r.client, tc.wantGroups)
AssertZSets(t, r.client, tc.wantAggregationSets) h.AssertRedisZSets(t, r.client, tc.wantAggregationSets)
AssertZSets(t, r.client, tc.wantAllAggregationSets) h.AssertRedisZSets(t, r.client, tc.wantAllAggregationSets)
} }
} }
// taskData holds the data required to seed tasks under the task key in test.
type taskData struct {
msg *base.TaskMessage
state base.TaskState
pendingSince time.Time
}
// TODO: move this helper somewhere more canonical
func SeedTasks(tb testing.TB, r redis.UniversalClient, taskData []*taskData) {
for _, data := range taskData {
msg := data.msg
ctx := context.Background()
key := base.TaskKey(msg.Queue, msg.ID)
v := map[string]interface{}{
"msg": h.MustMarshal(tb, msg),
"state": data.state.String(),
"unique_key": msg.UniqueKey,
"group": msg.GroupKey,
}
if !data.pendingSince.IsZero() {
v["pending_since"] = data.pendingSince.Unix()
}
if err := r.HSet(ctx, key, v).Err(); err != nil {
tb.Fatalf("Failed to write task data in redis: %v", err)
}
if len(msg.UniqueKey) > 0 {
err := r.SetNX(ctx, msg.UniqueKey, msg.ID, 1*time.Minute).Err()
if err != nil {
tb.Fatalf("Failed to set unique lock in redis: %v", err)
}
}
}
}
// TODO: move this helper somewhere more canonical
func SeedZSets(tb testing.TB, r redis.UniversalClient, zsets map[string][]*redis.Z) {
for key, zs := range zsets {
// FIXME: How come we can't simply do ZAdd(ctx, key, zs...) here?
for _, z := range zs {
if err := r.ZAdd(context.Background(), key, z).Err(); err != nil {
tb.Fatalf("Failed to seed zset (key=%q): %v", key, err)
}
}
}
}
func SeedSets(tb testing.TB, r redis.UniversalClient, sets map[string][]string) {
for key, set := range sets {
SeedSet(tb, r, key, set)
}
}
func SeedSet(tb testing.TB, r redis.UniversalClient, key string, members []string) {
for _, mem := range members {
if err := r.SAdd(context.Background(), key, mem).Err(); err != nil {
tb.Fatalf("Failed to seed set (key=%q): %v", key, err)
}
}
}
func SeedLists(tb testing.TB, r redis.UniversalClient, lists map[string][]string) {
for key, vals := range lists {
for _, v := range vals {
if err := r.LPush(context.Background(), key, v).Err(); err != nil {
tb.Fatalf("Failed to seed list (key=%q): %v", key, err)
}
}
}
}
// TODO: move this helper somewhere more canonical
func AssertZSets(t *testing.T, r redis.UniversalClient, wantZSets map[string][]redis.Z) {
for key, want := range wantZSets {
got, err := r.ZRangeWithScores(context.Background(), key, 0, -1).Result()
if err != nil {
t.Fatalf("Failed to read zset (key=%q): %v", key, err)
}
if diff := cmp.Diff(want, got, SortZSetEntryOpt); diff != "" {
t.Errorf("mismatch found in zset (key=%q): (-want,+got)\n%s", key, diff)
}
}
}
var SortZSetEntryOpt = cmp.Transformer("SortZSetEntries", func(in []redis.Z) []redis.Z {
out := append([]redis.Z(nil), in...) // Copy input to avoid mutating it
sort.Slice(out, func(i, j int) bool {
// TODO: If member is a comparable type (int, string, etc) compare by the member
// Use generic comparable type here once update to go1.18
if _, ok := out[i].Member.(string); ok {
// If member is a string, compare the member
return out[i].Member.(string) < out[j].Member.(string)
}
return out[i].Score < out[j].Score
})
return out
})
func TestListGroups(t *testing.T) { func TestListGroups(t *testing.T) {
r := setup(t) r := setup(t)
defer r.Close() defer r.Close()