diff --git a/CHANGELOG.md b/CHANGELOG.md index e9c96b4..f36332f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- `GetTaskID`, `GetRetryCount`, and `GetMaxRetry` functions were added to extract task metadata from context. + ## [0.9.0] - 2020-05-16 ### Changed diff --git a/context.go b/context.go new file mode 100644 index 0000000..ee69ba1 --- /dev/null +++ b/context.go @@ -0,0 +1,85 @@ +// Copyright 2020 Kentaro Hibino. All rights reserved. +// Use of this source code is governed by a MIT license +// that can be found in the LICENSE file. + +package asynq + +import ( + "context" + "time" + + "github.com/hibiken/asynq/internal/base" +) + +// A taskMetadata holds task scoped data to put in context. +type taskMetadata struct { + id string + maxRetry int + retryCount int +} + +// ctxKey type is unexported to prevent collisions with context keys defined in +// other packages. +type ctxKey int + +// metadataCtxKey is the context key for the task metadata. +// Its value of zero is arbitrary. +const metadataCtxKey ctxKey = 0 + +// createContext returns a context and cancel function for a given task message. +func createContext(msg *base.TaskMessage) (ctx context.Context, cancel context.CancelFunc) { + metadata := taskMetadata{ + id: msg.ID.String(), + maxRetry: msg.Retry, + retryCount: msg.Retried, + } + ctx = context.WithValue(context.Background(), metadataCtxKey, metadata) + timeout, err := time.ParseDuration(msg.Timeout) + if err == nil && timeout != 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) + } + deadline, err := time.Parse(time.RFC3339, msg.Deadline) + if err == nil && !deadline.IsZero() { + ctx, cancel = context.WithDeadline(ctx, deadline) + } + if cancel == nil { + ctx, cancel = context.WithCancel(ctx) + } + return ctx, cancel +} + +// GetTaskID extracts a task ID from a context, if any. +// +// ID of a task is guaranteed to be unique. +// ID of a task doesn't change if the task is being retried. +func GetTaskID(ctx context.Context) (id string, ok bool) { + metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) + if !ok { + return "", false + } + return metadata.id, true +} + +// GetRetryCount extracts retry count from a context, if any. +// +// Return value n indicates the number of times associated task has been +// retried so far. +func GetRetryCount(ctx context.Context) (n int, ok bool) { + metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) + if !ok { + return 0, false + } + return metadata.retryCount, true +} + +// GetMaxRetry extracts maximum retry from a context, if any. +// +// Return value n indicates the maximum number of times the assoicated task +// can be retried if ProcessTask returns a non-nil error. +func GetMaxRetry(ctx context.Context) (n int, ok bool) { + metadata, ok := ctx.Value(metadataCtxKey).(taskMetadata) + if !ok { + return 0, false + } + return metadata.maxRetry, true +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..1a5eff6 --- /dev/null +++ b/context_test.go @@ -0,0 +1,157 @@ +// Copyright 2020 Kentaro Hibino. All rights reserved. +// Use of this source code is governed by a MIT license +// that can be found in the LICENSE file. + +package asynq + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/hibiken/asynq/internal/base" + "github.com/rs/xid" +) + +func TestCreateContextWithTimeRestrictions(t *testing.T) { + var ( + noTimeout = time.Duration(0) + noDeadline = time.Time{} + ) + + tests := []struct { + desc string + timeout time.Duration + deadline time.Time + wantDeadline time.Time + }{ + {"only with timeout", 10 * time.Second, noDeadline, time.Now().Add(10 * time.Second)}, + {"only with deadline", noTimeout, time.Now().Add(time.Hour), time.Now().Add(time.Hour)}, + {"with timeout and deadline (timeout < deadline)", 10 * time.Second, time.Now().Add(time.Hour), time.Now().Add(10 * time.Second)}, + {"with timeout and deadline (timeout > deadline)", 10 * time.Minute, time.Now().Add(30 * time.Second), time.Now().Add(30 * time.Second)}, + } + + for _, tc := range tests { + msg := &base.TaskMessage{ + Type: "something", + ID: xid.New(), + Timeout: tc.timeout.String(), + Deadline: tc.deadline.Format(time.RFC3339), + } + + ctx, cancel := createContext(msg) + + select { + case x := <-ctx.Done(): + t.Errorf("%s: <-ctx.Done() == %v, want nothing (it should block)", tc.desc, x) + default: + } + + got, ok := ctx.Deadline() + if !ok { + t.Errorf("%s: ctx.Deadline() returned false, want deadline to be set", tc.desc) + } + if !cmp.Equal(tc.wantDeadline, got, cmpopts.EquateApproxTime(time.Second)) { + t.Errorf("%s: ctx.Deadline() returned %v, want %v", tc.desc, got, tc.wantDeadline) + } + + cancel() + + select { + case <-ctx.Done(): + default: + t.Errorf("ctx.Done() blocked, want it to be non-blocking") + } + } +} + +func TestCreateContextWithoutTimeRestrictions(t *testing.T) { + msg := &base.TaskMessage{ + Type: "something", + ID: xid.New(), + Timeout: time.Duration(0).String(), // zero value to indicate no timeout + Deadline: time.Time{}.Format(time.RFC3339), // zero value to indicate no deadline + } + + ctx, cancel := createContext(msg) + + select { + case x := <-ctx.Done(): + t.Errorf("<-ctx.Done() == %v, want nothing (it should block)", x) + default: + } + + _, ok := ctx.Deadline() + if ok { + t.Error("ctx.Deadline() returned true, want deadline to not be set") + } + + cancel() + + select { + case <-ctx.Done(): + default: + t.Error("ctx.Done() blocked, want it to be non-blocking") + } +} + +func TestGetTaskMetadataFromContext(t *testing.T) { + tests := []struct { + desc string + msg *base.TaskMessage + }{ + {"with zero retried message", &base.TaskMessage{Type: "something", ID: xid.New(), Retry: 25, Retried: 0}}, + {"with non-zero retried message", &base.TaskMessage{Type: "something", ID: xid.New(), Retry: 10, Retried: 5}}, + } + + for _, tc := range tests { + ctx, _ := createContext(tc.msg) + + id, ok := GetTaskID(ctx) + if !ok { + t.Errorf("%s: GetTaskID(ctx) returned ok == false", tc.desc) + } + if ok && id != tc.msg.ID.String() { + t.Errorf("%s: GetTaskID(ctx) returned id == %q, want %q", tc.desc, id, tc.msg.ID.String()) + } + + retried, ok := GetRetryCount(ctx) + if !ok { + t.Errorf("%s: GetRetryCount(ctx) returned ok == false", tc.desc) + } + if ok && retried != tc.msg.Retried { + t.Errorf("%s: GetRetryCount(ctx) returned n == %d want %d", tc.desc, retried, tc.msg.Retried) + } + + maxRetry, ok := GetMaxRetry(ctx) + if !ok { + t.Errorf("%s: GetMaxRetry(ctx) returned ok == false", tc.desc) + } + if ok && maxRetry != tc.msg.Retry { + t.Errorf("%s: GetMaxRetry(ctx) returned n == %d want %d", tc.desc, maxRetry, tc.msg.Retry) + } + } +} + +func TestGetTaskMetadataFromContextError(t *testing.T) { + tests := []struct { + desc string + ctx context.Context + }{ + {"with background context", context.Background()}, + } + + for _, tc := range tests { + if _, ok := GetTaskID(tc.ctx); ok { + t.Errorf("%s: GetTaskID(ctx) returned ok == true", tc.desc) + } + if _, ok := GetRetryCount(tc.ctx); ok { + t.Errorf("%s: GetRetryCount(ctx) returned ok == true", tc.desc) + } + if _, ok := GetMaxRetry(tc.ctx); ok { + t.Errorf("%s: GetMaxRetry(ctx) returned ok == true", tc.desc) + } + } +} diff --git a/processor.go b/processor.go index 125e80f..109644d 100644 --- a/processor.go +++ b/processor.go @@ -405,20 +405,3 @@ func gcd(xs ...int) int { } return res } - -// createContext returns a context and cancel function for a given task message. -func createContext(msg *base.TaskMessage) (ctx context.Context, cancel context.CancelFunc) { - ctx = context.Background() - timeout, err := time.ParseDuration(msg.Timeout) - if err == nil && timeout != 0 { - ctx, cancel = context.WithTimeout(ctx, timeout) - } - deadline, err := time.Parse(time.RFC3339, msg.Deadline) - if err == nil && !deadline.IsZero() { - ctx, cancel = context.WithDeadline(ctx, deadline) - } - if cancel == nil { - ctx, cancel = context.WithCancel(ctx) - } - return ctx, cancel -} diff --git a/processor_test.go b/processor_test.go index 98c4cd4..954fe0f 100644 --- a/processor_test.go +++ b/processor_test.go @@ -17,7 +17,6 @@ import ( h "github.com/hibiken/asynq/internal/asynqtest" "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/rdb" - "github.com/rs/xid" ) func TestProcessorSuccess(t *testing.T) { @@ -391,88 +390,6 @@ func TestPerform(t *testing.T) { } } -func TestCreateContextWithTimeRestrictions(t *testing.T) { - var ( - noTimeout = time.Duration(0) - noDeadline = time.Time{} - ) - - tests := []struct { - desc string - timeout time.Duration - deadline time.Time - wantDeadline time.Time - }{ - {"only with timeout", 10 * time.Second, noDeadline, time.Now().Add(10 * time.Second)}, - {"only with deadline", noTimeout, time.Now().Add(time.Hour), time.Now().Add(time.Hour)}, - {"with timeout and deadline (timeout < deadline)", 10 * time.Second, time.Now().Add(time.Hour), time.Now().Add(10 * time.Second)}, - {"with timeout and deadline (timeout > deadline)", 10 * time.Minute, time.Now().Add(30 * time.Second), time.Now().Add(30 * time.Second)}, - } - - for _, tc := range tests { - msg := &base.TaskMessage{ - Type: "something", - ID: xid.New(), - Timeout: tc.timeout.String(), - Deadline: tc.deadline.Format(time.RFC3339), - } - - ctx, cancel := createContext(msg) - - select { - case x := <-ctx.Done(): - t.Errorf("%s: <-ctx.Done() == %v, want nothing (it should block)", tc.desc, x) - default: - } - - got, ok := ctx.Deadline() - if !ok { - t.Errorf("%s: ctx.Deadline() returned false, want deadline to be set", tc.desc) - } - if !cmp.Equal(tc.wantDeadline, got, cmpopts.EquateApproxTime(time.Second)) { - t.Errorf("%s: ctx.Deadline() returned %v, want %v", tc.desc, got, tc.wantDeadline) - } - - cancel() - - select { - case <-ctx.Done(): - default: - t.Errorf("ctx.Done() blocked, want it to be non-blocking") - } - } -} - -func TestCreateContextWithoutTimeRestrictions(t *testing.T) { - msg := &base.TaskMessage{ - Type: "something", - ID: xid.New(), - Timeout: time.Duration(0).String(), // zero value to indicate no timeout - Deadline: time.Time{}.Format(time.RFC3339), // zero value to indicate no deadline - } - - ctx, cancel := createContext(msg) - - select { - case x := <-ctx.Done(): - t.Errorf("<-ctx.Done() == %v, want nothing (it should block)", x) - default: - } - - _, ok := ctx.Deadline() - if ok { - t.Error("ctx.Deadline() returned true, want deadline to not be set") - } - - cancel() - - select { - case <-ctx.Done(): - default: - t.Error("ctx.Done() blocked, want it to be non-blocking") - } -} - func TestGCD(t *testing.T) { tests := []struct { input []int