From d841dc2f8d916af59042681165628d3ca553831d Mon Sep 17 00:00:00 2001
From: Ken Hibino
Date: Tue, 8 Mar 2022 06:38:35 -0800
Subject: [PATCH] Add initial implementation of aggregator

---
 aggregator.go              | 129 +++++++++++++++++++++++++++++++++++++
 internal/base/base.go      |  31 +++++++--
 internal/base/base_test.go |  39 +++++++++--
 internal/rdb/rdb.go        |  26 ++++++++
 4 files changed, 216 insertions(+), 9 deletions(-)
 create mode 100644 aggregator.go

diff --git a/aggregator.go b/aggregator.go
new file mode 100644
index 0000000..27cfb36
--- /dev/null
+++ b/aggregator.go
@@ -0,0 +1,129 @@
+// Copyright 2022 Kentaro Hibino. All rights reserved.
+// Use of this source code is governed by a MIT license
+// that can be found in the LICENSE file.
+
+package asynq
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/hibiken/asynq/internal/base"
+	"github.com/hibiken/asynq/internal/log"
+)
+
+// An aggregator is responsible for checking groups and aggregating the tasks in each group
+// into one task if any of the grouping conditions is met.
+type aggregator struct {
+	logger *log.Logger
+	broker base.Broker
+	client *Client
+
+	// channel to communicate back to the long-running "aggregator" goroutine.
+	done chan struct{}
+
+	// list of queue names to check and aggregate.
+	queues []string
+
+	// Group configurations
+	gracePeriod time.Duration
+	maxDelay    time.Duration
+	maxSize     int
+
+	// Aggregation function
+	aggregateFunc func(gname string, tasks []*Task) *Task
+}
+
+type aggregatorParams struct {
+	logger        *log.Logger
+	broker        base.Broker
+	queues        []string
+	gracePeriod   time.Duration
+	maxDelay      time.Duration
+	maxSize       int
+	aggregateFunc func(gname string, msgs []*Task) *Task
+}
+
+func newAggregator(params aggregatorParams) *aggregator {
+	return &aggregator{
+		logger:        params.logger,
+		broker:        params.broker,
+		client:        &Client{broker: params.broker},
+		done:          make(chan struct{}),
+		queues:        params.queues,
+		gracePeriod:   params.gracePeriod,
+		maxDelay:      params.maxDelay,
+		maxSize:       params.maxSize,
+		aggregateFunc: params.aggregateFunc,
+	}
+}
+
+func (a *aggregator) shutdown() {
+	a.logger.Debug("Aggregator shutting down...")
+	// Signal the aggregator goroutine to stop.
+	a.done <- struct{}{}
+}
+
+func (a *aggregator) start(wg *sync.WaitGroup) {
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		timer := time.NewTimer(a.gracePeriod)
+		for {
+			select {
+			case <-a.done:
+				a.logger.Debug("Aggregator done")
+				return
+			case <-timer.C:
+				a.exec()
+				timer.Reset(a.gracePeriod)
+			}
+		}
+	}()
+}
+
+func (a *aggregator) exec() {
+	for _, qname := range a.queues {
+		groups, err := a.broker.ListGroups(qname)
+		if err != nil {
+			a.logger.Errorf("Failed to list groups in queue: %q", qname)
+			continue
+		}
+		for _, gname := range groups {
+			aggregationSetID, err := a.broker.AggregationCheck(qname, gname)
+			if err != nil {
+				a.logger.Errorf("Failed to run aggregation check: queue=%q group=%q", qname, gname)
+				continue
+			}
+			if aggregationSetID == "" {
+				a.logger.Debugf("No aggregation needed at this time: queue=%q group=%q", qname, gname)
+				continue
+			}
+
+			// Aggregate and enqueue.
+			msgs, deadline, err := a.broker.ReadAggregationSet(qname, gname, aggregationSetID)
+			if err != nil {
+				a.logger.Errorf("Failed to read aggregation set: queue=%q, group=%q, setID=%q",
+					qname, gname, aggregationSetID)
+				continue
+			}
+			tasks := make([]*Task, len(msgs))
+			for i, m := range msgs {
+				tasks[i] = NewTask(m.Type, m.Payload)
+			}
+			aggregatedTask := a.aggregateFunc(gname, tasks)
+			ctx, cancel := context.WithDeadline(context.Background(), deadline)
+			if _, err := a.client.EnqueueContext(ctx, aggregatedTask); err != nil {
+				a.logger.Errorf("Failed to enqueue aggregated task: %v", err)
+				cancel()
+				continue
+			}
+			if err := a.broker.DeleteAggregationSet(ctx, qname, gname, aggregationSetID); err != nil {
+				a.logger.Warnf("Failed to delete aggregation set: queue=%q, group=%q, setID=%q",
+					qname, gname, aggregationSetID)
+			}
+			cancel()
+		}
+	}
+}
diff --git a/internal/base/base.go b/internal/base/base.go
index 04f1a00..6b9606d 100644
--- a/internal/base/base.go
+++ b/internal/base/base.go
@@ -215,14 +215,20 @@ func GroupKey(qname, gkey string) string {
 	return fmt.Sprintf("%s%s", GroupKeyPrefix(qname), gkey)
 }
 
+// AggregationSetKey returns a redis key used for an aggregation set.
+func AggregationSetKey(qname, gname, setID string) string {
+	return fmt.Sprintf("%s:%s", GroupKey(qname, gname), setID)
+}
+
 // AllGroups return a redis key used to store all group keys used in a given queue.
 func AllGroups(qname string) string {
 	return fmt.Sprintf("%sgroups", QueueKeyPrefix(qname))
 }
 
-// AllStagedGroups returns a redis key used to store all groups staged to be aggregated in a given queue.
-func AllStagedGroups(qname string) string {
-	return fmt.Sprintf("%sstaged_groups", QueueKeyPrefix(qname))
+// AllAggregationSets returns a redis key used to store all aggregation sets (set of tasks staged to be aggregated)
+// in a given queue.
+func AllAggregationSets(qname string) string {
+	return fmt.Sprintf("%saggregation_sets", QueueKeyPrefix(qname))
 }
 
 // TaskMessage is the internal representation of a task with additional metadata fields.
@@ -708,6 +714,7 @@ func (l *Lease) IsValid() bool {
 // See rdb.RDB as a reference implementation.
 type Broker interface {
 	Ping() error
+	Close() error
 	Enqueue(ctx context.Context, msg *TaskMessage) error
 	EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error
 	Dequeue(qnames ...string) (*TaskMessage, time.Time, error)
@@ -719,13 +726,29 @@ type Broker interface {
 	Retry(ctx context.Context, msg *TaskMessage, processAt time.Time, errMsg string, isFailure bool) error
 	Archive(ctx context.Context, msg *TaskMessage, errMsg string) error
 	ForwardIfReady(qnames ...string) error
+
+	// Group aggregation related methods
+	AddToGroup(ctx context.Context, msg *TaskMessage, gname string) error
+	AddToGroupUnique(ctx context.Context, msg *TaskMessage, groupKey string, ttl time.Duration) error
+	ListGroups(qname string) ([]string, error)
+	AggregationCheck(qname, gname string) (aggregationSetID string, err error)
+	ReadAggregationSet(qname, gname, aggregationSetID string) ([]*TaskMessage, time.Time, error)
+	DeleteAggregationSet(ctx context.Context, qname, gname, aggregationSetID string) error
+
+	// Task retention related method
 	DeleteExpiredCompletedTasks(qname string) error
+
+	// Lease related methods
 	ListLeaseExpired(cutoff time.Time, qnames ...string) ([]*TaskMessage, error)
 	ExtendLease(qname string, ids ...string) (time.Time, error)
+
+	// State snapshot related methods
 	WriteServerState(info *ServerInfo, workers []*WorkerInfo, ttl time.Duration) error
 	ClearServerState(host string, pid int, serverID string) error
+
+	// Cancelation related methods
 	CancelationPubSub() (*redis.PubSub, error) // TODO: Need to decouple from redis to support other brokers
 	PublishCancelation(id string) error
+
 	WriteResult(qname, id string, data []byte) (n int, err error)
-	Close() error
 }
diff --git a/internal/base/base_test.go b/internal/base/base_test.go
index 704b02d..3ac2380 100644
--- a/internal/base/base_test.go
+++ b/internal/base/base_test.go
@@ -421,6 +421,35 @@ func TestGroupKey(t *testing.T) {
 	}
 }
 
+func TestAggregationSetKey(t *testing.T) {
+	tests := []struct {
+		qname string
+		gname string
+		setID string
+		want  string
+	}{
+		{
+			qname: "default",
+			gname: "mygroup",
+			setID: "12345",
+			want:  "asynq:{default}:g:mygroup:12345",
+		},
+		{
+			qname: "custom",
+			gname: "foo",
+			setID: "98765",
+			want:  "asynq:{custom}:g:foo:98765",
+		},
+	}
+
+	for _, tc := range tests {
+		got := AggregationSetKey(tc.qname, tc.gname, tc.setID)
+		if got != tc.want {
+			t.Errorf("AggregationSetKey(%q, %q, %q) = %q, want %q", tc.qname, tc.gname, tc.setID, got, tc.want)
+		}
+	}
+}
+
 func TestAllGroups(t *testing.T) {
 	tests := []struct {
 		qname string
@@ -444,25 +473,25 @@ func TestAllGroups(t *testing.T) {
 	}
 }
 
-func TestAllStagedGroups(t *testing.T) {
+func TestAllAggregationSets(t *testing.T) {
 	tests := []struct {
 		qname string
 		want  string
 	}{
 		{
 			qname: "default",
-			want:  "asynq:{default}:staged_groups",
+			want:  "asynq:{default}:aggregation_sets",
 		},
 		{
 			qname: "custom",
-			want:  "asynq:{custom}:staged_groups",
+			want:  "asynq:{custom}:aggregation_sets",
 		},
 	}
 
 	for _, tc := range tests {
-		got := AllStagedGroups(tc.qname)
+		got := AllAggregationSets(tc.qname)
 		if got != tc.want {
-			t.Errorf("AllStagedGroups(%q) = %q, want %q", tc.qname, got, tc.want)
+			t.Errorf("AllAggregationSets(%q) = %q, want %q", tc.qname, got, tc.want)
 		}
 	}
 }
diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go
index 7b9fab6..4f71f14 100644
--- a/internal/rdb/rdb.go
+++ b/internal/rdb/rdb.go
@@ -982,6 +982,32 @@ func (r *RDB) forwardAll(qname string) (err error) {
 	return nil
 }
 
+// ListGroups returns a list of all known groups in the given queue.
+func (r *RDB) ListGroups(qname string) ([]string, error) {
+	// TODO: Implement this with TDD
+	return nil, nil
+}
+
+// AggregationCheck checks the group identified by the given queue and group name to see if the tasks in the
+// group are ready to be aggregated. If so, it moves the tasks to be aggregated to an aggregation set and returns
+// the set ID. If not, it returns an empty string for the set ID.
+func (r *RDB) AggregationCheck(qname, gname string) (string, error) {
+	// TODO: Implement this with TDD
+	return "", nil
+}
+
+// ReadAggregationSet retrieves members of an aggregation set and returns a list of tasks and
+// the deadline for aggregating those tasks.
+func (r *RDB) ReadAggregationSet(qname, gname, setID string) ([]*base.TaskMessage, time.Time, error) {
+	// TODO: Implement this with TDD
+	return nil, time.Time{}, nil
+}
+
+// DeleteAggregationSet deletes the aggregation set identified by the parameters.
+func (r *RDB) DeleteAggregationSet(ctx context.Context, qname, gname, setID string) error {
+	return nil
+}
+
 // KEYS[1] -> asynq:{}:completed
 // ARGV[1] -> current time in unix time
 // ARGV[2] -> task key prefix
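
The RDB methods above are still stubs, so nothing in this patch runs end to end yet. As an illustration only (not part of the patch), the sketch below shows how the pieces added here could be wired together from inside the asynq package once those methods are implemented. The queue name, the thresholds, and the payload-joining aggregate function are invented for the example; the snippet assumes the standard library imports bytes, sync, and time plus the internal base and log packages that aggregator.go already uses.

// runAggregatorExample is a hypothetical helper showing how the new aggregator
// could be constructed and run. A real caller (e.g. Server) would own the
// broker, the logger, and the shutdown sequencing.
func runAggregatorExample(logger *log.Logger, broker base.Broker) {
	ag := newAggregator(aggregatorParams{
		logger:      logger,
		broker:      broker,
		queues:      []string{"default"},
		gracePeriod: 1 * time.Minute,  // interval at which exec() wakes up and checks groups
		maxDelay:    10 * time.Minute, // intended cap on how long a group may wait (enforced by AggregationCheck once implemented)
		maxSize:     100,              // intended cap on how many tasks are aggregated into one task
		aggregateFunc: func(gname string, tasks []*Task) *Task {
			// Example aggregation: concatenate the payloads of the grouped tasks
			// into a single task whose type records the group name.
			var b bytes.Buffer
			for _, t := range tasks {
				b.Write(t.Payload())
				b.WriteByte('\n')
			}
			return NewTask("aggregated:"+gname, b.Bytes())
		},
	})

	var wg sync.WaitGroup
	ag.start(&wg)

	// ... serve tasks until shutdown is requested ...

	ag.shutdown() // unblocks the aggregator goroutine via the done channel
	wg.Wait()
}

In the final design the server would presumably build this from user-facing configuration, but that surface is not part of this commit.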