mirror of
https://github.com/hibiken/asynq.git
synced 2025-04-20 07:40:19 +08:00
feat: dynamic queue concurrency
This commit is contained in:
parent
be15ef61d6
commit
2bace4cce4
@ -81,6 +81,11 @@ func newAggregator(params aggregatorParams) *aggregator {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *aggregator) resetState() {
|
||||
a.done = make(chan struct{})
|
||||
a.sema = make(chan struct{}, maxConcurrentAggregationChecks)
|
||||
}
|
||||
|
||||
func (a *aggregator) shutdown() {
|
||||
if a.ga == nil {
|
||||
return
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
"github.com/hibiken/asynq/internal/log"
|
||||
"github.com/hibiken/asynq/internal/timeutil"
|
||||
|
@ -14,12 +14,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq/internal/errors"
|
||||
pb "github.com/hibiken/asynq/internal/proto"
|
||||
"github.com/hibiken/asynq/internal/timeutil"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/hibiken/asynq/internal/errors"
|
||||
pb "github.com/hibiken/asynq/internal/proto"
|
||||
"github.com/hibiken/asynq/internal/timeutil"
|
||||
)
|
||||
|
||||
// Version of asynq library and CLI.
|
||||
@ -722,4 +723,5 @@ type Broker interface {
|
||||
PublishCancelation(id string) error
|
||||
|
||||
WriteResult(qname, id string, data []byte) (n int, err error)
|
||||
SetQueueConcurrency(qname string, concurrency int)
|
||||
}
|
||||
|
@ -30,7 +30,9 @@ type Option func(r *RDB)
|
||||
|
||||
func WithQueueConcurrency(queueConcurrency map[string]int) Option {
|
||||
return func(r *RDB) {
|
||||
r.queueConcurrency = queueConcurrency
|
||||
for qname, concurrency := range queueConcurrency {
|
||||
r.queueConcurrency.Store(qname, concurrency)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -39,7 +41,7 @@ type RDB struct {
|
||||
client redis.UniversalClient
|
||||
clock timeutil.Clock
|
||||
queuesPublished sync.Map
|
||||
queueConcurrency map[string]int
|
||||
queueConcurrency sync.Map
|
||||
}
|
||||
|
||||
// NewRDB returns a new instance of RDB.
|
||||
@ -271,8 +273,8 @@ func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, leaseExpirationT
|
||||
base.LeaseKey(qname),
|
||||
}
|
||||
leaseExpirationTime = r.clock.Now().Add(LeaseDuration)
|
||||
queueConcurrency, ok := r.queueConcurrency[qname]
|
||||
if !ok || queueConcurrency <= 0 {
|
||||
queueConcurrency, ok := r.queueConcurrency.Load(qname)
|
||||
if !ok || queueConcurrency.(int) <= 0 {
|
||||
queueConcurrency = math.MaxInt
|
||||
}
|
||||
argv := []interface{}{
|
||||
@ -1581,3 +1583,7 @@ func (r *RDB) WriteResult(qname, taskID string, data []byte) (int, error) {
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (r *RDB) SetQueueConcurrency(qname string, concurrency int) {
|
||||
r.queueConcurrency.Store(qname, concurrency)
|
||||
}
|
||||
|
@ -11,8 +11,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
)
|
||||
|
||||
var errRedisDown = errors.New("testutil: redis is down")
|
||||
@ -297,3 +298,7 @@ func (tb *TestBroker) ReclaimStaleAggregationSets(qname string) error {
|
||||
}
|
||||
return tb.real.ReclaimStaleAggregationSets(qname)
|
||||
}
|
||||
|
||||
func (tb *TestBroker) SetQueueConcurrency(qname string, concurrency int) {
|
||||
tb.real.SetQueueConcurrency(qname, concurrency)
|
||||
}
|
||||
|
18
processor.go
18
processor.go
@ -16,12 +16,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||
"github.com/hibiken/asynq/internal/errors"
|
||||
"github.com/hibiken/asynq/internal/log"
|
||||
"github.com/hibiken/asynq/internal/timeutil"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type processor struct {
|
||||
@ -57,7 +58,7 @@ type processor struct {
|
||||
// channel to communicate back to the long running "processor" goroutine.
|
||||
// once is used to send value to the channel only once.
|
||||
done chan struct{}
|
||||
once sync.Once
|
||||
once *sync.Once
|
||||
|
||||
// quit channel is closed when the shutdown of the "processor" goroutine starts.
|
||||
quit chan struct{}
|
||||
@ -112,6 +113,7 @@ func newProcessor(params processorParams) *processor {
|
||||
errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1),
|
||||
sema: make(chan struct{}, params.concurrency),
|
||||
done: make(chan struct{}),
|
||||
once: &sync.Once{},
|
||||
quit: make(chan struct{}),
|
||||
abort: make(chan struct{}),
|
||||
errHandler: params.errHandler,
|
||||
@ -139,7 +141,9 @@ func (p *processor) stop() {
|
||||
func (p *processor) shutdown() {
|
||||
p.stop()
|
||||
|
||||
time.AfterFunc(p.shutdownTimeout, func() { close(p.abort) })
|
||||
go func(abort chan struct{}) {
|
||||
time.AfterFunc(p.shutdownTimeout, func() { close(abort) })
|
||||
}(p.abort)
|
||||
|
||||
p.logger.Info("Waiting for all workers to finish...")
|
||||
// block until all workers have released the token
|
||||
@ -149,6 +153,14 @@ func (p *processor) shutdown() {
|
||||
p.logger.Info("All workers have finished")
|
||||
}
|
||||
|
||||
func (p *processor) resetState() {
|
||||
p.sema = make(chan struct{}, cap(p.sema))
|
||||
p.done = make(chan struct{})
|
||||
p.quit = make(chan struct{})
|
||||
p.abort = make(chan struct{})
|
||||
p.once = &sync.Once{}
|
||||
}
|
||||
|
||||
func (p *processor) start(wg *sync.WaitGroup) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
|
85
server.go
85
server.go
@ -44,6 +44,10 @@ type Server struct {
|
||||
|
||||
state *serverState
|
||||
|
||||
mu sync.RWMutex
|
||||
queues map[string]int
|
||||
strictPriority bool
|
||||
|
||||
// wait group to wait for all goroutines to finish.
|
||||
wg sync.WaitGroup
|
||||
forwarder *forwarder
|
||||
@ -481,7 +485,9 @@ func NewServerFromRedisClient(c redis.UniversalClient, cfg Config) *Server {
|
||||
}
|
||||
}
|
||||
if len(queues) == 0 {
|
||||
queues = defaultQueueConfig
|
||||
for qname, p := range defaultQueueConfig {
|
||||
queues[qname] = p
|
||||
}
|
||||
}
|
||||
var qnames []string
|
||||
for q := range queues {
|
||||
@ -610,6 +616,8 @@ func NewServerFromRedisClient(c redis.UniversalClient, cfg Config) *Server {
|
||||
groupAggregator: cfg.GroupAggregator,
|
||||
})
|
||||
return &Server{
|
||||
queues: queues,
|
||||
strictPriority: cfg.StrictPriority,
|
||||
logger: logger,
|
||||
broker: rdb,
|
||||
sharedConnection: true,
|
||||
@ -792,3 +800,78 @@ func (srv *Server) Ping() error {
|
||||
|
||||
return srv.broker.Ping()
|
||||
}
|
||||
|
||||
func (srv *Server) AddQueue(qname string, priority, concurrency int) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
if _, ok := srv.queues[qname]; ok {
|
||||
srv.logger.Warnf("queue %s already exists, skipping", qname)
|
||||
return
|
||||
}
|
||||
|
||||
srv.queues[qname] = priority
|
||||
|
||||
srv.state.mu.Lock()
|
||||
state := srv.state.value
|
||||
srv.state.mu.Unlock()
|
||||
if state == srvStateNew || state == srvStateClosed {
|
||||
srv.queues[qname] = priority
|
||||
return
|
||||
}
|
||||
|
||||
srv.logger.Info("restart server...")
|
||||
srv.forwarder.shutdown()
|
||||
srv.processor.shutdown()
|
||||
srv.recoverer.shutdown()
|
||||
srv.syncer.shutdown()
|
||||
srv.subscriber.shutdown()
|
||||
srv.janitor.shutdown()
|
||||
srv.aggregator.shutdown()
|
||||
srv.healthchecker.shutdown()
|
||||
srv.heartbeater.shutdown()
|
||||
srv.wg.Wait()
|
||||
|
||||
qnames := make([]string, 0, len(srv.queues))
|
||||
for q := range srv.queues {
|
||||
qnames = append(qnames, q)
|
||||
}
|
||||
srv.broker.SetQueueConcurrency(qname, concurrency)
|
||||
srv.heartbeater.queues = srv.queues
|
||||
srv.recoverer.queues = qnames
|
||||
srv.forwarder.queues = qnames
|
||||
srv.processor.resetState()
|
||||
queues := normalizeQueues(srv.queues)
|
||||
orderedQueues := []string(nil)
|
||||
if srv.strictPriority {
|
||||
orderedQueues = sortByPriority(queues)
|
||||
}
|
||||
srv.processor.queueConfig = srv.queues
|
||||
srv.processor.orderedQueues = orderedQueues
|
||||
srv.janitor.queues = qnames
|
||||
srv.aggregator.resetState()
|
||||
srv.aggregator.queues = qnames
|
||||
|
||||
srv.heartbeater.start(&srv.wg)
|
||||
srv.healthchecker.start(&srv.wg)
|
||||
srv.subscriber.start(&srv.wg)
|
||||
srv.syncer.start(&srv.wg)
|
||||
srv.recoverer.start(&srv.wg)
|
||||
srv.forwarder.start(&srv.wg)
|
||||
srv.processor.start(&srv.wg)
|
||||
srv.janitor.start(&srv.wg)
|
||||
srv.aggregator.start(&srv.wg)
|
||||
|
||||
srv.logger.Info("server restarted")
|
||||
}
|
||||
|
||||
func (srv *Server) HasQueue(qname string) bool {
|
||||
srv.mu.RLock()
|
||||
defer srv.mu.RUnlock()
|
||||
_, ok := srv.queues[qname]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (srv *Server) SetQueueConcurrency(queue string, concurrency int) {
|
||||
srv.broker.SetQueueConcurrency(queue, concurrency)
|
||||
}
|
||||
|
105
server_test.go
105
server_test.go
@ -92,9 +92,6 @@ func TestServerWithQueueConcurrency(t *testing.T) {
|
||||
t.Fatalf("asynq: unsupported RedisConnOpt type %T", r)
|
||||
}
|
||||
|
||||
c := NewClient(redisConnOpt)
|
||||
defer c.Close()
|
||||
|
||||
const taskNum = 8
|
||||
const serverNum = 2
|
||||
tests := []struct {
|
||||
@ -134,6 +131,8 @@ func TestServerWithQueueConcurrency(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var err error
|
||||
testutil.FlushDB(t, r)
|
||||
c := NewClient(redisConnOpt)
|
||||
defer c.Close()
|
||||
for i := 0; i < taskNum; i++ {
|
||||
_, err = c.Enqueue(NewTask("send_email",
|
||||
testutil.JSON(map[string]interface{}{"recipient_id": i + 123})))
|
||||
@ -173,6 +172,106 @@ func TestServerWithQueueConcurrency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerWithDynamicQueue(t *testing.T) {
|
||||
// https://github.com/go-redis/redis/issues/1029
|
||||
ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper")
|
||||
defer goleak.VerifyNone(t, ignoreOpt)
|
||||
|
||||
redisConnOpt := getRedisConnOpt(t)
|
||||
r, ok := redisConnOpt.MakeRedisClient().(redis.UniversalClient)
|
||||
if !ok {
|
||||
t.Fatalf("asynq: unsupported RedisConnOpt type %T", r)
|
||||
}
|
||||
|
||||
const taskNum = 8
|
||||
const serverNum = 2
|
||||
tests := []struct {
|
||||
name string
|
||||
concurrency int
|
||||
queueConcurrency int
|
||||
wantActiveNum int
|
||||
}{
|
||||
{
|
||||
name: "based on client concurrency control",
|
||||
concurrency: 2,
|
||||
queueConcurrency: 6,
|
||||
wantActiveNum: 2 * serverNum,
|
||||
},
|
||||
{
|
||||
name: "no queue concurrency control",
|
||||
concurrency: 2,
|
||||
queueConcurrency: 0,
|
||||
wantActiveNum: 2 * serverNum,
|
||||
},
|
||||
{
|
||||
name: "based on queue concurrency control",
|
||||
concurrency: 6,
|
||||
queueConcurrency: 2,
|
||||
wantActiveNum: 2 * serverNum,
|
||||
},
|
||||
}
|
||||
|
||||
// no-op handler
|
||||
handle := func(ctx context.Context, task *Task) error {
|
||||
time.Sleep(time.Second * 2)
|
||||
return nil
|
||||
}
|
||||
|
||||
var DynamicQueueNameFmt = "dynamic:%d:%d"
|
||||
var servers [serverNum]*Server
|
||||
for tcn, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var err error
|
||||
testutil.FlushDB(t, r)
|
||||
c := NewClient(redisConnOpt)
|
||||
defer c.Close()
|
||||
for i := 0; i < taskNum; i++ {
|
||||
_, err = c.Enqueue(NewTask("send_email",
|
||||
testutil.JSON(map[string]interface{}{"recipient_id": i + 123})),
|
||||
Queue(fmt.Sprintf(DynamicQueueNameFmt, tcn, i%2)))
|
||||
if err != nil {
|
||||
t.Fatalf("could not enqueue a task: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < serverNum; i++ {
|
||||
srv := NewServer(redisConnOpt, Config{
|
||||
Concurrency: tc.concurrency,
|
||||
LogLevel: testLogLevel,
|
||||
QueueConcurrency: map[string]int{base.DefaultQueueName: tc.queueConcurrency},
|
||||
})
|
||||
err = srv.Start(HandlerFunc(handle))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
srv.AddQueue(fmt.Sprintf(DynamicQueueNameFmt, tcn, i), 1, tc.queueConcurrency)
|
||||
servers[i] = srv
|
||||
}
|
||||
defer func() {
|
||||
for _, srv := range servers {
|
||||
srv.Shutdown()
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(time.Second)
|
||||
inspector := NewInspector(redisConnOpt)
|
||||
|
||||
var tasks []*TaskInfo
|
||||
|
||||
for i := range servers {
|
||||
qtasks, err := inspector.ListActiveTasks(fmt.Sprintf(DynamicQueueNameFmt, tcn, i))
|
||||
if err != nil {
|
||||
t.Fatalf("could not list active tasks: %v", err)
|
||||
}
|
||||
tasks = append(tasks, qtasks...)
|
||||
}
|
||||
|
||||
if len(tasks) != tc.wantActiveNum {
|
||||
t.Errorf("dynamic queue has %d active tasks, want %d", len(tasks), tc.wantActiveNum)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerRun(t *testing.T) {
|
||||
// https://github.com/go-redis/redis/issues/1029
|
||||
|
@ -80,6 +80,9 @@ func (s *subscriber) start(wg *sync.WaitGroup) {
|
||||
s.logger.Debug("Subscriber done")
|
||||
return
|
||||
case msg := <-cancelCh:
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
cancel, ok := s.cancelations.Get(msg.Payload)
|
||||
if ok {
|
||||
cancel()
|
||||
|
Loading…
x
Reference in New Issue
Block a user