From e869722b704f432b8ef74be9bf3c933003f61714 Mon Sep 17 00:00:00 2001 From: ajatprabha Date: Wed, 27 Oct 2021 18:52:33 +0530 Subject: [PATCH] add counting semaphore to control concurrency across multiple asynq servers --- x/rate/semaphore.go | 72 +++++++++++++++ x/rate/semaphore_test.go | 194 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+) create mode 100644 x/rate/semaphore.go create mode 100644 x/rate/semaphore_test.go diff --git a/x/rate/semaphore.go b/x/rate/semaphore.go new file mode 100644 index 0000000..353882f --- /dev/null +++ b/x/rate/semaphore.go @@ -0,0 +1,72 @@ +package rate + +import ( + "context" + "fmt" + "strings" + + "github.com/go-redis/redis/v8" + "github.com/hibiken/asynq" +) + +// NewSemaphore creates a new counting Semaphore. +func NewSemaphore(rco asynq.RedisConnOpt, name string, maxConcurrency int) *Semaphore { + rc, ok := rco.MakeRedisClient().(redis.UniversalClient) + if !ok { + panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco)) + } + + if maxConcurrency < 1 { + panic("rate.NewSemaphore: maxConcurrency cannot be less than 1") + } + + if len(strings.TrimSpace(name)) == 0 { + panic("rate.NewSemaphore: name should not be empty") + } + + return &Semaphore{ + rc: rc, + name: name, + maxConcurrency: maxConcurrency, + } +} + +// Semaphore is a distributed counting semaphore which can be used to set maxConcurrency across multiple asynq servers. +type Semaphore struct { + rc redis.UniversalClient + maxConcurrency int + name string +} + +var acquireCmd = redis.NewScript(` +local lockCount = tonumber(redis.call('GET', KEYS[1])) + +if (not lockCount or lockCount < tonumber(ARGV[1])) then + redis.call('INCR', KEYS[1]) + return true +else + return false +end +`) + +// Acquire will try to acquire lock on the counting semaphore. +// - Returns true, iff semaphore key exists and current value is less than maxConcurrency +// - Returns false otherwise +func (s *Semaphore) Acquire(ctx context.Context) bool { + val, _ := acquireCmd.Run(ctx, s.rc, []string{semaphoreKey(s.name)}, []interface{}{s.maxConcurrency}...).Bool() + return val +} + +// Release will release the lock on the counting semaphore. +func (s *Semaphore) Release(ctx context.Context) { + s.rc.Decr(ctx, semaphoreKey(s.name)) +} + +// Close closes the connection with redis. +func (s *Semaphore) Close() error { + return s.rc.Close() +} + +func semaphoreKey(name string) string { + return fmt.Sprintf("asynq:sema:{%s}", name) +} diff --git a/x/rate/semaphore_test.go b/x/rate/semaphore_test.go new file mode 100644 index 0000000..e53f4f9 --- /dev/null +++ b/x/rate/semaphore_test.go @@ -0,0 +1,194 @@ +package rate + +import ( + "context" + "flag" + "strings" + "testing" + + "github.com/go-redis/redis/v8" + "github.com/hibiken/asynq" +) + +var ( + redisAddr string + redisDB int + + useRedisCluster bool + redisClusterAddrs string // comma-separated list of host:port +) + +func init() { + flag.StringVar(&redisAddr, "redis_addr", "localhost:6379", "redis address to use in testing") + flag.IntVar(&redisDB, "redis_db", 14, "redis db number to use in testing") + flag.BoolVar(&useRedisCluster, "redis_cluster", false, "use redis cluster as a broker in testing") + flag.StringVar(&redisClusterAddrs, "redis_cluster_addrs", "localhost:7000,localhost:7001,localhost:7002", "comma separated list of redis server addresses") +} + +func TestNewSemaphore(t *testing.T) { + tests := []struct { + desc string + name string + maxConcurrency int + wantPanic string + connOpt asynq.RedisConnOpt + }{ + { + desc: "Bad RedisConnOpt", + wantPanic: "rate.NewSemaphore: unsupported RedisConnOpt type *rate.badConnOpt", + connOpt: &badConnOpt{}, + }, + { + desc: "Zero maxConcurrency should panic", + wantPanic: "rate.NewSemaphore: maxConcurrency cannot be less than 1", + }, + { + desc: "Empty name should panic", + maxConcurrency: 2, + name: " ", + wantPanic: "rate.NewSemaphore: name should not be empty", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + if tt.wantPanic != "" { + defer func() { + if r := recover(); r.(string) != tt.wantPanic { + t.Errorf("%s;\nNewSemaphore should panic with msg: %s, got %s", tt.desc, tt.wantPanic, r.(string)) + } + }() + } + + opt := tt.connOpt + if tt.connOpt == nil { + opt = getRedisConnOpt(t) + } + + sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) + defer sema.Close() + }) + } +} + +func TestNewSemaphore_Acquire(t *testing.T) { + tests := []struct { + desc string + name string + maxConcurrency int + callAcquire int + }{ + { + desc: "Should acquire lock when current lock count is less than maxConcurrency", + name: "task-1", + maxConcurrency: 3, + callAcquire: 2, + }, + { + desc: "Should fail acquiring lock when current lock count is equal to maxConcurrency", + name: "task-2", + maxConcurrency: 3, + callAcquire: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) + } + + sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) + defer sema.Close() + + for i := 0; i < tt.callAcquire; i++ { + if got := sema.Acquire(context.Background()); got != (i < tt.maxConcurrency) { + t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, i < tt.maxConcurrency) + } + } + }) + } +} + +func TestNewSemaphore_Release(t *testing.T) { + tests := []struct { + desc string + name string + maxConcurrency int + callRelease int + }{ + { + desc: "Should decrease lock count", + name: "task-3", + maxConcurrency: 3, + callRelease: 1, + }, + { + desc: "Should decrease lock count by 2", + name: "task-4", + maxConcurrency: 3, + callRelease: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) + } + + if err := rc.IncrBy(context.Background(), semaphoreKey(tt.name), int64(tt.maxConcurrency)).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.IncrBy() got error %v", tt.desc, err) + } + + sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) + defer sema.Close() + + for i := 0; i < tt.callRelease; i++ { + sema.Release(context.Background()) + } + + i, err := rc.Get(context.Background(), semaphoreKey(tt.name)).Int() + if err != nil { + t.Errorf("%s;\nredis.UniversalClient.Get() got error %v", tt.desc, err) + } + + if i != tt.maxConcurrency-tt.callRelease { + t.Errorf("%s;\nSemaphore.Release(ctx) didn't release lock, got %v want %v", + tt.desc, i, tt.maxConcurrency-tt.callRelease) + } + }) + } +} + +func getRedisConnOpt(tb testing.TB) asynq.RedisConnOpt { + tb.Helper() + if useRedisCluster { + addrs := strings.Split(redisClusterAddrs, ",") + if len(addrs) == 0 { + tb.Fatal("No redis cluster addresses provided. Please set addresses using --redis_cluster_addrs flag.") + } + return asynq.RedisClusterClientOpt{ + Addrs: addrs, + } + } + return asynq.RedisClientOpt{ + Addr: redisAddr, + DB: redisDB, + } +} + +type badConnOpt struct { +} + +func (b badConnOpt) MakeRedisClient() interface{} { + return nil +}