mirror of
https://github.com/hibiken/asynq.git
synced 2024-12-25 23:32:17 +08:00
Add asynq/x/rate package
- Added a directory /x for external, experimental packeges - Added a `rate` package to enable rate limiting across multiple asynq worker servers
This commit is contained in:
parent
0d2c0f612b
commit
23c522dc9f
3
.gitignore
vendored
3
.gitignore
vendored
@ -21,4 +21,5 @@
|
|||||||
.asynq.*
|
.asynq.*
|
||||||
|
|
||||||
# Ignore editor config files
|
# Ignore editor config files
|
||||||
.vscode
|
.vscode
|
||||||
|
.idea
|
55
context.go
55
context.go
@ -6,49 +6,16 @@ package asynq
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/hibiken/asynq/internal/base"
|
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A taskMetadata holds task scoped data to put in context.
|
|
||||||
type taskMetadata struct {
|
|
||||||
id string
|
|
||||||
maxRetry int
|
|
||||||
retryCount int
|
|
||||||
qname string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ctxKey type is unexported to prevent collisions with context keys defined in
|
|
||||||
// other packages.
|
|
||||||
type ctxKey int
|
|
||||||
|
|
||||||
// metadataCtxKey is the context key for the task metadata.
|
|
||||||
// Its value of zero is arbitrary.
|
|
||||||
const metadataCtxKey ctxKey = 0
|
|
||||||
|
|
||||||
// createContext returns a context and cancel function for a given task message.
|
|
||||||
func createContext(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) {
|
|
||||||
metadata := taskMetadata{
|
|
||||||
id: msg.ID.String(),
|
|
||||||
maxRetry: msg.Retry,
|
|
||||||
retryCount: msg.Retried,
|
|
||||||
qname: msg.Queue,
|
|
||||||
}
|
|
||||||
ctx := context.WithValue(context.Background(), metadataCtxKey, metadata)
|
|
||||||
return context.WithDeadline(ctx, deadline)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTaskID extracts a task ID from a context, if any.
|
// GetTaskID extracts a task ID from a context, if any.
|
||||||
//
|
//
|
||||||
// ID of a task is guaranteed to be unique.
|
// ID of a task is guaranteed to be unique.
|
||||||
// ID of a task doesn't change if the task is being retried.
|
// ID of a task doesn't change if the task is being retried.
|
||||||
func GetTaskID(ctx context.Context) (id string, ok bool) {
|
func GetTaskID(ctx context.Context) (id string, ok bool) {
|
||||||
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
|
return asynqcontext.GetTaskID(ctx)
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return metadata.id, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRetryCount extracts retry count from a context, if any.
|
// GetRetryCount extracts retry count from a context, if any.
|
||||||
@ -56,11 +23,7 @@ func GetTaskID(ctx context.Context) (id string, ok bool) {
|
|||||||
// Return value n indicates the number of times associated task has been
|
// Return value n indicates the number of times associated task has been
|
||||||
// retried so far.
|
// retried so far.
|
||||||
func GetRetryCount(ctx context.Context) (n int, ok bool) {
|
func GetRetryCount(ctx context.Context) (n int, ok bool) {
|
||||||
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
|
return asynqcontext.GetRetryCount(ctx)
|
||||||
if !ok {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return metadata.retryCount, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMaxRetry extracts maximum retry from a context, if any.
|
// GetMaxRetry extracts maximum retry from a context, if any.
|
||||||
@ -68,20 +31,12 @@ func GetRetryCount(ctx context.Context) (n int, ok bool) {
|
|||||||
// Return value n indicates the maximum number of times the assoicated task
|
// Return value n indicates the maximum number of times the assoicated task
|
||||||
// can be retried if ProcessTask returns a non-nil error.
|
// can be retried if ProcessTask returns a non-nil error.
|
||||||
func GetMaxRetry(ctx context.Context) (n int, ok bool) {
|
func GetMaxRetry(ctx context.Context) (n int, ok bool) {
|
||||||
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
|
return asynqcontext.GetMaxRetry(ctx)
|
||||||
if !ok {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return metadata.maxRetry, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetQueueName extracts queue name from a context, if any.
|
// GetQueueName extracts queue name from a context, if any.
|
||||||
//
|
//
|
||||||
// Return value qname indicates which queue the task was pulled from.
|
// Return value qname indicates which queue the task was pulled from.
|
||||||
func GetQueueName(ctx context.Context) (qname string, ok bool) {
|
func GetQueueName(ctx context.Context) (qname string, ok bool) {
|
||||||
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
|
return asynqcontext.GetQueueName(ctx)
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return metadata.qname, true
|
|
||||||
}
|
}
|
||||||
|
87
internal/context/context.go
Normal file
87
internal/context/context.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
// Copyright 2020 Kentaro Hibino. All rights reserved.
|
||||||
|
// Use of this source code is governed by a MIT license
|
||||||
|
// that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package context
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hibiken/asynq/internal/base"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A taskMetadata holds task scoped data to put in context.
|
||||||
|
type taskMetadata struct {
|
||||||
|
id string
|
||||||
|
maxRetry int
|
||||||
|
retryCount int
|
||||||
|
qname string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ctxKey type is unexported to prevent collisions with context keys defined in
|
||||||
|
// other packages.
|
||||||
|
type ctxKey int
|
||||||
|
|
||||||
|
// metadataCtxKey is the context key for the task metadata.
|
||||||
|
// Its value of zero is arbitrary.
|
||||||
|
const metadataCtxKey ctxKey = 0
|
||||||
|
|
||||||
|
// New returns a context and cancel function for a given task message.
|
||||||
|
func New(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) {
|
||||||
|
metadata := taskMetadata{
|
||||||
|
id: msg.ID.String(),
|
||||||
|
maxRetry: msg.Retry,
|
||||||
|
retryCount: msg.Retried,
|
||||||
|
qname: msg.Queue,
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(context.Background(), metadataCtxKey, metadata)
|
||||||
|
return context.WithDeadline(ctx, deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTaskID extracts a task ID from a context, if any.
|
||||||
|
//
|
||||||
|
// ID of a task is guaranteed to be unique.
|
||||||
|
// ID of a task doesn't change if the task is being retried.
|
||||||
|
func GetTaskID(ctx context.Context) (id string, ok bool) {
|
||||||
|
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return metadata.id, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRetryCount extracts retry count from a context, if any.
|
||||||
|
//
|
||||||
|
// Return value n indicates the number of times associated task has been
|
||||||
|
// retried so far.
|
||||||
|
func GetRetryCount(ctx context.Context) (n int, ok bool) {
|
||||||
|
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
|
||||||
|
if !ok {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return metadata.retryCount, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMaxRetry extracts maximum retry from a context, if any.
|
||||||
|
//
|
||||||
|
// Return value n indicates the maximum number of times the assoicated task
|
||||||
|
// can be retried if ProcessTask returns a non-nil error.
|
||||||
|
func GetMaxRetry(ctx context.Context) (n int, ok bool) {
|
||||||
|
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
|
||||||
|
if !ok {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return metadata.maxRetry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQueueName extracts queue name from a context, if any.
|
||||||
|
//
|
||||||
|
// Return value qname indicates which queue the task was pulled from.
|
||||||
|
func GetQueueName(ctx context.Context) (qname string, ok bool) {
|
||||||
|
metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return metadata.qname, true
|
||||||
|
}
|
@ -2,7 +2,7 @@
|
|||||||
// Use of this source code is governed by a MIT license
|
// Use of this source code is governed by a MIT license
|
||||||
// that can be found in the LICENSE file.
|
// that can be found in the LICENSE file.
|
||||||
|
|
||||||
package asynq
|
package context
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -28,7 +28,7 @@ func TestCreateContextWithFutureDeadline(t *testing.T) {
|
|||||||
Payload: nil,
|
Payload: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := createContext(msg, tc.deadline)
|
ctx, cancel := New(msg, tc.deadline)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case x := <-ctx.Done():
|
case x := <-ctx.Done():
|
||||||
@ -68,7 +68,7 @@ func TestCreateContextWithPastDeadline(t *testing.T) {
|
|||||||
Payload: nil,
|
Payload: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := createContext(msg, tc.deadline)
|
ctx, cancel := New(msg, tc.deadline)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -98,7 +98,7 @@ func TestGetTaskMetadataFromContext(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
ctx, cancel := createContext(tc.msg, time.Now().Add(30*time.Minute))
|
ctx, cancel := New(tc.msg, time.Now().Add(30*time.Minute))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
id, ok := GetTaskID(ctx)
|
id, ok := GetTaskID(ctx)
|
@ -16,6 +16,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hibiken/asynq/internal/base"
|
"github.com/hibiken/asynq/internal/base"
|
||||||
|
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||||
"github.com/hibiken/asynq/internal/errors"
|
"github.com/hibiken/asynq/internal/errors"
|
||||||
"github.com/hibiken/asynq/internal/log"
|
"github.com/hibiken/asynq/internal/log"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
@ -189,7 +190,7 @@ func (p *processor) exec() {
|
|||||||
<-p.sema // release token
|
<-p.sema // release token
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ctx, cancel := createContext(msg, deadline)
|
ctx, cancel := asynqcontext.New(msg, deadline)
|
||||||
p.cancelations.Add(msg.ID.String(), cancel)
|
p.cancelations.Add(msg.ID.String(), cancel)
|
||||||
defer func() {
|
defer func() {
|
||||||
cancel()
|
cancel()
|
||||||
|
40
x/rate/example_test.go
Normal file
40
x/rate/example_test.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package rate_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hibiken/asynq"
|
||||||
|
"github.com/hibiken/asynq/x/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RateLimitError struct {
|
||||||
|
RetryIn time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RateLimitError) Error() string {
|
||||||
|
return fmt.Sprintf("rate limited (retry in %v)", e.RetryIn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleNewSemaphore() {
|
||||||
|
redisConnOpt := asynq.RedisClientOpt{Addr: ":6379"}
|
||||||
|
sema := rate.NewSemaphore(redisConnOpt, "my_queue", 10)
|
||||||
|
// call sema.Close() when appropriate
|
||||||
|
|
||||||
|
_ = asynq.HandlerFunc(func(ctx context.Context, task *asynq.Task) error {
|
||||||
|
ok, err := sema.Acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return &RateLimitError{RetryIn: 30 * time.Second}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure to release the token once we're done.
|
||||||
|
defer sema.Release(ctx)
|
||||||
|
|
||||||
|
// Process task
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
114
x/rate/semaphore.go
Normal file
114
x/rate/semaphore.go
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
// Package rate contains rate limiting strategies for asynq.Handler(s).
|
||||||
|
package rate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"github.com/hibiken/asynq"
|
||||||
|
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewSemaphore creates a counting Semaphore for the given scope with the given number of tokens.
|
||||||
|
func NewSemaphore(rco asynq.RedisConnOpt, scope string, maxTokens int) *Semaphore {
|
||||||
|
rc, ok := rco.MakeRedisClient().(redis.UniversalClient)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("rate.NewSemaphore: unsupported RedisConnOpt type %T", rco))
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxTokens < 1 {
|
||||||
|
panic("rate.NewSemaphore: maxTokens cannot be less than 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(strings.TrimSpace(scope)) == 0 {
|
||||||
|
panic("rate.NewSemaphore: scope should not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Semaphore{
|
||||||
|
rc: rc,
|
||||||
|
scope: scope,
|
||||||
|
maxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Semaphore is a distributed counting semaphore which can be used to set maxTokens across multiple asynq servers.
|
||||||
|
type Semaphore struct {
|
||||||
|
rc redis.UniversalClient
|
||||||
|
maxTokens int
|
||||||
|
scope string
|
||||||
|
}
|
||||||
|
|
||||||
|
// KEYS[1] -> asynq:sema:<scope>
|
||||||
|
// ARGV[1] -> max concurrency
|
||||||
|
// ARGV[2] -> current time in unix time
|
||||||
|
// ARGV[3] -> deadline in unix time
|
||||||
|
// ARGV[4] -> task ID
|
||||||
|
var acquireCmd = redis.NewScript(`
|
||||||
|
redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", tonumber(ARGV[2])-1)
|
||||||
|
local count = redis.call("ZCARD", KEYS[1])
|
||||||
|
|
||||||
|
if (count < tonumber(ARGV[1])) then
|
||||||
|
redis.call("ZADD", KEYS[1], ARGV[3], ARGV[4])
|
||||||
|
return 'true'
|
||||||
|
else
|
||||||
|
return 'false'
|
||||||
|
end
|
||||||
|
`)
|
||||||
|
|
||||||
|
// Acquire attempts to acquire a token from the semaphore.
|
||||||
|
// - Returns (true, nil), iff semaphore key exists and current value is less than maxTokens
|
||||||
|
// - Returns (false, nil) when token cannot be acquired
|
||||||
|
// - Returns (false, error) otherwise
|
||||||
|
//
|
||||||
|
// The context.Context passed to Acquire must have a deadline set,
|
||||||
|
// this ensures that token is released if the job goroutine crashes and does not call Release.
|
||||||
|
func (s *Semaphore) Acquire(ctx context.Context) (bool, error) {
|
||||||
|
d, ok := ctx.Deadline()
|
||||||
|
if !ok {
|
||||||
|
return false, fmt.Errorf("provided context must have a deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID, ok := asynqcontext.GetTaskID(ctx)
|
||||||
|
if !ok {
|
||||||
|
return false, fmt.Errorf("provided context is missing task ID value")
|
||||||
|
}
|
||||||
|
|
||||||
|
return acquireCmd.Run(ctx, s.rc,
|
||||||
|
[]string{semaphoreKey(s.scope)},
|
||||||
|
s.maxTokens,
|
||||||
|
time.Now().Unix(),
|
||||||
|
d.Unix(),
|
||||||
|
taskID,
|
||||||
|
).Bool()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release will release the token on the counting semaphore.
|
||||||
|
func (s *Semaphore) Release(ctx context.Context) error {
|
||||||
|
taskID, ok := asynqcontext.GetTaskID(ctx)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("provided context is missing task ID value")
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := s.rc.ZRem(ctx, semaphoreKey(s.scope), taskID).Result()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("redis command failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
return fmt.Errorf("no token found for task %q", taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection to redis.
|
||||||
|
func (s *Semaphore) Close() error {
|
||||||
|
return s.rc.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func semaphoreKey(scope string) string {
|
||||||
|
return fmt.Sprintf("asynq:sema:%s", scope)
|
||||||
|
}
|
407
x/rate/semaphore_test.go
Normal file
407
x/rate/semaphore_test.go
Normal file
@ -0,0 +1,407 @@
|
|||||||
|
package rate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/hibiken/asynq"
|
||||||
|
"github.com/hibiken/asynq/internal/base"
|
||||||
|
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
redisAddr string
|
||||||
|
redisDB int
|
||||||
|
|
||||||
|
useRedisCluster bool
|
||||||
|
redisClusterAddrs string // comma-separated list of host:port
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
flag.StringVar(&redisAddr, "redis_addr", "localhost:6379", "redis address to use in testing")
|
||||||
|
flag.IntVar(&redisDB, "redis_db", 14, "redis db number to use in testing")
|
||||||
|
flag.BoolVar(&useRedisCluster, "redis_cluster", false, "use redis cluster as a broker in testing")
|
||||||
|
flag.StringVar(&redisClusterAddrs, "redis_cluster_addrs", "localhost:7000,localhost:7001,localhost:7002", "comma separated list of redis server addresses")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSemaphore(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
name string
|
||||||
|
maxConcurrency int
|
||||||
|
wantPanic string
|
||||||
|
connOpt asynq.RedisConnOpt
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Bad RedisConnOpt",
|
||||||
|
wantPanic: "rate.NewSemaphore: unsupported RedisConnOpt type *rate.badConnOpt",
|
||||||
|
connOpt: &badConnOpt{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Zero maxTokens should panic",
|
||||||
|
wantPanic: "rate.NewSemaphore: maxTokens cannot be less than 1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Empty scope should panic",
|
||||||
|
maxConcurrency: 2,
|
||||||
|
name: " ",
|
||||||
|
wantPanic: "rate.NewSemaphore: scope should not be empty",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
if tt.wantPanic != "" {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r.(string) != tt.wantPanic {
|
||||||
|
t.Errorf("%s;\nNewSemaphore should panic with msg: %s, got %s", tt.desc, tt.wantPanic, r.(string))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
opt := tt.connOpt
|
||||||
|
if tt.connOpt == nil {
|
||||||
|
opt = getRedisConnOpt(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency)
|
||||||
|
defer sema.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSemaphore_Acquire(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
name string
|
||||||
|
maxConcurrency int
|
||||||
|
taskIDs []uuid.UUID
|
||||||
|
ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc)
|
||||||
|
want []bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Should acquire token when current token count is less than maxTokens",
|
||||||
|
name: "task-1",
|
||||||
|
maxConcurrency: 3,
|
||||||
|
taskIDs: []uuid.UUID{uuid.New(), uuid.New()},
|
||||||
|
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||||
|
return asynqcontext.New(&base.TaskMessage{
|
||||||
|
ID: id,
|
||||||
|
Queue: "task-1",
|
||||||
|
}, time.Now().Add(time.Second))
|
||||||
|
},
|
||||||
|
want: []bool{true, true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Should fail acquiring token when current token count is equal to maxTokens",
|
||||||
|
name: "task-2",
|
||||||
|
maxConcurrency: 3,
|
||||||
|
taskIDs: []uuid.UUID{uuid.New(), uuid.New(), uuid.New(), uuid.New()},
|
||||||
|
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||||
|
return asynqcontext.New(&base.TaskMessage{
|
||||||
|
ID: id,
|
||||||
|
Queue: "task-2",
|
||||||
|
}, time.Now().Add(time.Second))
|
||||||
|
},
|
||||||
|
want: []bool{true, true, true, false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
opt := getRedisConnOpt(t)
|
||||||
|
rc := opt.MakeRedisClient().(redis.UniversalClient)
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil {
|
||||||
|
t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency)
|
||||||
|
defer sema.Close()
|
||||||
|
|
||||||
|
for i := 0; i < len(tt.taskIDs); i++ {
|
||||||
|
ctx, cancel := tt.ctxFunc(tt.taskIDs[i])
|
||||||
|
|
||||||
|
got, err := sema.Acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%s;\nSemaphore.Acquire() got error %v", tt.desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != tt.want[i] {
|
||||||
|
t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, tt.want[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSemaphore_Acquire_Error(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
name string
|
||||||
|
maxConcurrency int
|
||||||
|
taskIDs []uuid.UUID
|
||||||
|
ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc)
|
||||||
|
errStr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Should return error if context has no deadline",
|
||||||
|
name: "task-3",
|
||||||
|
maxConcurrency: 1,
|
||||||
|
taskIDs: []uuid.UUID{uuid.New(), uuid.New()},
|
||||||
|
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||||
|
return context.Background(), func() {}
|
||||||
|
},
|
||||||
|
errStr: "provided context must have a deadline",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Should return error when context is missing taskID",
|
||||||
|
name: "task-4",
|
||||||
|
maxConcurrency: 1,
|
||||||
|
taskIDs: []uuid.UUID{uuid.New()},
|
||||||
|
ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) {
|
||||||
|
return context.WithTimeout(context.Background(), time.Second)
|
||||||
|
},
|
||||||
|
errStr: "provided context is missing task ID value",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
opt := getRedisConnOpt(t)
|
||||||
|
rc := opt.MakeRedisClient().(redis.UniversalClient)
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil {
|
||||||
|
t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency)
|
||||||
|
defer sema.Close()
|
||||||
|
|
||||||
|
for i := 0; i < len(tt.taskIDs); i++ {
|
||||||
|
ctx, cancel := tt.ctxFunc(tt.taskIDs[i])
|
||||||
|
|
||||||
|
_, err := sema.Acquire(ctx)
|
||||||
|
if err == nil || err.Error() != tt.errStr {
|
||||||
|
t.Errorf("%s;\nSemaphore.Acquire() got error %v want error %v", tt.desc, err, tt.errStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSemaphore_Acquire_StaleToken(t *testing.T) {
|
||||||
|
opt := getRedisConnOpt(t)
|
||||||
|
rc := opt.MakeRedisClient().(redis.UniversalClient)
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
taskID := uuid.New()
|
||||||
|
|
||||||
|
// adding a set member to mimic the case where token is acquired but the goroutine crashed,
|
||||||
|
// in which case, the token will not be explicitly removed and should be present already
|
||||||
|
rc.ZAdd(context.Background(), semaphoreKey("stale-token"), &redis.Z{
|
||||||
|
Score: float64(time.Now().Add(-10 * time.Second).Unix()),
|
||||||
|
Member: taskID.String(),
|
||||||
|
})
|
||||||
|
|
||||||
|
sema := NewSemaphore(opt, "stale-token", 1)
|
||||||
|
defer sema.Close()
|
||||||
|
|
||||||
|
ctx, cancel := asynqcontext.New(&base.TaskMessage{
|
||||||
|
ID: taskID,
|
||||||
|
Queue: "task-1",
|
||||||
|
}, time.Now().Add(time.Second))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
got, err := sema.Acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Acquire_StaleToken;\nSemaphore.Acquire() got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !got {
|
||||||
|
t.Error("Acquire_StaleToken;\nSemaphore.Acquire() got false want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSemaphore_Release(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
name string
|
||||||
|
taskIDs []uuid.UUID
|
||||||
|
ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc)
|
||||||
|
wantCount int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Should decrease token count",
|
||||||
|
name: "task-5",
|
||||||
|
taskIDs: []uuid.UUID{uuid.New()},
|
||||||
|
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||||
|
return asynqcontext.New(&base.TaskMessage{
|
||||||
|
ID: id,
|
||||||
|
Queue: "task-3",
|
||||||
|
}, time.Now().Add(time.Second))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Should decrease token count by 2",
|
||||||
|
name: "task-6",
|
||||||
|
taskIDs: []uuid.UUID{uuid.New(), uuid.New()},
|
||||||
|
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||||
|
return asynqcontext.New(&base.TaskMessage{
|
||||||
|
ID: id,
|
||||||
|
Queue: "task-4",
|
||||||
|
}, time.Now().Add(time.Second))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
opt := getRedisConnOpt(t)
|
||||||
|
rc := opt.MakeRedisClient().(redis.UniversalClient)
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil {
|
||||||
|
t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var members []*redis.Z
|
||||||
|
for i := 0; i < len(tt.taskIDs); i++ {
|
||||||
|
members = append(members, &redis.Z{
|
||||||
|
Score: float64(time.Now().Add(time.Duration(i) * time.Second).Unix()),
|
||||||
|
Member: tt.taskIDs[i].String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if err := rc.ZAdd(context.Background(), semaphoreKey(tt.name), members...).Err(); err != nil {
|
||||||
|
t.Errorf("%s;\nredis.UniversalClient.ZAdd() got error %v", tt.desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sema := NewSemaphore(opt, tt.name, 3)
|
||||||
|
defer sema.Close()
|
||||||
|
|
||||||
|
for i := 0; i < len(tt.taskIDs); i++ {
|
||||||
|
ctx, cancel := tt.ctxFunc(tt.taskIDs[i])
|
||||||
|
|
||||||
|
if err := sema.Release(ctx); err != nil {
|
||||||
|
t.Errorf("%s;\nSemaphore.Release() got error %v", tt.desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
i, err := rc.ZCount(context.Background(), semaphoreKey(tt.name), "-inf", "+inf").Result()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%s;\nredis.UniversalClient.ZCount() got error %v", tt.desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if i != tt.wantCount {
|
||||||
|
t.Errorf("%s;\nSemaphore.Release(ctx) didn't release token, got %v want 0", tt.desc, i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSemaphore_Release_Error(t *testing.T) {
|
||||||
|
testID := uuid.New()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
name string
|
||||||
|
taskIDs []uuid.UUID
|
||||||
|
ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc)
|
||||||
|
errStr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Should return error when context is missing taskID",
|
||||||
|
name: "task-7",
|
||||||
|
taskIDs: []uuid.UUID{uuid.New()},
|
||||||
|
ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) {
|
||||||
|
return context.WithTimeout(context.Background(), time.Second)
|
||||||
|
},
|
||||||
|
errStr: "provided context is missing task ID value",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Should return error when context has taskID which never acquired token",
|
||||||
|
name: "task-8",
|
||||||
|
taskIDs: []uuid.UUID{uuid.New()},
|
||||||
|
ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) {
|
||||||
|
return asynqcontext.New(&base.TaskMessage{
|
||||||
|
ID: testID,
|
||||||
|
Queue: "task-4",
|
||||||
|
}, time.Now().Add(time.Second))
|
||||||
|
},
|
||||||
|
errStr: fmt.Sprintf("no token found for task %q", testID.String()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
opt := getRedisConnOpt(t)
|
||||||
|
rc := opt.MakeRedisClient().(redis.UniversalClient)
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
if err := rc.Del(context.Background(), semaphoreKey(tt.name)).Err(); err != nil {
|
||||||
|
t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var members []*redis.Z
|
||||||
|
for i := 0; i < len(tt.taskIDs); i++ {
|
||||||
|
members = append(members, &redis.Z{
|
||||||
|
Score: float64(time.Now().Add(time.Duration(i) * time.Second).Unix()),
|
||||||
|
Member: tt.taskIDs[i].String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if err := rc.ZAdd(context.Background(), semaphoreKey(tt.name), members...).Err(); err != nil {
|
||||||
|
t.Errorf("%s;\nredis.UniversalClient.ZAdd() got error %v", tt.desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sema := NewSemaphore(opt, tt.name, 3)
|
||||||
|
defer sema.Close()
|
||||||
|
|
||||||
|
for i := 0; i < len(tt.taskIDs); i++ {
|
||||||
|
ctx, cancel := tt.ctxFunc(tt.taskIDs[i])
|
||||||
|
|
||||||
|
if err := sema.Release(ctx); err == nil || err.Error() != tt.errStr {
|
||||||
|
t.Errorf("%s;\nSemaphore.Release() got error %v want error %v", tt.desc, err, tt.errStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRedisConnOpt(tb testing.TB) asynq.RedisConnOpt {
|
||||||
|
tb.Helper()
|
||||||
|
if useRedisCluster {
|
||||||
|
addrs := strings.Split(redisClusterAddrs, ",")
|
||||||
|
if len(addrs) == 0 {
|
||||||
|
tb.Fatal("No redis cluster addresses provided. Please set addresses using --redis_cluster_addrs flag.")
|
||||||
|
}
|
||||||
|
return asynq.RedisClusterClientOpt{
|
||||||
|
Addrs: addrs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return asynq.RedisClientOpt{
|
||||||
|
Addr: redisAddr,
|
||||||
|
DB: redisDB,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type badConnOpt struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b badConnOpt) MakeRedisClient() interface{} {
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user