Define GroupAggregator interface

This commit is contained in:
Ken Hibino
2022-04-07 06:13:49 -07:00
parent a369443955
commit 829f64fd38
3 changed files with 55 additions and 39 deletions

View File

@@ -31,8 +31,8 @@ type aggregator struct {
maxDelay time.Duration
maxSize int
// Aggregation function
aggregateFunc func(gname string, tasks []*Task) *Task
// User provided group aggregator.
ga GroupAggregator
// interval used to check for aggregation
interval time.Duration
@@ -43,13 +43,13 @@ type aggregator struct {
}
type aggregatorParams struct {
logger *log.Logger
broker base.Broker
queues []string
gracePeriod time.Duration
maxDelay time.Duration
maxSize int
aggregateFunc func(gname string, msgs []*Task) *Task
logger *log.Logger
broker base.Broker
queues []string
gracePeriod time.Duration
maxDelay time.Duration
maxSize int
groupAggregator GroupAggregator
}
const (
@@ -67,22 +67,22 @@ func newAggregator(params aggregatorParams) *aggregator {
interval = params.gracePeriod
}
return &aggregator{
logger: params.logger,
broker: params.broker,
client: &Client{broker: params.broker},
done: make(chan struct{}),
queues: params.queues,
gracePeriod: params.gracePeriod,
maxDelay: params.maxDelay,
maxSize: params.maxSize,
aggregateFunc: params.aggregateFunc,
sema: make(chan struct{}, maxConcurrentAggregationChecks),
interval: interval,
logger: params.logger,
broker: params.broker,
client: &Client{broker: params.broker},
done: make(chan struct{}),
queues: params.queues,
gracePeriod: params.gracePeriod,
maxDelay: params.maxDelay,
maxSize: params.maxSize,
ga: params.groupAggregator,
sema: make(chan struct{}, maxConcurrentAggregationChecks),
interval: interval,
}
}
func (a *aggregator) shutdown() {
if a.aggregateFunc == nil {
if a.ga == nil {
return
}
a.logger.Debug("Aggregator shutting down...")
@@ -91,7 +91,7 @@ func (a *aggregator) shutdown() {
}
func (a *aggregator) start(wg *sync.WaitGroup) {
if a.aggregateFunc == nil {
if a.ga == nil {
return
}
wg.Add(1)
@@ -158,7 +158,7 @@ func (a *aggregator) aggregate(t time.Time) {
for i, m := range msgs {
tasks[i] = NewTask(m.Type, m.Payload)
}
aggregatedTask := a.aggregateFunc(gname, tasks)
aggregatedTask := a.ga.Aggregate(gname, tasks)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
if _, err := a.client.EnqueueContext(ctx, aggregatedTask); err != nil {
a.logger.Errorf("Failed to enqueue aggregated task (queue=%q, group=%q, setID=%q): %v",