2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-12-27 00:02:19 +08:00

Check for aggregation at an interval <= gracePeriod

This commit is contained in:
Ken Hibino 2022-03-09 06:01:37 -08:00
parent d841dc2f8d
commit 4542b52da8

View File

@ -33,6 +33,13 @@ type aggregator struct {
// Aggregation function // Aggregation function
aggregateFunc func(gname string, tasks []*Task) *Task aggregateFunc func(gname string, tasks []*Task) *Task
// interval used to check for aggregation
interval time.Duration
// sema is a counting semaphore to ensure the number of active aggregating function
// does not exceed the limit.
sema chan struct{}
} }
type aggregatorParams struct { type aggregatorParams struct {
@ -45,7 +52,20 @@ type aggregatorParams struct {
aggregateFunc func(gname string, msgs []*Task) *Task aggregateFunc func(gname string, msgs []*Task) *Task
} }
const (
// Maximum number of aggregation checks in flight concurrently.
maxConcurrentAggregationChecks = 3
// Default interval used for aggregation checks. If the provided gracePeriod is less than
// the default, use the gracePeriod.
defaultAggregationCheckInterval = 7 * time.Second
)
func newAggregator(params aggregatorParams) *aggregator { func newAggregator(params aggregatorParams) *aggregator {
interval := defaultAggregationCheckInterval
if params.gracePeriod < interval {
interval = params.gracePeriod
}
return &aggregator{ return &aggregator{
logger: params.logger, logger: params.logger,
broker: params.broker, broker: params.broker,
@ -56,6 +76,8 @@ func newAggregator(params aggregatorParams) *aggregator {
maxDelay: params.maxDelay, maxDelay: params.maxDelay,
maxSize: params.maxSize, maxSize: params.maxSize,
aggregateFunc: params.aggregateFunc, aggregateFunc: params.aggregateFunc,
sema: make(chan struct{}, maxConcurrentAggregationChecks),
interval: interval,
} }
} }
@ -69,21 +91,38 @@ func (a *aggregator) start(wg *sync.WaitGroup) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
timer := time.NewTimer(a.gracePeriod) ticker := time.NewTicker(a.interval)
for { for {
select { select {
case <-a.done: case <-a.done:
a.logger.Debug("Waiting for all aggregation checks to finish...")
// block until all aggregation checks released the token
for i := 0; i < cap(a.sema); i++ {
a.sema <- struct{}{}
}
a.logger.Debug("Aggregator done") a.logger.Debug("Aggregator done")
ticker.Stop()
return return
case <-timer.C: case <-ticker.C:
a.exec() a.exec()
timer.Reset(a.gracePeriod)
} }
} }
}() }()
} }
func (a *aggregator) exec() { func (a *aggregator) exec() {
select {
case a.sema <- struct{}{}: // acquire token
go a.aggregate()
default:
// If the semaphore blocks, then we are currently running max number of
// aggregation checks. Skip this round and log warning.
a.logger.Warnf("Max number of aggregation checks in flight. Skipping")
}
}
func (a *aggregator) aggregate() {
defer func() { <-a.sema /* release token */ }()
for _, qname := range a.queues { for _, qname := range a.queues {
groups, err := a.broker.ListGroups(qname) groups, err := a.broker.ListGroups(qname)
if err != nil { if err != nil {
@ -91,7 +130,7 @@ func (a *aggregator) exec() {
continue continue
} }
for _, gname := range groups { for _, gname := range groups {
aggregationSetID, err := a.broker.AggregationCheck(qname, gname) aggregationSetID, err := a.broker.AggregationCheck(qname, gname, a.gracePeriod, a.maxDelay, a.maxSize)
if err != nil { if err != nil {
a.logger.Errorf("Failed to run aggregation check: queue=%q group=%q", qname, gname) a.logger.Errorf("Failed to run aggregation check: queue=%q group=%q", qname, gname)
continue continue
@ -115,7 +154,8 @@ func (a *aggregator) exec() {
aggregatedTask := a.aggregateFunc(gname, tasks) aggregatedTask := a.aggregateFunc(gname, tasks)
ctx, cancel := context.WithDeadline(context.Background(), deadline) ctx, cancel := context.WithDeadline(context.Background(), deadline)
if _, err := a.client.EnqueueContext(ctx, aggregatedTask); err != nil { if _, err := a.client.EnqueueContext(ctx, aggregatedTask); err != nil {
a.logger.Errorf("Failed to enqueue aggregated task: %v", err) a.logger.Errorf("Failed to enqueue aggregated task (queue=%q, group=%q, setID=%q): %v",
qname, gname, aggregationSetID, err)
cancel() cancel()
continue continue
} }