diff --git a/x/rate/semaphore.go b/x/rate/semaphore.go index 7519db1..f123e1c 100644 --- a/x/rate/semaphore.go +++ b/x/rate/semaphore.go @@ -11,33 +11,33 @@ import ( asynqcontext "github.com/hibiken/asynq/internal/context" ) -// NewSemaphore creates a new counting Semaphore. -func NewSemaphore(rco asynq.RedisConnOpt, name string, maxConcurrency int) *Semaphore { +// NewSemaphore creates a counting Semaphore for the given scope with the given number of tokens. +func NewSemaphore(rco asynq.RedisConnOpt, scope string, maxTokens int) *Semaphore { rc, ok := rco.MakeRedisClient().(redis.UniversalClient) if !ok { panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco)) } - if maxConcurrency < 1 { - panic("rate.NewSemaphore: maxConcurrency cannot be less than 1") + if maxTokens < 1 { + panic("rate.NewSemaphore: maxTokens cannot be less than 1") } - if len(strings.TrimSpace(name)) == 0 { - panic("rate.NewSemaphore: name should not be empty") + if len(strings.TrimSpace(scope)) == 0 { + panic("rate.NewSemaphore: scope should not be empty") } return &Semaphore{ - rc: rc, - name: name, - maxConcurrency: maxConcurrency, + rc: rc, + scope: scope, + maxTokens: maxTokens, } } -// Semaphore is a distributed counting semaphore which can be used to set maxConcurrency across multiple asynq servers. +// Semaphore is a distributed counting semaphore which can be used to set maxTokens across multiple asynq servers. type Semaphore struct { - rc redis.UniversalClient - maxConcurrency int - name string + rc redis.UniversalClient + maxTokens int + scope string } // KEYS[1] -> asynq:sema: @@ -47,9 +47,9 @@ type Semaphore struct { // ARGV[4] -> task ID var acquireCmd = redis.NewScript(` redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", tonumber(ARGV[2])-1) -local lockCount = redis.call("ZCARD", KEYS[1]) +local count = redis.call("ZCARD", KEYS[1]) -if (lockCount < tonumber(ARGV[1])) then +if (count < tonumber(ARGV[1])) then redis.call("ZADD", KEYS[1], ARGV[3], ARGV[4]) return true else @@ -57,13 +57,13 @@ else end `) -// Acquire will try to acquire lock on the counting semaphore. -// - Returns (true, nil), iff semaphore key exists and current value is less than maxConcurrency -// - Returns (false, nil) when lock cannot be acquired +// Acquire attempts to acquire a token from the semaphore. +// - Returns (true, nil), iff semaphore key exists and current value is less than maxTokens +// - Returns (false, nil) when token cannot be acquired // - Returns (false, error) otherwise // // The context.Context passed to Acquire must have a deadline set, -// this ensures that lock is released if the job goroutine crashes and does not call Release. +// this ensures that token is released if the job goroutine crashes and does not call Release. func (s *Semaphore) Acquire(ctx context.Context) (bool, error) { d, ok := ctx.Deadline() if !ok { @@ -76,13 +76,12 @@ func (s *Semaphore) Acquire(ctx context.Context) (bool, error) { } b, err := acquireCmd.Run(ctx, s.rc, - []string{semaphoreKey(s.name)}, - []interface{}{ - s.maxConcurrency, - time.Now().Unix(), - d.Unix(), - taskID, - }...).Bool() + []string{semaphoreKey(s.scope)}, + s.maxTokens, + time.Now().Unix(), + d.Unix(), + taskID, + ).Bool() if err == redis.Nil { return b, nil } @@ -90,26 +89,26 @@ func (s *Semaphore) Acquire(ctx context.Context) (bool, error) { return b, err } -// Release will release the lock on the counting semaphore. +// Release will release the token on the counting semaphore. func (s *Semaphore) Release(ctx context.Context) error { taskID, ok := asynqcontext.GetTaskID(ctx) if !ok { return fmt.Errorf("provided context is missing task ID value") } - n, err := s.rc.ZRem(ctx, semaphoreKey(s.name), taskID).Result() + n, err := s.rc.ZRem(ctx, semaphoreKey(s.scope), taskID).Result() if err != nil { return fmt.Errorf("redis command failed: %v", err) } if n == 0 { - return fmt.Errorf("no lock found for task %q", taskID) + return fmt.Errorf("no token found for task %q", taskID) } return nil } -// Close closes the connection with redis. +// Close closes the connection to redis. func (s *Semaphore) Close() error { return s.rc.Close() } diff --git a/x/rate/semaphore_test.go b/x/rate/semaphore_test.go index 114261a..e9ee484 100644 --- a/x/rate/semaphore_test.go +++ b/x/rate/semaphore_test.go @@ -43,14 +43,14 @@ func TestNewSemaphore(t *testing.T) { connOpt: &badConnOpt{}, }, { - desc: "Zero maxConcurrency should panic", - wantPanic: "rate.NewSemaphore: maxConcurrency cannot be less than 1", + desc: "Zero maxTokens should panic", + wantPanic: "rate.NewSemaphore: maxTokens cannot be less than 1", }, { - desc: "Empty name should panic", + desc: "Empty scope should panic", maxConcurrency: 2, name: " ", - wantPanic: "rate.NewSemaphore: name should not be empty", + wantPanic: "rate.NewSemaphore: scope should not be empty", }, } @@ -86,7 +86,7 @@ func TestNewSemaphore_Acquire(t *testing.T) { wantErr string }{ { - desc: "Should acquire lock when current lock count is less than maxConcurrency", + desc: "Should acquire lock when current lock count is less than maxTokens", name: "task-1", maxConcurrency: 3, taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, @@ -99,7 +99,7 @@ func TestNewSemaphore_Acquire(t *testing.T) { want: []bool{true, true}, }, { - desc: "Should fail acquiring lock when current lock count is equal to maxConcurrency", + desc: "Should fail acquiring lock when current lock count is equal to maxTokens", name: "task-2", maxConcurrency: 3, taskIDs: []uuid.UUID{uuid.New(), uuid.New(), uuid.New(), uuid.New()}, @@ -250,7 +250,7 @@ func TestNewSemaphore_Release(t *testing.T) { }, time.Now().Add(time.Second)) }, wantCount: 1, - wantErr: fmt.Sprintf("no lock found for task %q", testID.String()), + wantErr: fmt.Sprintf("no token found for task %q", testID.String()), }, }