diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 22df506..a29a262 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -13,11 +13,12 @@ import ( "time" "github.com/google/uuid" + "github.com/redis/go-redis/v9" + "github.com/spf13/cast" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/errors" "github.com/hibiken/asynq/internal/timeutil" - "github.com/redis/go-redis/v9" - "github.com/spf13/cast" ) const statsTTL = 90 * 24 * time.Hour // 90 days @@ -25,19 +26,32 @@ const statsTTL = 90 * 24 * time.Hour // 90 days // LeaseDuration is the duration used to initially create a lease and to extend it thereafter. const LeaseDuration = 30 * time.Second +type Option func(r *RDB) + +func WithQueueConcurrency(queueConcurrency map[string]int) Option { + return func(r *RDB) { + r.queueConcurrency = queueConcurrency + } +} + // RDB is a client interface to query and mutate task queues. type RDB struct { - client redis.UniversalClient - clock timeutil.Clock - queuesPublished sync.Map + client redis.UniversalClient + clock timeutil.Clock + queuesPublished sync.Map + queueConcurrency map[string]int } // NewRDB returns a new instance of RDB. -func NewRDB(client redis.UniversalClient) *RDB { - return &RDB{ +func NewRDB(client redis.UniversalClient, opts ...Option) *RDB { + r := &RDB{ client: client, clock: timeutil.NewRealClock(), } + for _, opt := range opts { + opt(r) + } + return r } // Close closes the connection with redis server. @@ -217,6 +231,7 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time // -- // ARGV[1] -> initial lease expiration Unix time // ARGV[2] -> task key prefix +// ARGV[3] -> queue concurrency // // Output: // Returns nil if no processable task is found in the given queue. @@ -225,15 +240,20 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time // Note: dequeueCmd checks whether a queue is paused first, before // calling RPOPLPUSH to pop a task from the queue. var dequeueCmd = redis.NewScript(` -if redis.call("EXISTS", KEYS[2]) == 0 then - local id = redis.call("RPOPLPUSH", KEYS[1], KEYS[3]) - if id then - local key = ARGV[2] .. id - redis.call("HSET", key, "state", "active") - redis.call("HDEL", key, "pending_since") - redis.call("ZADD", KEYS[4], ARGV[1], id) - return redis.call("HGET", key, "msg") - end +if redis.call("EXISTS", KEYS[2]) > 0 then + return nil +end +local count = redis.call("ZCARD", KEYS[4]) +if (count >= tonumber(ARGV[3])) then + return nil +end +local id = redis.call("RPOPLPUSH", KEYS[1], KEYS[3]) +if id then + local key = ARGV[2] .. id + redis.call("HSET", key, "state", "active") + redis.call("HDEL", key, "pending_since") + redis.call("ZADD", KEYS[4], ARGV[1], id) + return redis.call("HGET", key, "msg") end return nil`) @@ -251,9 +271,14 @@ func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, leaseExpirationT base.LeaseKey(qname), } leaseExpirationTime = r.clock.Now().Add(LeaseDuration) + queueConcurrency, ok := r.queueConcurrency[qname] + if !ok || queueConcurrency <= 0 { + queueConcurrency = math.MaxInt + } argv := []interface{}{ leaseExpirationTime.Unix(), base.TaskKeyPrefix(qname), + queueConcurrency, } res, err := dequeueCmd.Run(context.Background(), r.client, keys, argv...).Result() if err == redis.Nil { diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index 5249a29..e58a461 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -384,6 +384,7 @@ func TestDequeue(t *testing.T) { wantPending map[string][]*base.TaskMessage wantActive map[string][]*base.TaskMessage wantLease map[string][]base.Z + queueConcurrency map[string]int }{ { pending: map[string][]*base.TaskMessage{ @@ -494,6 +495,86 @@ func TestDequeue(t *testing.T) { } } +func TestDequeueWithQueueConcurrency(t *testing.T) { + r := setup(t) + defer r.Close() + now := time.Now() + r.SetClock(timeutil.NewSimulatedClock(now)) + const taskNum = 3 + msgs := make([]*base.TaskMessage, 0, taskNum) + for i := 0; i < taskNum; i++ { + msg := &base.TaskMessage{ + ID: uuid.NewString(), + Type: "send_email", + Payload: h.JSON(map[string]interface{}{"subject": "hello!"}), + Queue: "default", + Timeout: 1800, + Deadline: 0, + } + msgs = append(msgs, msg) + } + + tests := []struct { + name string + pending map[string][]*base.TaskMessage + qnames []string // list of queues to query + queueConcurrency map[string]int + wantMsgs []*base.TaskMessage + }{ + { + name: "without queue concurrency control", + pending: map[string][]*base.TaskMessage{ + "default": msgs, + }, + qnames: []string{"default"}, + wantMsgs: msgs, + }, + { + name: "with queue concurrency control", + pending: map[string][]*base.TaskMessage{ + "default": msgs, + }, + qnames: []string{"default"}, + queueConcurrency: map[string]int{"default": 2}, + wantMsgs: msgs[:2], + }, + { + name: "with queue concurrency zero", + pending: map[string][]*base.TaskMessage{ + "default": msgs, + }, + qnames: []string{"default"}, + queueConcurrency: map[string]int{"default": 0}, + wantMsgs: msgs, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + h.FlushDB(t, r.client) // clean up db before each test case + h.SeedAllPendingQueues(t, r.client, tc.pending) + + r.queueConcurrency = tc.queueConcurrency + gotMsgs := make([]*base.TaskMessage, 0, len(msgs)) + for i := 0; i < len(msgs); i++ { + msg, _, err := r.Dequeue(tc.qnames...) + if errors.Is(err, errors.ErrNoProcessableTask) { + break + } + if err != nil { + t.Errorf("(*RDB).Dequeue(%v) returned error %v", tc.qnames, err) + continue + } + gotMsgs = append(gotMsgs, msg) + } + if diff := cmp.Diff(tc.wantMsgs, gotMsgs, h.SortZSetEntryOpt); diff != "" { + t.Errorf("(*RDB).Dequeue(%v) returned message %v; want %v", + tc.qnames, gotMsgs, tc.wantMsgs) + } + }) + } +} + func TestDequeueError(t *testing.T) { r := setup(t) defer r.Close() diff --git a/server.go b/server.go index 0cc4f38..111be65 100644 --- a/server.go +++ b/server.go @@ -15,10 +15,11 @@ import ( "sync" "time" + "github.com/redis/go-redis/v9" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/log" "github.com/hibiken/asynq/internal/rdb" - "github.com/redis/go-redis/v9" ) // Server is responsible for task processing and task lifecycle management. @@ -253,6 +254,11 @@ type Config struct { // If unset or zero, default batch size of 100 is used. // Make sure to not put a big number as the batch size to prevent a long-running script. JanitorBatchSize int + + // Maximum number of concurrent tasks of a queue. + // + // If set to a zero or not set, NewServer will not limit concurrency of the queue. + QueueConcurrency map[string]int } // GroupAggregator aggregates a group of tasks into one before the tasks are passed to the Handler. @@ -504,7 +510,7 @@ func NewServerFromRedisClient(c redis.UniversalClient, cfg Config) *Server { } logger.SetLevel(toInternalLogLevel(loglevel)) - rdb := rdb.NewRDB(c) + rdb := rdb.NewRDB(c, rdb.WithQueueConcurrency(cfg.QueueConcurrency)) starting := make(chan *workerInfo) finished := make(chan *base.TaskMessage) syncCh := make(chan *syncRequest) diff --git a/server_test.go b/server_test.go index 967f519..0d7d6c8 100644 --- a/server_test.go +++ b/server_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/rdb" "github.com/hibiken/asynq/internal/testbroker" "github.com/hibiken/asynq/internal/testutil" @@ -80,6 +81,99 @@ func TestServerFromRedisClient(t *testing.T) { } } +func TestServerWithQueueConcurrency(t *testing.T) { + // https://github.com/go-redis/redis/issues/1029 + ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper") + defer goleak.VerifyNone(t, ignoreOpt) + + redisConnOpt := getRedisConnOpt(t) + r, ok := redisConnOpt.MakeRedisClient().(redis.UniversalClient) + if !ok { + t.Fatalf("asynq: unsupported RedisConnOpt type %T", r) + } + + c := NewClient(redisConnOpt) + defer c.Close() + + const taskNum = 8 + const serverNum = 2 + tests := []struct { + name string + concurrency int + queueConcurrency int + wantActiveNum int + }{ + { + name: "based on client concurrency control", + concurrency: 2, + queueConcurrency: 6, + wantActiveNum: 2 * serverNum, + }, + { + name: "no queue concurrency control", + concurrency: 2, + queueConcurrency: 0, + wantActiveNum: 2 * serverNum, + }, + { + name: "based on queue concurrency control", + concurrency: 6, + queueConcurrency: 2, + wantActiveNum: 2, + }, + } + + // no-op handler + handle := func(ctx context.Context, task *Task) error { + time.Sleep(time.Second * 2) + return nil + } + + var servers [serverNum]*Server + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var err error + testutil.FlushDB(t, r) + for i := 0; i < taskNum; i++ { + _, err = c.Enqueue(NewTask("send_email", + testutil.JSON(map[string]interface{}{"recipient_id": i + 123}))) + if err != nil { + t.Fatalf("could not enqueue a task: %v", err) + } + } + + for i := 0; i < serverNum; i++ { + srv := NewServer(redisConnOpt, Config{ + Concurrency: tc.concurrency, + LogLevel: testLogLevel, + QueueConcurrency: map[string]int{base.DefaultQueueName: tc.queueConcurrency}, + }) + err = srv.Start(HandlerFunc(handle)) + if err != nil { + t.Fatal(err) + } + servers[i] = srv + } + defer func() { + for _, srv := range servers { + srv.Shutdown() + } + }() + + time.Sleep(time.Second) + inspector := NewInspector(redisConnOpt) + tasks, err := inspector.ListActiveTasks(base.DefaultQueueName) + if err != nil { + t.Fatalf("could not list active tasks: %v", err) + } + if len(tasks) != tc.wantActiveNum { + t.Errorf("default queue has %d active tasks, want %d", len(tasks), tc.wantActiveNum) + } + }) + } +} + + func TestServerRun(t *testing.T) { // https://github.com/go-redis/redis/issues/1029 ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper")