diff --git a/context.go b/context.go index ba46450..3a64150 100644 --- a/context.go +++ b/context.go @@ -27,28 +27,14 @@ type ctxKey int const metadataCtxKey ctxKey = 0 // createContext returns a context and cancel function for a given task message. -func createContext(msg *base.TaskMessage) (ctx context.Context, cancel context.CancelFunc) { +func createContext(msg *base.TaskMessage, deadline time.Time) (ctx context.Context, cancel context.CancelFunc) { metadata := taskMetadata{ id: msg.ID.String(), maxRetry: msg.Retry, retryCount: msg.Retried, } ctx = context.WithValue(context.Background(), metadataCtxKey, metadata) - if msg.Timeout == 0 && msg.Deadline == 0 { - panic("asynq: internal error: missing both timeout and deadline") - } - if msg.Timeout != 0 { - timeout := time.Duration(msg.Timeout) * time.Second - ctx, cancel = context.WithTimeout(ctx, timeout) - } - if msg.Deadline != 0 { - deadline := time.Unix(int64(msg.Deadline), 0) - ctx, cancel = context.WithDeadline(ctx, deadline) - } - if cancel == nil { - ctx, cancel = context.WithCancel(ctx) - } - return ctx, cancel + return context.WithDeadline(ctx, deadline) } // GetTaskID extracts a task ID from a context, if any. diff --git a/context_test.go b/context_test.go index 4bb1fbb..261def5 100644 --- a/context_test.go +++ b/context_test.go @@ -10,46 +10,38 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/hibiken/asynq/internal/base" "github.com/rs/xid" ) -func TestCreateContextWithTimeRestrictions(t *testing.T) { +func TestCreateContextWithFutureDeadline(t *testing.T) { tests := []struct { - desc string - timeout time.Duration - deadline time.Time - wantDeadline time.Time + deadline time.Time }{ - {"only with timeout", 10 * time.Second, noDeadline, time.Now().Add(10 * time.Second)}, - {"only with deadline", noTimeout, time.Now().Add(time.Hour), time.Now().Add(time.Hour)}, - {"with timeout and deadline (timeout < deadline)", 10 * time.Second, time.Now().Add(time.Hour), time.Now().Add(10 * time.Second)}, - {"with timeout and deadline (timeout > deadline)", 10 * time.Minute, time.Now().Add(30 * time.Second), time.Now().Add(30 * time.Second)}, + {time.Now().Add(time.Hour)}, } for _, tc := range tests { msg := &base.TaskMessage{ - Type: "something", - ID: xid.New(), - Timeout: int(tc.timeout.Seconds()), - Deadline: int(tc.deadline.Unix()), + Type: "something", + ID: xid.New(), + Payload: nil, } - ctx, cancel := createContext(msg) + ctx, cancel := createContext(msg, tc.deadline) select { case x := <-ctx.Done(): - t.Errorf("%s: <-ctx.Done() == %v, want nothing (it should block)", tc.desc, x) + t.Errorf("<-ctx.Done() == %v, want nothing (it should block)", x) default: } got, ok := ctx.Deadline() if !ok { - t.Errorf("%s: ctx.Deadline() returned false, want deadline to be set", tc.desc) + t.Errorf("ctx.Deadline() returned false, want deadline to be set") } - if !cmp.Equal(tc.wantDeadline, got, cmpopts.EquateApproxTime(time.Second)) { - t.Errorf("%s: ctx.Deadline() returned %v, want %v", tc.desc, got, tc.wantDeadline) + if !cmp.Equal(tc.deadline, got) { + t.Errorf("ctx.Deadline() returned %v, want %v", got, tc.deadline) } cancel() @@ -62,19 +54,37 @@ func TestCreateContextWithTimeRestrictions(t *testing.T) { } } -func TestCreateContextWithoutTimeRestrictions(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Error("did not panic, want panic when both timeout and deadline are missing") - } - }() - msg := &base.TaskMessage{ - Type: "something", - ID: xid.New(), - Timeout: 0, // zero indicates no timeout - Deadline: 0, // zero indicates no deadline +func TestCreateContextWithPastDeadline(t *testing.T) { + tests := []struct { + deadline time.Time + }{ + {time.Now().Add(-2 * time.Hour)}, + } + + for _, tc := range tests { + msg := &base.TaskMessage{ + Type: "something", + ID: xid.New(), + Payload: nil, + } + + ctx, cancel := createContext(msg, tc.deadline) + defer cancel() + + select { + case <-ctx.Done(): + default: + t.Errorf("ctx.Done() blocked, want it to be non-blocking") + } + + got, ok := ctx.Deadline() + if !ok { + t.Errorf("ctx.Deadline() returned false, want deadline to be set") + } + if !cmp.Equal(tc.deadline, got) { + t.Errorf("ctx.Deadline() returned %v, want %v", got, tc.deadline) + } } - createContext(msg) } func TestGetTaskMetadataFromContext(t *testing.T) { @@ -87,7 +97,8 @@ func TestGetTaskMetadataFromContext(t *testing.T) { } for _, tc := range tests { - ctx, _ := createContext(tc.msg) + ctx, cancel := createContext(tc.msg, time.Now().Add(30*time.Minute)) + defer cancel() id, ok := GetTaskID(ctx) if !ok { diff --git a/internal/base/base.go b/internal/base/base.go index a03d381..4b2f54e 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -267,7 +267,7 @@ func (c *Cancelations) GetAll() []context.CancelFunc { type Broker interface { Enqueue(msg *TaskMessage) error EnqueueUnique(msg *TaskMessage, ttl time.Duration) error - Dequeue(qnames ...string) (*TaskMessage, int, error) + Dequeue(qnames ...string) (*TaskMessage, time.Time, error) Done(msg *TaskMessage) error Requeue(msg *TaskMessage) error Schedule(msg *TaskMessage, processAt time.Time) error diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index d577fd9..2b0660a 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -60,11 +60,11 @@ func (tb *TestBroker) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) er return tb.real.EnqueueUnique(msg, ttl) } -func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, int, error) { +func (tb *TestBroker) Dequeue(qnames ...string) (*base.TaskMessage, time.Time, error) { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { - return nil, 0, errRedisDown + return nil, time.Time{}, errRedisDown } return tb.real.Dequeue(qnames...) } diff --git a/processor.go b/processor.go index 5216a62..b41421c 100644 --- a/processor.go +++ b/processor.go @@ -158,30 +158,30 @@ func (p *processor) start(wg *sync.WaitGroup) { // exec pulls a task out of the queue and starts a worker goroutine to // process the task. func (p *processor) exec() { - qnames := p.queues() - msg, err := p.broker.Dequeue(qnames...) - switch { - case err == rdb.ErrNoProcessableTask: - p.logger.Debug("All queues are empty") - // Queues are empty, this is a normal behavior. - // Sleep to avoid slamming redis and let scheduler move tasks into queues. - // Note: We are not using blocking pop operation and polling queues instead. - // This adds significant load to redis. - time.Sleep(time.Second) - return - case err != nil: - if p.errLogLimiter.Allow() { - p.logger.Errorf("Dequeue error: %v", err) - } - return - } - select { case <-p.abort: - // shutdown is starting, return immediately after requeuing the message. - p.requeue(msg) return case p.sema <- struct{}{}: // acquire token + qnames := p.queues() + msg, deadline, err := p.broker.Dequeue(qnames...) + switch { + case err == rdb.ErrNoProcessableTask: + p.logger.Debug("All queues are empty") + // Queues are empty, this is a normal behavior. + // Sleep to avoid slamming redis and let scheduler move tasks into queues. + // Note: We are not using blocking pop operation and polling queues instead. + // This adds significant load to redis. + time.Sleep(time.Second) + <-p.sema // release token + return + case err != nil: + if p.errLogLimiter.Allow() { + p.logger.Errorf("Dequeue error: %v", err) + } + <-p.sema // release token + return + } + p.starting <- msg go func() { defer func() { @@ -189,7 +189,7 @@ func (p *processor) exec() { <-p.sema // release token }() - ctx, cancel := createContext(msg) + ctx, cancel := createContext(msg, deadline) p.cancelations.Add(msg.ID.String(), cancel) defer func() { cancel() @@ -206,6 +206,10 @@ func (p *processor) exec() { p.logger.Warnf("Quitting worker. task id=%s", msg.ID) p.requeue(msg) return + case <-ctx.Done(): + p.logger.Debugf("Retrying task. task id=%s", msg.ID) // TODO: Improve this log message and above + p.retryOrKill(msg, ctx.Err()) + return case resErr := <-resCh: // Note: One of three things should happen. // 1) Done -> Removes the message from InProgress @@ -215,11 +219,7 @@ func (p *processor) exec() { if p.errHandler != nil { p.errHandler.HandleError(task, resErr, msg.Retried, msg.Retry) } - if msg.Retried >= msg.Retry { - p.kill(msg, resErr) - } else { - p.retry(msg, resErr) - } + p.retryOrKill(msg, resErr) return } p.markAsDone(msg) @@ -251,6 +251,14 @@ func (p *processor) markAsDone(msg *base.TaskMessage) { } } +func (p *processor) retryOrKill(msg *base.TaskMessage, err error) { + if msg.Retried >= msg.Retry { + p.kill(msg, err) + } else { + p.retry(msg, err) + } +} + func (p *processor) retry(msg *base.TaskMessage, e error) { d := p.retryDelayFunc(msg.Retried, e, NewTask(msg.Type, msg.Payload)) retryAt := time.Now().Add(d)