mirror of
https://github.com/hibiken/asynq.git
synced 2025-07-01 02:43:40 +08:00
add counting semaphore to control concurrency across multiple asynq servers
This commit is contained in:
parent
b18f27bad7
commit
e869722b70
72
x/rate/semaphore.go
Normal file
72
x/rate/semaphore.go
Normal file
@ -0,0 +1,72 @@
|
||||
package rate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/hibiken/asynq"
|
||||
)
|
||||
|
||||
// NewSemaphore creates a new counting Semaphore.
|
||||
func NewSemaphore(rco asynq.RedisConnOpt, name string, maxConcurrency int) *Semaphore {
|
||||
rc, ok := rco.MakeRedisClient().(redis.UniversalClient)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco))
|
||||
}
|
||||
|
||||
if maxConcurrency < 1 {
|
||||
panic("rate.NewSemaphore: maxConcurrency cannot be less than 1")
|
||||
}
|
||||
|
||||
if len(strings.TrimSpace(name)) == 0 {
|
||||
panic("rate.NewSemaphore: name should not be empty")
|
||||
}
|
||||
|
||||
return &Semaphore{
|
||||
rc: rc,
|
||||
name: name,
|
||||
maxConcurrency: maxConcurrency,
|
||||
}
|
||||
}
|
||||
|
||||
// Semaphore is a distributed counting semaphore which can be used to set maxConcurrency across multiple asynq servers.
|
||||
type Semaphore struct {
|
||||
rc redis.UniversalClient
|
||||
maxConcurrency int
|
||||
name string
|
||||
}
|
||||
|
||||
var acquireCmd = redis.NewScript(`
|
||||
local lockCount = tonumber(redis.call('GET', KEYS[1]))
|
||||
|
||||
if (not lockCount or lockCount < tonumber(ARGV[1])) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
return true
|
||||
else
|
||||
return false
|
||||
end
|
||||
`)
|
||||
|
||||
// Acquire will try to acquire lock on the counting semaphore.
|
||||
// - Returns true, iff semaphore key exists and current value is less than maxConcurrency
|
||||
// - Returns false otherwise
|
||||
func (s *Semaphore) Acquire(ctx context.Context) bool {
|
||||
val, _ := acquireCmd.Run(ctx, s.rc, []string{semaphoreKey(s.name)}, []interface{}{s.maxConcurrency}...).Bool()
|
||||
return val
|
||||
}
|
||||
|
||||
// Release will release the lock on the counting semaphore.
|
||||
func (s *Semaphore) Release(ctx context.Context) {
|
||||
s.rc.Decr(ctx, semaphoreKey(s.name))
|
||||
}
|
||||
|
||||
// Close closes the connection with redis.
|
||||
func (s *Semaphore) Close() error {
|
||||
return s.rc.Close()
|
||||
}
|
||||
|
||||
func semaphoreKey(name string) string {
|
||||
return fmt.Sprintf("asynq:sema:{%s}", name)
|
||||
}
|
194
x/rate/semaphore_test.go
Normal file
194
x/rate/semaphore_test.go
Normal file
@ -0,0 +1,194 @@
|
||||
package rate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/hibiken/asynq"
|
||||
)
|
||||
|
||||
var (
|
||||
redisAddr string
|
||||
redisDB int
|
||||
|
||||
useRedisCluster bool
|
||||
redisClusterAddrs string // comma-separated list of host:port
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&redisAddr, "redis_addr", "localhost:6379", "redis address to use in testing")
|
||||
flag.IntVar(&redisDB, "redis_db", 14, "redis db number to use in testing")
|
||||
flag.BoolVar(&useRedisCluster, "redis_cluster", false, "use redis cluster as a broker in testing")
|
||||
flag.StringVar(&redisClusterAddrs, "redis_cluster_addrs", "localhost:7000,localhost:7001,localhost:7002", "comma separated list of redis server addresses")
|
||||
}
|
||||
|
||||
func TestNewSemaphore(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
name string
|
||||
maxConcurrency int
|
||||
wantPanic string
|
||||
connOpt asynq.RedisConnOpt
|
||||
}{
|
||||
{
|
||||
desc: "Bad RedisConnOpt",
|
||||
wantPanic: "rate.NewSemaphore: unsupported RedisConnOpt type *rate.badConnOpt",
|
||||
connOpt: &badConnOpt{},
|
||||
},
|
||||
{
|
||||
desc: "Zero maxConcurrency should panic",
|
||||
wantPanic: "rate.NewSemaphore: maxConcurrency cannot be less than 1",
|
||||
},
|
||||
{
|
||||
desc: "Empty name should panic",
|
||||
maxConcurrency: 2,
|
||||
name: " ",
|
||||
wantPanic: "rate.NewSemaphore: name should not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
if tt.wantPanic != "" {
|
||||
defer func() {
|
||||
if r := recover(); r.(string) != tt.wantPanic {
|
||||
t.Errorf("%s;\nNewSemaphore should panic with msg: %s, got %s", tt.desc, tt.wantPanic, r.(string))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
opt := tt.connOpt
|
||||
if tt.connOpt == nil {
|
||||
opt = getRedisConnOpt(t)
|
||||
}
|
||||
|
||||
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency)
|
||||
defer sema.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSemaphore_Acquire(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
name string
|
||||
maxConcurrency int
|
||||
callAcquire int
|
||||
}{
|
||||
{
|
||||
desc: "Should acquire lock when current lock count is less than maxConcurrency",
|
||||
name: "task-1",
|
||||
maxConcurrency: 3,
|
||||
callAcquire: 2,
|
||||
},
|
||||
{
|
||||
desc: "Should fail acquiring lock when current lock count is equal to maxConcurrency",
|
||||
name: "task-2",
|
||||
maxConcurrency: 3,
|
||||
callAcquire: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
opt := getRedisConnOpt(t)
|
||||
rc := opt.MakeRedisClient().(redis.UniversalClient)
|
||||
defer rc.Close()
|
||||
|
||||
if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil {
|
||||
t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err)
|
||||
}
|
||||
|
||||
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency)
|
||||
defer sema.Close()
|
||||
|
||||
for i := 0; i < tt.callAcquire; i++ {
|
||||
if got := sema.Acquire(context.Background()); got != (i < tt.maxConcurrency) {
|
||||
t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, i < tt.maxConcurrency)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSemaphore_Release(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
name string
|
||||
maxConcurrency int
|
||||
callRelease int
|
||||
}{
|
||||
{
|
||||
desc: "Should decrease lock count",
|
||||
name: "task-3",
|
||||
maxConcurrency: 3,
|
||||
callRelease: 1,
|
||||
},
|
||||
{
|
||||
desc: "Should decrease lock count by 2",
|
||||
name: "task-4",
|
||||
maxConcurrency: 3,
|
||||
callRelease: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
opt := getRedisConnOpt(t)
|
||||
rc := opt.MakeRedisClient().(redis.UniversalClient)
|
||||
defer rc.Close()
|
||||
|
||||
if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil {
|
||||
t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err)
|
||||
}
|
||||
|
||||
if err := rc.IncrBy(context.Background(), semaphoreKey(tt.name), int64(tt.maxConcurrency)).Err(); err != nil {
|
||||
t.Errorf("%s;\nredis.UniversalClient.IncrBy() got error %v", tt.desc, err)
|
||||
}
|
||||
|
||||
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency)
|
||||
defer sema.Close()
|
||||
|
||||
for i := 0; i < tt.callRelease; i++ {
|
||||
sema.Release(context.Background())
|
||||
}
|
||||
|
||||
i, err := rc.Get(context.Background(), semaphoreKey(tt.name)).Int()
|
||||
if err != nil {
|
||||
t.Errorf("%s;\nredis.UniversalClient.Get() got error %v", tt.desc, err)
|
||||
}
|
||||
|
||||
if i != tt.maxConcurrency-tt.callRelease {
|
||||
t.Errorf("%s;\nSemaphore.Release(ctx) didn't release lock, got %v want %v",
|
||||
tt.desc, i, tt.maxConcurrency-tt.callRelease)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getRedisConnOpt(tb testing.TB) asynq.RedisConnOpt {
|
||||
tb.Helper()
|
||||
if useRedisCluster {
|
||||
addrs := strings.Split(redisClusterAddrs, ",")
|
||||
if len(addrs) == 0 {
|
||||
tb.Fatal("No redis cluster addresses provided. Please set addresses using --redis_cluster_addrs flag.")
|
||||
}
|
||||
return asynq.RedisClusterClientOpt{
|
||||
Addrs: addrs,
|
||||
}
|
||||
}
|
||||
return asynq.RedisClientOpt{
|
||||
Addr: redisAddr,
|
||||
DB: redisDB,
|
||||
}
|
||||
}
|
||||
|
||||
type badConnOpt struct {
|
||||
}
|
||||
|
||||
func (b badConnOpt) MakeRedisClient() interface{} {
|
||||
return nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user