mirror of
https://github.com/hibiken/asynq.git
synced 2025-07-01 19:03:42 +08:00
73 lines
1.8 KiB
Go
73 lines
1.8 KiB
Go
package rate
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
"github.com/hibiken/asynq"
|
|
)
|
|
|
|
// NewSemaphore creates a new counting Semaphore.
|
|
func NewSemaphore(rco asynq.RedisConnOpt, name string, maxConcurrency int) *Semaphore {
|
|
rc, ok := rco.MakeRedisClient().(redis.UniversalClient)
|
|
if !ok {
|
|
panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco))
|
|
}
|
|
|
|
if maxConcurrency < 1 {
|
|
panic("rate.NewSemaphore: maxConcurrency cannot be less than 1")
|
|
}
|
|
|
|
if len(strings.TrimSpace(name)) == 0 {
|
|
panic("rate.NewSemaphore: name should not be empty")
|
|
}
|
|
|
|
return &Semaphore{
|
|
rc: rc,
|
|
name: name,
|
|
maxConcurrency: maxConcurrency,
|
|
}
|
|
}
|
|
|
|
// Semaphore is a distributed counting semaphore which can be used to set maxConcurrency across multiple asynq servers.
|
|
type Semaphore struct {
|
|
rc redis.UniversalClient
|
|
maxConcurrency int
|
|
name string
|
|
}
|
|
|
|
var acquireCmd = redis.NewScript(`
|
|
local lockCount = tonumber(redis.call('GET', KEYS[1]))
|
|
|
|
if (not lockCount or lockCount < tonumber(ARGV[1])) then
|
|
redis.call('INCR', KEYS[1])
|
|
return true
|
|
else
|
|
return false
|
|
end
|
|
`)
|
|
|
|
// Acquire will try to acquire lock on the counting semaphore.
|
|
// - Returns true, iff semaphore key exists and current value is less than maxConcurrency
|
|
// - Returns false otherwise
|
|
func (s *Semaphore) Acquire(ctx context.Context) bool {
|
|
val, _ := acquireCmd.Run(ctx, s.rc, []string{semaphoreKey(s.name)}, []interface{}{s.maxConcurrency}...).Bool()
|
|
return val
|
|
}
|
|
|
|
// Release will release the lock on the counting semaphore.
|
|
func (s *Semaphore) Release(ctx context.Context) {
|
|
s.rc.Decr(ctx, semaphoreKey(s.name))
|
|
}
|
|
|
|
// Close closes the connection with redis.
|
|
func (s *Semaphore) Close() error {
|
|
return s.rc.Close()
|
|
}
|
|
|
|
func semaphoreKey(name string) string {
|
|
return fmt.Sprintf("asynq:sema:{%s}", name)
|
|
}
|