mirror of
https://github.com/hibiken/asynq.git
synced 2025-04-22 16:50:18 +08:00
use sorted-set to release stale locks
This commit is contained in:
parent
cb720fc10b
commit
34cc2d8081
38
alias.go
Normal file
38
alias.go
Normal file
@ -0,0 +1,38 @@
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||
)
|
||||
|
||||
// GetTaskID extracts a task ID from a context, if any.
|
||||
//
|
||||
// ID of a task is guaranteed to be unique.
|
||||
// ID of a task doesn't change if the task is being retried.
|
||||
func GetTaskID(ctx context.Context) (id string, ok bool) {
|
||||
return asynqcontext.GetTaskID(ctx)
|
||||
}
|
||||
|
||||
// GetRetryCount extracts retry count from a context, if any.
|
||||
//
|
||||
// Return value n indicates the number of times associated task has been
|
||||
// retried so far.
|
||||
func GetRetryCount(ctx context.Context) (n int, ok bool) {
|
||||
return asynqcontext.GetRetryCount(ctx)
|
||||
}
|
||||
|
||||
// GetMaxRetry extracts maximum retry from a context, if any.
|
||||
//
|
||||
// Return value n indicates the maximum number of times the assoicated task
|
||||
// can be retried if ProcessTask returns a non-nil error.
|
||||
func GetMaxRetry(ctx context.Context) (n int, ok bool) {
|
||||
return asynqcontext.GetMaxRetry(ctx)
|
||||
}
|
||||
|
||||
// GetQueueName extracts queue name from a context, if any.
|
||||
//
|
||||
// Return value qname indicates which queue the task was pulled from.
|
||||
func GetQueueName(ctx context.Context) (qname string, ok bool) {
|
||||
return asynqcontext.GetQueueName(ctx)
|
||||
}
|
@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
h "github.com/hibiken/asynq/internal/asynqtest"
|
||||
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||
)
|
||||
|
||||
// Creates a new task of type "task<n>" with payload {"data": n}.
|
||||
@ -104,7 +105,7 @@ func BenchmarkEndToEnd(b *testing.B) {
|
||||
n = 1
|
||||
b.Logf("internal error: could not get data from payload")
|
||||
}
|
||||
retried, ok := GetRetryCount(ctx)
|
||||
retried, ok := asynqcontext.GetRetryCount(ctx)
|
||||
if !ok {
|
||||
b.Logf("internal error: could not get retry count from context")
|
||||
}
|
||||
|
@ -2,7 +2,7 @@
|
||||
// Use of this source code is governed by a MIT license
|
||||
// that can be found in the LICENSE file.
|
||||
|
||||
package asynq
|
||||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -27,8 +27,8 @@ type ctxKey int
|
||||
// Its value of zero is arbitrary.
|
||||
const metadataCtxKey ctxKey = 0
|
||||
|
||||
// createContext returns a context and cancel function for a given task message.
|
||||
func createContext(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) {
|
||||
// New returns a context and cancel function for a given task message.
|
||||
func New(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) {
|
||||
metadata := taskMetadata{
|
||||
id: msg.ID.String(),
|
||||
maxRetry: msg.Retry,
|
@ -2,7 +2,7 @@
|
||||
// Use of this source code is governed by a MIT license
|
||||
// that can be found in the LICENSE file.
|
||||
|
||||
package asynq
|
||||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -28,7 +28,7 @@ func TestCreateContextWithFutureDeadline(t *testing.T) {
|
||||
Payload: nil,
|
||||
}
|
||||
|
||||
ctx, cancel := createContext(msg, tc.deadline)
|
||||
ctx, cancel := New(msg, tc.deadline)
|
||||
|
||||
select {
|
||||
case x := <-ctx.Done():
|
||||
@ -68,7 +68,7 @@ func TestCreateContextWithPastDeadline(t *testing.T) {
|
||||
Payload: nil,
|
||||
}
|
||||
|
||||
ctx, cancel := createContext(msg, tc.deadline)
|
||||
ctx, cancel := New(msg, tc.deadline)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
@ -98,7 +98,7 @@ func TestGetTaskMetadataFromContext(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
ctx, cancel := createContext(tc.msg, time.Now().Add(30*time.Minute))
|
||||
ctx, cancel := New(tc.msg, time.Now().Add(30*time.Minute))
|
||||
defer cancel()
|
||||
|
||||
id, ok := GetTaskID(ctx)
|
@ -16,6 +16,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||
"github.com/hibiken/asynq/internal/errors"
|
||||
"github.com/hibiken/asynq/internal/log"
|
||||
"golang.org/x/time/rate"
|
||||
@ -189,7 +190,7 @@ func (p *processor) exec() {
|
||||
<-p.sema // release token
|
||||
}()
|
||||
|
||||
ctx, cancel := createContext(msg, deadline)
|
||||
ctx, cancel := asynqcontext.New(msg, deadline)
|
||||
p.cancelations.Add(msg.ID.String(), cancel)
|
||||
defer func() {
|
||||
cancel()
|
||||
|
@ -23,9 +23,14 @@ func ExampleNewSemaphore() {
|
||||
// call sema.Close() when appropriate
|
||||
|
||||
_ = asynq.HandlerFunc(func(ctx context.Context, task *asynq.Task) error {
|
||||
if !sema.Acquire(ctx) {
|
||||
ok, err := sema.Acquire(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return &RateLimitError{RetryIn: 30 * time.Second}
|
||||
}
|
||||
|
||||
// Make sure to release the token once we're done.
|
||||
defer sema.Release(ctx)
|
||||
|
||||
|
@ -4,9 +4,11 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/hibiken/asynq"
|
||||
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||
)
|
||||
|
||||
// NewSemaphore creates a new counting Semaphore.
|
||||
@ -38,11 +40,17 @@ type Semaphore struct {
|
||||
name string
|
||||
}
|
||||
|
||||
// KEYS[1] -> asynq:sema:<scope>
|
||||
// ARGV[1] -> max concurrency
|
||||
// ARGV[2] -> current time in unix time
|
||||
// ARGV[3] -> deadline in unix time
|
||||
// ARGV[4] -> task ID
|
||||
var acquireCmd = redis.NewScript(`
|
||||
local lockCount = tonumber(redis.call('GET', KEYS[1]))
|
||||
redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", tonumber(ARGV[2])-1)
|
||||
local lockCount = redis.call("ZCARD", KEYS[1])
|
||||
|
||||
if (not lockCount or lockCount < tonumber(ARGV[1])) then
|
||||
redis.call('INCR', KEYS[1])
|
||||
if (lockCount < tonumber(ARGV[1])) then
|
||||
redis.call("ZADD", KEYS[1], ARGV[3], ARGV[4])
|
||||
return true
|
||||
else
|
||||
return false
|
||||
@ -50,16 +58,55 @@ end
|
||||
`)
|
||||
|
||||
// Acquire will try to acquire lock on the counting semaphore.
|
||||
// - Returns true, iff semaphore key exists and current value is less than maxConcurrency
|
||||
// - Returns false otherwise
|
||||
func (s *Semaphore) Acquire(ctx context.Context) bool {
|
||||
val, _ := acquireCmd.Run(ctx, s.rc, []string{semaphoreKey(s.name)}, []interface{}{s.maxConcurrency}...).Bool()
|
||||
return val
|
||||
// - Returns (true, nil), iff semaphore key exists and current value is less than maxConcurrency
|
||||
// - Returns (false, nil) when lock cannot be acquired
|
||||
// - Returns (false, error) otherwise
|
||||
//
|
||||
// The context.Context passed to Acquire must have a deadline set,
|
||||
// this ensures that lock is released if the job goroutine crashes and does not call Release.
|
||||
func (s *Semaphore) Acquire(ctx context.Context) (bool, error) {
|
||||
d, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
return false, fmt.Errorf("provided context must have a deadline")
|
||||
}
|
||||
|
||||
taskID, ok := asynqcontext.GetTaskID(ctx)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("provided context is missing task ID value")
|
||||
}
|
||||
|
||||
b, err := acquireCmd.Run(ctx, s.rc,
|
||||
[]string{semaphoreKey(s.name)},
|
||||
[]interface{}{
|
||||
s.maxConcurrency,
|
||||
time.Now().Unix(),
|
||||
d.Unix(),
|
||||
taskID,
|
||||
}...).Bool()
|
||||
if err == redis.Nil {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
return b, err
|
||||
}
|
||||
|
||||
// Release will release the lock on the counting semaphore.
|
||||
func (s *Semaphore) Release(ctx context.Context) {
|
||||
s.rc.Decr(ctx, semaphoreKey(s.name))
|
||||
func (s *Semaphore) Release(ctx context.Context) error {
|
||||
taskID, ok := asynqcontext.GetTaskID(ctx)
|
||||
if !ok {
|
||||
return fmt.Errorf("provided context is missing task ID value")
|
||||
}
|
||||
|
||||
n, err := s.rc.ZRem(ctx, semaphoreKey(s.name), taskID).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("redis command failed: %v", err)
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return fmt.Errorf("no lock found for task %q", taskID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection with redis.
|
||||
@ -68,5 +115,5 @@ func (s *Semaphore) Close() error {
|
||||
}
|
||||
|
||||
func semaphoreKey(name string) string {
|
||||
return fmt.Sprintf("asynq:sema:{%s}", name)
|
||||
return fmt.Sprintf("asynq:sema:%s", name)
|
||||
}
|
||||
|
@ -3,11 +3,15 @@ package rate
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hibiken/asynq"
|
||||
"github.com/hibiken/asynq/internal/base"
|
||||
asynqcontext "github.com/hibiken/asynq/internal/context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/hibiken/asynq"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -76,19 +80,58 @@ func TestNewSemaphore_Acquire(t *testing.T) {
|
||||
desc string
|
||||
name string
|
||||
maxConcurrency int
|
||||
callAcquire int
|
||||
taskIDs []uuid.UUID
|
||||
ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc)
|
||||
want []bool
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
desc: "Should acquire lock when current lock count is less than maxConcurrency",
|
||||
name: "task-1",
|
||||
maxConcurrency: 3,
|
||||
callAcquire: 2,
|
||||
taskIDs: []uuid.UUID{uuid.New(), uuid.New()},
|
||||
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||
return asynqcontext.New(&base.TaskMessage{
|
||||
ID: id,
|
||||
Queue: "task-1",
|
||||
}, time.Now().Add(time.Second))
|
||||
},
|
||||
want: []bool{true, true},
|
||||
},
|
||||
{
|
||||
desc: "Should fail acquiring lock when current lock count is equal to maxConcurrency",
|
||||
name: "task-2",
|
||||
maxConcurrency: 3,
|
||||
callAcquire: 4,
|
||||
taskIDs: []uuid.UUID{uuid.New(), uuid.New(), uuid.New(), uuid.New()},
|
||||
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||
return asynqcontext.New(&base.TaskMessage{
|
||||
ID: id,
|
||||
Queue: "task-2",
|
||||
}, time.Now().Add(time.Second))
|
||||
},
|
||||
want: []bool{true, true, true, false},
|
||||
},
|
||||
{
|
||||
desc: "Should return error if context has no deadline",
|
||||
name: "task-3",
|
||||
maxConcurrency: 1,
|
||||
taskIDs: []uuid.UUID{uuid.New(), uuid.New()},
|
||||
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||
return context.Background(), func() {}
|
||||
},
|
||||
want: []bool{false, false},
|
||||
wantErr: "provided context must have a deadline",
|
||||
},
|
||||
{
|
||||
desc: "Should return error when context is missing taskID",
|
||||
name: "task-4",
|
||||
maxConcurrency: 1,
|
||||
taskIDs: []uuid.UUID{uuid.New()},
|
||||
ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), time.Second)
|
||||
},
|
||||
want: []bool{false},
|
||||
wantErr: "provided context is missing task ID value",
|
||||
},
|
||||
}
|
||||
|
||||
@ -105,33 +148,109 @@ func TestNewSemaphore_Acquire(t *testing.T) {
|
||||
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency)
|
||||
defer sema.Close()
|
||||
|
||||
for i := 0; i < tt.callAcquire; i++ {
|
||||
if got := sema.Acquire(context.Background()); got != (i < tt.maxConcurrency) {
|
||||
t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, i < tt.maxConcurrency)
|
||||
for i := 0; i < len(tt.taskIDs); i++ {
|
||||
ctx, cancel := tt.ctxFunc(tt.taskIDs[i])
|
||||
|
||||
got, err := sema.Acquire(ctx)
|
||||
if got != tt.want[i] {
|
||||
t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, tt.want[i])
|
||||
}
|
||||
if (tt.wantErr == "" && err != nil) || (tt.wantErr != "" && (err == nil || err.Error() != tt.wantErr)) {
|
||||
t.Errorf("%s;\nSemaphore.Acquire() got error %v want error %v", tt.desc, err, tt.wantErr)
|
||||
}
|
||||
|
||||
cancel()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSemaphore_Acquire_StaleLock(t *testing.T) {
|
||||
opt := getRedisConnOpt(t)
|
||||
rc := opt.MakeRedisClient().(redis.UniversalClient)
|
||||
defer rc.Close()
|
||||
|
||||
taskID := uuid.New()
|
||||
|
||||
rc.ZAdd(context.Background(), semaphoreKey("stale-lock"), &redis.Z{
|
||||
Score: float64(time.Now().Add(-10 * time.Second).Unix()),
|
||||
Member: taskID.String(),
|
||||
})
|
||||
|
||||
sema := NewSemaphore(opt, "stale-lock", 1)
|
||||
defer sema.Close()
|
||||
|
||||
ctx, cancel := asynqcontext.New(&base.TaskMessage{
|
||||
ID: taskID,
|
||||
Queue: "task-1",
|
||||
}, time.Now().Add(time.Second))
|
||||
defer cancel()
|
||||
|
||||
got, err := sema.Acquire(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Acquire_StaleLock;\nSemaphore.Acquire() got error %v", err)
|
||||
}
|
||||
|
||||
if !got {
|
||||
t.Error("Acquire_StaleLock;\nSemaphore.Acquire() got false want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSemaphore_Release(t *testing.T) {
|
||||
testID := uuid.New()
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
name string
|
||||
maxConcurrency int
|
||||
callRelease int
|
||||
taskIDs []uuid.UUID
|
||||
ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc)
|
||||
wantCount int64
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
desc: "Should decrease lock count",
|
||||
name: "task-3",
|
||||
maxConcurrency: 3,
|
||||
callRelease: 1,
|
||||
name: "task-5",
|
||||
taskIDs: []uuid.UUID{uuid.New()},
|
||||
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||
return asynqcontext.New(&base.TaskMessage{
|
||||
ID: id,
|
||||
Queue: "task-3",
|
||||
}, time.Now().Add(time.Second))
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Should decrease lock count by 2",
|
||||
name: "task-4",
|
||||
maxConcurrency: 3,
|
||||
callRelease: 2,
|
||||
name: "task-6",
|
||||
taskIDs: []uuid.UUID{uuid.New(), uuid.New()},
|
||||
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
|
||||
return asynqcontext.New(&base.TaskMessage{
|
||||
ID: id,
|
||||
Queue: "task-4",
|
||||
}, time.Now().Add(time.Second))
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Should return error when context is missing taskID",
|
||||
name: "task-7",
|
||||
taskIDs: []uuid.UUID{uuid.New()},
|
||||
ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), time.Second)
|
||||
},
|
||||
wantCount: 1,
|
||||
wantErr: "provided context is missing task ID value",
|
||||
},
|
||||
{
|
||||
desc: "Should return error when context has taskID which never acquired lock",
|
||||
name: "task-8",
|
||||
taskIDs: []uuid.UUID{uuid.New()},
|
||||
ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) {
|
||||
return asynqcontext.New(&base.TaskMessage{
|
||||
ID: testID,
|
||||
Queue: "task-4",
|
||||
}, time.Now().Add(time.Second))
|
||||
},
|
||||
wantCount: 1,
|
||||
wantErr: fmt.Sprintf("no lock found for task %q", testID.String()),
|
||||
},
|
||||
}
|
||||
|
||||
@ -145,25 +264,43 @@ func TestNewSemaphore_Release(t *testing.T) {
|
||||
t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err)
|
||||
}
|
||||
|
||||
if err := rc.IncrBy(context.Background(), semaphoreKey(tt.name), int64(tt.maxConcurrency)).Err(); err != nil {
|
||||
t.Errorf("%s;\nredis.UniversalClient.IncrBy() got error %v", tt.desc, err)
|
||||
var members []*redis.Z
|
||||
for i := 0; i < len(tt.taskIDs); i++ {
|
||||
members = append(members, &redis.Z{
|
||||
Score: float64(time.Now().Add(time.Duration(i) * time.Second).Unix()),
|
||||
Member: tt.taskIDs[i].String(),
|
||||
})
|
||||
}
|
||||
if err := rc.ZAdd(context.Background(), semaphoreKey(tt.name), members...).Err(); err != nil {
|
||||
t.Errorf("%s;\nredis.UniversalClient.ZAdd() got error %v", tt.desc, err)
|
||||
}
|
||||
|
||||
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency)
|
||||
sema := NewSemaphore(opt, tt.name, 3)
|
||||
defer sema.Close()
|
||||
|
||||
for i := 0; i < tt.callRelease; i++ {
|
||||
sema.Release(context.Background())
|
||||
for i := 0; i < len(tt.taskIDs); i++ {
|
||||
ctx, cancel := tt.ctxFunc(tt.taskIDs[i])
|
||||
|
||||
err := sema.Release(ctx)
|
||||
|
||||
if tt.wantErr == "" && err != nil {
|
||||
t.Errorf("%s;\nSemaphore.Release() got error %v", tt.desc, err)
|
||||
}
|
||||
|
||||
i, err := rc.Get(context.Background(), semaphoreKey(tt.name)).Int()
|
||||
if tt.wantErr != "" && (err == nil || err.Error() != tt.wantErr) {
|
||||
t.Errorf("%s;\nSemaphore.Release() got error %v want error %v", tt.desc, err, tt.wantErr)
|
||||
}
|
||||
|
||||
cancel()
|
||||
}
|
||||
|
||||
i, err := rc.ZCount(context.Background(), semaphoreKey(tt.name), "-inf", "+inf").Result()
|
||||
if err != nil {
|
||||
t.Errorf("%s;\nredis.UniversalClient.Get() got error %v", tt.desc, err)
|
||||
t.Errorf("%s;\nredis.UniversalClient.ZCount() got error %v", tt.desc, err)
|
||||
}
|
||||
|
||||
if i != tt.maxConcurrency-tt.callRelease {
|
||||
t.Errorf("%s;\nSemaphore.Release(ctx) didn't release lock, got %v want %v",
|
||||
tt.desc, i, tt.maxConcurrency-tt.callRelease)
|
||||
if i != tt.wantCount {
|
||||
t.Errorf("%s;\nSemaphore.Release(ctx) didn't release lock, got %v want 0", tt.desc, i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user