From 23c522dc9feb95e055786f8fff9d04b973eedbb6 Mon Sep 17 00:00:00 2001 From: Ajat Prabha Date: Thu, 4 Nov 2021 04:25:23 +0530 Subject: [PATCH] Add asynq/x/rate package - Added a directory /x for external, experimental packeges - Added a `rate` package to enable rate limiting across multiple asynq worker servers --- .gitignore | 3 +- context.go | 55 +-- internal/context/context.go | 87 ++++ .../context/context_test.go | 8 +- processor.go | 3 +- x/rate/example_test.go | 40 ++ x/rate/semaphore.go | 114 +++++ x/rate/semaphore_test.go | 407 ++++++++++++++++++ 8 files changed, 661 insertions(+), 56 deletions(-) create mode 100644 internal/context/context.go rename context_test.go => internal/context/context_test.go (95%) create mode 100644 x/rate/example_test.go create mode 100644 x/rate/semaphore.go create mode 100644 x/rate/semaphore_test.go diff --git a/.gitignore b/.gitignore index 63a7360..b443d16 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,5 @@ .asynq.* # Ignore editor config files -.vscode \ No newline at end of file +.vscode +.idea \ No newline at end of file diff --git a/context.go b/context.go index 4710bbd..6eac87d 100644 --- a/context.go +++ b/context.go @@ -6,49 +6,16 @@ package asynq import ( "context" - "time" - "github.com/hibiken/asynq/internal/base" + asynqcontext "github.com/hibiken/asynq/internal/context" ) -// A taskMetadata holds task scoped data to put in context. -type taskMetadata struct { - id string - maxRetry int - retryCount int - qname string -} - -// ctxKey type is unexported to prevent collisions with context keys defined in -// other packages. -type ctxKey int - -// metadataCtxKey is the context key for the task metadata. -// Its value of zero is arbitrary. -const metadataCtxKey ctxKey = 0 - -// createContext returns a context and cancel function for a given task message. -func createContext(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) { - metadata := taskMetadata{ - id: msg.ID.String(), - maxRetry: msg.Retry, - retryCount: msg.Retried, - qname: msg.Queue, - } - ctx := context.WithValue(context.Background(), metadataCtxKey, metadata) - return context.WithDeadline(ctx, deadline) -} - // GetTaskID extracts a task ID from a context, if any. // // ID of a task is guaranteed to be unique. // ID of a task doesn't change if the task is being retried. func GetTaskID(ctx context.Context) (id string, ok bool) { - metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) - if !ok { - return "", false - } - return metadata.id, true + return asynqcontext.GetTaskID(ctx) } // GetRetryCount extracts retry count from a context, if any. @@ -56,11 +23,7 @@ func GetTaskID(ctx context.Context) (id string, ok bool) { // Return value n indicates the number of times associated task has been // retried so far. func GetRetryCount(ctx context.Context) (n int, ok bool) { - metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) - if !ok { - return 0, false - } - return metadata.retryCount, true + return asynqcontext.GetRetryCount(ctx) } // GetMaxRetry extracts maximum retry from a context, if any. @@ -68,20 +31,12 @@ func GetRetryCount(ctx context.Context) (n int, ok bool) { // Return value n indicates the maximum number of times the assoicated task // can be retried if ProcessTask returns a non-nil error. func GetMaxRetry(ctx context.Context) (n int, ok bool) { - metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) - if !ok { - return 0, false - } - return metadata.maxRetry, true + return asynqcontext.GetMaxRetry(ctx) } // GetQueueName extracts queue name from a context, if any. // // Return value qname indicates which queue the task was pulled from. func GetQueueName(ctx context.Context) (qname string, ok bool) { - metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) - if !ok { - return "", false - } - return metadata.qname, true + return asynqcontext.GetQueueName(ctx) } diff --git a/internal/context/context.go b/internal/context/context.go new file mode 100644 index 0000000..9b233ce --- /dev/null +++ b/internal/context/context.go @@ -0,0 +1,87 @@ +// Copyright 2020 Kentaro Hibino. All rights reserved. +// Use of this source code is governed by a MIT license +// that can be found in the LICENSE file. + +package context + +import ( + "context" + "time" + + "github.com/hibiken/asynq/internal/base" +) + +// A taskMetadata holds task scoped data to put in context. +type taskMetadata struct { + id string + maxRetry int + retryCount int + qname string +} + +// ctxKey type is unexported to prevent collisions with context keys defined in +// other packages. +type ctxKey int + +// metadataCtxKey is the context key for the task metadata. +// Its value of zero is arbitrary. +const metadataCtxKey ctxKey = 0 + +// New returns a context and cancel function for a given task message. +func New(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) { + metadata := taskMetadata{ + id: msg.ID.String(), + maxRetry: msg.Retry, + retryCount: msg.Retried, + qname: msg.Queue, + } + ctx := context.WithValue(context.Background(), metadataCtxKey, metadata) + return context.WithDeadline(ctx, deadline) +} + +// GetTaskID extracts a task ID from a context, if any. +// +// ID of a task is guaranteed to be unique. +// ID of a task doesn't change if the task is being retried. +func GetTaskID(ctx context.Context) (id string, ok bool) { + metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) + if !ok { + return "", false + } + return metadata.id, true +} + +// GetRetryCount extracts retry count from a context, if any. +// +// Return value n indicates the number of times associated task has been +// retried so far. +func GetRetryCount(ctx context.Context) (n int, ok bool) { + metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) + if !ok { + return 0, false + } + return metadata.retryCount, true +} + +// GetMaxRetry extracts maximum retry from a context, if any. +// +// Return value n indicates the maximum number of times the assoicated task +// can be retried if ProcessTask returns a non-nil error. +func GetMaxRetry(ctx context.Context) (n int, ok bool) { + metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) + if !ok { + return 0, false + } + return metadata.maxRetry, true +} + +// GetQueueName extracts queue name from a context, if any. +// +// Return value qname indicates which queue the task was pulled from. +func GetQueueName(ctx context.Context) (qname string, ok bool) { + metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) + if !ok { + return "", false + } + return metadata.qname, true +} diff --git a/context_test.go b/internal/context/context_test.go similarity index 95% rename from context_test.go rename to internal/context/context_test.go index 305cd19..50498d7 100644 --- a/context_test.go +++ b/internal/context/context_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a MIT license // that can be found in the LICENSE file. -package asynq +package context import ( "context" @@ -28,7 +28,7 @@ func TestCreateContextWithFutureDeadline(t *testing.T) { Payload: nil, } - ctx, cancel := createContext(msg, tc.deadline) + ctx, cancel := New(msg, tc.deadline) select { case x := <-ctx.Done(): @@ -68,7 +68,7 @@ func TestCreateContextWithPastDeadline(t *testing.T) { Payload: nil, } - ctx, cancel := createContext(msg, tc.deadline) + ctx, cancel := New(msg, tc.deadline) defer cancel() select { @@ -98,7 +98,7 @@ func TestGetTaskMetadataFromContext(t *testing.T) { } for _, tc := range tests { - ctx, cancel := createContext(tc.msg, time.Now().Add(30*time.Minute)) + ctx, cancel := New(tc.msg, time.Now().Add(30*time.Minute)) defer cancel() id, ok := GetTaskID(ctx) diff --git a/processor.go b/processor.go index 5ede2d2..fcf7cb4 100644 --- a/processor.go +++ b/processor.go @@ -16,6 +16,7 @@ import ( "time" "github.com/hibiken/asynq/internal/base" + asynqcontext "github.com/hibiken/asynq/internal/context" "github.com/hibiken/asynq/internal/errors" "github.com/hibiken/asynq/internal/log" "golang.org/x/time/rate" @@ -189,7 +190,7 @@ func (p *processor) exec() { <-p.sema // release token }() - ctx, cancel := createContext(msg, deadline) + ctx, cancel := asynqcontext.New(msg, deadline) p.cancelations.Add(msg.ID.String(), cancel) defer func() { cancel() diff --git a/x/rate/example_test.go b/x/rate/example_test.go new file mode 100644 index 0000000..37016a4 --- /dev/null +++ b/x/rate/example_test.go @@ -0,0 +1,40 @@ +package rate_test + +import ( + "context" + "fmt" + "time" + + "github.com/hibiken/asynq" + "github.com/hibiken/asynq/x/rate" +) + +type RateLimitError struct { + RetryIn time.Duration +} + +func (e *RateLimitError) Error() string { + return fmt.Sprintf("rate limited (retry in %v)", e.RetryIn) +} + +func ExampleNewSemaphore() { + redisConnOpt := asynq.RedisClientOpt{Addr: ":6379"} + sema := rate.NewSemaphore(redisConnOpt, "my_queue", 10) + // call sema.Close() when appropriate + + _ = asynq.HandlerFunc(func(ctx context.Context, task *asynq.Task) error { + ok, err := sema.Acquire(ctx) + if err != nil { + return err + } + if !ok { + return &RateLimitError{RetryIn: 30 * time.Second} + } + + // Make sure to release the token once we're done. + defer sema.Release(ctx) + + // Process task + return nil + }) +} diff --git a/x/rate/semaphore.go b/x/rate/semaphore.go new file mode 100644 index 0000000..5a422a5 --- /dev/null +++ b/x/rate/semaphore.go @@ -0,0 +1,114 @@ +// Package rate contains rate limiting strategies for asynq.Handler(s). +package rate + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/go-redis/redis/v8" + "github.com/hibiken/asynq" + asynqcontext "github.com/hibiken/asynq/internal/context" +) + +// NewSemaphore creates a counting Semaphore for the given scope with the given number of tokens. +func NewSemaphore(rco asynq.RedisConnOpt, scope string, maxTokens int) *Semaphore { + rc, ok := rco.MakeRedisClient().(redis.UniversalClient) + if !ok { + panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco)) + } + + if maxTokens < 1 { + panic("rate.NewSemaphore: maxTokens cannot be less than 1") + } + + if len(strings.TrimSpace(scope)) == 0 { + panic("rate.NewSemaphore: scope should not be empty") + } + + return &Semaphore{ + rc: rc, + scope: scope, + maxTokens: maxTokens, + } +} + +// Semaphore is a distributed counting semaphore which can be used to set maxTokens across multiple asynq servers. +type Semaphore struct { + rc redis.UniversalClient + maxTokens int + scope string +} + +// KEYS[1] -> asynq:sema: +// ARGV[1] -> max concurrency +// ARGV[2] -> current time in unix time +// ARGV[3] -> deadline in unix time +// ARGV[4] -> task ID +var acquireCmd = redis.NewScript(` +redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", tonumber(ARGV[2])-1) +local count = redis.call("ZCARD", KEYS[1]) + +if (count < tonumber(ARGV[1])) then + redis.call("ZADD", KEYS[1], ARGV[3], ARGV[4]) + return 'true' +else + return 'false' +end +`) + +// Acquire attempts to acquire a token from the semaphore. +// - Returns (true, nil), iff semaphore key exists and current value is less than maxTokens +// - Returns (false, nil) when token cannot be acquired +// - Returns (false, error) otherwise +// +// The context.Context passed to Acquire must have a deadline set, +// this ensures that token is released if the job goroutine crashes and does not call Release. +func (s *Semaphore) Acquire(ctx context.Context) (bool, error) { + d, ok := ctx.Deadline() + if !ok { + return false, fmt.Errorf("provided context must have a deadline") + } + + taskID, ok := asynqcontext.GetTaskID(ctx) + if !ok { + return false, fmt.Errorf("provided context is missing task ID value") + } + + return acquireCmd.Run(ctx, s.rc, + []string{semaphoreKey(s.scope)}, + s.maxTokens, + time.Now().Unix(), + d.Unix(), + taskID, + ).Bool() +} + +// Release will release the token on the counting semaphore. +func (s *Semaphore) Release(ctx context.Context) error { + taskID, ok := asynqcontext.GetTaskID(ctx) + if !ok { + return fmt.Errorf("provided context is missing task ID value") + } + + n, err := s.rc.ZRem(ctx, semaphoreKey(s.scope), taskID).Result() + if err != nil { + return fmt.Errorf("redis command failed: %v", err) + } + + if n == 0 { + return fmt.Errorf("no token found for task %q", taskID) + } + + return nil +} + +// Close closes the connection to redis. +func (s *Semaphore) Close() error { + return s.rc.Close() +} + +func semaphoreKey(scope string) string { + return fmt.Sprintf("asynq:sema:%s", scope) +} diff --git a/x/rate/semaphore_test.go b/x/rate/semaphore_test.go new file mode 100644 index 0000000..6273687 --- /dev/null +++ b/x/rate/semaphore_test.go @@ -0,0 +1,407 @@ +package rate + +import ( + "context" + "flag" + "fmt" + "github.com/go-redis/redis/v8" + "github.com/google/uuid" + "github.com/hibiken/asynq" + "github.com/hibiken/asynq/internal/base" + asynqcontext "github.com/hibiken/asynq/internal/context" + "strings" + "testing" + "time" +) + +var ( + redisAddr string + redisDB int + + useRedisCluster bool + redisClusterAddrs string // comma-separated list of host:port +) + +func init() { + flag.StringVar(&redisAddr, "redis_addr", "localhost:6379", "redis address to use in testing") + flag.IntVar(&redisDB, "redis_db", 14, "redis db number to use in testing") + flag.BoolVar(&useRedisCluster, "redis_cluster", false, "use redis cluster as a broker in testing") + flag.StringVar(&redisClusterAddrs, "redis_cluster_addrs", "localhost:7000,localhost:7001,localhost:7002", "comma separated list of redis server addresses") +} + +func TestNewSemaphore(t *testing.T) { + tests := []struct { + desc string + name string + maxConcurrency int + wantPanic string + connOpt asynq.RedisConnOpt + }{ + { + desc: "Bad RedisConnOpt", + wantPanic: "rate.NewSemaphore: unsupported RedisConnOpt type *rate.badConnOpt", + connOpt: &badConnOpt{}, + }, + { + desc: "Zero maxTokens should panic", + wantPanic: "rate.NewSemaphore: maxTokens cannot be less than 1", + }, + { + desc: "Empty scope should panic", + maxConcurrency: 2, + name: " ", + wantPanic: "rate.NewSemaphore: scope should not be empty", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + if tt.wantPanic != "" { + defer func() { + if r := recover(); r.(string) != tt.wantPanic { + t.Errorf("%s;\nNewSemaphore should panic with msg: %s, got %s", tt.desc, tt.wantPanic, r.(string)) + } + }() + } + + opt := tt.connOpt + if tt.connOpt == nil { + opt = getRedisConnOpt(t) + } + + sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) + defer sema.Close() + }) + } +} + +func TestNewSemaphore_Acquire(t *testing.T) { + tests := []struct { + desc string + name string + maxConcurrency int + taskIDs []uuid.UUID + ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) + want []bool + }{ + { + desc: "Should acquire token when current token count is less than maxTokens", + name: "task-1", + maxConcurrency: 3, + taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: id, + Queue: "task-1", + }, time.Now().Add(time.Second)) + }, + want: []bool{true, true}, + }, + { + desc: "Should fail acquiring token when current token count is equal to maxTokens", + name: "task-2", + maxConcurrency: 3, + taskIDs: []uuid.UUID{uuid.New(), uuid.New(), uuid.New(), uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: id, + Queue: "task-2", + }, time.Now().Add(time.Second)) + }, + want: []bool{true, true, true, false}, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) + } + + sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) + defer sema.Close() + + for i := 0; i < len(tt.taskIDs); i++ { + ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) + + got, err := sema.Acquire(ctx) + if err != nil { + t.Errorf("%s;\nSemaphore.Acquire() got error %v", tt.desc, err) + } + + if got != tt.want[i] { + t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, tt.want[i]) + } + + cancel() + } + }) + } +} + +func TestNewSemaphore_Acquire_Error(t *testing.T) { + tests := []struct { + desc string + name string + maxConcurrency int + taskIDs []uuid.UUID + ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) + errStr string + }{ + { + desc: "Should return error if context has no deadline", + name: "task-3", + maxConcurrency: 1, + taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return context.Background(), func() {} + }, + errStr: "provided context must have a deadline", + }, + { + desc: "Should return error when context is missing taskID", + name: "task-4", + maxConcurrency: 1, + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) + }, + errStr: "provided context is missing task ID value", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) + } + + sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) + defer sema.Close() + + for i := 0; i < len(tt.taskIDs); i++ { + ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) + + _, err := sema.Acquire(ctx) + if err == nil || err.Error() != tt.errStr { + t.Errorf("%s;\nSemaphore.Acquire() got error %v want error %v", tt.desc, err, tt.errStr) + } + + cancel() + } + }) + } +} + +func TestNewSemaphore_Acquire_StaleToken(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + taskID := uuid.New() + + // adding a set member to mimic the case where token is acquired but the goroutine crashed, + // in which case, the token will not be explicitly removed and should be present already + rc.ZAdd(context.Background(), semaphoreKey("stale-token"), &redis.Z{ + Score: float64(time.Now().Add(-10 * time.Second).Unix()), + Member: taskID.String(), + }) + + sema := NewSemaphore(opt, "stale-token", 1) + defer sema.Close() + + ctx, cancel := asynqcontext.New(&base.TaskMessage{ + ID: taskID, + Queue: "task-1", + }, time.Now().Add(time.Second)) + defer cancel() + + got, err := sema.Acquire(ctx) + if err != nil { + t.Errorf("Acquire_StaleToken;\nSemaphore.Acquire() got error %v", err) + } + + if !got { + t.Error("Acquire_StaleToken;\nSemaphore.Acquire() got false want true") + } +} + +func TestNewSemaphore_Release(t *testing.T) { + tests := []struct { + desc string + name string + taskIDs []uuid.UUID + ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) + wantCount int64 + }{ + { + desc: "Should decrease token count", + name: "task-5", + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: id, + Queue: "task-3", + }, time.Now().Add(time.Second)) + }, + }, + { + desc: "Should decrease token count by 2", + name: "task-6", + taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: id, + Queue: "task-4", + }, time.Now().Add(time.Second)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) + } + + var members []*redis.Z + for i := 0; i < len(tt.taskIDs); i++ { + members = append(members, &redis.Z{ + Score: float64(time.Now().Add(time.Duration(i) * time.Second).Unix()), + Member: tt.taskIDs[i].String(), + }) + } + if err := rc.ZAdd(context.Background(), semaphoreKey(tt.name), members...).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.ZAdd() got error %v", tt.desc, err) + } + + sema := NewSemaphore(opt, tt.name, 3) + defer sema.Close() + + for i := 0; i < len(tt.taskIDs); i++ { + ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) + + if err := sema.Release(ctx); err != nil { + t.Errorf("%s;\nSemaphore.Release() got error %v", tt.desc, err) + } + + cancel() + } + + i, err := rc.ZCount(context.Background(), semaphoreKey(tt.name), "-inf", "+inf").Result() + if err != nil { + t.Errorf("%s;\nredis.UniversalClient.ZCount() got error %v", tt.desc, err) + } + + if i != tt.wantCount { + t.Errorf("%s;\nSemaphore.Release(ctx) didn't release token, got %v want 0", tt.desc, i) + } + }) + } +} + +func TestNewSemaphore_Release_Error(t *testing.T) { + testID := uuid.New() + + tests := []struct { + desc string + name string + taskIDs []uuid.UUID + ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) + errStr string + }{ + { + desc: "Should return error when context is missing taskID", + name: "task-7", + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) + }, + errStr: "provided context is missing task ID value", + }, + { + desc: "Should return error when context has taskID which never acquired token", + name: "task-8", + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: testID, + Queue: "task-4", + }, time.Now().Add(time.Second)) + }, + errStr: fmt.Sprintf("no token found for task %q", testID.String()), + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) + } + + var members []*redis.Z + for i := 0; i < len(tt.taskIDs); i++ { + members = append(members, &redis.Z{ + Score: float64(time.Now().Add(time.Duration(i) * time.Second).Unix()), + Member: tt.taskIDs[i].String(), + }) + } + if err := rc.ZAdd(context.Background(), semaphoreKey(tt.name), members...).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.ZAdd() got error %v", tt.desc, err) + } + + sema := NewSemaphore(opt, tt.name, 3) + defer sema.Close() + + for i := 0; i < len(tt.taskIDs); i++ { + ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) + + if err := sema.Release(ctx); err == nil || err.Error() != tt.errStr { + t.Errorf("%s;\nSemaphore.Release() got error %v want error %v", tt.desc, err, tt.errStr) + } + + cancel() + } + }) + } +} + +func getRedisConnOpt(tb testing.TB) asynq.RedisConnOpt { + tb.Helper() + if useRedisCluster { + addrs := strings.Split(redisClusterAddrs, ",") + if len(addrs) == 0 { + tb.Fatal("No redis cluster addresses provided. Please set addresses using --redis_cluster_addrs flag.") + } + return asynq.RedisClusterClientOpt{ + Addrs: addrs, + } + } + return asynq.RedisClientOpt{ + Addr: redisAddr, + DB: redisDB, + } +} + +type badConnOpt struct { +} + +func (b badConnOpt) MakeRedisClient() interface{} { + return nil +}