2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-11-14 19:38:49 +08:00

feat: concurrency control queue

This commit is contained in:
kanzihuang 2024-03-16 21:15:48 +08:00
parent d04888e748
commit 56bf84cb4a
4 changed files with 219 additions and 15 deletions

View File

@ -24,18 +24,31 @@ const statsTTL = 90 * 24 * time.Hour // 90 days
// LeaseDuration is the duration used to initially create a lease and to extend it thereafter. // LeaseDuration is the duration used to initially create a lease and to extend it thereafter.
const LeaseDuration = 30 * time.Second const LeaseDuration = 30 * time.Second
type Option func(r *RDB)
func WithQueueConcurrency(queueConcurrency map[string]int) Option {
return func(r *RDB) {
r.queueConcurrency = queueConcurrency
}
}
// RDB is a client interface to query and mutate task queues. // RDB is a client interface to query and mutate task queues.
type RDB struct { type RDB struct {
client redis.UniversalClient client redis.UniversalClient
clock timeutil.Clock clock timeutil.Clock
queueConcurrency map[string]int
} }
// NewRDB returns a new instance of RDB. // NewRDB returns a new instance of RDB.
func NewRDB(client redis.UniversalClient) *RDB { func NewRDB(client redis.UniversalClient, opts ...Option) *RDB {
return &RDB{ r := &RDB{
client: client, client: client,
clock: timeutil.NewRealClock(), clock: timeutil.NewRealClock(),
} }
for _, opt := range opts {
opt(r)
}
return r
} }
// Close closes the connection with redis server. // Close closes the connection with redis server.
@ -209,6 +222,7 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time
// -- // --
// ARGV[1] -> initial lease expiration Unix time // ARGV[1] -> initial lease expiration Unix time
// ARGV[2] -> task key prefix // ARGV[2] -> task key prefix
// ARGV[3] -> queue concurrency
// //
// Output: // Output:
// Returns nil if no processable task is found in the given queue. // Returns nil if no processable task is found in the given queue.
@ -217,15 +231,20 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time
// Note: dequeueCmd checks whether a queue is paused first, before // Note: dequeueCmd checks whether a queue is paused first, before
// calling RPOPLPUSH to pop a task from the queue. // calling RPOPLPUSH to pop a task from the queue.
var dequeueCmd = redis.NewScript(` var dequeueCmd = redis.NewScript(`
if redis.call("EXISTS", KEYS[2]) == 0 then if redis.call("EXISTS", KEYS[2]) > 0 then
local id = redis.call("RPOPLPUSH", KEYS[1], KEYS[3]) return nil
if id then end
local key = ARGV[2] .. id local count = redis.call("ZCARD", KEYS[4])
redis.call("HSET", key, "state", "active") if (count >= tonumber(ARGV[3])) then
redis.call("HDEL", key, "pending_since") return nil
redis.call("ZADD", KEYS[4], ARGV[1], id) end
return redis.call("HGET", key, "msg") local id = redis.call("RPOPLPUSH", KEYS[1], KEYS[3])
end if id then
local key = ARGV[2] .. id
redis.call("HSET", key, "state", "active")
redis.call("HDEL", key, "pending_since")
redis.call("ZADD", KEYS[4], ARGV[1], id)
return redis.call("HGET", key, "msg")
end end
return nil`) return nil`)
@ -243,9 +262,14 @@ func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, leaseExpirationT
base.LeaseKey(qname), base.LeaseKey(qname),
} }
leaseExpirationTime = r.clock.Now().Add(LeaseDuration) leaseExpirationTime = r.clock.Now().Add(LeaseDuration)
queueConcurrency, ok := r.queueConcurrency[qname]
if !ok || queueConcurrency <= 0 {
queueConcurrency = math.MaxInt
}
argv := []interface{}{ argv := []interface{}{
leaseExpirationTime.Unix(), leaseExpirationTime.Unix(),
base.TaskKeyPrefix(qname), base.TaskKeyPrefix(qname),
queueConcurrency,
} }
res, err := dequeueCmd.Run(context.Background(), r.client, keys, argv...).Result() res, err := dequeueCmd.Run(context.Background(), r.client, keys, argv...).Result()
if err == redis.Nil { if err == redis.Nil {

View File

@ -331,6 +331,7 @@ func TestDequeue(t *testing.T) {
wantPending map[string][]*base.TaskMessage wantPending map[string][]*base.TaskMessage
wantActive map[string][]*base.TaskMessage wantActive map[string][]*base.TaskMessage
wantLease map[string][]base.Z wantLease map[string][]base.Z
queueConcurrency map[string]int
}{ }{
{ {
pending: map[string][]*base.TaskMessage{ pending: map[string][]*base.TaskMessage{
@ -441,6 +442,86 @@ func TestDequeue(t *testing.T) {
} }
} }
func TestDequeueWithQueueConcurrency(t *testing.T) {
r := setup(t)
defer r.Close()
now := time.Now()
r.SetClock(timeutil.NewSimulatedClock(now))
const taskNum = 3
msgs := make([]*base.TaskMessage, 0, taskNum)
for i := 0; i < taskNum; i++ {
msg := &base.TaskMessage{
ID: uuid.NewString(),
Type: "send_email",
Payload: h.JSON(map[string]interface{}{"subject": "hello!"}),
Queue: "default",
Timeout: 1800,
Deadline: 0,
}
msgs = append(msgs, msg)
}
tests := []struct {
name string
pending map[string][]*base.TaskMessage
qnames []string // list of queues to query
queueConcurrency map[string]int
wantMsgs []*base.TaskMessage
}{
{
name: "without queue concurrency control",
pending: map[string][]*base.TaskMessage{
"default": msgs,
},
qnames: []string{"default"},
wantMsgs: msgs,
},
{
name: "with queue concurrency control",
pending: map[string][]*base.TaskMessage{
"default": msgs,
},
qnames: []string{"default"},
queueConcurrency: map[string]int{"default": 2},
wantMsgs: msgs[:2],
},
{
name: "with queue concurrency zero",
pending: map[string][]*base.TaskMessage{
"default": msgs,
},
qnames: []string{"default"},
queueConcurrency: map[string]int{"default": 0},
wantMsgs: msgs,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
h.FlushDB(t, r.client) // clean up db before each test case
h.SeedAllPendingQueues(t, r.client, tc.pending)
r.queueConcurrency = tc.queueConcurrency
gotMsgs := make([]*base.TaskMessage, 0, len(msgs))
for i := 0; i < len(msgs); i++ {
msg, _, err := r.Dequeue(tc.qnames...)
if errors.Is(err, errors.ErrNoProcessableTask) {
break
}
if err != nil {
t.Errorf("(*RDB).Dequeue(%v) returned error %v", tc.qnames, err)
continue
}
gotMsgs = append(gotMsgs, msg)
}
if diff := cmp.Diff(tc.wantMsgs, gotMsgs, h.SortZSetEntryOpt); diff != "" {
t.Errorf("(*RDB).Dequeue(%v) returned message %v; want %v",
tc.qnames, gotMsgs, tc.wantMsgs)
}
})
}
}
func TestDequeueError(t *testing.T) { func TestDequeueError(t *testing.T) {
r := setup(t) r := setup(t)
defer r.Close() defer r.Close()

View File

@ -103,7 +103,7 @@ type Config struct {
// If BaseContext is nil, the default is context.Background(). // If BaseContext is nil, the default is context.Background().
// If this is defined, then it MUST return a non-nil context // If this is defined, then it MUST return a non-nil context
BaseContext func() context.Context BaseContext func() context.Context
// TaskCheckInterval specifies the interval between checks for new tasks to process when all queues are empty. // TaskCheckInterval specifies the interval between checks for new tasks to process when all queues are empty.
// //
// If unset, zero or a negative value, the interval is set to 1 second. // If unset, zero or a negative value, the interval is set to 1 second.
@ -250,6 +250,11 @@ type Config struct {
// If unset or zero, default batch size of 100 is used. // If unset or zero, default batch size of 100 is used.
// Make sure to not put a big number as the batch size to prevent a long-running script. // Make sure to not put a big number as the batch size to prevent a long-running script.
JanitorBatchSize int JanitorBatchSize int
// Maximum number of concurrent tasks of a queue.
//
// If set to a zero or not set, NewServer will not limit concurrency of the queue.
QueueConcurrency map[string]int
} }
// GroupAggregator aggregates a group of tasks into one before the tasks are passed to the Handler. // GroupAggregator aggregates a group of tasks into one before the tasks are passed to the Handler.
@ -493,7 +498,7 @@ func NewServer(r RedisConnOpt, cfg Config) *Server {
} }
logger.SetLevel(toInternalLogLevel(loglevel)) logger.SetLevel(toInternalLogLevel(loglevel))
rdb := rdb.NewRDB(c) rdb := rdb.NewRDB(c, rdb.WithQueueConcurrency(cfg.QueueConcurrency))
starting := make(chan *workerInfo) starting := make(chan *workerInfo)
finished := make(chan *base.TaskMessage) finished := make(chan *base.TaskMessage)
syncCh := make(chan *syncRequest) syncCh := make(chan *syncRequest)

View File

@ -11,9 +11,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/rdb" "github.com/hibiken/asynq/internal/rdb"
"github.com/hibiken/asynq/internal/testbroker" "github.com/hibiken/asynq/internal/testbroker"
"github.com/hibiken/asynq/internal/testutil" "github.com/hibiken/asynq/internal/testutil"
"github.com/redis/go-redis/v9"
"go.uber.org/goleak" "go.uber.org/goleak"
) )
@ -53,6 +55,98 @@ func TestServer(t *testing.T) {
srv.Shutdown() srv.Shutdown()
} }
func TestServerWithQueueConcurrency(t *testing.T) {
// https://github.com/go-redis/redis/issues/1029
ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper")
defer goleak.VerifyNone(t, ignoreOpt)
redisConnOpt := getRedisConnOpt(t)
r, ok := redisConnOpt.MakeRedisClient().(redis.UniversalClient)
if !ok {
t.Fatalf("asynq: unsupported RedisConnOpt type %T", r)
}
c := NewClient(redisConnOpt)
defer c.Close()
const taskNum = 8
const serverNum = 2
tests := []struct {
name string
concurrency int
queueConcurrency int
wantActiveNum int
}{
{
name: "based on client concurrency control",
concurrency: 2,
queueConcurrency: 6,
wantActiveNum: 2 * serverNum,
},
{
name: "no queue concurrency control",
concurrency: 2,
queueConcurrency: 0,
wantActiveNum: 2 * serverNum,
},
{
name: "based on queue concurrency control",
concurrency: 6,
queueConcurrency: 2,
wantActiveNum: 2,
},
}
// no-op handler
handle := func(ctx context.Context, task *Task) error {
time.Sleep(time.Second * 2)
return nil
}
var servers [serverNum]*Server
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var err error
testutil.FlushDB(t, r)
for i := 0; i < taskNum; i++ {
_, err = c.Enqueue(NewTask("send_email",
testutil.JSON(map[string]interface{}{"recipient_id": i + 123})))
if err != nil {
t.Fatalf("could not enqueue a task: %v", err)
}
}
for i := 0; i < serverNum; i++ {
srv := NewServer(redisConnOpt, Config{
Concurrency: tc.concurrency,
LogLevel: testLogLevel,
QueueConcurrency: map[string]int{base.DefaultQueueName: tc.queueConcurrency},
})
err = srv.Start(HandlerFunc(handle))
if err != nil {
t.Fatal(err)
}
servers[i] = srv
}
defer func() {
for _, srv := range servers {
srv.Shutdown()
}
}()
time.Sleep(time.Second)
inspector := NewInspector(redisConnOpt)
tasks, err := inspector.ListActiveTasks(base.DefaultQueueName)
if err != nil {
t.Fatalf("could not list active tasks: %v", err)
}
if len(tasks) != tc.wantActiveNum {
t.Errorf("default queue has %d active tasks, want %d", len(tasks), tc.wantActiveNum)
}
})
}
}
func TestServerRun(t *testing.T) { func TestServerRun(t *testing.T) {
// https://github.com/go-redis/redis/issues/1029 // https://github.com/go-redis/redis/issues/1029
ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper") ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper")