2
0
mirror of https://github.com/hibiken/asynq.git synced 2025-04-22 16:50:18 +08:00

use sorted-set to release stale locks

This commit is contained in:
ajatprabha 2021-10-30 16:07:46 +05:30
parent cb720fc10b
commit 34cc2d8081
No known key found for this signature in database
GPG Key ID: EEA3FDB0312545DA
8 changed files with 283 additions and 54 deletions

38
alias.go Normal file
View File

@ -0,0 +1,38 @@
package asynq
import (
"context"
asynqcontext "github.com/hibiken/asynq/internal/context"
)
// GetTaskID extracts a task ID from a context, if any.
//
// ID of a task is guaranteed to be unique.
// ID of a task doesn't change if the task is being retried.
func GetTaskID(ctx context.Context) (id string, ok bool) {
return asynqcontext.GetTaskID(ctx)
}
// GetRetryCount extracts retry count from a context, if any.
//
// Return value n indicates the number of times associated task has been
// retried so far.
func GetRetryCount(ctx context.Context) (n int, ok bool) {
return asynqcontext.GetRetryCount(ctx)
}
// GetMaxRetry extracts maximum retry from a context, if any.
//
// Return value n indicates the maximum number of times the assoicated task
// can be retried if ProcessTask returns a non-nil error.
func GetMaxRetry(ctx context.Context) (n int, ok bool) {
return asynqcontext.GetMaxRetry(ctx)
}
// GetQueueName extracts queue name from a context, if any.
//
// Return value qname indicates which queue the task was pulled from.
func GetQueueName(ctx context.Context) (qname string, ok bool) {
return asynqcontext.GetQueueName(ctx)
}

View File

@ -13,6 +13,7 @@ import (
"time" "time"
h "github.com/hibiken/asynq/internal/asynqtest" h "github.com/hibiken/asynq/internal/asynqtest"
asynqcontext "github.com/hibiken/asynq/internal/context"
) )
// Creates a new task of type "task<n>" with payload {"data": n}. // Creates a new task of type "task<n>" with payload {"data": n}.
@ -104,7 +105,7 @@ func BenchmarkEndToEnd(b *testing.B) {
n = 1 n = 1
b.Logf("internal error: could not get data from payload") b.Logf("internal error: could not get data from payload")
} }
retried, ok := GetRetryCount(ctx) retried, ok := asynqcontext.GetRetryCount(ctx)
if !ok { if !ok {
b.Logf("internal error: could not get retry count from context") b.Logf("internal error: could not get retry count from context")
} }

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a MIT license // Use of this source code is governed by a MIT license
// that can be found in the LICENSE file. // that can be found in the LICENSE file.
package asynq package context
import ( import (
"context" "context"
@ -27,8 +27,8 @@ type ctxKey int
// Its value of zero is arbitrary. // Its value of zero is arbitrary.
const metadataCtxKey ctxKey = 0 const metadataCtxKey ctxKey = 0
// createContext returns a context and cancel function for a given task message. // New returns a context and cancel function for a given task message.
func createContext(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) { func New(msg *base.TaskMessage, deadline time.Time) (context.Context, context.CancelFunc) {
metadata := taskMetadata{ metadata := taskMetadata{
id: msg.ID.String(), id: msg.ID.String(),
maxRetry: msg.Retry, maxRetry: msg.Retry,

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a MIT license // Use of this source code is governed by a MIT license
// that can be found in the LICENSE file. // that can be found in the LICENSE file.
package asynq package context
import ( import (
"context" "context"
@ -28,7 +28,7 @@ func TestCreateContextWithFutureDeadline(t *testing.T) {
Payload: nil, Payload: nil,
} }
ctx, cancel := createContext(msg, tc.deadline) ctx, cancel := New(msg, tc.deadline)
select { select {
case x := <-ctx.Done(): case x := <-ctx.Done():
@ -68,7 +68,7 @@ func TestCreateContextWithPastDeadline(t *testing.T) {
Payload: nil, Payload: nil,
} }
ctx, cancel := createContext(msg, tc.deadline) ctx, cancel := New(msg, tc.deadline)
defer cancel() defer cancel()
select { select {
@ -98,7 +98,7 @@ func TestGetTaskMetadataFromContext(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
ctx, cancel := createContext(tc.msg, time.Now().Add(30*time.Minute)) ctx, cancel := New(tc.msg, time.Now().Add(30*time.Minute))
defer cancel() defer cancel()
id, ok := GetTaskID(ctx) id, ok := GetTaskID(ctx)

View File

@ -16,6 +16,7 @@ import (
"time" "time"
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
asynqcontext "github.com/hibiken/asynq/internal/context"
"github.com/hibiken/asynq/internal/errors" "github.com/hibiken/asynq/internal/errors"
"github.com/hibiken/asynq/internal/log" "github.com/hibiken/asynq/internal/log"
"golang.org/x/time/rate" "golang.org/x/time/rate"
@ -189,7 +190,7 @@ func (p *processor) exec() {
<-p.sema // release token <-p.sema // release token
}() }()
ctx, cancel := createContext(msg, deadline) ctx, cancel := asynqcontext.New(msg, deadline)
p.cancelations.Add(msg.ID.String(), cancel) p.cancelations.Add(msg.ID.String(), cancel)
defer func() { defer func() {
cancel() cancel()

View File

@ -23,9 +23,14 @@ func ExampleNewSemaphore() {
// call sema.Close() when appropriate // call sema.Close() when appropriate
_ = asynq.HandlerFunc(func(ctx context.Context, task *asynq.Task) error { _ = asynq.HandlerFunc(func(ctx context.Context, task *asynq.Task) error {
if !sema.Acquire(ctx) { ok, err := sema.Acquire(ctx)
if err != nil {
return err
}
if !ok {
return &RateLimitError{RetryIn: 30 * time.Second} return &RateLimitError{RetryIn: 30 * time.Second}
} }
// Make sure to release the token once we're done. // Make sure to release the token once we're done.
defer sema.Release(ctx) defer sema.Release(ctx)

View File

@ -4,9 +4,11 @@ import (
"context" "context"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/hibiken/asynq" "github.com/hibiken/asynq"
asynqcontext "github.com/hibiken/asynq/internal/context"
) )
// NewSemaphore creates a new counting Semaphore. // NewSemaphore creates a new counting Semaphore.
@ -38,28 +40,73 @@ type Semaphore struct {
name string name string
} }
// KEYS[1] -> asynq:sema:<scope>
// ARGV[1] -> max concurrency
// ARGV[2] -> current time in unix time
// ARGV[3] -> deadline in unix time
// ARGV[4] -> task ID
var acquireCmd = redis.NewScript(` var acquireCmd = redis.NewScript(`
local lockCount = tonumber(redis.call('GET', KEYS[1])) redis.call("ZREMRANGEBYSCORE", KEYS[1], "-inf", tonumber(ARGV[2])-1)
local lockCount = redis.call("ZCARD", KEYS[1])
if (not lockCount or lockCount < tonumber(ARGV[1])) then if (lockCount < tonumber(ARGV[1])) then
redis.call('INCR', KEYS[1]) redis.call("ZADD", KEYS[1], ARGV[3], ARGV[4])
return true return true
else else
return false return false
end end
`) `)
// Acquire will try to acquire lock on the counting semaphore. // Acquire will try to acquire lock on the counting semaphore.
// - Returns true, iff semaphore key exists and current value is less than maxConcurrency // - Returns (true, nil), iff semaphore key exists and current value is less than maxConcurrency
// - Returns false otherwise // - Returns (false, nil) when lock cannot be acquired
func (s *Semaphore) Acquire(ctx context.Context) bool { // - Returns (false, error) otherwise
val, _ := acquireCmd.Run(ctx, s.rc, []string{semaphoreKey(s.name)}, []interface{}{s.maxConcurrency}...).Bool() //
return val // The context.Context passed to Acquire must have a deadline set,
// this ensures that lock is released if the job goroutine crashes and does not call Release.
func (s *Semaphore) Acquire(ctx context.Context) (bool, error) {
d, ok := ctx.Deadline()
if !ok {
return false, fmt.Errorf("provided context must have a deadline")
}
taskID, ok := asynqcontext.GetTaskID(ctx)
if !ok {
return false, fmt.Errorf("provided context is missing task ID value")
}
b, err := acquireCmd.Run(ctx, s.rc,
[]string{semaphoreKey(s.name)},
[]interface{}{
s.maxConcurrency,
time.Now().Unix(),
d.Unix(),
taskID,
}...).Bool()
if err == redis.Nil {
return b, nil
}
return b, err
} }
// Release will release the lock on the counting semaphore. // Release will release the lock on the counting semaphore.
func (s *Semaphore) Release(ctx context.Context) { func (s *Semaphore) Release(ctx context.Context) error {
s.rc.Decr(ctx, semaphoreKey(s.name)) taskID, ok := asynqcontext.GetTaskID(ctx)
if !ok {
return fmt.Errorf("provided context is missing task ID value")
}
n, err := s.rc.ZRem(ctx, semaphoreKey(s.name), taskID).Result()
if err != nil {
return fmt.Errorf("redis command failed: %v", err)
}
if n == 0 {
return fmt.Errorf("no lock found for task %q", taskID)
}
return nil
} }
// Close closes the connection with redis. // Close closes the connection with redis.
@ -68,5 +115,5 @@ func (s *Semaphore) Close() error {
} }
func semaphoreKey(name string) string { func semaphoreKey(name string) string {
return fmt.Sprintf("asynq:sema:{%s}", name) return fmt.Sprintf("asynq:sema:%s", name)
} }

View File

@ -3,11 +3,15 @@ package rate
import ( import (
"context" "context"
"flag" "flag"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/google/uuid"
"github.com/hibiken/asynq"
"github.com/hibiken/asynq/internal/base"
asynqcontext "github.com/hibiken/asynq/internal/context"
"strings" "strings"
"testing" "testing"
"time"
"github.com/go-redis/redis/v8"
"github.com/hibiken/asynq"
) )
var ( var (
@ -76,19 +80,58 @@ func TestNewSemaphore_Acquire(t *testing.T) {
desc string desc string
name string name string
maxConcurrency int maxConcurrency int
callAcquire int taskIDs []uuid.UUID
ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc)
want []bool
wantErr string
}{ }{
{ {
desc: "Should acquire lock when current lock count is less than maxConcurrency", desc: "Should acquire lock when current lock count is less than maxConcurrency",
name: "task-1", name: "task-1",
maxConcurrency: 3, maxConcurrency: 3,
callAcquire: 2, taskIDs: []uuid.UUID{uuid.New(), uuid.New()},
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
return asynqcontext.New(&base.TaskMessage{
ID: id,
Queue: "task-1",
}, time.Now().Add(time.Second))
},
want: []bool{true, true},
}, },
{ {
desc: "Should fail acquiring lock when current lock count is equal to maxConcurrency", desc: "Should fail acquiring lock when current lock count is equal to maxConcurrency",
name: "task-2", name: "task-2",
maxConcurrency: 3, maxConcurrency: 3,
callAcquire: 4, taskIDs: []uuid.UUID{uuid.New(), uuid.New(), uuid.New(), uuid.New()},
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
return asynqcontext.New(&base.TaskMessage{
ID: id,
Queue: "task-2",
}, time.Now().Add(time.Second))
},
want: []bool{true, true, true, false},
},
{
desc: "Should return error if context has no deadline",
name: "task-3",
maxConcurrency: 1,
taskIDs: []uuid.UUID{uuid.New(), uuid.New()},
ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
return context.Background(), func() {}
},
want: []bool{false, false},
wantErr: "provided context must have a deadline",
},
{
desc: "Should return error when context is missing taskID",
name: "task-4",
maxConcurrency: 1,
taskIDs: []uuid.UUID{uuid.New()},
ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Second)
},
want: []bool{false},
wantErr: "provided context is missing task ID value",
}, },
} }
@ -105,33 +148,109 @@ func TestNewSemaphore_Acquire(t *testing.T) {
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) sema := NewSemaphore(opt, tt.name, tt.maxConcurrency)
defer sema.Close() defer sema.Close()
for i := 0; i < tt.callAcquire; i++ { for i := 0; i < len(tt.taskIDs); i++ {
if got := sema.Acquire(context.Background()); got != (i < tt.maxConcurrency) { ctx, cancel := tt.ctxFunc(tt.taskIDs[i])
t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, i < tt.maxConcurrency)
got, err := sema.Acquire(ctx)
if got != tt.want[i] {
t.Errorf("%s;\nSemaphore.Acquire(ctx) returned %v, want %v", tt.desc, got, tt.want[i])
} }
if (tt.wantErr == "" && err != nil) || (tt.wantErr != "" && (err == nil || err.Error() != tt.wantErr)) {
t.Errorf("%s;\nSemaphore.Acquire() got error %v want error %v", tt.desc, err, tt.wantErr)
}
cancel()
} }
}) })
} }
} }
func TestNewSemaphore_Acquire_StaleLock(t *testing.T) {
opt := getRedisConnOpt(t)
rc := opt.MakeRedisClient().(redis.UniversalClient)
defer rc.Close()
taskID := uuid.New()
rc.ZAdd(context.Background(), semaphoreKey("stale-lock"), &redis.Z{
Score: float64(time.Now().Add(-10 * time.Second).Unix()),
Member: taskID.String(),
})
sema := NewSemaphore(opt, "stale-lock", 1)
defer sema.Close()
ctx, cancel := asynqcontext.New(&base.TaskMessage{
ID: taskID,
Queue: "task-1",
}, time.Now().Add(time.Second))
defer cancel()
got, err := sema.Acquire(ctx)
if err != nil {
t.Errorf("Acquire_StaleLock;\nSemaphore.Acquire() got error %v", err)
}
if !got {
t.Error("Acquire_StaleLock;\nSemaphore.Acquire() got false want true")
}
}
func TestNewSemaphore_Release(t *testing.T) { func TestNewSemaphore_Release(t *testing.T) {
testID := uuid.New()
tests := []struct { tests := []struct {
desc string desc string
name string name string
maxConcurrency int taskIDs []uuid.UUID
callRelease int ctxFunc func(uuid.UUID) (context.Context, context.CancelFunc)
wantCount int64
wantErr string
}{ }{
{ {
desc: "Should decrease lock count", desc: "Should decrease lock count",
name: "task-3", name: "task-5",
maxConcurrency: 3, taskIDs: []uuid.UUID{uuid.New()},
callRelease: 1, ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
return asynqcontext.New(&base.TaskMessage{
ID: id,
Queue: "task-3",
}, time.Now().Add(time.Second))
},
}, },
{ {
desc: "Should decrease lock count by 2", desc: "Should decrease lock count by 2",
name: "task-4", name: "task-6",
maxConcurrency: 3, taskIDs: []uuid.UUID{uuid.New(), uuid.New()},
callRelease: 2, ctxFunc: func(id uuid.UUID) (context.Context, context.CancelFunc) {
return asynqcontext.New(&base.TaskMessage{
ID: id,
Queue: "task-4",
}, time.Now().Add(time.Second))
},
},
{
desc: "Should return error when context is missing taskID",
name: "task-7",
taskIDs: []uuid.UUID{uuid.New()},
ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Second)
},
wantCount: 1,
wantErr: "provided context is missing task ID value",
},
{
desc: "Should return error when context has taskID which never acquired lock",
name: "task-8",
taskIDs: []uuid.UUID{uuid.New()},
ctxFunc: func(_ uuid.UUID) (context.Context, context.CancelFunc) {
return asynqcontext.New(&base.TaskMessage{
ID: testID,
Queue: "task-4",
}, time.Now().Add(time.Second))
},
wantCount: 1,
wantErr: fmt.Sprintf("no lock found for task %q", testID.String()),
}, },
} }
@ -145,25 +264,43 @@ func TestNewSemaphore_Release(t *testing.T) {
t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err) t.Errorf("%s;\nredis.UniversalClient.Del() got error %v", tt.desc, err)
} }
if err := rc.IncrBy(context.Background(), semaphoreKey(tt.name), int64(tt.maxConcurrency)).Err(); err != nil { var members []*redis.Z
t.Errorf("%s;\nredis.UniversalClient.IncrBy() got error %v", tt.desc, err) for i := 0; i < len(tt.taskIDs); i++ {
members = append(members, &redis.Z{
Score: float64(time.Now().Add(time.Duration(i) * time.Second).Unix()),
Member: tt.taskIDs[i].String(),
})
}
if err := rc.ZAdd(context.Background(), semaphoreKey(tt.name), members...).Err(); err != nil {
t.Errorf("%s;\nredis.UniversalClient.ZAdd() got error %v", tt.desc, err)
} }
sema := NewSemaphore(opt, tt.name, tt.maxConcurrency) sema := NewSemaphore(opt, tt.name, 3)
defer sema.Close() defer sema.Close()
for i := 0; i < tt.callRelease; i++ { for i := 0; i < len(tt.taskIDs); i++ {
sema.Release(context.Background()) ctx, cancel := tt.ctxFunc(tt.taskIDs[i])
err := sema.Release(ctx)
if tt.wantErr == "" && err != nil {
t.Errorf("%s;\nSemaphore.Release() got error %v", tt.desc, err)
}
if tt.wantErr != "" && (err == nil || err.Error() != tt.wantErr) {
t.Errorf("%s;\nSemaphore.Release() got error %v want error %v", tt.desc, err, tt.wantErr)
}
cancel()
} }
i, err := rc.Get(context.Background(), semaphoreKey(tt.name)).Int() i, err := rc.ZCount(context.Background(), semaphoreKey(tt.name), "-inf", "+inf").Result()
if err != nil { if err != nil {
t.Errorf("%s;\nredis.UniversalClient.Get() got error %v", tt.desc, err) t.Errorf("%s;\nredis.UniversalClient.ZCount() got error %v", tt.desc, err)
} }
if i != tt.maxConcurrency-tt.callRelease { if i != tt.wantCount {
t.Errorf("%s;\nSemaphore.Release(ctx) didn't release lock, got %v want %v", t.Errorf("%s;\nSemaphore.Release(ctx) didn't release lock, got %v want 0", tt.desc, i)
tt.desc, i, tt.maxConcurrency-tt.callRelease)
} }
}) })
} }