diff --git a/aggregator.go b/aggregator.go index 9f8da70..318d7b2 100644 --- a/aggregator.go +++ b/aggregator.go @@ -81,6 +81,11 @@ func newAggregator(params aggregatorParams) *aggregator { } } +func (a *aggregator) resetState() { + a.done = make(chan struct{}) + a.sema = make(chan struct{}, maxConcurrentAggregationChecks) +} + func (a *aggregator) shutdown() { if a.ga == nil { return diff --git a/heartbeat.go b/heartbeat.go index f426445..c1ac62b 100644 --- a/heartbeat.go +++ b/heartbeat.go @@ -10,6 +10,7 @@ import ( "time" "github.com/google/uuid" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/log" "github.com/hibiken/asynq/internal/timeutil" diff --git a/internal/base/base.go b/internal/base/base.go index 505e1ba..fbca74f 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -14,12 +14,13 @@ import ( "sync" "time" - "github.com/hibiken/asynq/internal/errors" - pb "github.com/hibiken/asynq/internal/proto" - "github.com/hibiken/asynq/internal/timeutil" "github.com/redis/go-redis/v9" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/hibiken/asynq/internal/errors" + pb "github.com/hibiken/asynq/internal/proto" + "github.com/hibiken/asynq/internal/timeutil" ) // Version of asynq library and CLI. @@ -722,4 +723,5 @@ type Broker interface { PublishCancelation(id string) error WriteResult(qname, id string, data []byte) (n int, err error) + SetQueueConcurrency(qname string, concurrency int) } diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index a29a262..9ee3d03 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -30,7 +30,9 @@ type Option func(r *RDB) func WithQueueConcurrency(queueConcurrency map[string]int) Option { return func(r *RDB) { - r.queueConcurrency = queueConcurrency + for qname, concurrency := range queueConcurrency { + r.queueConcurrency.Store(qname, concurrency) + } } } @@ -39,7 +41,7 @@ type RDB struct { client redis.UniversalClient clock timeutil.Clock queuesPublished sync.Map - queueConcurrency map[string]int + queueConcurrency sync.Map } // NewRDB returns a new instance of RDB. @@ -271,8 +273,8 @@ func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, leaseExpirationT base.LeaseKey(qname), } leaseExpirationTime = r.clock.Now().Add(LeaseDuration) - queueConcurrency, ok := r.queueConcurrency[qname] - if !ok || queueConcurrency <= 0 { + queueConcurrency, ok := r.queueConcurrency.Load(qname) + if !ok || queueConcurrency.(int) <= 0 { queueConcurrency = math.MaxInt } argv := []interface{}{ @@ -1581,3 +1583,7 @@ func (r *RDB) WriteResult(qname, taskID string, data []byte) (int, error) { } return len(data), nil } + +func (r *RDB) SetQueueConcurrency(qname string, concurrency int) { + r.queueConcurrency.Store(qname, concurrency) +} diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index ffab6fe..510bf46 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -11,8 +11,9 @@ import ( "sync" "time" - "github.com/hibiken/asynq/internal/base" "github.com/redis/go-redis/v9" + + "github.com/hibiken/asynq/internal/base" ) var errRedisDown = errors.New("testutil: redis is down") @@ -297,3 +298,7 @@ func (tb *TestBroker) ReclaimStaleAggregationSets(qname string) error { } return tb.real.ReclaimStaleAggregationSets(qname) } + +func (tb *TestBroker) SetQueueConcurrency(qname string, concurrency int) { + tb.real.SetQueueConcurrency(qname, concurrency) +} diff --git a/processor.go b/processor.go index fa810d6..fbd7b05 100644 --- a/processor.go +++ b/processor.go @@ -16,12 +16,13 @@ import ( "sync" "time" + "golang.org/x/time/rate" + "github.com/hibiken/asynq/internal/base" asynqcontext "github.com/hibiken/asynq/internal/context" "github.com/hibiken/asynq/internal/errors" "github.com/hibiken/asynq/internal/log" "github.com/hibiken/asynq/internal/timeutil" - "golang.org/x/time/rate" ) type processor struct { @@ -57,7 +58,7 @@ type processor struct { // channel to communicate back to the long running "processor" goroutine. // once is used to send value to the channel only once. done chan struct{} - once sync.Once + once *sync.Once // quit channel is closed when the shutdown of the "processor" goroutine starts. quit chan struct{} @@ -112,6 +113,7 @@ func newProcessor(params processorParams) *processor { errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1), sema: make(chan struct{}, params.concurrency), done: make(chan struct{}), + once: &sync.Once{}, quit: make(chan struct{}), abort: make(chan struct{}), errHandler: params.errHandler, @@ -139,7 +141,9 @@ func (p *processor) stop() { func (p *processor) shutdown() { p.stop() - time.AfterFunc(p.shutdownTimeout, func() { close(p.abort) }) + go func(abort chan struct{}) { + time.AfterFunc(p.shutdownTimeout, func() { close(abort) }) + }(p.abort) p.logger.Info("Waiting for all workers to finish...") // block until all workers have released the token @@ -149,6 +153,14 @@ func (p *processor) shutdown() { p.logger.Info("All workers have finished") } +func (p *processor) resetState() { + p.sema = make(chan struct{}, cap(p.sema)) + p.done = make(chan struct{}) + p.quit = make(chan struct{}) + p.abort = make(chan struct{}) + p.once = &sync.Once{} +} + func (p *processor) start(wg *sync.WaitGroup) { wg.Add(1) go func() { diff --git a/server.go b/server.go index 111be65..298ad57 100644 --- a/server.go +++ b/server.go @@ -44,6 +44,10 @@ type Server struct { state *serverState + mu sync.RWMutex + queues map[string]int + strictPriority bool + // wait group to wait for all goroutines to finish. wg sync.WaitGroup forwarder *forwarder @@ -481,7 +485,9 @@ func NewServerFromRedisClient(c redis.UniversalClient, cfg Config) *Server { } } if len(queues) == 0 { - queues = defaultQueueConfig + for qname, p := range defaultQueueConfig { + queues[qname] = p + } } var qnames []string for q := range queues { @@ -610,6 +616,8 @@ func NewServerFromRedisClient(c redis.UniversalClient, cfg Config) *Server { groupAggregator: cfg.GroupAggregator, }) return &Server{ + queues: queues, + strictPriority: cfg.StrictPriority, logger: logger, broker: rdb, sharedConnection: true, @@ -792,3 +800,78 @@ func (srv *Server) Ping() error { return srv.broker.Ping() } + +func (srv *Server) AddQueue(qname string, priority, concurrency int) { + srv.mu.Lock() + defer srv.mu.Unlock() + + if _, ok := srv.queues[qname]; ok { + srv.logger.Warnf("queue %s already exists, skipping", qname) + return + } + + srv.queues[qname] = priority + + srv.state.mu.Lock() + state := srv.state.value + srv.state.mu.Unlock() + if state == srvStateNew || state == srvStateClosed { + srv.queues[qname] = priority + return + } + + srv.logger.Info("restart server...") + srv.forwarder.shutdown() + srv.processor.shutdown() + srv.recoverer.shutdown() + srv.syncer.shutdown() + srv.subscriber.shutdown() + srv.janitor.shutdown() + srv.aggregator.shutdown() + srv.healthchecker.shutdown() + srv.heartbeater.shutdown() + srv.wg.Wait() + + qnames := make([]string, 0, len(srv.queues)) + for q := range srv.queues { + qnames = append(qnames, q) + } + srv.broker.SetQueueConcurrency(qname, concurrency) + srv.heartbeater.queues = srv.queues + srv.recoverer.queues = qnames + srv.forwarder.queues = qnames + srv.processor.resetState() + queues := normalizeQueues(srv.queues) + orderedQueues := []string(nil) + if srv.strictPriority { + orderedQueues = sortByPriority(queues) + } + srv.processor.queueConfig = srv.queues + srv.processor.orderedQueues = orderedQueues + srv.janitor.queues = qnames + srv.aggregator.resetState() + srv.aggregator.queues = qnames + + srv.heartbeater.start(&srv.wg) + srv.healthchecker.start(&srv.wg) + srv.subscriber.start(&srv.wg) + srv.syncer.start(&srv.wg) + srv.recoverer.start(&srv.wg) + srv.forwarder.start(&srv.wg) + srv.processor.start(&srv.wg) + srv.janitor.start(&srv.wg) + srv.aggregator.start(&srv.wg) + + srv.logger.Info("server restarted") +} + +func (srv *Server) HasQueue(qname string) bool { + srv.mu.RLock() + defer srv.mu.RUnlock() + _, ok := srv.queues[qname] + return ok +} + +func (srv *Server) SetQueueConcurrency(queue string, concurrency int) { + srv.broker.SetQueueConcurrency(queue, concurrency) +} diff --git a/server_test.go b/server_test.go index 0d7d6c8..437e408 100644 --- a/server_test.go +++ b/server_test.go @@ -92,9 +92,6 @@ func TestServerWithQueueConcurrency(t *testing.T) { t.Fatalf("asynq: unsupported RedisConnOpt type %T", r) } - c := NewClient(redisConnOpt) - defer c.Close() - const taskNum = 8 const serverNum = 2 tests := []struct { @@ -134,6 +131,8 @@ func TestServerWithQueueConcurrency(t *testing.T) { t.Run(tc.name, func(t *testing.T) { var err error testutil.FlushDB(t, r) + c := NewClient(redisConnOpt) + defer c.Close() for i := 0; i < taskNum; i++ { _, err = c.Enqueue(NewTask("send_email", testutil.JSON(map[string]interface{}{"recipient_id": i + 123}))) @@ -173,6 +172,106 @@ func TestServerWithQueueConcurrency(t *testing.T) { } } +func TestServerWithDynamicQueue(t *testing.T) { + // https://github.com/go-redis/redis/issues/1029 + ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper") + defer goleak.VerifyNone(t, ignoreOpt) + + redisConnOpt := getRedisConnOpt(t) + r, ok := redisConnOpt.MakeRedisClient().(redis.UniversalClient) + if !ok { + t.Fatalf("asynq: unsupported RedisConnOpt type %T", r) + } + + const taskNum = 8 + const serverNum = 2 + tests := []struct { + name string + concurrency int + queueConcurrency int + wantActiveNum int + }{ + { + name: "based on client concurrency control", + concurrency: 2, + queueConcurrency: 6, + wantActiveNum: 2 * serverNum, + }, + { + name: "no queue concurrency control", + concurrency: 2, + queueConcurrency: 0, + wantActiveNum: 2 * serverNum, + }, + { + name: "based on queue concurrency control", + concurrency: 6, + queueConcurrency: 2, + wantActiveNum: 2 * serverNum, + }, + } + + // no-op handler + handle := func(ctx context.Context, task *Task) error { + time.Sleep(time.Second * 2) + return nil + } + + var DynamicQueueNameFmt = "dynamic:%d:%d" + var servers [serverNum]*Server + for tcn, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var err error + testutil.FlushDB(t, r) + c := NewClient(redisConnOpt) + defer c.Close() + for i := 0; i < taskNum; i++ { + _, err = c.Enqueue(NewTask("send_email", + testutil.JSON(map[string]interface{}{"recipient_id": i + 123})), + Queue(fmt.Sprintf(DynamicQueueNameFmt, tcn, i%2))) + if err != nil { + t.Fatalf("could not enqueue a task: %v", err) + } + } + + for i := 0; i < serverNum; i++ { + srv := NewServer(redisConnOpt, Config{ + Concurrency: tc.concurrency, + LogLevel: testLogLevel, + QueueConcurrency: map[string]int{base.DefaultQueueName: tc.queueConcurrency}, + }) + err = srv.Start(HandlerFunc(handle)) + if err != nil { + t.Fatal(err) + } + srv.AddQueue(fmt.Sprintf(DynamicQueueNameFmt, tcn, i), 1, tc.queueConcurrency) + servers[i] = srv + } + defer func() { + for _, srv := range servers { + srv.Shutdown() + } + }() + + time.Sleep(time.Second) + inspector := NewInspector(redisConnOpt) + + var tasks []*TaskInfo + + for i := range servers { + qtasks, err := inspector.ListActiveTasks(fmt.Sprintf(DynamicQueueNameFmt, tcn, i)) + if err != nil { + t.Fatalf("could not list active tasks: %v", err) + } + tasks = append(tasks, qtasks...) + } + + if len(tasks) != tc.wantActiveNum { + t.Errorf("dynamic queue has %d active tasks, want %d", len(tasks), tc.wantActiveNum) + } + }) + } +} func TestServerRun(t *testing.T) { // https://github.com/go-redis/redis/issues/1029 diff --git a/subscriber.go b/subscriber.go index 8fc4eac..d4d0d0f 100644 --- a/subscriber.go +++ b/subscriber.go @@ -80,6 +80,9 @@ func (s *subscriber) start(wg *sync.WaitGroup) { s.logger.Debug("Subscriber done") return case msg := <-cancelCh: + if msg == nil { + return + } cancel, ok := s.cancelations.Get(msg.Payload) if ok { cancel()