diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 235de12..1bfba53 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -3484,6 +3484,93 @@ func TestDeleteAggregationSet(t *testing.T) { } } +func TestDeleteAggregationSetError(t *testing.T) { + r := setup(t) + defer r.Close() + + now := time.Now() + setID := uuid.NewString() + m1 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("mygroup").Build() + m2 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("mygroup").Build() + m3 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("mygroup").Build() + deadlineExceededCtx, cancel := context.WithDeadline(context.Background(), now.Add(-10*time.Second)) + defer cancel() + + tests := []struct { + desc string + // initial data + tasks []*taskData + aggregationSets map[string][]*redis.Z + allAggregationSets map[string][]*redis.Z + + // args + ctx context.Context + qname string + gname string + setID string + + // expectations + wantAggregationSets map[string][]redis.Z + wantAllAggregationSets map[string][]redis.Z + }{ + { + desc: "with deadline exceeded context", + tasks: []*taskData{ + {msg: m1, state: base.TaskStateAggregating}, + {msg: m2, state: base.TaskStateAggregating}, + {msg: m3, state: base.TaskStateAggregating}, + }, + aggregationSets: map[string][]*redis.Z{ + base.AggregationSetKey("default", "mygroup", setID): { + {Member: m1.ID, Score: float64(now.Add(-5 * time.Minute).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-4 * time.Minute).Unix())}, + {Member: m3.ID, Score: float64(now.Add(-3 * time.Minute).Unix())}, + }, + }, + allAggregationSets: map[string][]*redis.Z{ + base.AllAggregationSets("default"): { + {Member: base.AggregationSetKey("default", "mygroup", setID), Score: float64(now.Add(aggregationTimeout).Unix())}, + }, + }, + ctx: deadlineExceededCtx, + qname: "default", + gname: "mygroup", + setID: setID, + // want data unchanged. + wantAggregationSets: map[string][]redis.Z{ + base.AggregationSetKey("default", "mygroup", setID): { + {Member: m1.ID, Score: float64(now.Add(-5 * time.Minute).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-4 * time.Minute).Unix())}, + {Member: m3.ID, Score: float64(now.Add(-3 * time.Minute).Unix())}, + }, + }, + // want data unchanged. + wantAllAggregationSets: map[string][]redis.Z{ + base.AllAggregationSets("default"): { + {Member: base.AggregationSetKey("default", "mygroup", setID), Score: float64(now.Add(aggregationTimeout).Unix())}, + }, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + t.Run(tc.desc, func(t *testing.T) { + SeedTasks(t, r.client, tc.tasks) + SeedZSets(t, r.client, tc.aggregationSets) + SeedZSets(t, r.client, tc.allAggregationSets) + + if err := r.DeleteAggregationSet(tc.ctx, tc.qname, tc.gname, tc.setID); err == nil { + t.Fatal("DeleteAggregationSet returned nil, want non-nil error") + } + + // Make sure zsets are unchanged. + AssertZSets(t, r.client, tc.wantAggregationSets) + AssertZSets(t, r.client, tc.wantAllAggregationSets) + }) + } +} + func TestReclaimStaleAggregationSets(t *testing.T) { r := setup(t) defer r.Close()