mirror of https://github.com/hibiken/asynq.git synced 2024-11-10 11:31:58 +08:00

Add initial implementation of aggregator

Ken Hibino 2022-03-08 06:38:35 -08:00
parent ab28234767
commit d841dc2f8d
4 changed files with 216 additions and 9 deletions

aggregator.go (new file, 129 lines)

@@ -0,0 +1,129 @@
// Copyright 2022 Kentaro Hibino. All rights reserved.
// Use of this source code is governed by a MIT license
// that can be found in the LICENSE file.
package asynq
import (
"context"
"sync"
"time"
"github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/log"
)
// An aggregator is responsible for checking groups and aggregating the tasks in a group
// into one task if any of the grouping conditions is met.
type aggregator struct {
logger *log.Logger
broker base.Broker
client *Client
// channel to communicate back to the long running "aggregator" goroutine.
done chan struct{}
// list of queue names to check and aggregate.
queues []string
// Group configurations
gracePeriod time.Duration
maxDelay time.Duration
maxSize int
// Aggregation function
aggregateFunc func(gname string, tasks []*Task) *Task
}
type aggregatorParams struct {
logger *log.Logger
broker base.Broker
queues []string
gracePeriod time.Duration
maxDelay time.Duration
maxSize int
aggregateFunc func(gname string, msgs []*Task) *Task
}
func newAggregator(params aggregatorParams) *aggregator {
return &aggregator{
logger: params.logger,
broker: params.broker,
client: &Client{broker: params.broker},
done: make(chan struct{}),
queues: params.queues,
gracePeriod: params.gracePeriod,
maxDelay: params.maxDelay,
maxSize: params.maxSize,
aggregateFunc: params.aggregateFunc,
}
}
func (a *aggregator) shutdown() {
a.logger.Debug("Aggregator shutting down...")
// Signal the aggregator goroutine to stop.
a.done <- struct{}{}
}
func (a *aggregator) start(wg *sync.WaitGroup) {
wg.Add(1)
go func() {
defer wg.Done()
timer := time.NewTimer(a.gracePeriod)
for {
select {
case <-a.done:
a.logger.Debug("Aggregator done")
return
case <-timer.C:
a.exec()
timer.Reset(a.gracePeriod)
}
}
}()
}
func (a *aggregator) exec() {
for _, qname := range a.queues {
groups, err := a.broker.ListGroups(qname)
if err != nil {
a.logger.Errorf("Failed to list groups in queue: %q", qname)
continue
}
for _, gname := range groups {
aggregationSetID, err := a.broker.AggregationCheck(qname, gname)
if err != nil {
a.logger.Errorf("Failed to run aggregation check: queue=%q group=%q", qname, gname)
continue
}
if aggregationSetID == "" {
a.logger.Debugf("No aggregation needed at this time: queue=%q group=%q", qname, gname)
continue
}
// Aggregate and enqueue.
msgs, deadline, err := a.broker.ReadAggregationSet(qname, gname, aggregationSetID)
if err != nil {
a.logger.Errorf("Failed to read aggregation set: queue=%q, group=%q, setID=%q",
qname, gname, aggregationSetID)
continue
}
tasks := make([]*Task, len(msgs))
for i, m := range msgs {
tasks[i] = NewTask(m.Type, m.Payload)
}
aggregatedTask := a.aggregateFunc(gname, tasks)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
if _, err := a.client.EnqueueContext(ctx, aggregatedTask); err != nil {
a.logger.Errorf("Failed to enqueue aggregated task: %v", err)
cancel()
continue
}
if err := a.broker.DeleteAggregationSet(ctx, qname, gname, aggregationSetID); err != nil {
a.logger.Warnf("Failed to delete aggregation set: queue=%q, group=%q, setID=%q",
qname, gname, aggregationSetID)
}
cancel()
}
}
}
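For context, here is a minimal sketch of an aggregation function with the same signature as the aggregateFunc field above. Everything in it (the function name, the "batch:process" task type, and the newline-joined payload) is illustrative and not part of this commit; it only shows the shape of the input and output the aggregator expects.

package example

import (
	"bytes"

	"github.com/hibiken/asynq"
)

// combineTasks merges all tasks in a group into a single task by joining
// their payloads with newlines. Hypothetical example only; the task type
// name and payload format are assumptions, not part of this commit.
func combineTasks(gname string, tasks []*asynq.Task) *asynq.Task {
	var payloads [][]byte
	for _, t := range tasks {
		payloads = append(payloads, t.Payload())
	}
	return asynq.NewTask("batch:process", bytes.Join(payloads, []byte("\n")))
}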


@@ -215,14 +215,20 @@ func GroupKey(qname, gkey string) string {
	return fmt.Sprintf("%s%s", GroupKeyPrefix(qname), gkey)
}

+// AggregationSetKey returns a redis key used for an aggregation set.
+func AggregationSetKey(qname, gname, setID string) string {
+	return fmt.Sprintf("%s:%s", GroupKey(qname, gname), setID)
+}
+
// AllGroups return a redis key used to store all group keys used in a given queue.
func AllGroups(qname string) string {
	return fmt.Sprintf("%sgroups", QueueKeyPrefix(qname))
}

-// AllStagedGroups returns a redis key used to store all groups staged to be aggregated in a given queue.
-func AllStagedGroups(qname string) string {
-	return fmt.Sprintf("%sstaged_groups", QueueKeyPrefix(qname))
-}
+// AllAggregationSets returns a redis key used to store all aggregation sets (set of tasks staged to be aggregated)
+// in a given queue.
+func AllAggregationSets(qname string) string {
+	return fmt.Sprintf("%saggregation_sets", QueueKeyPrefix(qname))
+}

// TaskMessage is the internal representation of a task with additional metadata fields.
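As a quick reference for the key layout introduced here, the snippet below prints what these helpers evaluate to for a queue named "default" and a group named "mygroup". The expected strings are taken from the tests added in this commit; note that base is an internal package, so this only compiles inside the asynq module itself.

package main

import (
	"fmt"

	"github.com/hibiken/asynq/internal/base"
)

func main() {
	// Expected outputs are taken from the tests added in this commit.
	fmt.Println(base.GroupKey("default", "mygroup"))                   // asynq:{default}:g:mygroup
	fmt.Println(base.AggregationSetKey("default", "mygroup", "12345")) // asynq:{default}:g:mygroup:12345
	fmt.Println(base.AllGroups("default"))                             // asynq:{default}:groups
	fmt.Println(base.AllAggregationSets("default"))                    // asynq:{default}:aggregation_sets
}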
@@ -708,6 +714,7 @@ func (l *Lease) IsValid() bool {
// See rdb.RDB as a reference implementation.
type Broker interface {
	Ping() error
+	Close() error
	Enqueue(ctx context.Context, msg *TaskMessage) error
	EnqueueUnique(ctx context.Context, msg *TaskMessage, ttl time.Duration) error
	Dequeue(qnames ...string) (*TaskMessage, time.Time, error)
@@ -719,13 +726,29 @@ type Broker interface {
	Retry(ctx context.Context, msg *TaskMessage, processAt time.Time, errMsg string, isFailure bool) error
	Archive(ctx context.Context, msg *TaskMessage, errMsg string) error
	ForwardIfReady(qnames ...string) error

+	// Group aggregation related methods
+	AddToGroup(ctx context.Context, msg *TaskMessage, gname string) error
+	AddToGroupUnique(ctx context.Context, msg *TaskMessage, groupKey string, ttl time.Duration) error
+	ListGroups(qname string) ([]string, error)
+	AggregationCheck(qname, gname string) (aggregationSetID string, err error)
+	ReadAggregationSet(qname, gname, aggregationSetID string) ([]*TaskMessage, time.Time, error)
+	DeleteAggregationSet(ctx context.Context, qname, gname, aggregationSetID string) error

+	// Task retention related method
	DeleteExpiredCompletedTasks(qname string) error

+	// Lease related methods
	ListLeaseExpired(cutoff time.Time, qnames ...string) ([]*TaskMessage, error)
	ExtendLease(qname string, ids ...string) (time.Time, error)

+	// State snapshot related methods
	WriteServerState(info *ServerInfo, workers []*WorkerInfo, ttl time.Duration) error
	ClearServerState(host string, pid int, serverID string) error

+	// Cancelation related methods
	CancelationPubSub() (*redis.PubSub, error) // TODO: Need to decouple from redis to support other brokers
	PublishCancelation(id string) error
	WriteResult(qname, id string, data []byte) (n int, err error)
-	Close() error
}


@@ -421,6 +421,35 @@ func TestGroupKey(t *testing.T) {
	}
}

+func TestAggregationSetKey(t *testing.T) {
+	tests := []struct {
+		qname string
+		gname string
+		setID string
+		want  string
+	}{
+		{
+			qname: "default",
+			gname: "mygroup",
+			setID: "12345",
+			want:  "asynq:{default}:g:mygroup:12345",
+		},
+		{
+			qname: "custom",
+			gname: "foo",
+			setID: "98765",
+			want:  "asynq:{custom}:g:foo:98765",
+		},
+	}
+	for _, tc := range tests {
+		got := AggregationSetKey(tc.qname, tc.gname, tc.setID)
+		if got != tc.want {
+			t.Errorf("AggregationSetKey(%q, %q, %q) = %q, want %q", tc.qname, tc.gname, tc.setID, got, tc.want)
+		}
+	}
+}
+
func TestAllGroups(t *testing.T) {
	tests := []struct {
		qname string
@@ -444,25 +473,25 @@ func TestAllGroups(t *testing.T) {
	}
}

-func TestAllStagedGroups(t *testing.T) {
+func TestAllAggregationSets(t *testing.T) {
	tests := []struct {
		qname string
		want  string
	}{
		{
			qname: "default",
-			want:  "asynq:{default}:staged_groups",
+			want:  "asynq:{default}:aggregation_sets",
		},
		{
			qname: "custom",
-			want:  "asynq:{custom}:staged_groups",
+			want:  "asynq:{custom}:aggregation_sets",
		},
	}
	for _, tc := range tests {
-		got := AllStagedGroups(tc.qname)
+		got := AllAggregationSets(tc.qname)
		if got != tc.want {
-			t.Errorf("AllStagedGroups(%q) = %q, want %q", tc.qname, got, tc.want)
+			t.Errorf("AllAggregationSets(%q) = %q, want %q", tc.qname, got, tc.want)
		}
	}
}


@@ -982,6 +982,32 @@ func (r *RDB) forwardAll(qname string) (err error) {
	return nil
}

+// ListGroups returns a list of all known groups in the given queue.
+func (r *RDB) ListGroups(qname string) ([]string, error) {
+	// TODO: Implement this with TDD
+	return nil, nil
+}
+
+// AggregationCheck checks the group identified by the given queue and group name to see if the tasks in the
+// group are ready to be aggregated. If so, it moves the tasks to be aggregated to an aggregation set and returns
+// the set ID. If not, it returns an empty string for the set ID.
+func (r *RDB) AggregationCheck(qname, gname string) (string, error) {
+	// TODO: Implement this with TDD
+	return "", nil
+}
+
+// ReadAggregationSet retrieves the members of an aggregation set and returns a list of tasks and
+// the deadline for aggregating those tasks.
+func (r *RDB) ReadAggregationSet(qname, gname, setID string) ([]*base.TaskMessage, time.Time, error) {
+	// TODO: Implement this with TDD
+	return nil, time.Time{}, nil
+}
+
+// DeleteAggregationSet deletes the aggregation set identified by the parameters.
+func (r *RDB) DeleteAggregationSet(ctx context.Context, qname, gname, setID string) error {
+	return nil
+}
+
// KEYS[1] -> asynq:{<qname>}:completed
// ARGV[1] -> current time in unix time
// ARGV[2] -> task key prefix
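The four aggregation methods above are left as TODO stubs in this commit. Purely as a sketch of where one of them might go, the standalone function below reads group names with SMEMBERS from the per-queue groups set whose key shape is defined by AllGroups in base.go; the key string, function name, and use of go-redis here are assumptions, and the eventual TDD-driven implementation in RDB may look quite different.

package rdbsketch

import (
	"context"
	"fmt"

	"github.com/go-redis/redis/v8"
)

// listGroups is a hypothetical sketch of what RDB.ListGroups might do once
// implemented: return the members of asynq:{<qname>}:groups (see base.AllGroups).
func listGroups(ctx context.Context, client redis.UniversalClient, qname string) ([]string, error) {
	key := fmt.Sprintf("asynq:{%s}:groups", qname)
	return client.SMembers(ctx, key).Result()
}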