diff --git a/aggregator.go b/aggregator.go
index d4a5680..7340f3c 100644
--- a/aggregator.go
+++ b/aggregator.go
@@ -31,8 +31,8 @@ type aggregator struct {
 	maxDelay    time.Duration
 	maxSize     int
 
-	// Aggregation function
-	aggregateFunc func(gname string, tasks []*Task) *Task
+	// User provided group aggregator.
+	ga GroupAggregator
 
 	// interval used to check for aggregation
 	interval time.Duration
@@ -43,13 +43,13 @@ type aggregator struct {
 }
 
 type aggregatorParams struct {
-	logger        *log.Logger
-	broker        base.Broker
-	queues        []string
-	gracePeriod   time.Duration
-	maxDelay      time.Duration
-	maxSize       int
-	aggregateFunc func(gname string, msgs []*Task) *Task
+	logger          *log.Logger
+	broker          base.Broker
+	queues          []string
+	gracePeriod     time.Duration
+	maxDelay        time.Duration
+	maxSize         int
+	groupAggregator GroupAggregator
 }
 
 const (
@@ -67,22 +67,22 @@ func newAggregator(params aggregatorParams) *aggregator {
 		interval = params.gracePeriod
 	}
 	return &aggregator{
-		logger:        params.logger,
-		broker:        params.broker,
-		client:        &Client{broker: params.broker},
-		done:          make(chan struct{}),
-		queues:        params.queues,
-		gracePeriod:   params.gracePeriod,
-		maxDelay:      params.maxDelay,
-		maxSize:       params.maxSize,
-		aggregateFunc: params.aggregateFunc,
-		sema:          make(chan struct{}, maxConcurrentAggregationChecks),
-		interval:      interval,
+		logger:      params.logger,
+		broker:      params.broker,
+		client:      &Client{broker: params.broker},
+		done:        make(chan struct{}),
+		queues:      params.queues,
+		gracePeriod: params.gracePeriod,
+		maxDelay:    params.maxDelay,
+		maxSize:     params.maxSize,
+		ga:          params.groupAggregator,
+		sema:        make(chan struct{}, maxConcurrentAggregationChecks),
+		interval:    interval,
 	}
 }
 
 func (a *aggregator) shutdown() {
-	if a.aggregateFunc == nil {
+	if a.ga == nil {
 		return
 	}
 	a.logger.Debug("Aggregator shutting down...")
@@ -91,7 +91,7 @@ func (a *aggregator) shutdown() {
 }
 
 func (a *aggregator) start(wg *sync.WaitGroup) {
-	if a.aggregateFunc == nil {
+	if a.ga == nil {
 		return
 	}
 	wg.Add(1)
@@ -158,7 +158,7 @@ func (a *aggregator) aggregate(t time.Time) {
 			for i, m := range msgs {
 				tasks[i] = NewTask(m.Type, m.Payload)
 			}
-			aggregatedTask := a.aggregateFunc(gname, tasks)
+			aggregatedTask := a.ga.Aggregate(gname, tasks)
 			ctx, cancel := context.WithDeadline(context.Background(), deadline)
 			if _, err := a.client.EnqueueContext(ctx, aggregatedTask); err != nil {
 				a.logger.Errorf("Failed to enqueue aggregated task (queue=%q, group=%q, setID=%q): %v",
diff --git a/aggregator_test.go b/aggregator_test.go
index d46d639..ccce306 100644
--- a/aggregator_test.go
+++ b/aggregator_test.go
@@ -120,13 +120,13 @@ func TestAggregator(t *testing.T) {
 		h.FlushDB(t, r)
 
 		aggregator := newAggregator(aggregatorParams{
-			logger:        testLogger,
-			broker:        rdbClient,
-			queues:        []string{"default"},
-			gracePeriod:   tc.gracePeriod,
-			maxDelay:      tc.maxDelay,
-			maxSize:       tc.maxSize,
-			aggregateFunc: tc.aggregateFunc,
+			logger:          testLogger,
+			broker:          rdbClient,
+			queues:          []string{"default"},
+			gracePeriod:     tc.gracePeriod,
+			maxDelay:        tc.maxDelay,
+			maxSize:         tc.maxSize,
+			groupAggregator: GroupAggregatorFunc(tc.aggregateFunc),
 		})
 
 		var wg sync.WaitGroup
diff --git a/server.go b/server.go
index ad6d7e3..945a7c3 100644
--- a/server.go
+++ b/server.go
@@ -216,10 +216,26 @@ type Config struct {
 	// If unset or zero, no size limit is used.
 	GroupMaxSize int
 
-	// GroupAggregateFunc specifies the aggregation function used to aggregate multiple tasks in a group into one task.
+	// GroupAggregator specifies the aggregation function used to aggregate multiple tasks in a group into one task.
 	//
 	// If unset or nil, the group aggregation feature will be disabled on the server.
-	GroupAggregateFunc func(groupKey string, tasks []*Task) *Task
+	GroupAggregator GroupAggregator
+}
+
+// GroupAggregator aggregates a group of tasks into one before the tasks are passed to the Handler.
+type GroupAggregator interface {
+	// Aggregate aggregates the given tasks which belong to a same group
+	// and returns a new task which is the aggregation of those tasks.
+	Aggregate(groupKey string, tasks []*Task) *Task
+}
+
+// The GroupAggregatorFunc type is an adapter to allow the use of ordinary functions as a GroupAggregator.
+// If f is a function with the appropriate signature, GroupAggregatorFunc(f) is a GroupAggregator that calls f.
+type GroupAggregatorFunc func(groupKey string, tasks []*Task) *Task
+
+// Aggregate calls fn(groupKey, tasks)
+func (fn GroupAggregatorFunc) Aggregate(groupKey string, tasks []*Task) *Task {
+	return fn(groupKey, tasks)
 }
 
 // An ErrorHandler handles an error occured during task processing.
@@ -506,13 +522,13 @@ func NewServer(r RedisConnOpt, cfg Config) *Server {
 		interval: 8 * time.Second,
 	})
 	aggregator := newAggregator(aggregatorParams{
-		logger:        logger,
-		broker:        rdb,
-		queues:        qnames,
-		gracePeriod:   groupGracePeriod,
-		maxDelay:      cfg.GroupMaxDelay,
-		maxSize:       cfg.GroupMaxSize,
-		aggregateFunc: cfg.GroupAggregateFunc,
+		logger:          logger,
+		broker:          rdb,
+		queues:          qnames,
+		gracePeriod:     groupGracePeriod,
+		maxDelay:        cfg.GroupMaxDelay,
+		maxSize:         cfg.GroupMaxSize,
+		groupAggregator: cfg.GroupAggregator,
 	})
 	return &Server{
 		logger: logger,
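For context, here is a minimal usage sketch of the API introduced by this diff: an ordinary function is wrapped with GroupAggregatorFunc to satisfy the GroupAggregator interface and is wired into Config.GroupAggregator. Only Config, NewServer, NewTask, GroupAggregator, GroupAggregatorFunc, GroupMaxDelay, and GroupMaxSize are taken from the diff itself; the Redis address, the "aggregated:" type-name prefix, the payload concatenation scheme, and the specific delay/size values are illustrative assumptions, not part of this change.

package main

import (
	"log"
	"strings"
	"time"

	"github.com/hibiken/asynq"
)

func main() {
	srv := asynq.NewServer(
		asynq.RedisClientOpt{Addr: "localhost:6379"},
		asynq.Config{
			// Illustrative grouping limits; per the doc comment above, only
			// GroupAggregator must be set to enable aggregation.
			GroupMaxDelay: 10 * time.Minute,
			GroupMaxSize:  20,
			// Wrap a plain function with GroupAggregatorFunc so it satisfies
			// the GroupAggregator interface added in this diff.
			GroupAggregator: asynq.GroupAggregatorFunc(func(groupKey string, tasks []*asynq.Task) *asynq.Task {
				// Concatenate all payloads in the group into one task payload
				// (an arbitrary aggregation strategy chosen for illustration).
				var b strings.Builder
				for _, t := range tasks {
					b.Write(t.Payload())
					b.WriteByte('\n')
				}
				return asynq.NewTask("aggregated:"+groupKey, []byte(b.String()))
			}),
		},
	)
	if err := srv.Run(asynq.NewServeMux()); err != nil {
		log.Fatal(err)
	}
}

The aggregated task produced by the callback is enqueued by the aggregator (see a.client.EnqueueContext in the first hunk) and then handled like any other task by the server's Handler.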