diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index e4f03c6..235de12 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -3356,65 +3356,131 @@ func TestAggregationCheck(t *testing.T) { } } -// TODO: Rewrite this test with the new pattern of using redis key-value as data. func TestDeleteAggregationSet(t *testing.T) { r := setup(t) defer r.Close() now := time.Now() - ctx := context.Background() setID := uuid.NewString() - msg1 := h.NewTaskMessageBuilder().SetType("foo").SetQueue("default").SetGroup("mygroup").Build() - msg2 := h.NewTaskMessageBuilder().SetType("bar").SetQueue("default").SetGroup("mygroup").Build() - msg3 := h.NewTaskMessageBuilder().SetType("baz").SetQueue("default").SetGroup("mygroup").Build() + otherSetID := uuid.NewString() + m1 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("mygroup").Build() + m2 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("mygroup").Build() + m3 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("mygroup").Build() tests := []struct { - aggregationSet []base.Z - qname string - gname string - setID string + desc string + // initial data + tasks []*taskData + aggregationSets map[string][]*redis.Z + allAggregationSets map[string][]*redis.Z + + // args + ctx context.Context + qname string + gname string + setID string + + // expectations + wantDeletedKeys []string // redis keys to check for non-existence + wantAggregationSets map[string][]redis.Z + wantAllAggregationSets map[string][]redis.Z }{ { - aggregationSet: []base.Z{ - {msg1, now.Add(-3 * time.Minute).Unix()}, - {msg2, now.Add(-2 * time.Minute).Unix()}, - {msg3, now.Add(-1 * time.Minute).Unix()}, + desc: "with a single active aggregation set", + tasks: []*taskData{ + {msg: m1, state: base.TaskStateAggregating}, + {msg: m2, state: base.TaskStateAggregating}, + {msg: m3, state: base.TaskStateAggregating}, }, + aggregationSets: map[string][]*redis.Z{ + base.AggregationSetKey("default", "mygroup", setID): { + 
{Member: m1.ID, Score: float64(now.Add(-5 * time.Minute).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-4 * time.Minute).Unix())}, + {Member: m3.ID, Score: float64(now.Add(-3 * time.Minute).Unix())}, + }, + }, + allAggregationSets: map[string][]*redis.Z{ + base.AllAggregationSets("default"): { + {Member: base.AggregationSetKey("default", "mygroup", setID), Score: float64(now.Add(aggregationTimeout).Unix())}, + }, + }, + ctx: context.Background(), qname: "default", gname: "mygroup", setID: setID, + wantDeletedKeys: []string{ + base.AggregationSetKey("default", "mygroup", setID), + base.TaskKey(m1.Queue, m1.ID), + base.TaskKey(m2.Queue, m2.ID), + base.TaskKey(m3.Queue, m3.ID), + }, + wantAggregationSets: map[string][]redis.Z{}, + wantAllAggregationSets: map[string][]redis.Z{ + base.AllAggregationSets("default"): {}, + }, + }, + { + desc: "with multiple active aggregation sets", + tasks: []*taskData{ + {msg: m1, state: base.TaskStateAggregating}, + {msg: m2, state: base.TaskStateAggregating}, + {msg: m3, state: base.TaskStateAggregating}, + }, + aggregationSets: map[string][]*redis.Z{ + base.AggregationSetKey("default", "mygroup", setID): { + {Member: m1.ID, Score: float64(now.Add(-5 * time.Minute).Unix())}, + }, + base.AggregationSetKey("default", "mygroup", otherSetID): { + {Member: m2.ID, Score: float64(now.Add(-4 * time.Minute).Unix())}, + {Member: m3.ID, Score: float64(now.Add(-3 * time.Minute).Unix())}, + }, + }, + allAggregationSets: map[string][]*redis.Z{ + base.AllAggregationSets("default"): { + {Member: base.AggregationSetKey("default", "mygroup", setID), Score: float64(now.Add(aggregationTimeout).Unix())}, + {Member: base.AggregationSetKey("default", "mygroup", otherSetID), Score: float64(now.Add(aggregationTimeout).Unix())}, + }, + }, + ctx: context.Background(), + qname: "default", + gname: "mygroup", + setID: setID, + wantDeletedKeys: []string{ + base.AggregationSetKey("default", "mygroup", setID), + base.TaskKey(m1.Queue, m1.ID), + }, + 
wantAggregationSets: map[string][]redis.Z{ + base.AggregationSetKey("default", "mygroup", otherSetID): { + {Member: m2.ID, Score: float64(now.Add(-4 * time.Minute).Unix())}, + {Member: m3.ID, Score: float64(now.Add(-3 * time.Minute).Unix())}, + }, + }, + wantAllAggregationSets: map[string][]redis.Z{ + base.AllAggregationSets("default"): { + {Member: base.AggregationSetKey("default", "mygroup", otherSetID), Score: float64(now.Add(aggregationTimeout).Unix())}, + }, + }, }, } for _, tc := range tests { h.FlushDB(t, r.client) - h.SeedAggregationSet(t, r.client, tc.aggregationSet, tc.qname, tc.gname, tc.setID) - key := base.AggregationSetKey(tc.qname, tc.gname, tc.setID) - if err := r.client.ZAdd(context.Background(), - base.AllAggregationSets(tc.qname), - &redis.Z{Member: key, Score: float64(now.Add(aggregationTimeout).Unix())}).Err(); err != nil { - t.Fatal(err) - } + t.Run(tc.desc, func(t *testing.T) { + SeedTasks(t, r.client, tc.tasks) + SeedZSets(t, r.client, tc.aggregationSets) + SeedZSets(t, r.client, tc.allAggregationSets) - if err := r.DeleteAggregationSet(ctx, tc.qname, tc.gname, tc.setID); err != nil { - t.Fatalf("DeleteAggregationSet returned error: %v", err) - } - // Check if the set is deleted. - if r.client.Exists(ctx, key).Val() != 0 { - t.Errorf("aggregation set key %q still exists", key) - } - - // Check all tasks in the set are deleted. 
- for _, z := range tc.aggregationSet { - taskKey := base.TaskKey(z.Message.Queue, z.Message.ID) - if r.client.Exists(ctx, taskKey).Val() != 0 { - t.Errorf("task key %q still exists", taskKey) + if err := r.DeleteAggregationSet(tc.ctx, tc.qname, tc.gname, tc.setID); err != nil { + t.Fatalf("DeleteAggregationSet returned error: %v", err) } - } - if _, err := r.client.ZScore(ctx, base.AllAggregationSets(tc.qname), key).Result(); err != redis.Nil { - t.Errorf("aggregation_set key %q is still in key %q", key, base.AllAggregationSets(tc.qname)) - } + for _, key := range tc.wantDeletedKeys { + if r.client.Exists(context.Background(), key).Val() != 0 { + t.Errorf("key=%q still exists, want deleted", key) + } + } + AssertZSets(t, r.client, tc.wantAllAggregationSets) + }) } } @@ -3505,6 +3571,40 @@ func TestReclaimStaleAggregationSets(t *testing.T) { } } +// taskData holds the data required to seed tasks under the task key in test. +type taskData struct { + msg *base.TaskMessage + state base.TaskState + pendingSince time.Time +} + +// TODO: move this helper somewhere more canonical +func SeedTasks(tb testing.TB, r redis.UniversalClient, taskData []*taskData) { + for _, data := range taskData { + msg := data.msg + ctx := context.Background() + key := base.TaskKey(msg.Queue, msg.ID) + v := map[string]interface{}{ + "msg": h.MustMarshal(tb, msg), + "state": data.state.String(), + "unique_key": msg.UniqueKey, + "group": msg.GroupKey, + } + if !data.pendingSince.IsZero() { + v["pending_since"] = data.pendingSince.Unix() + } + if err := r.HSet(ctx, key, v).Err(); err != nil { + tb.Fatalf("Failed to write task data in redis: %v", err) + } + if len(msg.UniqueKey) > 0 { + err := r.SetNX(ctx, msg.UniqueKey, msg.ID, 1*time.Minute).Err() + if err != nil { + tb.Fatalf("Failed to set unique lock in redis: %v", err) + } + } + } +} + // TODO: move this helper somewhere more canonical func SeedZSets(tb testing.TB, r redis.UniversalClient, zsets map[string][]*redis.Z) { for key, zs := range 
zsets {