diff --git a/internal/base/base.go b/internal/base/base.go index 210ff37..0abb9e9 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -734,6 +734,7 @@ type Broker interface { AggregationCheck(qname, gname string, t time.Time, gracePeriod, maxDelay time.Duration, maxSize int) (aggregationSetID string, err error) ReadAggregationSet(qname, gname, aggregationSetID string) ([]*TaskMessage, time.Time, error) DeleteAggregationSet(ctx context.Context, qname, gname, aggregationSetID string) error + ReclaimStaleAggregationSets(qname string) error // Task retention related method DeleteExpiredCompletedTasks(qname string) error diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 2b7efa4..9d8a15d 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -1156,6 +1156,12 @@ func (r *RDB) DeleteAggregationSet(ctx context.Context, qname, gname, setID stri return r.runScript(ctx, op, deleteAggregationSetCmd, []string{base.AggregationSetKey(qname, gname, setID)}, base.TaskKeyPrefix(qname)) } +// ReclaimStaleAggregationSets checks for any stale aggregation sets in the given queue, and +// reclaims tasks in the stale aggregation sets by putting them back in their groups. 
+func (r *RDB) ReclaimStaleAggregationSets(qname string) error { + return nil +} + // KEYS[1] -> asynq:{}:completed // ARGV[1] -> current time in unix time // ARGV[2] -> task key prefix diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 0d61c45..aa58a99 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -3402,6 +3402,118 @@ func TestDeleteAggregationSet(t *testing.T) { } } +func TestReclaimStaleAggregationSets(t *testing.T) { + r := setup(t) + defer r.Close() + + now := time.Now() + r.SetClock(timeutil.NewSimulatedClock(now)) + + m1 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("foo").Build() + m2 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("foo").Build() + m3 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("bar").Build() + m4 := h.NewTaskMessageBuilder().SetQueue("default").SetGroup("qux").Build() + + // Note: In this test, we're trying out a new way to test RDB by exactly describing how + // keys and values are represented in Redis. 
+ tests := []struct { + groups map[string][]*redis.Z // map redis-key to redis-zset + aggregationSets map[string][]*redis.Z + allAggregationSets map[string][]*redis.Z + qname string + wantGroups map[string][]redis.Z + wantAggregationSets map[string][]redis.Z + wantAllAggregationSets map[string][]redis.Z + }{ + { + groups: map[string][]*redis.Z{ + base.GroupKey("default", "foo"): {}, + base.GroupKey("default", "bar"): {}, + base.GroupKey("default", "qux"): { + {Member: m4.ID, Score: float64(now.Add(-10 * time.Second).Unix())}, + }, + }, + aggregationSets: map[string][]*redis.Z{ + base.AggregationSetKey("default", "foo", "set1"): { + {Member: m1.ID, Score: float64(now.Add(-3 * time.Minute).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-4 * time.Minute).Unix())}, + }, + base.AggregationSetKey("default", "bar", "set2"): { + {Member: m3.ID, Score: float64(now.Add(-1 * time.Minute).Unix())}, + }, + }, + allAggregationSets: map[string][]*redis.Z{ + base.AllAggregationSets("default"): { + {Member: base.AggregationSetKey("default", "foo", "set1"), Score: float64(now.Add(-10 * time.Second).Unix())}, // set1 is expired + {Member: base.AggregationSetKey("default", "bar", "set2"), Score: float64(now.Add(40 * time.Second).Unix())}, // set2 is not expired + }, + }, + qname: "default", + wantGroups: map[string][]redis.Z{ + base.GroupKey("default", "foo"): { + {Member: m1.ID, Score: float64(now.Add(-3 * time.Minute).Unix())}, + {Member: m2.ID, Score: float64(now.Add(-4 * time.Minute).Unix())}, + }, + base.GroupKey("default", "bar"): {}, + base.GroupKey("default", "qux"): { + {Member: m4.ID, Score: float64(now.Add(-10 * time.Second).Unix())}, + }, + }, + wantAggregationSets: map[string][]redis.Z{ + base.AggregationSetKey("default", "bar", "set2"): { + {Member: m3.ID, Score: float64(now.Add(-1 * time.Minute).Unix())}, + }, + }, + wantAllAggregationSets: map[string][]redis.Z{ + base.AllAggregationSets("default"): { + {Member: base.AggregationSetKey("default", "bar", "set2"), 
Score: float64(now.Add(40 * time.Second).Unix())}, + }, + }, + }, + } + + for _, tc := range tests { + h.FlushDB(t, r.client) + SeedZSets(t, r.client, tc.groups) + SeedZSets(t, r.client, tc.aggregationSets) + SeedZSets(t, r.client, tc.allAggregationSets) + + if err := r.ReclaimStaleAggregationSets(tc.qname); err != nil { + t.Errorf("ReclaimStaleAggregationSets returned error: %v", err) + continue + } + + AssertZSets(t, r.client, tc.wantGroups) + AssertZSets(t, r.client, tc.wantAggregationSets) + AssertZSets(t, r.client, tc.wantAllAggregationSets) + } +} + +// TODO: move this helper somewhere more canonical +func SeedZSets(tb testing.TB, r redis.UniversalClient, zsets map[string][]*redis.Z) { + for key, zs := range zsets { + // FIXME: How come we can't simply do ZAdd(ctx, key, zs...) here? + for _, z := range zs { + if err := r.ZAdd(context.Background(), key, z).Err(); err != nil { + tb.Fatalf("Failed to seed zset (key=%q): %v", key, err) + } + } + } +} + +// TODO: move this helper somewhere more canonical +func AssertZSets(t *testing.T, r redis.UniversalClient, wantZSets map[string][]redis.Z) { + for key, want := range wantZSets { + got, err := r.ZRangeWithScores(context.Background(), key, 0, -1).Result() + if err != nil { + t.Fatalf("Failed to read zset (key=%q): %v", key, err) + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("mismatch found in zset (key=%q): (-want,+got)\n%s", key, diff) + } + } +} + func TestListGroups(t *testing.T) { r := setup(t) defer r.Close() diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index 5d228d5..ef2717c 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -288,3 +288,12 @@ func (tb *TestBroker) DeleteAggregationSet(ctx context.Context, qname, gname, ag } return tb.real.DeleteAggregationSet(ctx, qname, gname, aggregationSetID) } + +func (tb *TestBroker) ReclaimStaleAggregationSets(qname string) error { + tb.mu.Lock() + defer tb.mu.Unlock() + if 
tb.sleeping { + return errRedisDown + } + return tb.real.ReclaimStaleAggregationSets(qname) +} diff --git a/recoverer.go b/recoverer.go index a0107b4..c7a41dd 100644 --- a/recoverer.go +++ b/recoverer.go @@ -82,11 +82,16 @@ func (r *recoverer) start(wg *sync.WaitGroup) { var ErrLeaseExpired = errors.New("asynq: task lease expired") func (r *recoverer) recover() { + r.recoverLeaseExpiredTasks() + r.recoverStaleAggregationSets() +} + +func (r *recoverer) recoverLeaseExpiredTasks() { // Get all tasks which have expired 30 seconds ago or earlier to accomodate certain amount of clock skew. cutoff := time.Now().Add(-30 * time.Second) msgs, err := r.broker.ListLeaseExpired(cutoff, r.queues...) if err != nil { - r.logger.Warn("recoverer: could not list lease expired tasks") + r.logger.Warnf("recoverer: could not list lease expired tasks: %v", err) return } for _, msg := range msgs { @@ -98,6 +103,14 @@ func (r *recoverer) recover() { } } +func (r *recoverer) recoverStaleAggregationSets() { + for _, qname := range r.queues { + if err := r.broker.ReclaimStaleAggregationSets(qname); err != nil { + r.logger.Warnf("recoverer: could not reclaim stale aggregation sets in queue %q: %v", qname, err) + } + } +} + func (r *recoverer) retry(msg *base.TaskMessage, err error) { delay := r.retryDelayFunc(msg.Retried, err, NewTask(msg.Type, msg.Payload)) retryAt := time.Now().Add(delay)