From 34cc2d8081a6c739a2d5cc87b96e3d288a8b4f37 Mon Sep 17 00:00:00 2001 From: ajatprabha Date: Sat, 30 Oct 2021 16:07:46 +0530 Subject: [PATCH] use sorted-set to release stale locks --- alias.go | 38 ++++ benchmark_test.go | 3 +- context.go => internal/context/context.go | 6 +- .../context/context_test.go | 8 +- processor.go | 3 +- x/rate/example_test.go | 7 +- x/rate/semaphore.go | 73 +++++-- x/rate/semaphore_test.go | 199 +++++++++++++++--- 8 files changed, 283 insertions(+), 54 deletions(-) create mode 100644 alias.go rename context.go => internal/context/context.go (92%) rename context_test.go => internal/context/context_test.go (95%) diff --git a/alias.go b/alias.go new file mode 100644 index 0000000..f3eba57 --- /dev/null +++ b/alias.go @@ -0,0 +1,38 @@ +package asynq + +import ( + "context" + + asynqcontext "github.com/hibiken/asynq/internal/context" +) + +// GetTaskID extracts a task ID from a context, if any. +// +// ID of a task is guaranteed to be unique. +// ID of a task doesn't change if the task is being retried. +func GetTaskID(ctx context.Context) (id string, ok bool) { + return asynqcontext.GetTaskID(ctx) +} + +// GetRetryCount extracts retry count from a context, if any. +// +// Return value n indicates the number of times associated task has been +// retried so far. +func GetRetryCount(ctx context.Context) (n int, ok bool) { + return asynqcontext.GetRetryCount(ctx) +} + +// GetMaxRetry extracts maximum retry from a context, if any. +// +// Return value n indicates the maximum number of times the associated task +// can be retried if ProcessTask returns a non-nil error. +func GetMaxRetry(ctx context.Context) (n int, ok bool) { + return asynqcontext.GetMaxRetry(ctx) +} + +// GetQueueName extracts queue name from a context, if any. +// +// Return value qname indicates which queue the task was pulled from. 
+func GetQueueName(ctx context.Context) (qname string, ok bool) { + return asynqcontext.GetQueueName(ctx) +} diff --git a/benchmark_test.go b/benchmark_test.go index b98ea34..042bf43 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -13,6 +13,7 @@ import ( "time" h "github.com/hibiken/asynq/internal/asynqtest" + asynqcontext "github.com/hibiken/asynq/internal/context" ) // Creates a new task of type "task" with payload {"data": n}. @@ -104,7 +105,7 @@ func BenchmarkEndToEnd(b *testing.B) { n = 1 b.Logf("internal error: could not get data from payload") } - retried, ok := GetRetryCount(ctx) + retried, ok := asynqcontext.GetRetryCount(ctx) if !ok { b.Logf("internal error: could not get retry count from context") } diff --git a/context.go b/internal/context/context.go similarity index 92% rename from context.go rename to internal/context/context.go index 4710bbd..9b233ce 100644 --- a/context.go +++ b/internal/context/context.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a MIT license // that can be found in the LICENSE file. -package asynq +package context import ( "context" @@ -27,8 +27,8 @@ type ctxKey int // Its value of zero is arbitrary. const metadataCtxKey ctxKey = 0 -// createContext returns a context and cancel function for a given task message. -func createContext(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) { +// New returns a context and cancel function for a given task message. +func New(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) { metadata := taskMetadata{ id: msg.ID.String(), maxRetry: msg.Retry, diff --git a/context_test.go b/internal/context/context_test.go similarity index 95% rename from context_test.go rename to internal/context/context_test.go index 305cd19..50498d7 100644 --- a/context_test.go +++ b/internal/context/context_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a MIT license // that can be found in the LICENSE file. 
-package asynq +package context import ( "context" @@ -28,7 +28,7 @@ func TestCreateContextWithFutureDeadline(t *testing.T) { Payload: nil, } - ctx, cancel := createContext(msg, tc.deadline) + ctx, cancel := New(msg, tc.deadline) select { case x := <-ctx.Done(): @@ -68,7 +68,7 @@ func TestCreateContextWithPastDeadline(t *testing.T) { Payload: nil, } - ctx, cancel := createContext(msg, tc.deadline) + ctx, cancel := New(msg, tc.deadline) defer cancel() select { @@ -98,7 +98,7 @@ func TestGetTaskMetadataFromContext(t *testing.T) { } for _, tc := range tests { - ctx, cancel := createContext(tc.msg, time.Now().Add(30*time.Minute)) + ctx, cancel := New(tc.msg, time.Now().Add(30*time.Minute)) defer cancel() id, ok := GetTaskID(ctx) diff --git a/processor.go b/processor.go index 5ede2d2..fcf7cb4 100644 --- a/processor.go +++ b/processor.go @@ -16,6 +16,7 @@ import ( "time" "github.com/hibiken/asynq/internal/base" + asynqcontext "github.com/hibiken/asynq/internal/context" "github.com/hibiken/asynq/internal/errors" "github.com/hibiken/asynq/internal/log" "golang.org/x/time/rate" @@ -189,7 +190,7 @@ func (p *processor) exec() { <-p.sema // release token }() - ctx, cancel := createContext(msg, deadline) + ctx, cancel := asynqcontext.New(msg, deadline) p.cancelations.Add(msg.ID.String(), cancel) defer func() { cancel() diff --git a/x/rate/example_test.go b/x/rate/example_test.go index 55c414e..37016a4 100644 --- a/x/rate/example_test.go +++ b/x/rate/example_test.go @@ -23,9 +23,14 @@ func ExampleNewSemaphore() { // call sema.Close() when appropriate _ = asynq.HandlerFunc(func(ctx context.Context, task *asynq.Task) error { - if !sema.Acquire(ctx) { + ok, err := sema.Acquire(ctx) + if err != nil { + return err + } + if !ok { return &RateLimitError{RetryIn: 30 * time.Second} } + // Make sure to release the token once we're done. 
defer sema.Release(ctx) diff --git a/x/rate/semaphore.go b/x/rate/semaphore.go index 353882f..7519db1 100644 --- a/x/rate/semaphore.go +++ b/x/rate/semaphore.go @@ -4,9 +4,11 @@ import ( "context" "fmt" "strings" + "time" "github.com/go-redis/redis/v8" "github.com/hibiken/asynq" + asynqcontext "github.com/hibiken/asynq/internal/context" ) // NewSemaphore creates a new counting Semaphore. @@ -38,28 +40,73 @@ type Semaphore struct { name string } +// KEYS[1] -> asynq:sema:<name> +// ARGV[1] -> max concurrency +// ARGV[2] -> current time in unix time +// ARGV[3] -> deadline in unix time +// ARGV[4] -> task ID var acquireCmd = redis.NewScript(` -local lockCount = tonumber(redis.call('GET', KEYS[1])) +redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", tonumber(ARGV[2])-1) +local lockCount = redis.call("ZCARD", KEYS[1]) -if (not lockCount or lockCount < tonumber(ARGV[1])) then - redis.call('INCR', KEYS[1]) - return true +if (lockCount < tonumber(ARGV[1])) then + redis.call("ZADD", KEYS[1], ARGV[3], ARGV[4]) + return true else - return false + return false end `) // Acquire will try to acquire lock on the counting semaphore. -// - Returns true, iff semaphore key exists and current value is less than maxConcurrency -// - Returns false otherwise -func (s *Semaphore) Acquire(ctx context.Context) bool { - val, _ := acquireCmd.Run(ctx, s.rc, []string{semaphoreKey(s.name)}, []interface{}{s.maxConcurrency}...).Bool() - return val +// - Returns (true, nil), iff the number of unexpired locks currently held is less than maxConcurrency +// - Returns (false, nil) when lock cannot be acquired +// - Returns (false, error) otherwise +// +// The context.Context passed to Acquire must have a deadline set; +// this ensures that the lock is released if the job goroutine crashes and does not call Release. 
+func (s *Semaphore) Acquire(ctx context.Context) (bool, error) { + d, ok := ctx.Deadline() + if !ok { + return false, fmt.Errorf("provided context must have a deadline") + } + + taskID, ok := asynqcontext.GetTaskID(ctx) + if !ok { + return false, fmt.Errorf("provided context is missing task ID value") + } + + b, err := acquireCmd.Run(ctx, s.rc, + []string{semaphoreKey(s.name)}, + []interface{}{ + s.maxConcurrency, + time.Now().Unix(), + d.Unix(), + taskID, + }...).Bool() + if err == redis.Nil { + return b, nil + } + + return b, err } // Release will release the lock on the counting semaphore. -func (s *Semaphore) Release(ctx context.Context) { - s.rc.Decr(ctx, semaphoreKey(s.name)) +func (s *Semaphore) Release(ctx context.Context) error { + taskID, ok := asynqcontext.GetTaskID(ctx) + if !ok { + return fmt.Errorf("provided context is missing task ID value") + } + + n, err := s.rc.ZRem(ctx, semaphoreKey(s.name), taskID).Result() + if err != nil { + return fmt.Errorf("redis command failed: %v", err) + } + + if n == 0 { + return fmt.Errorf("no lock found for task %q", taskID) + } + + return nil } // Close closes the connection with redis. 
@@ -68,5 +115,5 @@ func (s *Semaphore) Close() error { } func semaphoreKey(name string) string { - return fmt.Sprintf("asynq:sema:{%s}", name) + return fmt.Sprintf("asynq:sema:%s", name) } diff --git a/x/rate/semaphore_test.go b/x/rate/semaphore_test.go index e53f4f9..114261a 100644 --- a/x/rate/semaphore_test.go +++ b/x/rate/semaphore_test.go @@ -3,11 +3,15 @@ package rate import ( "context" "flag" + "fmt" + "github.com/go-redis/redis/v8" + "github.com/google/uuid" + "github.com/hibiken/asynq" + "github.com/hibiken/asynq/internal/base" + asynqcontext "github.com/hibiken/asynq/internal/context" "strings" "testing" - - "github.com/go-redis/redis/v8" - "github.com/hibiken/asynq" + "time" ) var ( @@ -76,19 +80,58 @@ func TestNewSemaphore_Acquire(t *testing.T) { desc string name string maxConcurrency int - callAcquire int + taskIDs []uuid.UUID + ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) + want []bool + wantErr string }{ { desc: "Should acquire lock when current lock count is less than maxConcurrency", name: "task-1", maxConcurrency: 3, - callAcquire: 2, + taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: id, + Queue: "task-1", + }, time.Now().Add(time.Second)) + }, + want: []bool{true, true}, }, { desc: "Should fail acquiring lock when current lock count is equal to maxConcurrency", name: "task-2", maxConcurrency: 3, - callAcquire: 4, + taskIDs: []uuid.UUID{uuid.New(), uuid.New(), uuid.New(), uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: id, + Queue: "task-2", + }, time.Now().Add(time.Second)) + }, + want: []bool{true, true, true, false}, + }, + { + desc: "Should return error if context has no deadline", + name: "task-3", + maxConcurrency: 1, + taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, 
context.CancelFunc) { + return context.Background(), func() {} + }, + want: []bool{false, false}, + wantErr: "provided context must have a deadline", + }, + { + desc: "Should return error when context is missing taskID", + name: "task-4", + maxConcurrency: 1, + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) + }, + want: []bool{false}, + wantErr: "provided context is missing task ID value", }, } @@ -105,33 +148,109 @@ func TestNewSemaphore_Acquire(t *testing.T) { sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) defer sema.Close() - for i := 0; i < tt.callAcquire; i++ { - if got := sema.Acquire(context.Background()); got != (i < tt.maxConcurrency) { - t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, i < tt.maxConcurrency) + for i := 0; i < len(tt.taskIDs); i++ { + ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) + + got, err := sema.Acquire(ctx) + if got != tt.want[i] { + t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, tt.want[i]) } + if (tt.wantErr == "" && err != nil) || (tt.wantErr != "" && (err == nil || err.Error() != tt.wantErr)) { + t.Errorf("%s;\nSemaphore.Acquire() got error %v want error %v", tt.desc, err, tt.wantErr) + } + + cancel() } }) } } +func TestNewSemaphore_Acquire_StaleLock(t *testing.T) { + opt := getRedisConnOpt(t) + rc := opt.MakeRedisClient().(redis.UniversalClient) + defer rc.Close() + + taskID := uuid.New() + + rc.ZAdd(context.Background(), semaphoreKey("stale-lock"), &redis.Z{ + Score: float64(time.Now().Add(-10 * time.Second).Unix()), + Member: taskID.String(), + }) + + sema := NewSemaphore(opt, "stale-lock", 1) + defer sema.Close() + + ctx, cancel := asynqcontext.New(&base.TaskMessage{ + ID: taskID, + Queue: "task-1", + }, time.Now().Add(time.Second)) + defer cancel() + + got, err := sema.Acquire(ctx) + if err != nil { + t.Errorf("Acquire_StaleLock;\nSemaphore.Acquire() 
got error %v", err) + } + + if !got { + t.Error("Acquire_StaleLock;\nSemaphore.Acquire() got false want true") + } +} + func TestNewSemaphore_Release(t *testing.T) { + testID := uuid.New() + tests := []struct { - desc string - name string - maxConcurrency int - callRelease int + desc string + name string + taskIDs []uuid.UUID + ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc) + wantCount int64 + wantErr string }{ { - desc: "Should decrease lock count", - name: "task-3", - maxConcurrency: 3, - callRelease: 1, + desc: "Should decrease lock count", + name: "task-5", + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: id, + Queue: "task-3", + }, time.Now().Add(time.Second)) + }, }, { - desc: "Should decrease lock count by 2", - name: "task-4", - maxConcurrency: 3, - callRelease: 2, + desc: "Should decrease lock count by 2", + name: "task-6", + taskIDs: []uuid.UUID{uuid.New(), uuid.New()}, + ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: id, + Queue: "task-4", + }, time.Now().Add(time.Second)) + }, + }, + { + desc: "Should return error when context is missing taskID", + name: "task-7", + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Second) + }, + wantCount: 1, + wantErr: "provided context is missing task ID value", + }, + { + desc: "Should return error when context has taskID which never acquired lock", + name: "task-8", + taskIDs: []uuid.UUID{uuid.New()}, + ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) { + return asynqcontext.New(&base.TaskMessage{ + ID: testID, + Queue: "task-4", + }, time.Now().Add(time.Second)) + }, + wantCount: 1, + wantErr: fmt.Sprintf("no lock found for task %q", testID.String()), }, } @@ -145,25 +264,43 @@ func 
TestNewSemaphore_Release(t *testing.T) { t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) } - if err := rc.IncrBy(context.Background(), semaphoreKey(tt.name), int64(tt.maxConcurrency)).Err(); err != nil { - t.Errorf("%s;\nredis.UniversalClient.IncrBy() got error %v", tt.desc, err) + var members []*redis.Z + for i := 0; i < len(tt.taskIDs); i++ { + members = append(members, &redis.Z{ + Score: float64(time.Now().Add(time.Duration(i) * time.Second).Unix()), + Member: tt.taskIDs[i].String(), + }) + } + if err := rc.ZAdd(context.Background(), semaphoreKey(tt.name), members...).Err(); err != nil { + t.Errorf("%s;\nredis.UniversalClient.ZAdd() got error %v", tt.desc, err) } - sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) + sema := NewSemaphore(opt, tt.name, 3) defer sema.Close() - for i := 0; i < tt.callRelease; i++ { - sema.Release(context.Background()) + for i := 0; i < len(tt.taskIDs); i++ { + ctx, cancel := tt.ctxFunc(tt.taskIDs[i]) + + err := sema.Release(ctx) + + if tt.wantErr == "" && err != nil { + t.Errorf("%s;\nSemaphore.Release() got error %v", tt.desc, err) + } + + if tt.wantErr != "" && (err == nil || err.Error() != tt.wantErr) { + t.Errorf("%s;\nSemaphore.Release() got error %v want error %v", tt.desc, err, tt.wantErr) + } + + cancel() } - i, err := rc.Get(context.Background(), semaphoreKey(tt.name)).Int() + i, err := rc.ZCount(context.Background(), semaphoreKey(tt.name), "-inf", "+inf").Result() if err != nil { - t.Errorf("%s;\nredis.UniversalClient.Get() got error %v", tt.desc, err) + t.Errorf("%s;\nredis.UniversalClient.ZCount() got error %v", tt.desc, err) } - if i != tt.maxConcurrency-tt.callRelease { - t.Errorf("%s;\nSemaphore.Release(ctx) didn't release lock, got %v want %v", - tt.desc, i, tt.maxConcurrency-tt.callRelease) + if i != tt.wantCount { + t.Errorf("%s;\nSemaphore.Release(ctx) didn't release lock, got %v want 0", tt.desc, i) } }) }