2
0
mirror of https://github.com/hibiken/asynq.git synced 2025-07-01 19:03:42 +08:00
asynq/x/rate/semaphore.go

73 lines
1.8 KiB
Go

package rate
import (
"context"
"fmt"
"strings"
"github.com/go-redis/redis/v8"
"github.com/hibiken/asynq"
)
// NewSemaphore creates a new counting Semaphore called name that admits at
// most maxConcurrency concurrent holders.
//
// It panics if maxConcurrency is less than 1, if name is empty or consists
// only of whitespace, or if rco yields a client that is not a
// redis.UniversalClient. The returned Semaphore owns the Redis client it
// creates; call Close to release it.
func NewSemaphore(rco asynq.RedisConnOpt, name string, maxConcurrency int) *Semaphore {
	// Validate the plain arguments before constructing the Redis client, so an
	// invalid call does not leave behind a client that nothing will ever close.
	if maxConcurrency < 1 {
		panic("rate.NewSemaphore: maxConcurrency cannot be less than 1")
	}
	if strings.TrimSpace(name) == "" {
		panic("rate.NewSemaphore: name should not be empty")
	}
	rc, ok := rco.MakeRedisClient().(redis.UniversalClient)
	if !ok {
		panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco))
	}
	return &Semaphore{
		rc:             rc,
		name:           name,
		maxConcurrency: maxConcurrency,
	}
}
// Semaphore is a distributed counting semaphore which can be used to set maxConcurrency across multiple asynq servers.
//
// The number of current holders is kept as an integer counter in Redis under
// the key returned by semaphoreKey(name). Nothing in this file sets a TTL on
// that counter, so a holder that never calls Release keeps its slot.
type Semaphore struct {
	// rc is the Redis client built by NewSemaphore; Close closes it.
	rc redis.UniversalClient
	// maxConcurrency is the upper bound on the counter; Acquire passes it to acquireCmd.
	maxConcurrency int
	// name identifies the semaphore and is used to derive its Redis key.
	name string
}
// acquireCmd atomically checks and increments the semaphore counter.
//
// KEYS[1] is the counter key and ARGV[1] is the maximum number of holders.
// A missing key (or a non-numeric value, which tonumber maps to nil) is treated
// as zero holders. If the count is below ARGV[1], the script INCRs the key and
// returns true, which Redis converts to the integer reply 1. Otherwise it returns
// false, which Redis converts to a nil reply.
var acquireCmd = redis.NewScript(`
local lockCount = tonumber(redis.call('GET', KEYS[1]))
if (not lockCount or lockCount < tonumber(ARGV[1])) then
redis.call('INCR', KEYS[1])
return true
else
return false
end
`)
// Acquire tries to take one slot on the counting semaphore without blocking.
//
//   - Returns true if the counter at the semaphore's Redis key is missing or
//     below maxConcurrency. The counter is then incremented, and the caller
//     should call Release when done.
//   - Returns false if the semaphore is full or the Redis call fails.
func (s *Semaphore) Acquire(ctx context.Context) bool {
	acquired, err := acquireCmd.Run(ctx, s.rc, []string{semaphoreKey(s.name)}, s.maxConcurrency).Bool()
	if err != nil {
		// A full semaphore makes the script return Lua false. Redis converts that
		// to a nil reply, which go-redis reports as redis.Nil. Any other error
		// (connection failure, cancelled ctx) is also reported as "not acquired",
		// because the bool-only signature cannot surface it.
		return false
	}
	return acquired
}
// Release will release the lock on the counting semaphore.
func (s *Semaphore) Release(ctx context.Context) {
s.rc.Decr(ctx, semaphoreKey(s.name))
}
// Close shuts down the Redis client held by the Semaphore. It returns whatever
// error the client reports while closing, or nil on success.
func (s *Semaphore) Close() error {
	if err := s.rc.Close(); err != nil {
		return err
	}
	return nil
}
// semaphoreKey builds the Redis key holding the counter for the semaphore
// called name. The name is wrapped in braces, so under Redis Cluster only the
// name is hashed when choosing the key's slot.
func semaphoreKey(name string) string {
	const prefix = "asynq:sema:"
	// All operands are strings, so fmt.Sprint inserts no separating spaces.
	return fmt.Sprint(prefix, "{", name, "}")
}