mirror of https://github.com/hibiken/asynq.git synced 2025-10-20 21:26:14 +08:00

Compare commits


27 Commits

Author SHA1 Message Date
Ken Hibino
f91c05b92c v0.7.0 2020-03-22 12:04:37 -07:00
Ken Hibino
9b4438347e Fix comment 2020-03-21 11:44:26 -07:00
Ken Hibino
c33dd447ac Allow client to enqueue a task with unique option
Changes:

- Added Unique option for clients
- Require Go v1.13 or above (to use the new error wrapping functions)
- Fixed adding queue key to all-queues set (asynq:queues) when scheduling.
2020-03-21 11:40:40 -07:00
Ken Hibino
6df2c3ae2b v0.6.2 2020-03-15 21:02:28 -07:00
Ken Hibino
37554fd23c Update readme example code 2020-03-15 14:56:00 -07:00
Ken Hibino
77f5a38453 Refactor payload_test to reduce cyclomatic complexities 2020-03-14 12:30:42 -07:00
Ken Hibino
8d2b9d6be7 Add comments to exported types and functions from internal/log package 2020-03-13 21:04:45 -07:00
Bo-Yi Wu
1b7d557c66 fix typo 2020-03-13 20:02:26 -07:00
Bo-Yi Wu
30b68728d4 chore(lint): fix from gofmt -s 2020-03-13 20:01:39 -07:00
Ken Hibino
310d38620d Minor tweak to readme example code 2020-03-13 17:27:20 -07:00
Ken Hibino
1a53bbf21b Update changelog 2020-03-13 17:27:20 -07:00
Ken Hibino
9c79a7d507 Simplify code with gofmt -s 2020-03-13 14:24:24 -07:00
Ken Hibino
516f95edff Add Use method to better support middlewares with ServeMux 2020-03-13 14:13:17 -07:00
Ken Hibino
cf7a677312 v0.6.1 2020-03-12 08:42:34 -07:00
Ken Hibino
0bc6eba021 Allow custom logger to be used in Background 2020-03-12 08:40:37 -07:00
Ken Hibino
d664d68fa4 Extract out log package 2020-03-09 07:17:52 -07:00
Ken Hibino
a425f54d23 [ci skip] Remove todo comment 2020-03-09 06:09:07 -07:00
Ken Hibino
3c722386b0 Add Deadline option when enqueuing tasks
The Deadline option sets the deadline for the given task's context.
2020-03-08 17:12:42 -07:00
Ken Hibino
25992c2781 [ci skip] Minor readme update 2020-03-03 21:39:16 -08:00
Ken Hibino
b9e3cad7a7 [ci skip] Update readme
- Added flow chart for task queue
- Reordered sections
2020-03-02 07:06:11 -08:00
Ken Hibino
b6486716b4 v0.6.0 2020-03-01 15:54:59 -08:00
Ken Hibino
742ed6546f Add ServeMux type
Allow users to use the ServeMux type as a Handler.
The ServeMux API is designed to be similar to the net/http.ServeMux API.
2020-03-01 15:53:18 -08:00
Ken Hibino
897ab4e28b Add ErrorHandler type to changelog 2020-02-29 22:09:13 -08:00
Ken Hibino
a4e4c0b1d5 Call error handler when task was not processed successfully 2020-02-29 22:09:13 -08:00
Ken Hibino
95b7dcaad4 Clean up processor test 2020-02-29 22:09:13 -08:00
Ken Hibino
8d3248e850 Add ErrorHandler type and add it to Config 2020-02-29 22:09:13 -08:00
Ken Hibino
e69264dc04 Run travis build with go v1.14.x 2020-02-27 08:28:35 -08:00
30 changed files with 2132 additions and 593 deletions


@@ -2,9 +2,7 @@ language: go
go_import_path: github.com/hibiken/asynq go_import_path: github.com/hibiken/asynq
git: git:
depth: 1 depth: 1
env: go: [1.13.x, 1.14.x]
- GO111MODULE=on # go modules are the default
go: [1.12.x, 1.13.x]
script: script:
- go test -race -v -coverprofile=coverage.txt -covermode=atomic ./... - go test -race -v -coverprofile=coverage.txt -covermode=atomic ./...
services: services:


@@ -7,6 +7,36 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [0.7.0] - 2020-03-22
### Changed
- Support Go v1.13+, dropped support for go v1.12
### Added
- `Unique` option was added to allow client to enqueue a task only if it's unique within a certain time period.
## [0.6.2] - 2020-03-15
### Added
- `Use` method was added to `ServeMux` to apply middlewares to all handlers.
## [0.6.1] - 2020-03-12
### Added
- `Client` can optionally schedule a task with `asynq.Deadline(time)` to specify a deadline for the task's context. Default is no deadline.
- `Logger` option was added to config, which allows user to specify the logger used by the background instance.
## [0.6.0] - 2020-03-01
### Added
- Added `ServeMux` type to make it easy for users to implement Handler interface.
- `ErrorHandler` type was added. Allows users to specify an error handling function (e.g. report errors to an error reporting service such as Honeybadger, Bugsnag, etc.)
## [0.5.0] - 2020-02-23 ## [0.5.0] - 2020-02-23
### Changed ### Changed
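
To illustrate the `Unique` and `Deadline` options introduced in the changelog above, here is a minimal sketch of enqueueing a task with both options and detecting a duplicate, based on the client API shown in the diffs further down; the Redis address and task type are illustrative placeholders.

```go
package main

import (
	"errors"
	"log"
	"time"

	"github.com/hibiken/asynq"
)

func main() {
	c := asynq.NewClient(&asynq.RedisClientOpt{Addr: "127.0.0.1:6379"})

	t := asynq.NewTask("email:deliver", map[string]interface{}{"user_id": 42})

	// Unique(ttl) ensures only one task with the same type, payload, and queue
	// is enqueued within the TTL; Deadline(t) bounds the task's context.
	err := c.Enqueue(t,
		asynq.Unique(time.Hour),
		asynq.Deadline(time.Now().Add(30*time.Minute)),
	)
	switch {
	case errors.Is(err, asynq.ErrDuplicateTask):
		log.Println("task already enqueued; skipping")
	case err != nil:
		log.Fatalf("could not enqueue task: %v", err)
	}
}
```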

README.md (215 lines changed)

@@ -12,15 +12,7 @@ It is backed by Redis and it is designed to have a low barrier to entry. It shou
**Important Note**: Current major version is zero (v0.x.x) to accommodate rapid development and fast iteration while getting early feedback from users. The public API could change without a major version update before v1.0.0 release. **Important Note**: Current major version is zero (v0.x.x) to accommodate rapid development and fast iteration while getting early feedback from users. The public API could change without a major version update before v1.0.0 release.
![Gif](/docs/assets/demo.gif) ![Task Queue Diagram](/docs/assets/task-queue.png)
## Installation
To install `asynq` library, run the following command:
```sh
go get -u github.com/hibiken/asynq
```
## Quickstart ## Quickstart
@@ -30,48 +22,134 @@ First, make sure you are running a Redis server locally.
$ redis-server $ redis-server
``` ```
To create and schedule tasks, use `Client` and provide a task and when to process the task. Next, write a package that encapsulates task creation and task handling.
```go ```go
func main() { package tasks
r := &asynq.RedisClientOpt{
Addr: "127.0.0.1:6379", import (
"fmt"
"github.com/hibiken/asynq"
)
// A list of background task types.
const (
EmailDelivery = "email:deliver"
ImageProcessing = "image:process"
)
// Write function NewXXXTask to create a task.
func NewEmailDeliveryTask(userID int, tmplID string) *asynq.Task {
payload := map[string]interface{}{"user_id": userID, "template_id": tmplID}
return asynq.NewTask(EmailDelivery, payload)
}
func NewImageProcessingTask(src, dst string) *asynq.Task {
payload := map[string]interface{}{"src": src, "dst": dst}
return asynq.NewTask(ImageProcessing, payload)
}
// Write function HandleXXXTask to handle the given task.
// NOTE: It satisfies the asynq.HandlerFunc type.
func HandleEmailDeliveryTask(ctx context.Context, t *asynq.Task) error {
userID, err := t.Payload.GetInt("user_id")
if err != nil {
return err
} }
tmplID, err := t.Payload.GetString("template_id")
if err != nil {
return err
}
fmt.Printf("Send Email to User: user_id = %d, template_id = %s\n", userID, tmplID)
// Email delivery logic ...
return nil
}
client := asynq.NewClient(r) func HandleImageProcessingTask(ctx context.Context, t *asynq.Task) error {
src, err := t.Payload.GetString("src")
// Create a task with task type and payload if err != nil {
t1 := asynq.NewTask("send_welcome_email", map[string]interface{}{"user_id": 42}) return err
}
t2 := asynq.NewTask("send_reminder_email", map[string]interface{}{"user_id": 42}) dst, err := t.Payload.GetString("dst")
if err != nil {
// Process immediately return err
err := client.Enqueue(t1) }
fmt.Printf("Process image: src = %s, dst = %s\n", src, dst)
// Process 24 hrs later // Image processing logic ...
err = client.EnqueueIn(24*time.Hour, t2) return nil
// Process at specified time.
t := time.Date(2020, time.March, 6, 10, 0, 0, 0, time.UTC)
err = client.EnqueueAt(t, t2)
// Pass options to specify processing behavior for a given task.
//
// MaxRetry specifies the maximum number of times this task will be retried (Default is 25).
// Queue specifies which queue to enqueue this task to (Default is "default").
// Timeout specifies the timeout for the task's context (Default is no timeout).
err = client.Enqueue(t1, asynq.MaxRetry(10), asynq.Queue("critical"), asynq.Timeout(time.Minute))
} }
``` ```
To start the background workers, use `Background` and provide your `Handler` to process the tasks. In your web application code, import the above package and use [`Client`](https://pkg.go.dev/github.com/hibiken/asynq?tab=doc#Client) to enqueue tasks to the task queue.
A task will be processed by a background worker as soon as the task gets enqueued.
Scheduled tasks will be stored in Redis and will be enqueued at the specified time.
```go ```go
package main
import (
"time"
"github.com/hibiken/asynq"
"your/app/package/tasks"
)
const redisAddr = "127.0.0.1:6379"
func main() { func main() {
r := &asynq.RedisClientOpt{ r := &asynq.RedisClientOpt{Addr: redisAddr}
Addr: "127.0.0.1:6379", c := asynq.NewClient(r)
// Example 1: Enqueue task to be processed immediately.
t := tasks.NewEmailDeliveryTask(42, "some:template:id")
err := c.Enqueue(t)
if err != nil {
log.Fatal("could not enqueue task: %v", err)
} }
// Example 2: Schedule task to be processed in the future.
t = tasks.NewEmailDeliveryTask(42, "other:template:id")
err = c.EnqueueIn(24*time.Hour, t)
if err != nil {
log.Fatal("could not schedule task: %v", err)
}
// Example 3: Pass options to tune task processing behavior.
// Options include MaxRetry, Queue, Timeout, Deadline, etc.
t = tasks.NewImageProcessingTask("some/blobstore/url", "other/blobstore/url")
err = c.Enqueue(t, asynq.MaxRetry(10), asynq.Queue("critical"), asynq.Timeout(time.Minute))
if err != nil {
log.Fatal("could not enqueue task: %v", err)
}
}
```
Next, create a binary to process these tasks in the background.
To start the background workers, use [`Background`](https://pkg.go.dev/github.com/hibiken/asynq?tab=doc#Background) and provide your [`Handler`](https://pkg.go.dev/github.com/hibiken/asynq?tab=doc#Handler) to process the tasks.
You can optionally use [`ServeMux`](https://pkg.go.dev/github.com/hibiken/asynq?tab=doc#ServeMux) to create a handler, just as you would with [`"net/http"`](https://golang.org/pkg/net/http/) Handler.
```go
package main
import (
"github.com/hibiken/asynq"
"your/app/package/tasks"
)
const redisAddr = "127.0.0.1:6379"
func main() {
r := &asynq.RedisClientOpt{Addr: redisAddr}
bg := asynq.NewBackground(r, &asynq.Config{ bg := asynq.NewBackground(r, &asynq.Config{
// Specify how many concurrent workers to use // Specify how many concurrent workers to use
Concurrency: 10, Concurrency: 10,
@@ -84,45 +162,50 @@ func main() {
// See the godoc for other configuration options // See the godoc for other configuration options
}) })
bg.Run(handler) // mux maps a type to a handler
} mux := asynq.NewServeMux()
``` mux.HandleFunc(tasks.EmailDelivery, tasks.HandleEmailDeliveryTask)
mux.HandleFunc(tasks.ImageProcessing, tasks.HandleImageProcessingTask)
// ...register other handlers...
`Handler` is an interface with one method `ProcessTask` with the following signature. bg.Run(mux)
```go
// ProcessTask should return nil if the processing of a task
// is successful.
//
// If ProcessTask returns a non-nil error or panics, the task
// will be retried after delay.
type Handler interface {
ProcessTask(context.Context, *asynq.Task) error
} }
``` ```
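
The `Use` method mentioned in the v0.6.2 changelog entry is not shown in the quickstart; a minimal sketch of attaching a logging middleware to the mux might look like the following, assuming `Use` accepts `func(asynq.Handler) asynq.Handler` middleware in the same net/http-inspired spirit described in the commit messages above.

```go
package main

import (
	"context"
	"log"

	"github.com/hibiken/asynq"
)

// loggingMiddleware wraps a Handler and logs before and after each task.
func loggingMiddleware(next asynq.Handler) asynq.Handler {
	return asynq.HandlerFunc(func(ctx context.Context, t *asynq.Task) error {
		log.Printf("start processing %q", t.Type)
		err := next.ProcessTask(ctx, t)
		log.Printf("finished processing %q: err=%v", t.Type, err)
		return err
	})
}

func main() {
	mux := asynq.NewServeMux()
	mux.Use(loggingMiddleware) // applied to every handler registered on mux
	// Register handlers and pass mux to Background.Run as in the example above.
}
```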
For a more detailed walk-through of the library, see our [Getting Started Guide](https://github.com/hibiken/asynq/wiki/Getting-Started). For a more detailed walk-through of the library, see our [Getting Started Guide](https://github.com/hibiken/asynq/wiki/Getting-Started).
To learn more about `asynq` features and APIs, see our [Wiki pages](https://github.com/hibiken/asynq/wiki) and [godoc](https://godoc.org/github.com/hibiken/asynq). To learn more about `asynq` features and APIs, see our [Wiki](https://github.com/hibiken/asynq/wiki) and [godoc](https://godoc.org/github.com/hibiken/asynq).
## Command Line Tool
Asynq ships with a command line tool to inspect the state of queues and tasks.
Here's an example of running the `stats` command.
![Gif](/docs/assets/demo.gif)
For details on how to use the tool, refer to the tool's [README](/tools/asynqmon/README.md).
## Installation
To install `asynq` library, run the following command:
```sh
go get -u github.com/hibiken/asynq
```
To install the CLI tool, run the following command:
```sh
go get -u github.com/hibiken/asynq/tools/asynqmon
```
## Requirements ## Requirements
| Dependency | Version | | Dependency | Version |
| -------------------------- | ------- | | -------------------------- | ------- |
| [Redis](https://redis.io/) | v2.8+ | | [Redis](https://redis.io/) | v2.8+ |
| [Go](https://golang.org/) | v1.12+ | | [Go](https://golang.org/) | v1.13+ |
## Command Line Tool
Asynq ships with a command line tool to inspect the state of queues and tasks.
To install, run the following command:
```sh
go get -u github.com/hibiken/asynq/tools/asynqmon
```
For details on how to use the tool, refer to the tool's [README](/tools/asynqmon/README.md).
## Contributing ## Contributing


@@ -138,6 +138,6 @@ func createRedisClient(r RedisConnOpt) *redis.Client {
TLSConfig: r.TLSConfig, TLSConfig: r.TLSConfig,
}) })
default: default:
panic(fmt.Sprintf("unexpected type %T for RedisConnOpt", r)) panic(fmt.Sprintf("asynq: unexpected type %T for RedisConnOpt", r))
} }
} }


@@ -5,12 +5,14 @@
package asynq package asynq
import ( import (
"os"
"sort" "sort"
"testing" "testing"
"github.com/go-redis/redis/v7" "github.com/go-redis/redis/v7"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
h "github.com/hibiken/asynq/internal/asynqtest" h "github.com/hibiken/asynq/internal/asynqtest"
"github.com/hibiken/asynq/internal/log"
) )
// This file defines test helper functions used by // This file defines test helper functions used by
@@ -22,6 +24,8 @@ const (
redisDB = 14 redisDB = 14
) )
var testLogger = log.NewLogger(os.Stderr)
func setup(tb testing.TB) *redis.Client { func setup(tb testing.TB) *redis.Client {
tb.Helper() tb.Helper()
r := redis.NewClient(&redis.Options{ r := redis.NewClient(&redis.Options{


@@ -16,6 +16,7 @@ import (
"time" "time"
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/log"
"github.com/hibiken/asynq/internal/rdb" "github.com/hibiken/asynq/internal/rdb"
) )
@@ -39,6 +40,8 @@ type Background struct {
// wait group to wait for all goroutines to finish. // wait group to wait for all goroutines to finish.
wg sync.WaitGroup wg sync.WaitGroup
logger Logger
rdb *rdb.RDB rdb *rdb.RDB
scheduler *scheduler scheduler *scheduler
processor *processor processor *processor
@@ -89,6 +92,59 @@ type Config struct {
// The tasks in lower priority queues are processed only when those queues with // The tasks in lower priority queues are processed only when those queues with
// higher priorities are empty. // higher priorities are empty.
StrictPriority bool StrictPriority bool
// ErrorHandler handles errors returned by the task handler.
//
// HandleError is invoked only if the task handler returns a non-nil error.
//
// Example:
// func reportError(task *asynq.Task, err error, retried, maxRetry int) {
// if retried >= maxRetry {
// err = fmt.Errorf("retry exhausted for task %s: %w", task.Type, err)
// }
// errorReportingService.Notify(err)
// }
//
// ErrorHandler: asynq.ErrorHandlerFunc(reportError)
ErrorHandler ErrorHandler
// Logger specifies the logger used by the background instance.
//
// If unset, default logger is used.
Logger Logger
}
// An ErrorHandler handles errors returned by the task handler.
type ErrorHandler interface {
HandleError(task *Task, err error, retried, maxRetry int)
}
// The ErrorHandlerFunc type is an adapter to allow the use of ordinary functions as a ErrorHandler.
// If f is a function with the appropriate signature, ErrorHandlerFunc(f) is a ErrorHandler that calls f.
type ErrorHandlerFunc func(task *Task, err error, retried, maxRetry int)
// HandleError calls fn(task, err, retried, maxRetry)
func (fn ErrorHandlerFunc) HandleError(task *Task, err error, retried, maxRetry int) {
fn(task, err, retried, maxRetry)
}
// Logger implements logging with various log levels.
type Logger interface {
// Debug logs a message at Debug level.
Debug(format string, args ...interface{})
// Info logs a message at Info level.
Info(format string, args ...interface{})
// Warn logs a message at Warning level.
Warn(format string, args ...interface{})
// Error logs a message at Error level.
Error(format string, args ...interface{})
// Fatal logs a message at Fatal level
// and process will exit with status set to 1.
Fatal(format string, args ...interface{})
} }
// Formula taken from https://github.com/mperham/sidekiq. // Formula taken from https://github.com/mperham/sidekiq.
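
As a sketch of how the new `Logger` config field above might be used, the adapter below satisfies the `Logger` interface from this diff using the standard library logger; the prefix strings, flags, and Redis address are illustrative, not part of the diff.

```go
package main

import (
	"log"
	"os"

	"github.com/hibiken/asynq"
)

// stdLogger adapts the standard library *log.Logger to the asynq.Logger
// interface shown above (printf-style Debug/Info/Warn/Error/Fatal).
type stdLogger struct{ l *log.Logger }

func (s stdLogger) Debug(format string, args ...interface{}) { s.l.Printf("DEBUG: "+format, args...) }
func (s stdLogger) Info(format string, args ...interface{})  { s.l.Printf("INFO: "+format, args...) }
func (s stdLogger) Warn(format string, args ...interface{})  { s.l.Printf("WARN: "+format, args...) }
func (s stdLogger) Error(format string, args ...interface{}) { s.l.Printf("ERROR: "+format, args...) }
func (s stdLogger) Fatal(format string, args ...interface{}) {
	s.l.Printf("FATAL: "+format, args...)
	os.Exit(1)
}

func main() {
	bg := asynq.NewBackground(&asynq.RedisClientOpt{Addr: "127.0.0.1:6379"}, &asynq.Config{
		Concurrency: 10,
		Logger:      stdLogger{l: log.New(os.Stderr, "", log.LstdFlags)},
	})
	_ = bg // pass a Handler (e.g. a ServeMux) to bg.Run as usual.
}
```

If the supplied logger also exposes a SetPrefix(prefix string) method, Run will set an "asynq: pid=..." prefix on it, as shown by the prefixLogger type assertion later in this diff.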
@@ -122,6 +178,10 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background {
if len(queues) == 0 { if len(queues) == 0 {
queues = defaultQueueConfig queues = defaultQueueConfig
} }
logger := cfg.Logger
if logger == nil {
logger = log.NewLogger(os.Stderr)
}
host, err := os.Hostname() host, err := os.Hostname()
if err != nil { if err != nil {
@@ -133,12 +193,13 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background {
ps := base.NewProcessState(host, pid, n, queues, cfg.StrictPriority) ps := base.NewProcessState(host, pid, n, queues, cfg.StrictPriority)
syncCh := make(chan *syncRequest) syncCh := make(chan *syncRequest)
cancels := base.NewCancelations() cancels := base.NewCancelations()
syncer := newSyncer(syncCh, 5*time.Second) syncer := newSyncer(logger, syncCh, 5*time.Second)
heartbeater := newHeartbeater(rdb, ps, 5*time.Second) heartbeater := newHeartbeater(logger, rdb, ps, 5*time.Second)
scheduler := newScheduler(rdb, 5*time.Second, queues) scheduler := newScheduler(logger, rdb, 5*time.Second, queues)
processor := newProcessor(rdb, ps, delayFunc, syncCh, cancels) processor := newProcessor(logger, rdb, ps, delayFunc, syncCh, cancels, cfg.ErrorHandler)
subscriber := newSubscriber(rdb, cancels) subscriber := newSubscriber(logger, rdb, cancels)
return &Background{ return &Background{
logger: logger,
rdb: rdb, rdb: rdb,
ps: ps, ps: ps,
scheduler: scheduler, scheduler: scheduler,
@@ -149,7 +210,7 @@ func NewBackground(r RedisConnOpt, cfg *Config) *Background {
} }
} }
// A Handler processes a task. // A Handler processes tasks.
// //
// ProcessTask should return nil if the processing of a task // ProcessTask should return nil if the processing of a task
// is successful. // is successful.
@@ -176,14 +237,20 @@ func (fn HandlerFunc) ProcessTask(ctx context.Context, task *Task) error {
// a signal, it gracefully shuts down all pending workers and other // a signal, it gracefully shuts down all pending workers and other
// goroutines to process the tasks. // goroutines to process the tasks.
func (bg *Background) Run(handler Handler) { func (bg *Background) Run(handler Handler) {
logger.SetPrefix(fmt.Sprintf("asynq: pid=%d ", os.Getpid())) type prefixLogger interface {
logger.info("Starting processing") SetPrefix(prefix string)
}
// If logger supports setting prefix, then set prefix for log output.
if l, ok := bg.logger.(prefixLogger); ok {
l.SetPrefix(fmt.Sprintf("asynq: pid=%d ", os.Getpid()))
}
bg.logger.Info("Starting processing")
bg.start(handler) bg.start(handler)
defer bg.stop() defer bg.stop()
logger.info("Send signal TSTP to stop processing new tasks") bg.logger.Info("Send signal TSTP to stop processing new tasks")
logger.info("Send signal TERM or INT to terminate the process") bg.logger.Info("Send signal TERM or INT to terminate the process")
// Wait for a signal to terminate. // Wait for a signal to terminate.
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
@@ -198,7 +265,7 @@ func (bg *Background) Run(handler Handler) {
break break
} }
fmt.Println() fmt.Println()
logger.info("Starting graceful shutdown") bg.logger.Info("Starting graceful shutdown")
} }
// starts the background-task processing. // starts the background-task processing.
@@ -242,5 +309,5 @@ func (bg *Background) stop() {
bg.rdb.Close() bg.rdb.Close()
bg.running = false bg.running = false
logger.info("Bye!") bg.logger.Info("Bye!")
} }

client.go (125 lines changed)

@@ -5,6 +5,9 @@
package asynq package asynq
import ( import (
"errors"
"fmt"
"sort"
"strings" "strings"
"time" "time"
@@ -34,9 +37,11 @@ type Option interface{}
// Internal option representations. // Internal option representations.
type ( type (
retryOption int retryOption int
queueOption string queueOption string
timeoutOption time.Duration timeoutOption time.Duration
deadlineOption time.Time
uniqueOption time.Duration
) )
// MaxRetry returns an option to specify the max number of times // MaxRetry returns an option to specify the max number of times
@@ -64,17 +69,43 @@ func Timeout(d time.Duration) Option {
return timeoutOption(d) return timeoutOption(d)
} }
// Deadline returns an option to specify the deadline for the given task.
func Deadline(t time.Time) Option {
return deadlineOption(t)
}
// Unique returns an option to enqueue a task only if the given task is unique.
// Task enqueued with this option is guaranteed to be unique within the given ttl.
// Once the task gets processed successfully or once the TTL has expired, another task with the same uniqueness may be enqueued.
// ErrDuplicateTask error is returned when enqueueing a duplicate task.
//
// Uniqueness of a task is based on the following properties:
// - Task Type
// - Task Payload
// - Queue Name
func Unique(ttl time.Duration) Option {
return uniqueOption(ttl)
}
// ErrDuplicateTask indicates that the given task could not be enqueued since it's a duplicate of another task.
//
// ErrDuplicateTask error only applies to tasks enqueued with a Unique option.
var ErrDuplicateTask = errors.New("task already exists")
type option struct { type option struct {
retry int retry int
queue string queue string
timeout time.Duration timeout time.Duration
deadline time.Time
uniqueTTL time.Duration
} }
func composeOptions(opts ...Option) option { func composeOptions(opts ...Option) option {
res := option{ res := option{
retry: defaultMaxRetry, retry: defaultMaxRetry,
queue: base.DefaultQueueName, queue: base.DefaultQueueName,
timeout: 0, timeout: 0,
deadline: time.Time{},
} }
for _, opt := range opts { for _, opt := range opts {
switch opt := opt.(type) { switch opt := opt.(type) {
@@ -84,6 +115,10 @@ func composeOptions(opts ...Option) option {
res.queue = string(opt) res.queue = string(opt)
case timeoutOption: case timeoutOption:
res.timeout = time.Duration(opt) res.timeout = time.Duration(opt)
case deadlineOption:
res.deadline = time.Time(opt)
case uniqueOption:
res.uniqueTTL = time.Duration(opt)
default: default:
// ignore unexpected option // ignore unexpected option
} }
@@ -91,6 +126,39 @@ func composeOptions(opts ...Option) option {
return res return res
} }
// uniqueKey computes the redis key used for the given task.
// It returns an empty string if ttl is zero.
func uniqueKey(t *Task, ttl time.Duration, qname string) string {
if ttl == 0 {
return ""
}
return fmt.Sprintf("%s:%s:%s", t.Type, serializePayload(t.Payload.data), qname)
}
func serializePayload(payload map[string]interface{}) string {
if payload == nil {
return "nil"
}
type entry struct {
k string
v interface{}
}
var es []entry
for k, v := range payload {
es = append(es, entry{k, v})
}
// sort entries by key
sort.Slice(es, func(i, j int) bool { return es[i].k < es[j].k })
var b strings.Builder
for _, e := range es {
if b.Len() > 0 {
b.WriteString(",")
}
b.WriteString(fmt.Sprintf("%s=%v", e.k, e.v))
}
return b.String()
}
const ( const (
// Max retry count by default // Max retry count by default
defaultMaxRetry = 25 defaultMaxRetry = 25
@@ -105,14 +173,25 @@ const (
func (c *Client) EnqueueAt(t time.Time, task *Task, opts ...Option) error { func (c *Client) EnqueueAt(t time.Time, task *Task, opts ...Option) error {
opt := composeOptions(opts...) opt := composeOptions(opts...)
msg := &base.TaskMessage{ msg := &base.TaskMessage{
ID: xid.New(), ID: xid.New(),
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Queue: opt.queue, Queue: opt.queue,
Retry: opt.retry, Retry: opt.retry,
Timeout: opt.timeout.String(), Timeout: opt.timeout.String(),
Deadline: opt.deadline.Format(time.RFC3339),
UniqueKey: uniqueKey(task, opt.uniqueTTL, opt.queue),
} }
return c.enqueue(msg, t) var err error
if time.Now().After(t) {
err = c.enqueue(msg, opt.uniqueTTL)
} else {
err = c.schedule(msg, t, opt.uniqueTTL)
}
if err == rdb.ErrDuplicateTask {
return fmt.Errorf("%w", ErrDuplicateTask)
}
return err
} }
// Enqueue enqueues task to be processed immediately. // Enqueue enqueues task to be processed immediately.
@@ -135,9 +214,17 @@ func (c *Client) EnqueueIn(d time.Duration, task *Task, opts ...Option) error {
return c.EnqueueAt(time.Now().Add(d), task, opts...) return c.EnqueueAt(time.Now().Add(d), task, opts...)
} }
func (c *Client) enqueue(msg *base.TaskMessage, t time.Time) error { func (c *Client) enqueue(msg *base.TaskMessage, uniqueTTL time.Duration) error {
if time.Now().After(t) { if uniqueTTL > 0 {
return c.rdb.Enqueue(msg) return c.rdb.EnqueueUnique(msg, uniqueTTL)
}
return c.rdb.Enqueue(msg)
}
func (c *Client) schedule(msg *base.TaskMessage, t time.Time, uniqueTTL time.Duration) error {
if uniqueTTL > 0 {
ttl := t.Add(uniqueTTL).Sub(time.Now())
return c.rdb.ScheduleUnique(msg, t, ttl)
} }
return c.rdb.Schedule(msg, t) return c.rdb.Schedule(msg, t)
} }


@@ -5,10 +5,12 @@
package asynq package asynq
import ( import (
"errors"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
h "github.com/hibiken/asynq/internal/asynqtest" h "github.com/hibiken/asynq/internal/asynqtest"
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
) )
@@ -25,6 +27,9 @@ func TestClientEnqueueAt(t *testing.T) {
var ( var (
now = time.Now() now = time.Now()
oneHourLater = now.Add(time.Hour) oneHourLater = now.Add(time.Hour)
noTimeout = time.Duration(0).String()
noDeadline = time.Time{}.Format(time.RFC3339)
) )
tests := []struct { tests := []struct {
@@ -41,13 +46,14 @@ func TestClientEnqueueAt(t *testing.T) {
processAt: now, processAt: now,
opts: []Option{}, opts: []Option{},
wantEnqueued: map[string][]*base.TaskMessage{ wantEnqueued: map[string][]*base.TaskMessage{
"default": []*base.TaskMessage{ "default": {
&base.TaskMessage{ {
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: defaultMaxRetry, Retry: defaultMaxRetry,
Queue: "default", Queue: "default",
Timeout: time.Duration(0).String(), Timeout: noTimeout,
Deadline: noDeadline,
}, },
}, },
}, },
@@ -62,11 +68,12 @@ func TestClientEnqueueAt(t *testing.T) {
wantScheduled: []h.ZSetEntry{ wantScheduled: []h.ZSetEntry{
{ {
Msg: &base.TaskMessage{ Msg: &base.TaskMessage{
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: defaultMaxRetry, Retry: defaultMaxRetry,
Queue: "default", Queue: "default",
Timeout: time.Duration(0).String(), Timeout: noTimeout,
Deadline: noDeadline,
}, },
Score: float64(oneHourLater.Unix()), Score: float64(oneHourLater.Unix()),
}, },
@@ -106,6 +113,11 @@ func TestClientEnqueue(t *testing.T) {
task := NewTask("send_email", map[string]interface{}{"to": "customer@gmail.com", "from": "merchant@example.com"}) task := NewTask("send_email", map[string]interface{}{"to": "customer@gmail.com", "from": "merchant@example.com"})
var (
noTimeout = time.Duration(0).String()
noDeadline = time.Time{}.Format(time.RFC3339)
)
tests := []struct { tests := []struct {
desc string desc string
task *Task task *Task
@@ -119,13 +131,14 @@ func TestClientEnqueue(t *testing.T) {
MaxRetry(3), MaxRetry(3),
}, },
wantEnqueued: map[string][]*base.TaskMessage{ wantEnqueued: map[string][]*base.TaskMessage{
"default": []*base.TaskMessage{ "default": {
&base.TaskMessage{ {
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: 3, Retry: 3,
Queue: "default", Queue: "default",
Timeout: time.Duration(0).String(), Timeout: noTimeout,
Deadline: noDeadline,
}, },
}, },
}, },
@@ -137,13 +150,14 @@ func TestClientEnqueue(t *testing.T) {
MaxRetry(-2), MaxRetry(-2),
}, },
wantEnqueued: map[string][]*base.TaskMessage{ wantEnqueued: map[string][]*base.TaskMessage{
"default": []*base.TaskMessage{ "default": {
&base.TaskMessage{ {
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: 0, // Retry count should be set to zero Retry: 0, // Retry count should be set to zero
Queue: "default", Queue: "default",
Timeout: time.Duration(0).String(), Timeout: noTimeout,
Deadline: noDeadline,
}, },
}, },
}, },
@@ -156,13 +170,14 @@ func TestClientEnqueue(t *testing.T) {
MaxRetry(10), MaxRetry(10),
}, },
wantEnqueued: map[string][]*base.TaskMessage{ wantEnqueued: map[string][]*base.TaskMessage{
"default": []*base.TaskMessage{ "default": {
&base.TaskMessage{ {
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: 10, // Last option takes precedence Retry: 10, // Last option takes precedence
Queue: "default", Queue: "default",
Timeout: time.Duration(0).String(), Timeout: noTimeout,
Deadline: noDeadline,
}, },
}, },
}, },
@@ -174,13 +189,14 @@ func TestClientEnqueue(t *testing.T) {
Queue("custom"), Queue("custom"),
}, },
wantEnqueued: map[string][]*base.TaskMessage{ wantEnqueued: map[string][]*base.TaskMessage{
"custom": []*base.TaskMessage{ "custom": {
&base.TaskMessage{ {
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: defaultMaxRetry, Retry: defaultMaxRetry,
Queue: "custom", Queue: "custom",
Timeout: time.Duration(0).String(), Timeout: noTimeout,
Deadline: noDeadline,
}, },
}, },
}, },
@@ -192,31 +208,52 @@ func TestClientEnqueue(t *testing.T) {
Queue("HIGH"), Queue("HIGH"),
}, },
wantEnqueued: map[string][]*base.TaskMessage{ wantEnqueued: map[string][]*base.TaskMessage{
"high": []*base.TaskMessage{ "high": {
&base.TaskMessage{ {
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: defaultMaxRetry, Retry: defaultMaxRetry,
Queue: "high", Queue: "high",
Timeout: time.Duration(0).String(), Timeout: noTimeout,
Deadline: noDeadline,
}, },
}, },
}, },
}, },
{ {
desc: "Timeout option sets the timeout duration", desc: "With timeout option",
task: task, task: task,
opts: []Option{ opts: []Option{
Timeout(20 * time.Second), Timeout(20 * time.Second),
}, },
wantEnqueued: map[string][]*base.TaskMessage{ wantEnqueued: map[string][]*base.TaskMessage{
"default": []*base.TaskMessage{ "default": {
&base.TaskMessage{ {
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: defaultMaxRetry, Retry: defaultMaxRetry,
Queue: "default", Queue: "default",
Timeout: (20 * time.Second).String(), Timeout: (20 * time.Second).String(),
Deadline: noDeadline,
},
},
},
},
{
desc: "With deadline option",
task: task,
opts: []Option{
Deadline(time.Date(2020, time.June, 24, 0, 0, 0, 0, time.UTC)),
},
wantEnqueued: map[string][]*base.TaskMessage{
"default": {
{
Type: task.Type,
Payload: task.Payload.data,
Retry: defaultMaxRetry,
Queue: "default",
Timeout: noTimeout,
Deadline: time.Date(2020, time.June, 24, 0, 0, 0, 0, time.UTC).Format(time.RFC3339),
}, },
}, },
}, },
@@ -250,6 +287,11 @@ func TestClientEnqueueIn(t *testing.T) {
task := NewTask("send_email", map[string]interface{}{"to": "customer@gmail.com", "from": "merchant@example.com"}) task := NewTask("send_email", map[string]interface{}{"to": "customer@gmail.com", "from": "merchant@example.com"})
var (
noTimeout = time.Duration(0).String()
noDeadline = time.Time{}.Format(time.RFC3339)
)
tests := []struct { tests := []struct {
desc string desc string
task *Task task *Task
@@ -267,11 +309,12 @@ func TestClientEnqueueIn(t *testing.T) {
wantScheduled: []h.ZSetEntry{ wantScheduled: []h.ZSetEntry{
{ {
Msg: &base.TaskMessage{ Msg: &base.TaskMessage{
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: defaultMaxRetry, Retry: defaultMaxRetry,
Queue: "default", Queue: "default",
Timeout: time.Duration(0).String(), Timeout: noTimeout,
Deadline: noDeadline,
}, },
Score: float64(time.Now().Add(time.Hour).Unix()), Score: float64(time.Now().Add(time.Hour).Unix()),
}, },
@@ -283,13 +326,14 @@ func TestClientEnqueueIn(t *testing.T) {
delay: 0, delay: 0,
opts: []Option{}, opts: []Option{},
wantEnqueued: map[string][]*base.TaskMessage{ wantEnqueued: map[string][]*base.TaskMessage{
"default": []*base.TaskMessage{ "default": {
&base.TaskMessage{ {
Type: task.Type, Type: task.Type,
Payload: task.Payload.data, Payload: task.Payload.data,
Retry: defaultMaxRetry, Retry: defaultMaxRetry,
Queue: "default", Queue: "default",
Timeout: time.Duration(0).String(), Timeout: noTimeout,
Deadline: noDeadline,
}, },
}, },
}, },
@@ -319,3 +363,210 @@ func TestClientEnqueueIn(t *testing.T) {
} }
} }
} }
func TestUniqueKey(t *testing.T) {
tests := []struct {
desc string
task *Task
ttl time.Duration
qname string
want string
}{
{
"with zero TTL",
NewTask("email:send", map[string]interface{}{"a": 123, "b": "hello", "c": true}),
0,
"default",
"",
},
{
"with primitive types",
NewTask("email:send", map[string]interface{}{"a": 123, "b": "hello", "c": true}),
10 * time.Minute,
"default",
"email:send:a=123,b=hello,c=true:default",
},
{
"with unsorted keys",
NewTask("email:send", map[string]interface{}{"b": "hello", "c": true, "a": 123}),
10 * time.Minute,
"default",
"email:send:a=123,b=hello,c=true:default",
},
{
"with composite types",
NewTask("email:send",
map[string]interface{}{
"address": map[string]string{"line": "123 Main St", "city": "Boston", "state": "MA"},
"names": []string{"bob", "mike", "rob"}}),
10 * time.Minute,
"default",
"email:send:address=map[city:Boston line:123 Main St state:MA],names=[bob mike rob]:default",
},
{
"with complex types",
NewTask("email:send",
map[string]interface{}{
"time": time.Date(2020, time.July, 28, 0, 0, 0, 0, time.UTC),
"duration": time.Hour}),
10 * time.Minute,
"default",
"email:send:duration=1h0m0s,time=2020-07-28 00:00:00 +0000 UTC:default",
},
{
"with nil payload",
NewTask("reindex", nil),
10 * time.Minute,
"default",
"reindex:nil:default",
},
}
for _, tc := range tests {
got := uniqueKey(tc.task, tc.ttl, tc.qname)
if got != tc.want {
t.Errorf("%s: uniqueKey(%v, %v, %q) = %q, want %q", tc.desc, tc.task, tc.ttl, tc.qname, got, tc.want)
}
}
}
func TestEnqueueUnique(t *testing.T) {
r := setup(t)
c := NewClient(RedisClientOpt{
Addr: redisAddr,
DB: redisDB,
})
tests := []struct {
task *Task
ttl time.Duration
}{
{
NewTask("email", map[string]interface{}{"user_id": 123}),
time.Hour,
},
}
for _, tc := range tests {
h.FlushDB(t, r) // clean up db before each test case.
// Enqueue the task first. It should succeed.
err := c.Enqueue(tc.task, Unique(tc.ttl))
if err != nil {
t.Fatal(err)
}
gotTTL := r.TTL(uniqueKey(tc.task, tc.ttl, base.DefaultQueueName)).Val()
if !cmp.Equal(tc.ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) {
t.Errorf("TTL = %v, want %v", gotTTL, tc.ttl)
continue
}
// Enqueue the task again. It should fail.
err = c.Enqueue(tc.task, Unique(tc.ttl))
if err == nil {
t.Errorf("Enqueueing %+v did not return an error", tc.task)
continue
}
if !errors.Is(err, ErrDuplicateTask) {
t.Errorf("Enqueueing %+v returned an error that is not ErrDuplicateTask", tc.task)
continue
}
}
}
func TestEnqueueInUnique(t *testing.T) {
r := setup(t)
c := NewClient(RedisClientOpt{
Addr: redisAddr,
DB: redisDB,
})
tests := []struct {
task *Task
d time.Duration
ttl time.Duration
}{
{
NewTask("reindex", nil),
time.Hour,
10 * time.Minute,
},
}
for _, tc := range tests {
h.FlushDB(t, r) // clean up db before each test case.
// Enqueue the task first. It should succeed.
err := c.EnqueueIn(tc.d, tc.task, Unique(tc.ttl))
if err != nil {
t.Fatal(err)
}
gotTTL := r.TTL(uniqueKey(tc.task, tc.ttl, base.DefaultQueueName)).Val()
wantTTL := time.Duration(tc.ttl.Seconds()+tc.d.Seconds()) * time.Second
if !cmp.Equal(wantTTL.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) {
t.Errorf("TTL = %v, want %v", gotTTL, wantTTL)
continue
}
// Enqueue the task again. It should fail.
err = c.EnqueueIn(tc.d, tc.task, Unique(tc.ttl))
if err == nil {
t.Errorf("Enqueueing %+v did not return an error", tc.task)
continue
}
if !errors.Is(err, ErrDuplicateTask) {
t.Errorf("Enqueueing %+v returned an error that is not ErrDuplicateTask", tc.task)
continue
}
}
}
func TestEnqueueAtUnique(t *testing.T) {
r := setup(t)
c := NewClient(RedisClientOpt{
Addr: redisAddr,
DB: redisDB,
})
tests := []struct {
task *Task
at time.Time
ttl time.Duration
}{
{
NewTask("reindex", nil),
time.Now().Add(time.Hour),
10 * time.Minute,
},
}
for _, tc := range tests {
h.FlushDB(t, r) // clean up db before each test case.
// Enqueue the task first. It should succeed.
err := c.EnqueueAt(tc.at, tc.task, Unique(tc.ttl))
if err != nil {
t.Fatal(err)
}
gotTTL := r.TTL(uniqueKey(tc.task, tc.ttl, base.DefaultQueueName)).Val()
wantTTL := tc.at.Add(tc.ttl).Sub(time.Now())
if !cmp.Equal(wantTTL.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) {
t.Errorf("TTL = %v, want %v", gotTTL, wantTTL)
continue
}
// Enqueue the task again. It should fail.
err = c.EnqueueAt(tc.at, tc.task, Unique(tc.ttl))
if err == nil {
t.Errorf("Enqueueing %+v did not return an error", tc.task)
continue
}
if !errors.Is(err, ErrDuplicateTask) {
t.Errorf("Enqueueing %+v returned an error that is not ErrDuplicateTask", tc.task)
continue
}
}
}

docs/assets/task-queue.png (new binary file, 54 KiB)


@@ -15,7 +15,8 @@ import (
// heartbeater is responsible for writing process info to redis periodically to // heartbeater is responsible for writing process info to redis periodically to
// indicate that the background worker process is up. // indicate that the background worker process is up.
type heartbeater struct { type heartbeater struct {
rdb *rdb.RDB logger Logger
rdb *rdb.RDB
ps *base.ProcessState ps *base.ProcessState
@@ -26,8 +27,9 @@ type heartbeater struct {
interval time.Duration interval time.Duration
} }
func newHeartbeater(rdb *rdb.RDB, ps *base.ProcessState, interval time.Duration) *heartbeater { func newHeartbeater(l Logger, rdb *rdb.RDB, ps *base.ProcessState, interval time.Duration) *heartbeater {
return &heartbeater{ return &heartbeater{
logger: l,
rdb: rdb, rdb: rdb,
ps: ps, ps: ps,
done: make(chan struct{}), done: make(chan struct{}),
@@ -36,7 +38,7 @@ func newHeartbeater(rdb *rdb.RDB, ps *base.ProcessState, interval time.Duration)
} }
func (h *heartbeater) terminate() { func (h *heartbeater) terminate() {
logger.info("Heartbeater shutting down...") h.logger.Info("Heartbeater shutting down...")
// Signal the heartbeater goroutine to stop. // Signal the heartbeater goroutine to stop.
h.done <- struct{}{} h.done <- struct{}{}
} }
@@ -52,7 +54,7 @@ func (h *heartbeater) start(wg *sync.WaitGroup) {
select { select {
case <-h.done: case <-h.done:
h.rdb.ClearProcessState(h.ps) h.rdb.ClearProcessState(h.ps)
logger.info("Heartbeater done") h.logger.Info("Heartbeater done")
return return
case <-time.After(h.interval): case <-time.After(h.interval):
h.beat() h.beat()
@@ -66,6 +68,6 @@ func (h *heartbeater) beat() {
// and short enough to expire quickly once the process is shut down or killed. // and short enough to expire quickly once the process is shut down or killed.
err := h.rdb.WriteProcessState(h.ps, h.interval*2) err := h.rdb.WriteProcessState(h.ps, h.interval*2)
if err != nil { if err != nil {
logger.error("could not write heartbeat data: %v", err) h.logger.Error("could not write heartbeat data: %v", err)
} }
} }


@@ -36,7 +36,7 @@ func TestHeartbeater(t *testing.T) {
h.FlushDB(t, r) h.FlushDB(t, r)
state := base.NewProcessState(tc.host, tc.pid, tc.concurrency, tc.queues, false) state := base.NewProcessState(tc.host, tc.pid, tc.concurrency, tc.queues, false)
hb := newHeartbeater(rdbClient, state, tc.interval) hb := newHeartbeater(testLogger, rdbClient, state, tc.interval)
var wg sync.WaitGroup var wg sync.WaitGroup
hb.start(&wg) hb.start(&wg)


@@ -90,6 +90,18 @@ type TaskMessage struct {
// //
// Zero means no limit. // Zero means no limit.
Timeout string Timeout string
// Deadline specifies the deadline for the task.
// Task won't be processed if it exceeded its deadline.
// The string should be in RFC3339 format.
//
// time.Time's zero value means no deadline.
Deadline string
// UniqueKey holds the redis key used for uniqueness lock for this task.
//
// Empty string indicates that no uniqueness lock was used.
UniqueKey string
} }
// ProcessState holds process level information. // ProcessState holds process level information.
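
For reference, the zero-value "no deadline" marker described above serializes as follows; this is a small illustrative snippet, not part of the diff, and matches the `noDeadline` helper used in client_test.go.

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// A zero time.Time formatted with RFC3339 is the "no deadline" value
	// stored in TaskMessage.Deadline.
	fmt.Println(time.Time{}.Format(time.RFC3339)) // 0001-01-01T00:00:00Z

	// A concrete deadline, as set by the Deadline option:
	d := time.Date(2020, time.June, 24, 0, 0, 0, 0, time.UTC)
	fmt.Println(d.Format(time.RFC3339)) // 2020-06-24T00:00:00Z
}
```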


@@ -110,9 +110,9 @@ func TestProcessStateConcurrentAccess(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
started := time.Now() started := time.Now()
msgs := []*TaskMessage{ msgs := []*TaskMessage{
&TaskMessage{ID: xid.New(), Type: "type1", Payload: map[string]interface{}{"user_id": 42}}, {ID: xid.New(), Type: "type1", Payload: map[string]interface{}{"user_id": 42}},
&TaskMessage{ID: xid.New(), Type: "type2"}, {ID: xid.New(), Type: "type2"},
&TaskMessage{ID: xid.New(), Type: "type3"}, {ID: xid.New(), Type: "type3"},
} }
// Simulate heartbeater calling SetStatus and SetStarted. // Simulate heartbeater calling SetStatus and SetStarted.

internal/log/log.go (new file, 57 lines)

@@ -0,0 +1,57 @@
// Copyright 2020 Kentaro Hibino. All rights reserved.
// Use of this source code is governed by a MIT license
// that can be found in the LICENSE file.
// Package log exports logging related types and functions.
package log
import (
"io"
stdlog "log"
"os"
)
// NewLogger creates and returns a new instance of Logger.
func NewLogger(out io.Writer) *Logger {
return &Logger{
stdlog.New(out, "", stdlog.Ldate|stdlog.Ltime|stdlog.Lmicroseconds|stdlog.LUTC),
}
}
// Logger is a wrapper object around log.Logger from the standard library.
// It supports logging at various log levels.
type Logger struct {
*stdlog.Logger
}
// Debug logs a message at Debug level.
func (l *Logger) Debug(format string, args ...interface{}) {
format = "DEBUG: " + format
l.Printf(format, args...)
}
// Info logs a message at Info level.
func (l *Logger) Info(format string, args ...interface{}) {
format = "INFO: " + format
l.Printf(format, args...)
}
// Warn logs a message at Warning level.
func (l *Logger) Warn(format string, args ...interface{}) {
format = "WARN: " + format
l.Printf(format, args...)
}
// Error logs a message at Error level.
func (l *Logger) Error(format string, args ...interface{}) {
format = "ERROR: " + format
l.Printf(format, args...)
}
// Fatal logs a message at Fatal level
// and process will exit with status set to 1.
func (l *Logger) Fatal(format string, args ...interface{}) {
format = "FATAL: " + format
l.Printf(format, args...)
os.Exit(1)
}
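
A brief usage sketch of this package as added in the diff; note that `internal/log` can only be imported from within the asynq module itself, so the enclosing package here is hypothetical.

```go
package somepkg // within the asynq module; internal packages are not importable externally

import (
	"os"

	"github.com/hibiken/asynq/internal/log"
)

func example() {
	l := log.NewLogger(os.Stderr)
	l.Info("processing %d tasks", 3)            // "<UTC timestamp> INFO: processing 3 tasks"
	l.Warn("queue %q is backed up", "critical") // "<UTC timestamp> WARN: queue \"critical\" is backed up"
}
```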


@@ -1,4 +1,8 @@
package asynq // Copyright 2020 Kentaro Hibino. All rights reserved.
// Use of this source code is governed by a MIT license
// that can be found in the LICENSE file.
package log
import ( import (
"bytes" "bytes"
@@ -20,6 +24,38 @@ type tester struct {
wantPattern string // regexp that log output must match wantPattern string // regexp that log output must match
} }
func TestLoggerDebug(t *testing.T) {
tests := []tester{
{
desc: "without trailing newline, logger adds newline",
message: "hello, world!",
wantPattern: fmt.Sprintf("^%s %s%s DEBUG: hello, world!\n$", rgxdate, rgxtime, rgxmicroseconds),
},
{
desc: "with trailing newline, logger preserves newline",
message: "hello, world!\n",
wantPattern: fmt.Sprintf("^%s %s%s DEBUG: hello, world!\n$", rgxdate, rgxtime, rgxmicroseconds),
},
}
for _, tc := range tests {
var buf bytes.Buffer
logger := NewLogger(&buf)
logger.Debug(tc.message)
got := buf.String()
matched, err := regexp.MatchString(tc.wantPattern, got)
if err != nil {
t.Fatal("pattern did not compile:", err)
}
if !matched {
t.Errorf("logger.info(%q) outputted %q, should match pattern %q",
tc.message, got, tc.wantPattern)
}
}
}
func TestLoggerInfo(t *testing.T) { func TestLoggerInfo(t *testing.T) {
tests := []tester{ tests := []tester{
{ {
@@ -36,9 +72,9 @@ func TestLoggerInfo(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
var buf bytes.Buffer var buf bytes.Buffer
logger := newLogger(&buf) logger := NewLogger(&buf)
logger.info(tc.message) logger.Info(tc.message)
got := buf.String() got := buf.String()
matched, err := regexp.MatchString(tc.wantPattern, got) matched, err := regexp.MatchString(tc.wantPattern, got)
@@ -68,9 +104,9 @@ func TestLoggerWarn(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
var buf bytes.Buffer var buf bytes.Buffer
logger := newLogger(&buf) logger := NewLogger(&buf)
logger.warn(tc.message) logger.Warn(tc.message)
got := buf.String() got := buf.String()
matched, err := regexp.MatchString(tc.wantPattern, got) matched, err := regexp.MatchString(tc.wantPattern, got)
@@ -100,9 +136,9 @@ func TestLoggerError(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
var buf bytes.Buffer var buf bytes.Buffer
logger := newLogger(&buf) logger := NewLogger(&buf)
logger.error(tc.message) logger.Error(tc.message)
got := buf.String() got := buf.String()
matched, err := regexp.MatchString(tc.wantPattern, got) matched, err := regexp.MatchString(tc.wantPattern, got)


@@ -22,6 +22,9 @@ var (
// ErrTaskNotFound indicates that a task that matches the given identifier was not found. // ErrTaskNotFound indicates that a task that matches the given identifier was not found.
ErrTaskNotFound = errors.New("could not find a task") ErrTaskNotFound = errors.New("could not find a task")
// ErrDuplicateTask indicates that another task with the same unique key holds the uniqueness lock.
ErrDuplicateTask = errors.New("task already exists")
) )
const statsTTL = 90 * 24 * time.Hour // 90 days const statsTTL = 90 * 24 * time.Hour // 90 days
@@ -59,6 +62,46 @@ func (r *RDB) Enqueue(msg *base.TaskMessage) error {
return enqueueCmd.Run(r.client, []string{key, base.AllQueues}, bytes).Err() return enqueueCmd.Run(r.client, []string{key, base.AllQueues}, bytes).Err()
} }
// KEYS[1] -> unique key in the form <type>:<payload>:<qname>
// KEYS[2] -> asynq:queues:<qname>
// KEYS[2] -> asynq:queues
// ARGV[1] -> task ID
// ARGV[2] -> uniqueness lock TTL
// ARGV[3] -> task message data
var enqueueUniqueCmd = redis.NewScript(`
local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2])
if not ok then
return 0
end
redis.call("LPUSH", KEYS[2], ARGV[3])
redis.call("SADD", KEYS[3], KEYS[2])
return 1
`)
// EnqueueUnique inserts the given task if the task's uniqueness lock can be acquired.
// It returns ErrDuplicateTask if the lock cannot be acquired.
func (r *RDB) EnqueueUnique(msg *base.TaskMessage, ttl time.Duration) error {
bytes, err := json.Marshal(msg)
if err != nil {
return err
}
key := base.QueueKey(msg.Queue)
res, err := enqueueUniqueCmd.Run(r.client,
[]string{msg.UniqueKey, key, base.AllQueues},
msg.ID.String(), int(ttl.Seconds()), bytes).Result()
if err != nil {
return err
}
n, ok := res.(int64)
if !ok {
return fmt.Errorf("could not cast %v to int64", res)
}
if n == 0 {
return ErrDuplicateTask
}
return nil
}
// Dequeue queries given queues in order and pops a task message if there is one and returns it. // Dequeue queries given queues in order and pops a task message if there is one and returns it.
// If all queues are empty, ErrNoProcessableTask error is returned. // If all queues are empty, ErrNoProcessableTask error is returned.
func (r *RDB) Dequeue(qnames ...string) (*base.TaskMessage, error) { func (r *RDB) Dequeue(qnames ...string) (*base.TaskMessage, error) {
@@ -67,7 +110,6 @@ func (r *RDB) Dequeue(qnames ...string) (*base.TaskMessage, error) {
if len(qnames) == 1 { if len(qnames) == 1 {
data, err = r.dequeueSingle(base.QueueKey(qnames[0])) data, err = r.dequeueSingle(base.QueueKey(qnames[0]))
} else { } else {
// TODO(hibiken): Take keys are argument and don't compute every time
var keys []string var keys []string
for _, q := range qnames { for _, q := range qnames {
keys = append(keys, base.QueueKey(q)) keys = append(keys, base.QueueKey(q))
@@ -119,8 +161,10 @@ func (r *RDB) dequeue(queues ...string) (data string, err error) {
// KEYS[1] -> asynq:in_progress // KEYS[1] -> asynq:in_progress
// KEYS[2] -> asynq:processed:<yyyy-mm-dd> // KEYS[2] -> asynq:processed:<yyyy-mm-dd>
// KEYS[3] -> unique key in the format <type>:<payload>:<qname>
// ARGV[1] -> base.TaskMessage value // ARGV[1] -> base.TaskMessage value
// ARGV[2] -> stats expiration timestamp // ARGV[2] -> stats expiration timestamp
// ARGV[3] -> task ID
// Note: LREM count ZERO means "remove all elements equal to val" // Note: LREM count ZERO means "remove all elements equal to val"
var doneCmd = redis.NewScript(` var doneCmd = redis.NewScript(`
redis.call("LREM", KEYS[1], 0, ARGV[1]) redis.call("LREM", KEYS[1], 0, ARGV[1])
@@ -128,10 +172,14 @@ local n = redis.call("INCR", KEYS[2])
if tonumber(n) == 1 then if tonumber(n) == 1 then
redis.call("EXPIREAT", KEYS[2], ARGV[2]) redis.call("EXPIREAT", KEYS[2], ARGV[2])
end end
if string.len(KEYS[3]) > 0 and redis.call("GET", KEYS[3]) == ARGV[3] then
redis.call("DEL", KEYS[3])
end
return redis.status_reply("OK") return redis.status_reply("OK")
`) `)
// Done removes the task from in-progress queue to mark the task as done. // Done removes the task from in-progress queue to mark the task as done.
// It removes a uniqueness lock acquired by the task, if any.
func (r *RDB) Done(msg *base.TaskMessage) error { func (r *RDB) Done(msg *base.TaskMessage) error {
bytes, err := json.Marshal(msg) bytes, err := json.Marshal(msg)
if err != nil { if err != nil {
@@ -141,8 +189,8 @@ func (r *RDB) Done(msg *base.TaskMessage) error {
processedKey := base.ProcessedKey(now) processedKey := base.ProcessedKey(now)
expireAt := now.Add(statsTTL) expireAt := now.Add(statsTTL)
return doneCmd.Run(r.client, return doneCmd.Run(r.client,
[]string{base.InProgressQueue, processedKey}, []string{base.InProgressQueue, processedKey, msg.UniqueKey},
bytes, expireAt.Unix()).Err() bytes, expireAt.Unix(), msg.ID.String()).Err()
} }
// KEYS[1] -> asynq:in_progress // KEYS[1] -> asynq:in_progress
@@ -165,15 +213,71 @@ func (r *RDB) Requeue(msg *base.TaskMessage) error {
string(bytes)).Err() string(bytes)).Err()
} }
// KEYS[1] -> asynq:scheduled
// KEYS[2] -> asynq:queues
// ARGV[1] -> score (process_at timestamp)
// ARGV[2] -> task message
// ARGV[3] -> queue key
var scheduleCmd = redis.NewScript(`
redis.call("ZADD", KEYS[1], ARGV[1], ARGV[2])
redis.call("SADD", KEYS[2], ARGV[3])
return 1
`)
// Schedule adds the task to the backlog queue to be processed in the future. // Schedule adds the task to the backlog queue to be processed in the future.
func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error { func (r *RDB) Schedule(msg *base.TaskMessage, processAt time.Time) error {
bytes, err := json.Marshal(msg) bytes, err := json.Marshal(msg)
if err != nil { if err != nil {
return err return err
} }
qkey := base.QueueKey(msg.Queue)
score := float64(processAt.Unix()) score := float64(processAt.Unix())
return r.client.ZAdd(base.ScheduledQueue, return scheduleCmd.Run(r.client,
&redis.Z{Member: string(bytes), Score: score}).Err() []string{base.ScheduledQueue, base.AllQueues},
score, bytes, qkey).Err()
}
// KEYS[1] -> unique key in the format <type>:<payload>:<qname>
// KEYS[2] -> asynq:scheduled
// KEYS[3] -> asynq:queues
// ARGV[1] -> task ID
// ARGV[2] -> uniqueness lock TTL
// ARGV[3] -> score (process_at timestamp)
// ARGV[4] -> task message
// ARGV[5] -> queue key
var scheduleUniqueCmd = redis.NewScript(`
local ok = redis.call("SET", KEYS[1], ARGV[1], "NX", "EX", ARGV[2])
if not ok then
return 0
end
redis.call("ZADD", KEYS[2], ARGV[3], ARGV[4])
redis.call("SADD", KEYS[3], ARGV[5])
return 1
`)
// ScheduleUnique adds the task to the backlog queue to be processed in the future if the uniqueness lock can be acquired.
// It returns ErrDuplicateTask if the lock cannot be acquired.
func (r *RDB) ScheduleUnique(msg *base.TaskMessage, processAt time.Time, ttl time.Duration) error {
bytes, err := json.Marshal(msg)
if err != nil {
return err
}
qkey := base.QueueKey(msg.Queue)
score := float64(processAt.Unix())
res, err := scheduleUniqueCmd.Run(r.client,
[]string{msg.UniqueKey, base.ScheduledQueue, base.AllQueues},
msg.ID.String(), int(ttl.Seconds()), score, bytes, qkey).Result()
if err != nil {
return err
}
n, ok := res.(int64)
if !ok {
return fmt.Errorf("could not cast %v to int64", res)
}
if n == 0 {
return ErrDuplicateTask
}
return nil
} }
// KEYS[1] -> asynq:in_progress // KEYS[1] -> asynq:in_progress


@@ -16,6 +16,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
h "github.com/hibiken/asynq/internal/asynqtest" h "github.com/hibiken/asynq/internal/asynqtest"
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
"github.com/rs/xid"
) )
// TODO(hibiken): Get Redis address and db number from ENV variables. // TODO(hibiken): Get Redis address and db number from ENV variables.
@@ -69,6 +70,48 @@ func TestEnqueue(t *testing.T) {
} }
} }
func TestEnqueueUnique(t *testing.T) {
r := setup(t)
m1 := base.TaskMessage{
ID: xid.New(),
Type: "email",
Payload: map[string]interface{}{"user_id": 123},
Queue: base.DefaultQueueName,
UniqueKey: "email:user_id=123:default",
}
tests := []struct {
msg *base.TaskMessage
ttl time.Duration // uniqueness ttl
}{
{&m1, time.Minute},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case.
err := r.EnqueueUnique(tc.msg, tc.ttl)
if err != nil {
t.Errorf("First message: (*RDB).EnqueueUnique(%v, %v) = %v, want nil",
tc.msg, tc.ttl, err)
continue
}
got := r.EnqueueUnique(tc.msg, tc.ttl)
if got != ErrDuplicateTask {
t.Errorf("Second message: (*RDB).EnqueueUnique(%v, %v) = %v, want %v",
tc.msg, tc.ttl, got, ErrDuplicateTask)
continue
}
gotTTL := r.client.TTL(tc.msg.UniqueKey).Val()
if !cmp.Equal(tc.ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) {
t.Errorf("TTL %q = %v, want %v", tc.msg.UniqueKey, gotTTL, tc.ttl)
continue
}
}
}
func TestDequeue(t *testing.T) { func TestDequeue(t *testing.T) {
r := setup(t) r := setup(t)
t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "hello!"}) t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "hello!"})
@@ -188,6 +231,13 @@ func TestDone(t *testing.T) {
r := setup(t) r := setup(t)
t1 := h.NewTaskMessage("send_email", nil) t1 := h.NewTaskMessage("send_email", nil)
t2 := h.NewTaskMessage("export_csv", nil) t2 := h.NewTaskMessage("export_csv", nil)
t3 := &base.TaskMessage{
ID: xid.New(),
Type: "reindex",
Payload: nil,
UniqueKey: "reindex:nil:default",
Queue: "default",
}
tests := []struct { tests := []struct {
inProgress []*base.TaskMessage // initial state of the in-progress list inProgress []*base.TaskMessage // initial state of the in-progress list
@@ -204,11 +254,25 @@ func TestDone(t *testing.T) {
target: t1, target: t1,
wantInProgress: []*base.TaskMessage{}, wantInProgress: []*base.TaskMessage{},
}, },
{
inProgress: []*base.TaskMessage{t1, t2, t3},
target: t3,
wantInProgress: []*base.TaskMessage{t1, t2},
},
} }
for _, tc := range tests { for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case h.FlushDB(t, r.client) // clean up db before each test case
h.SeedInProgressQueue(t, r.client, tc.inProgress) h.SeedInProgressQueue(t, r.client, tc.inProgress)
for _, msg := range tc.inProgress {
// Set uniqueness lock if unique key is present.
if len(msg.UniqueKey) > 0 {
err := r.client.SetNX(msg.UniqueKey, msg.ID.String(), time.Minute).Err()
if err != nil {
t.Fatal(err)
}
}
}
err := r.Done(tc.target) err := r.Done(tc.target)
if err != nil { if err != nil {
@@ -232,6 +296,10 @@ func TestDone(t *testing.T) {
if gotTTL > statsTTL { if gotTTL > statsTTL {
t.Errorf("TTL %q = %v, want less than or equal to %v", processedKey, gotTTL, statsTTL) t.Errorf("TTL %q = %v, want less than or equal to %v", processedKey, gotTTL, statsTTL)
} }
if len(tc.target.UniqueKey) > 0 && r.client.Exists(tc.target.UniqueKey).Val() != 0 {
t.Errorf("Uniqueness lock %q still exists", tc.target.UniqueKey)
}
} }
} }
@@ -344,6 +412,58 @@ func TestSchedule(t *testing.T) {
} }
} }
func TestScheduleUnique(t *testing.T) {
r := setup(t)
m1 := base.TaskMessage{
ID: xid.New(),
Type: "email",
Payload: map[string]interface{}{"user_id": 123},
Queue: base.DefaultQueueName,
UniqueKey: "email:user_id=123:default",
}
tests := []struct {
msg *base.TaskMessage
processAt time.Time
ttl time.Duration // uniqueness lock ttl
}{
{&m1, time.Now().Add(15 * time.Minute), time.Minute},
}
for _, tc := range tests {
h.FlushDB(t, r.client) // clean up db before each test case
desc := fmt.Sprintf("(*RDB).ScheduleUnique(%v, %v, %v)", tc.msg, tc.processAt, tc.ttl)
err := r.ScheduleUnique(tc.msg, tc.processAt, tc.ttl)
if err != nil {
t.Errorf("Frist task: %s = %v, want nil", desc, err)
continue
}
gotScheduled := h.GetScheduledEntries(t, r.client)
if len(gotScheduled) != 1 {
t.Errorf("%s inserted %d items to %q, want 1 items inserted", desc, len(gotScheduled), base.ScheduledQueue)
continue
}
if int64(gotScheduled[0].Score) != tc.processAt.Unix() {
t.Errorf("%s inserted an item with score %d, want %d", desc, int64(gotScheduled[0].Score), tc.processAt.Unix())
continue
}
got := r.ScheduleUnique(tc.msg, tc.processAt, tc.ttl)
if got != ErrDuplicateTask {
t.Errorf("Second task: %s = %v, want %v",
desc, got, ErrDuplicateTask)
}
gotTTL := r.client.TTL(tc.msg.UniqueKey).Val()
if !cmp.Equal(tc.ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) {
t.Errorf("TTL %q = %v, want %v", tc.msg.UniqueKey, gotTTL, tc.ttl)
continue
}
}
}
func TestRetry(t *testing.T) { func TestRetry(t *testing.T) {
r := setup(t) r := setup(t)
t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "Hola!"}) t1 := h.NewTaskMessage("send_email", map[string]interface{}{"subject": "Hola!"})
@@ -784,8 +904,7 @@ func TestWriteProcessState(t *testing.T) {
} }
// Check ProcessInfo TTL was set correctly // Check ProcessInfo TTL was set correctly
gotTTL := r.client.TTL(pkey).Val() gotTTL := r.client.TTL(pkey).Val()
timeCmpOpt := cmpopts.EquateApproxTime(time.Second) if !cmp.Equal(ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) {
if !cmp.Equal(ttl, gotTTL, timeCmpOpt) {
t.Errorf("TTL of %q was %v, want %v", pkey, gotTTL, ttl) t.Errorf("TTL of %q was %v, want %v", pkey, gotTTL, ttl)
} }
// Check ProcessInfo key was added to the set correctly // Check ProcessInfo key was added to the set correctly
@@ -858,8 +977,7 @@ func TestWriteProcessStateWithWorkers(t *testing.T) {
} }
// Check ProcessInfo TTL was set correctly // Check ProcessInfo TTL was set correctly
gotTTL := r.client.TTL(pkey).Val() gotTTL := r.client.TTL(pkey).Val()
timeCmpOpt := cmpopts.EquateApproxTime(time.Second) if !cmp.Equal(ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) {
if !cmp.Equal(ttl, gotTTL, timeCmpOpt) {
t.Errorf("TTL of %q was %v, want %v", pkey, gotTTL, ttl) t.Errorf("TTL of %q was %v, want %v", pkey, gotTTL, ttl)
} }
// Check ProcessInfo key was added to the set correctly // Check ProcessInfo key was added to the set correctly
@@ -884,7 +1002,7 @@ func TestWriteProcessStateWithWorkers(t *testing.T) {
gotWorkers[key] = &w gotWorkers[key] = &w
} }
wantWorkers := map[string]*base.WorkerInfo{ wantWorkers := map[string]*base.WorkerInfo{
msg1.ID.String(): &base.WorkerInfo{ msg1.ID.String(): {
Host: host, Host: host,
PID: pid, PID: pid,
ID: msg1.ID, ID: msg1.ID,
@@ -893,7 +1011,7 @@ func TestWriteProcessStateWithWorkers(t *testing.T) {
Payload: msg1.Payload, Payload: msg1.Payload,
Started: w1Started, Started: w1Started,
}, },
msg2.ID.String(): &base.WorkerInfo{ msg2.ID.String(): {
Host: host, Host: host,
PID: pid, PID: pid,
ID: msg2.ID, ID: msg2.ID,
@@ -1,35 +0,0 @@
package asynq
import (
"io"
"log"
"os"
)
// global logger used in asynq package.
var logger = newLogger(os.Stderr)
func newLogger(out io.Writer) *asynqLogger {
return &asynqLogger{
log.New(out, "", log.Ldate|log.Ltime|log.Lmicroseconds|log.LUTC),
}
}
type asynqLogger struct {
*log.Logger
}
func (l *asynqLogger) info(format string, args ...interface{}) {
format = "INFO: " + format
l.Printf(format, args...)
}
func (l *asynqLogger) warn(format string, args ...interface{}) {
format = "WARN: " + format
l.Printf(format, args...)
}
func (l *asynqLogger) error(format string, args ...interface{}) {
format = "ERROR: " + format
l.Printf(format, args...)
}
@@ -14,333 +14,626 @@ import (
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
) )
func TestPayloadGet(t *testing.T) { type payloadTest struct {
names := []string{"luke", "anakin", "rey"} data map[string]interface{}
primes := []int{2, 3, 5, 7, 11, 13, 17} key string
user := map[string]interface{}{"name": "Ken", "score": 3.14} nonkey string
location := map[string]string{"address": "123 Main St.", "state": "NY", "zipcode": "10002"} }
favs := map[string][]string{
"movies": []string{"forrest gump", "star wars"}, func TestPayloadString(t *testing.T) {
"tv_shows": []string{"game of thrones", "HIMYM", "breaking bad"}, tests := []payloadTest{
{
data: map[string]interface{}{"name": "gopher"},
key: "name",
nonkey: "unknown",
},
} }
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetString(tc.key)
if err != nil || got != tc.data[tc.key] {
t.Errorf("Payload.GetString(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetString(tc.key)
if err != nil || got != tc.data[tc.key] {
t.Errorf("With Marshaling: Payload.GetString(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetString(tc.nonkey)
if err == nil || got != "" {
t.Errorf("Payload.GetString(%q) = %v, %v; want '', error",
tc.key, got, err)
}
}
}
func TestPayloadInt(t *testing.T) {
tests := []payloadTest{
{
data: map[string]interface{}{"user_id": 42},
key: "user_id",
nonkey: "unknown",
},
}
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetInt(tc.key)
if err != nil || got != tc.data[tc.key] {
t.Errorf("Payload.GetInt(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetInt(tc.key)
if err != nil || got != tc.data[tc.key] {
t.Errorf("With Marshaling: Payload.GetInt(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetInt(tc.nonkey)
if err == nil || got != 0 {
t.Errorf("Payload.GetInt(%q) = %v, %v; want 0, error",
tc.key, got, err)
}
}
}
func TestPayloadFloat64(t *testing.T) {
tests := []payloadTest{
{
data: map[string]interface{}{"pi": 3.14},
key: "pi",
nonkey: "unknown",
},
}
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetFloat64(tc.key)
if err != nil || got != tc.data[tc.key] {
t.Errorf("Payload.GetFloat64(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetFloat64(tc.key)
if err != nil || got != tc.data[tc.key] {
t.Errorf("With Marshaling: Payload.GetFloat64(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetFloat64(tc.nonkey)
if err == nil || got != 0 {
t.Errorf("Payload.GetFloat64(%q) = %v, %v; want 0, error",
tc.key, got, err)
}
}
}
func TestPayloadBool(t *testing.T) {
tests := []payloadTest{
{
data: map[string]interface{}{"enabled": true},
key: "enabled",
nonkey: "unknown",
},
}
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetBool(tc.key)
if err != nil || got != tc.data[tc.key] {
t.Errorf("Payload.GetBool(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetBool(tc.key)
if err != nil || got != tc.data[tc.key] {
t.Errorf("With Marshaling: Payload.GetBool(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetBool(tc.nonkey)
if err == nil || got != false {
t.Errorf("Payload.GetBool(%q) = %v, %v; want false, error",
tc.key, got, err)
}
}
}
func TestPayloadStringSlice(t *testing.T) {
tests := []payloadTest{
{
data: map[string]interface{}{"names": []string{"luke", "rey", "anakin"}},
key: "names",
nonkey: "unknown",
},
}
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetStringSlice(tc.key)
diff := cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("Payload.GetStringSlice(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetStringSlice(tc.key)
diff = cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetStringSlice(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetStringSlice(tc.nonkey)
if err == nil || got != nil {
t.Errorf("Payload.GetStringSlice(%q) = %v, %v; want nil, error",
tc.key, got, err)
}
}
}
func TestPayloadIntSlice(t *testing.T) {
tests := []payloadTest{
{
data: map[string]interface{}{"nums": []int{9, 8, 7}},
key: "nums",
nonkey: "unknown",
},
}
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetIntSlice(tc.key)
diff := cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("Payload.GetIntSlice(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetIntSlice(tc.key)
diff = cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetIntSlice(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetIntSlice(tc.nonkey)
if err == nil || got != nil {
t.Errorf("Payload.GetIntSlice(%q) = %v, %v; want nil, error",
tc.key, got, err)
}
}
}
func TestPayloadStringMap(t *testing.T) {
tests := []payloadTest{
{
data: map[string]interface{}{"user": map[string]interface{}{"name": "Jon Doe", "score": 2.2}},
key: "user",
nonkey: "unknown",
},
}
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetStringMap(tc.key)
diff := cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("Payload.GetStringMap(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetStringMap(tc.key)
diff = cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetStringMap(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetStringMap(tc.nonkey)
if err == nil || got != nil {
t.Errorf("Payload.GetStringMap(%q) = %v, %v; want nil, error",
tc.key, got, err)
}
}
}
func TestPayloadStringMapString(t *testing.T) {
tests := []payloadTest{
{
data: map[string]interface{}{"address": map[string]string{"line": "123 Main St", "city": "San Francisco", "state": "CA"}},
key: "address",
nonkey: "unknown",
},
}
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetStringMapString(tc.key)
diff := cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("Payload.GetStringMapString(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetStringMapString(tc.key)
diff = cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetStringMapString(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetStringMapString(tc.nonkey)
if err == nil || got != nil {
t.Errorf("Payload.GetStringMapString(%q) = %v, %v; want nil, error",
tc.key, got, err)
}
}
}
func TestPayloadStringMapStringSlice(t *testing.T) {
favs := map[string][]string{
"movies": {"forrest gump", "star wars"},
"tv_shows": {"game of thrones", "HIMYM", "breaking bad"},
}
tests := []payloadTest{
{
data: map[string]interface{}{"favorites": favs},
key: "favorites",
nonkey: "unknown",
},
}
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetStringMapStringSlice(tc.key)
diff := cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("Payload.GetStringMapStringSlice(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetStringMapStringSlice(tc.key)
diff = cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetStringMapStringSlice(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetStringMapStringSlice(tc.nonkey)
if err == nil || got != nil {
t.Errorf("Payload.GetStringMapStringSlice(%q) = %v, %v; want nil, error",
tc.key, got, err)
}
}
}
func TestPayloadStringMapInt(t *testing.T) {
counter := map[string]int{ counter := map[string]int{
"a": 1, "a": 1,
"b": 101, "b": 101,
"c": 42, "c": 42,
} }
tests := []payloadTest{
{
data: map[string]interface{}{"counts": counter},
key: "counts",
nonkey: "unknown",
},
}
for _, tc := range tests {
payload := Payload{tc.data}
got, err := payload.GetStringMapInt(tc.key)
diff := cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("Payload.GetStringMapInt(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// encode and then decode task message.
in := h.NewTaskMessage("testing", tc.data)
b, err := json.Marshal(in)
if err != nil {
t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetStringMapInt(tc.key)
diff = cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetStringMapInt(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
// access non-existent key.
got, err = payload.GetStringMapInt(tc.nonkey)
if err == nil || got != nil {
t.Errorf("Payload.GetStringMapInt(%q) = %v, %v; want nil, error",
tc.key, got, err)
}
}
}
func TestPayloadStringMapBool(t *testing.T) {
features := map[string]bool{ features := map[string]bool{
"A": false, "A": false,
"B": true, "B": true,
"C": true, "C": true,
} }
now := time.Now() tests := []payloadTest{
duration := 15 * time.Minute {
data: map[string]interface{}{"features": features},
data := map[string]interface{}{ key: "features",
"greeting": "Hello", nonkey: "unknown",
"user_id": 9876, },
"pi": 3.1415,
"enabled": false,
"names": names,
"primes": primes,
"user": user,
"location": location,
"favs": favs,
"counter": counter,
"features": features,
"timestamp": now,
"duration": duration,
}
payload := Payload{data}
gotStr, err := payload.GetString("greeting")
if gotStr != "Hello" || err != nil {
t.Errorf("Payload.GetString(%q) = %v, %v, want %v, nil",
"greeting", gotStr, err, "Hello")
} }
gotInt, err := payload.GetInt("user_id") for _, tc := range tests {
if gotInt != 9876 || err != nil { payload := Payload{tc.data}
t.Errorf("Payload.GetInt(%q) = %v, %v, want, %v, nil",
"user_id", gotInt, err, 9876)
}
gotFloat, err := payload.GetFloat64("pi") got, err := payload.GetStringMapBool(tc.key)
if gotFloat != 3.1415 || err != nil { diff := cmp.Diff(got, tc.data[tc.key])
t.Errorf("Payload.GetFloat64(%q) = %v, %v, want, %v, nil", if err != nil || diff != "" {
"pi", gotFloat, err, 3.141592) t.Errorf("Payload.GetStringMapBool(%q) = %v, %v, want %v, nil",
} tc.key, got, err, tc.data[tc.key])
}
gotBool, err := payload.GetBool("enabled") // encode and then decode task message.
if gotBool != false || err != nil { in := h.NewTaskMessage("testing", tc.data)
t.Errorf("Payload.GetBool(%q) = %v, %v, want, %v, nil", b, err := json.Marshal(in)
"enabled", gotBool, err, false) if err != nil {
} t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetStringMapBool(tc.key)
diff = cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetStringMapBool(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
gotStrSlice, err := payload.GetStringSlice("names") // access non-existent key.
if diff := cmp.Diff(gotStrSlice, names); diff != "" { got, err = payload.GetStringMapBool(tc.nonkey)
t.Errorf("Payload.GetStringSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", if err == nil || got != nil {
"names", gotStrSlice, err, names, diff) t.Errorf("Payload.GetStringMapBool(%q) = %v, %v; want nil, error",
} tc.key, got, err)
}
gotIntSlice, err := payload.GetIntSlice("primes")
if diff := cmp.Diff(gotIntSlice, primes); diff != "" {
t.Errorf("Payload.GetIntSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"primes", gotIntSlice, err, primes, diff)
}
gotStrMap, err := payload.GetStringMap("user")
if diff := cmp.Diff(gotStrMap, user); diff != "" {
t.Errorf("Payload.GetStringMap(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"user", gotStrMap, err, user, diff)
}
gotStrMapStr, err := payload.GetStringMapString("location")
if diff := cmp.Diff(gotStrMapStr, location); diff != "" {
t.Errorf("Payload.GetStringMapString(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"location", gotStrMapStr, err, location, diff)
}
gotStrMapStrSlice, err := payload.GetStringMapStringSlice("favs")
if diff := cmp.Diff(gotStrMapStrSlice, favs); diff != "" {
t.Errorf("Payload.GetStringMapStringSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"favs", gotStrMapStrSlice, err, favs, diff)
}
gotStrMapInt, err := payload.GetStringMapInt("counter")
if diff := cmp.Diff(gotStrMapInt, counter); diff != "" {
t.Errorf("Payload.GetStringMapInt(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"counter", gotStrMapInt, err, counter, diff)
}
gotStrMapBool, err := payload.GetStringMapBool("features")
if diff := cmp.Diff(gotStrMapBool, features); diff != "" {
t.Errorf("Payload.GetStringMapBool(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"features", gotStrMapBool, err, features, diff)
}
gotTime, err := payload.GetTime("timestamp")
if !gotTime.Equal(now) {
t.Errorf("Payload.GetTime(%q) = %v, %v, want %v, nil",
"timestamp", gotTime, err, now)
}
gotDuration, err := payload.GetDuration("duration")
if gotDuration != duration {
t.Errorf("Payload.GetDuration(%q) = %v, %v, want %v, nil",
"duration", gotDuration, err, duration)
} }
} }
func TestPayloadGetWithMarshaling(t *testing.T) { func TestPayloadTime(t *testing.T) {
names := []string{"luke", "anakin", "rey"} tests := []payloadTest{
primes := []int{2, 3, 5, 7, 11, 13, 17} {
user := map[string]interface{}{"name": "Ken", "score": 3.14} data: map[string]interface{}{"current": time.Now()},
location := map[string]string{"address": "123 Main St.", "state": "NY", "zipcode": "10002"} key: "current",
favs := map[string][]string{ nonkey: "unknown",
"movies": []string{"forrest gump", "star wars"}, },
"tv_shows": []string{"game of throwns", "HIMYM", "breaking bad"},
}
counter := map[string]int{
"a": 1,
"b": 101,
"c": 42,
}
features := map[string]bool{
"A": false,
"B": true,
"C": true,
}
now := time.Now()
duration := 15 * time.Minute
in := Payload{map[string]interface{}{
"subject": "Hello",
"recipient_id": 9876,
"pi": 3.14,
"enabled": true,
"names": names,
"primes": primes,
"user": user,
"location": location,
"favs": favs,
"counter": counter,
"features": features,
"timestamp": now,
"duration": duration,
}}
// encode and then decode task messsage
inMsg := h.NewTaskMessage("testing", in.data)
data, err := json.Marshal(inMsg)
if err != nil {
t.Fatal(err)
}
var outMsg base.TaskMessage
err = json.Unmarshal(data, &outMsg)
if err != nil {
t.Fatal(err)
}
out := Payload{outMsg.Payload}
gotStr, err := out.GetString("subject")
if gotStr != "Hello" || err != nil {
t.Errorf("Payload.GetString(%q) = %v, %v; want %q, nil",
"subject", gotStr, err, "Hello")
} }
gotInt, err := out.GetInt("recipient_id") for _, tc := range tests {
if gotInt != 9876 || err != nil { payload := Payload{tc.data}
t.Errorf("Payload.GetInt(%q) = %v, %v; want %v, nil",
"recipient_id", gotInt, err, 9876)
}
gotFloat, err := out.GetFloat64("pi") got, err := payload.GetTime(tc.key)
if gotFloat != 3.14 || err != nil { diff := cmp.Diff(got, tc.data[tc.key])
t.Errorf("Payload.GetFloat64(%q) = %v, %v; want %v, nil", if err != nil || diff != "" {
"pi", gotFloat, err, 3.14) t.Errorf("Payload.GetTime(%q) = %v, %v, want %v, nil",
} tc.key, got, err, tc.data[tc.key])
}
gotBool, err := out.GetBool("enabled") // encode and then decode task message.
if gotBool != true || err != nil { in := h.NewTaskMessage("testing", tc.data)
t.Errorf("Payload.GetBool(%q) = %v, %v; want %v, nil", b, err := json.Marshal(in)
"enabled", gotBool, err, true) if err != nil {
} t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetTime(tc.key)
diff = cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetTime(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
gotStrSlice, err := out.GetStringSlice("names") // access non-existent key.
if diff := cmp.Diff(gotStrSlice, names); diff != "" { got, err = payload.GetTime(tc.nonkey)
t.Errorf("Payload.GetStringSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s", if err == nil || !got.IsZero() {
"names", gotStrSlice, err, names, diff) t.Errorf("Payload.GetTime(%q) = %v, %v; want %v, error",
} tc.key, got, err, time.Time{})
}
gotIntSlice, err := out.GetIntSlice("primes")
if diff := cmp.Diff(gotIntSlice, primes); diff != "" {
t.Errorf("Payload.GetIntSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"primes", gotIntSlice, err, primes, diff)
}
gotStrMap, err := out.GetStringMap("user")
if diff := cmp.Diff(gotStrMap, user); diff != "" {
t.Errorf("Payload.GetStringMap(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"user", gotStrMap, err, user, diff)
}
gotStrMapStr, err := out.GetStringMapString("location")
if diff := cmp.Diff(gotStrMapStr, location); diff != "" {
t.Errorf("Payload.GetStringMapString(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"location", gotStrMapStr, err, location, diff)
}
gotStrMapStrSlice, err := out.GetStringMapStringSlice("favs")
if diff := cmp.Diff(gotStrMapStrSlice, favs); diff != "" {
t.Errorf("Payload.GetStringMapStringSlice(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"favs", gotStrMapStrSlice, err, favs, diff)
}
gotStrMapInt, err := out.GetStringMapInt("counter")
if diff := cmp.Diff(gotStrMapInt, counter); diff != "" {
t.Errorf("Payload.GetStringMapInt(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"counter", gotStrMapInt, err, counter, diff)
}
gotStrMapBool, err := out.GetStringMapBool("features")
if diff := cmp.Diff(gotStrMapBool, features); diff != "" {
t.Errorf("Payload.GetStringMapBool(%q) = %v, %v, want %v, nil;\n(-want,+got)\n%s",
"features", gotStrMapBool, err, features, diff)
}
gotTime, err := out.GetTime("timestamp")
if !gotTime.Equal(now) {
t.Errorf("Payload.GetTime(%q) = %v, %v, want %v, nil",
"timestamp", gotTime, err, now)
}
gotDuration, err := out.GetDuration("duration")
if gotDuration != duration {
t.Errorf("Payload.GetDuration(%q) = %v, %v, want %v, nil",
"duration", gotDuration, err, duration)
} }
} }
func TestPayloadKeyNotFound(t *testing.T) { func TestPayloadDuration(t *testing.T) {
payload := Payload{nil} tests := []payloadTest{
{
key := "something" data: map[string]interface{}{"duration": 15 * time.Minute},
gotStr, err := payload.GetString(key) key: "duration",
if err == nil || gotStr != "" { nonkey: "unknown",
t.Errorf("Payload.GetString(%q) = %v, %v; want '', error", },
key, gotStr, err)
} }
gotInt, err := payload.GetInt(key) for _, tc := range tests {
if err == nil || gotInt != 0 { payload := Payload{tc.data}
t.Errorf("Payload.GetInt(%q) = %v, %v; want 0, error",
key, gotInt, err)
}
gotFloat, err := payload.GetFloat64(key) got, err := payload.GetDuration(tc.key)
if err == nil || gotFloat != 0 { diff := cmp.Diff(got, tc.data[tc.key])
t.Errorf("Payload.GetFloat64(%q = %v, %v; want 0, error", if err != nil || diff != "" {
key, gotFloat, err) t.Errorf("Payload.GetDuration(%q) = %v, %v, want %v, nil",
} tc.key, got, err, tc.data[tc.key])
}
gotBool, err := payload.GetBool(key) // encode and then decode task message.
if err == nil || gotBool != false { in := h.NewTaskMessage("testing", tc.data)
t.Errorf("Payload.GetBool(%q) = %v, %v; want false, error", b, err := json.Marshal(in)
key, gotBool, err) if err != nil {
} t.Fatal(err)
}
var out base.TaskMessage
err = json.Unmarshal(b, &out)
if err != nil {
t.Fatal(err)
}
payload = Payload{out.Payload}
got, err = payload.GetDuration(tc.key)
diff = cmp.Diff(got, tc.data[tc.key])
if err != nil || diff != "" {
t.Errorf("With Marshaling: Payload.GetDuration(%q) = %v, %v, want %v, nil",
tc.key, got, err, tc.data[tc.key])
}
gotStrSlice, err := payload.GetStringSlice(key) // access non-existent key.
if err == nil || gotStrSlice != nil { got, err = payload.GetDuration(tc.nonkey)
t.Errorf("Payload.GetStringSlice(%q) = %v, %v; want nil, error", if err == nil || got != 0 {
key, gotStrSlice, err) t.Errorf("Payload.GetDuration(%q) = %v, %v; want %v, error",
} tc.key, got, err, time.Duration(0))
}
gotIntSlice, err := payload.GetIntSlice(key)
if err == nil || gotIntSlice != nil {
t.Errorf("Payload.GetIntSlice(%q) = %v, %v; want nil, error",
key, gotIntSlice, err)
}
gotStrMap, err := payload.GetStringMap(key)
if err == nil || gotStrMap != nil {
t.Errorf("Payload.GetStringMap(%q) = %v, %v; want nil, error",
key, gotStrMap, err)
}
gotStrMapStr, err := payload.GetStringMapString(key)
if err == nil || gotStrMapStr != nil {
t.Errorf("Payload.GetStringMapString(%q) = %v, %v; want nil, error",
key, gotStrMapStr, err)
}
gotStrMapStrSlice, err := payload.GetStringMapStringSlice(key)
if err == nil || gotStrMapStrSlice != nil {
t.Errorf("Payload.GetStringMapStringSlice(%q) = %v, %v; want nil, error",
key, gotStrMapStrSlice, err)
}
gotStrMapInt, err := payload.GetStringMapInt(key)
if err == nil || gotStrMapInt != nil {
t.Errorf("Payload.GetStringMapInt(%q) = %v, %v, want nil, error",
key, gotStrMapInt, err)
}
gotStrMapBool, err := payload.GetStringMapBool(key)
if err == nil || gotStrMapBool != nil {
t.Errorf("Payload.GetStringMapBool(%q) = %v, %v, want nil, error",
key, gotStrMapBool, err)
}
gotTime, err := payload.GetTime(key)
if err == nil || !gotTime.IsZero() {
t.Errorf("Payload.GetTime(%q) = %v, %v, want %v, error",
key, gotTime, err, time.Time{})
}
gotDuration, err := payload.GetDuration(key)
if err == nil || gotDuration != 0 {
t.Errorf("Payload.GetDuration(%q) = %v, %v, want 0, error",
key, gotDuration, err)
} }
} }
@@ -18,7 +18,8 @@ import (
) )
type processor struct { type processor struct {
rdb *rdb.RDB logger Logger
rdb *rdb.RDB
ps *base.ProcessState ps *base.ProcessState
@@ -31,6 +32,8 @@ type processor struct {
retryDelayFunc retryDelayFunc retryDelayFunc retryDelayFunc
errHandler ErrorHandler
// channel via which to send sync requests to syncer. // channel via which to send sync requests to syncer.
syncRequestCh chan<- *syncRequest syncRequestCh chan<- *syncRequest
@@ -59,7 +62,8 @@ type processor struct {
type retryDelayFunc func(n int, err error, task *Task) time.Duration type retryDelayFunc func(n int, err error, task *Task) time.Duration
// newProcessor constructs a new processor. // newProcessor constructs a new processor.
func newProcessor(r *rdb.RDB, ps *base.ProcessState, fn retryDelayFunc, syncCh chan<- *syncRequest, c *base.Cancelations) *processor { func newProcessor(l Logger, r *rdb.RDB, ps *base.ProcessState, fn retryDelayFunc,
syncCh chan<- *syncRequest, c *base.Cancelations, errHandler ErrorHandler) *processor {
info := ps.Get() info := ps.Get()
qcfg := normalizeQueueCfg(info.Queues) qcfg := normalizeQueueCfg(info.Queues)
orderedQueues := []string(nil) orderedQueues := []string(nil)
@@ -67,6 +71,7 @@ func newProcessor(r *rdb.RDB, ps *base.ProcessState, fn retryDelayFunc, syncCh c
orderedQueues = sortByPriority(qcfg) orderedQueues = sortByPriority(qcfg)
} }
return &processor{ return &processor{
logger: l,
rdb: r, rdb: r,
ps: ps, ps: ps,
queueConfig: qcfg, queueConfig: qcfg,
@@ -79,6 +84,7 @@ func newProcessor(r *rdb.RDB, ps *base.ProcessState, fn retryDelayFunc, syncCh c
done: make(chan struct{}), done: make(chan struct{}),
abort: make(chan struct{}), abort: make(chan struct{}),
quit: make(chan struct{}), quit: make(chan struct{}),
errHandler: errHandler,
handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }), handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }),
} }
} }
@@ -87,7 +93,7 @@ func newProcessor(r *rdb.RDB, ps *base.ProcessState, fn retryDelayFunc, syncCh c
// It's safe to call this method multiple times. // It's safe to call this method multiple times.
func (p *processor) stop() { func (p *processor) stop() {
p.once.Do(func() { p.once.Do(func() {
logger.info("Processor shutting down...") p.logger.Info("Processor shutting down...")
// Unblock if processor is waiting for sema token. // Unblock if processor is waiting for sema token.
close(p.abort) close(p.abort)
// Signal the processor goroutine to stop processing tasks // Signal the processor goroutine to stop processing tasks
@@ -103,7 +109,7 @@ func (p *processor) terminate() {
// IDEA: Allow user to customize this timeout value. // IDEA: Allow user to customize this timeout value.
const timeout = 8 * time.Second const timeout = 8 * time.Second
time.AfterFunc(timeout, func() { close(p.quit) }) time.AfterFunc(timeout, func() { close(p.quit) })
logger.info("Waiting for all workers to finish...") p.logger.Info("Waiting for all workers to finish...")
// send cancellation signal to all in-progress task handlers // send cancellation signal to all in-progress task handlers
for _, cancel := range p.cancelations.GetAll() { for _, cancel := range p.cancelations.GetAll() {
@@ -114,7 +120,7 @@ func (p *processor) terminate() {
for i := 0; i < cap(p.sema); i++ { for i := 0; i < cap(p.sema); i++ {
p.sema <- struct{}{} p.sema <- struct{}{}
} }
logger.info("All workers have finished") p.logger.Info("All workers have finished")
p.restore() // move any unfinished tasks back to the queue. p.restore() // move any unfinished tasks back to the queue.
} }
@@ -128,7 +134,7 @@ func (p *processor) start(wg *sync.WaitGroup) {
for { for {
select { select {
case <-p.done: case <-p.done:
logger.info("Processor done") p.logger.Info("Processor done")
return return
default: default:
p.exec() p.exec()
@@ -154,7 +160,7 @@ func (p *processor) exec() {
} }
if err != nil { if err != nil {
if p.errLogLimiter.Allow() { if p.errLogLimiter.Allow() {
logger.error("Dequeue error: %v", err) p.logger.Error("Dequeue error: %v", err)
} }
return return
} }
@@ -184,7 +190,7 @@ func (p *processor) exec() {
select { select {
case <-p.quit: case <-p.quit:
// time is up, quit this worker goroutine. // time is up, quit this worker goroutine.
logger.warn("Quitting worker to process task id=%s", msg.ID) p.logger.Warn("Quitting worker. task id=%s", msg.ID)
return return
case resErr := <-resCh: case resErr := <-resCh:
// Note: One of three things should happen. // Note: One of three things should happen.
@@ -192,6 +198,9 @@ func (p *processor) exec() {
// 2) Retry -> Removes the message from InProgress & Adds the message to Retry // 2) Retry -> Removes the message from InProgress & Adds the message to Retry
// 3) Kill -> Removes the message from InProgress & Adds the message to Dead // 3) Kill -> Removes the message from InProgress & Adds the message to Dead
if resErr != nil { if resErr != nil {
if p.errHandler != nil {
p.errHandler.HandleError(task, resErr, msg.Retried, msg.Retry)
}
if msg.Retried >= msg.Retry { if msg.Retried >= msg.Retry {
p.kill(msg, resErr) p.kill(msg, resErr)
} else { } else {
@@ -210,17 +219,17 @@ func (p *processor) exec() {
func (p *processor) restore() { func (p *processor) restore() {
n, err := p.rdb.RequeueAll() n, err := p.rdb.RequeueAll()
if err != nil { if err != nil {
logger.error("Could not restore unfinished tasks: %v", err) p.logger.Error("Could not restore unfinished tasks: %v", err)
} }
if n > 0 { if n > 0 {
logger.info("Restored %d unfinished tasks back to queue", n) p.logger.Info("Restored %d unfinished tasks back to queue", n)
} }
} }
func (p *processor) requeue(msg *base.TaskMessage) { func (p *processor) requeue(msg *base.TaskMessage) {
err := p.rdb.Requeue(msg) err := p.rdb.Requeue(msg)
if err != nil { if err != nil {
logger.error("Could not push task id=%s back to queue: %v", msg.ID, err) p.logger.Error("Could not push task id=%s back to queue: %v", msg.ID, err)
} }
} }
@@ -228,7 +237,7 @@ func (p *processor) markAsDone(msg *base.TaskMessage) {
err := p.rdb.Done(msg) err := p.rdb.Done(msg)
if err != nil { if err != nil {
errMsg := fmt.Sprintf("Could not remove task id=%s from %q", msg.ID, base.InProgressQueue) errMsg := fmt.Sprintf("Could not remove task id=%s from %q", msg.ID, base.InProgressQueue)
logger.warn("%s; Will retry syncing", errMsg) p.logger.Warn("%s; Will retry syncing", errMsg)
p.syncRequestCh <- &syncRequest{ p.syncRequestCh <- &syncRequest{
fn: func() error { fn: func() error {
return p.rdb.Done(msg) return p.rdb.Done(msg)
@@ -244,7 +253,7 @@ func (p *processor) retry(msg *base.TaskMessage, e error) {
err := p.rdb.Retry(msg, retryAt, e.Error()) err := p.rdb.Retry(msg, retryAt, e.Error())
if err != nil { if err != nil {
errMsg := fmt.Sprintf("Could not move task id=%s from %q to %q", msg.ID, base.InProgressQueue, base.RetryQueue) errMsg := fmt.Sprintf("Could not move task id=%s from %q to %q", msg.ID, base.InProgressQueue, base.RetryQueue)
logger.warn("%s; Will retry syncing", errMsg) p.logger.Warn("%s; Will retry syncing", errMsg)
p.syncRequestCh <- &syncRequest{ p.syncRequestCh <- &syncRequest{
fn: func() error { fn: func() error {
return p.rdb.Retry(msg, retryAt, e.Error()) return p.rdb.Retry(msg, retryAt, e.Error())
@@ -255,11 +264,11 @@ func (p *processor) retry(msg *base.TaskMessage, e error) {
} }
func (p *processor) kill(msg *base.TaskMessage, e error) { func (p *processor) kill(msg *base.TaskMessage, e error) {
logger.warn("Retry exhausted for task id=%s", msg.ID) p.logger.Warn("Retry exhausted for task id=%s", msg.ID)
err := p.rdb.Kill(msg, e.Error()) err := p.rdb.Kill(msg, e.Error())
if err != nil { if err != nil {
errMsg := fmt.Sprintf("Could not move task id=%s from %q to %q", msg.ID, base.InProgressQueue, base.DeadQueue) errMsg := fmt.Sprintf("Could not move task id=%s from %q to %q", msg.ID, base.InProgressQueue, base.DeadQueue)
logger.warn("%s; Will retry syncing", errMsg) p.logger.Warn("%s; Will retry syncing", errMsg)
p.syncRequestCh <- &syncRequest{ p.syncRequestCh <- &syncRequest{
fn: func() error { fn: func() error {
return p.rdb.Kill(msg, e.Error()) return p.rdb.Kill(msg, e.Error())
@@ -384,14 +393,18 @@ func gcd(xs ...int) int {
} }
// createContext returns a context and cancel function for a given task message. // createContext returns a context and cancel function for a given task message.
func createContext(msg *base.TaskMessage) (context.Context, context.CancelFunc) { func createContext(msg *base.TaskMessage) (ctx context.Context, cancel context.CancelFunc) {
ctx = context.Background()
timeout, err := time.ParseDuration(msg.Timeout) timeout, err := time.ParseDuration(msg.Timeout)
if err != nil { if err == nil && timeout != 0 {
logger.error("cannot parse timeout duration for %+v", msg) ctx, cancel = context.WithTimeout(ctx, timeout)
return context.WithCancel(context.Background())
} }
if timeout == 0 { deadline, err := time.Parse(time.RFC3339, msg.Deadline)
return context.WithCancel(context.Background()) if err == nil && !deadline.IsZero() {
ctx, cancel = context.WithDeadline(ctx, deadline)
} }
return context.WithTimeout(context.Background(), timeout) if cancel == nil {
ctx, cancel = context.WithCancel(ctx)
}
return ctx, cancel
} }
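
Because the old and new columns are interleaved above, the added createContext is hard to read; pieced back together from the added lines, the new version is roughly:

func createContext(msg *base.TaskMessage) (ctx context.Context, cancel context.CancelFunc) {
	ctx = context.Background()
	// Apply a relative timeout if the message carries a non-zero one.
	timeout, err := time.ParseDuration(msg.Timeout)
	if err == nil && timeout != 0 {
		ctx, cancel = context.WithTimeout(ctx, timeout)
	}
	// Apply an absolute deadline if the message carries a non-zero one.
	deadline, err := time.Parse(time.RFC3339, msg.Deadline)
	if err == nil && !deadline.IsZero() {
		ctx, cancel = context.WithDeadline(ctx, deadline)
	}
	// Fall back to a plain cancelable context when neither restriction applies.
	if cancel == nil {
		ctx, cancel = context.WithCancel(ctx)
	}
	return ctx, cancel
}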
@@ -17,6 +17,7 @@ import (
h "github.com/hibiken/asynq/internal/asynqtest" h "github.com/hibiken/asynq/internal/asynqtest"
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/rdb" "github.com/hibiken/asynq/internal/rdb"
"github.com/rs/xid"
) )
func TestProcessorSuccess(t *testing.T) { func TestProcessorSuccess(t *testing.T) {
@@ -66,11 +67,9 @@ func TestProcessorSuccess(t *testing.T) {
processed = append(processed, task) processed = append(processed, task)
return nil return nil
} }
workerCh := make(chan int)
go fakeHeartbeater(workerCh)
ps := base.NewProcessState("localhost", 1234, 10, defaultQueueConfig, false) ps := base.NewProcessState("localhost", 1234, 10, defaultQueueConfig, false)
cancelations := base.NewCancelations() cancelations := base.NewCancelations()
p := newProcessor(rdbClient, ps, defaultDelayFunc, nil, cancelations) p := newProcessor(testLogger, rdbClient, ps, defaultDelayFunc, nil, cancelations, nil)
p.handler = HandlerFunc(handler) p.handler = HandlerFunc(handler)
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -84,7 +83,6 @@ func TestProcessorSuccess(t *testing.T) {
} }
time.Sleep(tc.wait) time.Sleep(tc.wait)
p.terminate() p.terminate()
close(workerCh)
if diff := cmp.Diff(tc.wantProcessed, processed, sortTaskOpt, cmp.AllowUnexported(Payload{})); diff != "" { if diff := cmp.Diff(tc.wantProcessed, processed, sortTaskOpt, cmp.AllowUnexported(Payload{})); diff != "" {
t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff) t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
@@ -123,24 +121,30 @@ func TestProcessorRetry(t *testing.T) {
now := time.Now() now := time.Now()
tests := []struct { tests := []struct {
enqueued []*base.TaskMessage // initial default queue state enqueued []*base.TaskMessage // initial default queue state
incoming []*base.TaskMessage // tasks to be enqueued during run incoming []*base.TaskMessage // tasks to be enqueued during run
delay time.Duration // retry delay duration delay time.Duration // retry delay duration
wait time.Duration // wait duration between starting and stopping processor for this test case handler Handler // task handler
wantRetry []h.ZSetEntry // tasks in retry queue at the end wait time.Duration // wait duration between starting and stopping processor for this test case
wantDead []*base.TaskMessage // tasks in dead queue at the end wantRetry []h.ZSetEntry // tasks in retry queue at the end
wantDead []*base.TaskMessage // tasks in dead queue at the end
wantErrCount int // number of times error handler should be called
}{ }{
{ {
enqueued: []*base.TaskMessage{m1, m2}, enqueued: []*base.TaskMessage{m1, m2},
incoming: []*base.TaskMessage{m3, m4}, incoming: []*base.TaskMessage{m3, m4},
delay: time.Minute, delay: time.Minute,
wait: time.Second, handler: HandlerFunc(func(ctx context.Context, task *Task) error {
return fmt.Errorf(errMsg)
}),
wait: time.Second,
wantRetry: []h.ZSetEntry{ wantRetry: []h.ZSetEntry{
{Msg: &r2, Score: float64(now.Add(time.Minute).Unix())}, {Msg: &r2, Score: float64(now.Add(time.Minute).Unix())},
{Msg: &r3, Score: float64(now.Add(time.Minute).Unix())}, {Msg: &r3, Score: float64(now.Add(time.Minute).Unix())},
{Msg: &r4, Score: float64(now.Add(time.Minute).Unix())}, {Msg: &r4, Score: float64(now.Add(time.Minute).Unix())},
}, },
wantDead: []*base.TaskMessage{&r1}, wantDead: []*base.TaskMessage{&r1},
wantErrCount: 4,
}, },
} }
@@ -152,15 +156,19 @@ func TestProcessorRetry(t *testing.T) {
delayFunc := func(n int, e error, t *Task) time.Duration { delayFunc := func(n int, e error, t *Task) time.Duration {
return tc.delay return tc.delay
} }
handler := func(ctx context.Context, task *Task) error { var (
return fmt.Errorf(errMsg) mu sync.Mutex // guards n
n int // number of times error handler is called
)
errHandler := func(t *Task, err error, retried, maxRetry int) {
mu.Lock()
defer mu.Unlock()
n++
} }
workerCh := make(chan int)
go fakeHeartbeater(workerCh)
ps := base.NewProcessState("localhost", 1234, 10, defaultQueueConfig, false) ps := base.NewProcessState("localhost", 1234, 10, defaultQueueConfig, false)
cancelations := base.NewCancelations() cancelations := base.NewCancelations()
p := newProcessor(rdbClient, ps, delayFunc, nil, cancelations) p := newProcessor(testLogger, rdbClient, ps, delayFunc, nil, cancelations, ErrorHandlerFunc(errHandler))
p.handler = HandlerFunc(handler) p.handler = tc.handler
var wg sync.WaitGroup var wg sync.WaitGroup
p.start(&wg) p.start(&wg)
@@ -173,7 +181,6 @@ func TestProcessorRetry(t *testing.T) {
} }
time.Sleep(tc.wait) time.Sleep(tc.wait)
p.terminate() p.terminate()
close(workerCh)
cmpOpt := cmpopts.EquateApprox(0, float64(time.Second)) // allow up to second difference in zset score cmpOpt := cmpopts.EquateApprox(0, float64(time.Second)) // allow up to second difference in zset score
gotRetry := h.GetRetryEntries(t, r) gotRetry := h.GetRetryEntries(t, r)
@@ -189,6 +196,10 @@ func TestProcessorRetry(t *testing.T) {
if l := r.LLen(base.InProgressQueue).Val(); l != 0 { if l := r.LLen(base.InProgressQueue).Val(); l != 0 {
t.Errorf("%q has %d tasks, want 0", base.InProgressQueue, l) t.Errorf("%q has %d tasks, want 0", base.InProgressQueue, l)
} }
if n != tc.wantErrCount {
t.Errorf("error handler was called %d times, want %d", n, tc.wantErrCount)
}
} }
} }
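
The wantErrCount assertion above relies on the ErrorHandlerFunc adapter exercised in this test: any func(*Task, error, int, int) can serve as an error handler. A hypothetical handler that only escalates once retries are exhausted might look like this (the log destination is just an example, not part of this change):

var reportError = ErrorHandlerFunc(func(task *Task, err error, retried, maxRetry int) {
	if retried >= maxRetry {
		// Retries are exhausted; the task is about to be moved to the dead queue.
		log.Printf("giving up on task %q: %v", task.Type, err)
		return
	}
	log.Printf("task %q failed (retry %d of %d): %v", task.Type, retried, maxRetry, err)
})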
@@ -222,7 +233,7 @@ func TestProcessorQueues(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
cancelations := base.NewCancelations() cancelations := base.NewCancelations()
ps := base.NewProcessState("localhost", 1234, 10, tc.queueCfg, false) ps := base.NewProcessState("localhost", 1234, 10, tc.queueCfg, false)
p := newProcessor(nil, ps, defaultDelayFunc, nil, cancelations) p := newProcessor(testLogger, nil, ps, defaultDelayFunc, nil, cancelations, nil)
got := p.queues() got := p.queues()
if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" { if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" {
t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s", t.Errorf("with queue config: %v\n(*processor).queues() = %v, want %v\n(-want,+got):\n%s",
@@ -288,18 +299,15 @@ func TestProcessorWithStrictPriority(t *testing.T) {
"low": 1, "low": 1,
} }
// Note: Set concurrency to 1 to make sure tasks are processed one at a time. // Note: Set concurrency to 1 to make sure tasks are processed one at a time.
workerCh := make(chan int)
go fakeHeartbeater(workerCh)
cancelations := base.NewCancelations() cancelations := base.NewCancelations()
ps := base.NewProcessState("localhost", 1234, 1 /* concurrency */, queueCfg, true /*strict*/) ps := base.NewProcessState("localhost", 1234, 1 /* concurrency */, queueCfg, true /*strict*/)
p := newProcessor(rdbClient, ps, defaultDelayFunc, nil, cancelations) p := newProcessor(testLogger, rdbClient, ps, defaultDelayFunc, nil, cancelations, nil)
p.handler = HandlerFunc(handler) p.handler = HandlerFunc(handler)
var wg sync.WaitGroup var wg sync.WaitGroup
p.start(&wg) p.start(&wg)
time.Sleep(tc.wait) time.Sleep(tc.wait)
p.terminate() p.terminate()
close(workerCh)
if diff := cmp.Diff(tc.wantProcessed, processed, cmp.AllowUnexported(Payload{})); diff != "" { if diff := cmp.Diff(tc.wantProcessed, processed, cmp.AllowUnexported(Payload{})); diff != "" {
t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff) t.Errorf("mismatch found in processed tasks; (-want, +got)\n%s", diff)
@@ -357,8 +365,84 @@ func TestPerform(t *testing.T) {
} }
} }
// fake heartbeater to receive sends from the worker channel. func TestCreateContextWithTimeRestrictions(t *testing.T) {
func fakeHeartbeater(ch <-chan int) { var (
for range ch { noTimeout = time.Duration(0)
noDeadline = time.Time{}
)
tests := []struct {
desc string
timeout time.Duration
deadline time.Time
wantDeadline time.Time
}{
{"only with timeout", 10 * time.Second, noDeadline, time.Now().Add(10 * time.Second)},
{"only with deadline", noTimeout, time.Now().Add(time.Hour), time.Now().Add(time.Hour)},
{"with timeout and deadline (timeout < deadline)", 10 * time.Second, time.Now().Add(time.Hour), time.Now().Add(10 * time.Second)},
{"with timeout and deadline (timeout > deadline)", 10 * time.Minute, time.Now().Add(30 * time.Second), time.Now().Add(30 * time.Second)},
}
for _, tc := range tests {
msg := &base.TaskMessage{
Type: "something",
ID: xid.New(),
Timeout: tc.timeout.String(),
Deadline: tc.deadline.Format(time.RFC3339),
}
ctx, cancel := createContext(msg)
select {
case x := <-ctx.Done():
t.Errorf("%s: <-ctx.Done() == %v, want nothing (it should block)", tc.desc, x)
default:
}
got, ok := ctx.Deadline()
if !ok {
t.Errorf("%s: ctx.Deadline() returned false, want deadline to be set", tc.desc)
}
if !cmp.Equal(tc.wantDeadline, got, cmpopts.EquateApproxTime(time.Second)) {
t.Errorf("%s: ctx.Deadline() returned %v, want %v", tc.desc, got, tc.wantDeadline)
}
cancel()
select {
case <-ctx.Done():
default:
t.Errorf("ctx.Done() blocked, want it to be non-blocking")
}
}
}
func TestCreateContextWithoutTimeRestrictions(t *testing.T) {
msg := &base.TaskMessage{
Type: "something",
ID: xid.New(),
Timeout: time.Duration(0).String(), // zero value to indicate no timeout
Deadline: time.Time{}.Format(time.RFC3339), // zero value to indicate no deadline
}
ctx, cancel := createContext(msg)
select {
case x := <-ctx.Done():
t.Errorf("<-ctx.Done() == %v, want nothing (it should block)", x)
default:
}
_, ok := ctx.Deadline()
if ok {
t.Error("ctx.Deadline() returned true, want deadline to not be set")
}
cancel()
select {
case <-ctx.Done():
default:
t.Error("ctx.Done() blocked, want it to be non-blocking")
} }
} }
@@ -12,7 +12,8 @@ import (
) )
type scheduler struct { type scheduler struct {
rdb *rdb.RDB logger Logger
rdb *rdb.RDB
// channel to communicate back to the long running "scheduler" goroutine. // channel to communicate back to the long running "scheduler" goroutine.
done chan struct{} done chan struct{}
@@ -24,12 +25,13 @@ type scheduler struct {
qnames []string qnames []string
} }
func newScheduler(r *rdb.RDB, avgInterval time.Duration, qcfg map[string]int) *scheduler { func newScheduler(l Logger, r *rdb.RDB, avgInterval time.Duration, qcfg map[string]int) *scheduler {
var qnames []string var qnames []string
for q := range qcfg { for q := range qcfg {
qnames = append(qnames, q) qnames = append(qnames, q)
} }
return &scheduler{ return &scheduler{
logger: l,
rdb: r, rdb: r,
done: make(chan struct{}), done: make(chan struct{}),
avgInterval: avgInterval, avgInterval: avgInterval,
@@ -38,7 +40,7 @@ func newScheduler(r *rdb.RDB, avgInterval time.Duration, qcfg map[string]int) *s
} }
func (s *scheduler) terminate() { func (s *scheduler) terminate() {
logger.info("Scheduler shutting down...") s.logger.Info("Scheduler shutting down...")
// Signal the scheduler goroutine to stop polling. // Signal the scheduler goroutine to stop polling.
s.done <- struct{}{} s.done <- struct{}{}
} }
@@ -51,7 +53,7 @@ func (s *scheduler) start(wg *sync.WaitGroup) {
for { for {
select { select {
case <-s.done: case <-s.done:
logger.info("Scheduler done") s.logger.Info("Scheduler done")
return return
case <-time.After(s.avgInterval): case <-time.After(s.avgInterval):
s.exec() s.exec()
@@ -62,6 +64,6 @@ func (s *scheduler) start(wg *sync.WaitGroup) {
func (s *scheduler) exec() { func (s *scheduler) exec() {
if err := s.rdb.CheckAndEnqueue(s.qnames...); err != nil { if err := s.rdb.CheckAndEnqueue(s.qnames...); err != nil {
logger.error("Could not enqueue scheduled tasks: %v", err) s.logger.Error("Could not enqueue scheduled tasks: %v", err)
} }
} }
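
The processor, scheduler, and subscriber in this change all log through an injected Logger value with printf-style Info, Warn, and Error methods instead of a package-global logger. Assuming that method set, a custom implementation wrapping the standard library could look roughly like this (stdLogger is a made-up name, not part of the package):

type stdLogger struct{ *log.Logger }

func (l *stdLogger) Info(format string, args ...interface{})  { l.Printf("INFO: "+format, args...) }
func (l *stdLogger) Warn(format string, args ...interface{})  { l.Printf("WARN: "+format, args...) }
func (l *stdLogger) Error(format string, args ...interface{}) { l.Printf("ERROR: "+format, args...) }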
@@ -19,7 +19,7 @@ func TestScheduler(t *testing.T) {
r := setup(t) r := setup(t)
rdbClient := rdb.NewRDB(r) rdbClient := rdb.NewRDB(r)
const pollInterval = time.Second const pollInterval = time.Second
s := newScheduler(rdbClient, pollInterval, defaultQueueConfig) s := newScheduler(testLogger, rdbClient, pollInterval, defaultQueueConfig)
t1 := h.NewTaskMessage("gen_thumbnail", nil) t1 := h.NewTaskMessage("gen_thumbnail", nil)
t2 := h.NewTaskMessage("send_email", nil) t2 := h.NewTaskMessage("send_email", nil)
t3 := h.NewTaskMessage("reindex", nil) t3 := h.NewTaskMessage("reindex", nil)
servemux.go (new file, 158 lines added)
@@ -0,0 +1,158 @@
// Copyright 2020 Kentaro Hibino. All rights reserved.
// Use of this source code is governed by a MIT license
// that can be found in the LICENSE file.
package asynq
import (
"context"
"fmt"
"sort"
"strings"
"sync"
)
// ServeMux is a multiplexer for asynchronous tasks.
// It matches the type of each task against a list of registered patterns
// and calls the handler for the pattern that most closely matches the
// task's type name.
//
// Longer patterns take precedence over shorter ones, so that if there are
// handlers registered for both "images" and "images:thumbnails",
// the latter handler will be called for tasks with a type name beginning with
// "images:thumbnails" and the former will receive tasks with type name beginning
// with "images".
type ServeMux struct {
mu sync.RWMutex
m map[string]muxEntry
es []muxEntry // slice of entries sorted from longest to shortest.
mws []MiddlewareFunc
}
type muxEntry struct {
h Handler
pattern string
}
// MiddlewareFunc is a function which receives an asynq.Handler and returns another asynq.Handler.
// Typically, the returned handler is a closure which does something with the context and task passed
// to it, and then calls the handler passed as parameter to the MiddlewareFunc.
type MiddlewareFunc func(Handler) Handler
// NewServeMux allocates and returns a new ServeMux.
func NewServeMux() *ServeMux {
return new(ServeMux)
}
// ProcessTask dispatches the task to the handler whose
// pattern most closely matches the task type.
func (mux *ServeMux) ProcessTask(ctx context.Context, task *Task) error {
h, _ := mux.Handler(task)
return h.ProcessTask(ctx, task)
}
// Handler returns the handler to use for the given task.
// It always returns a non-nil handler.
//
// Handler also returns the registered pattern that matches the task.
//
// If there is no registered handler that applies to the task,
// handler returns a 'not found' handler which returns an error.
func (mux *ServeMux) Handler(t *Task) (h Handler, pattern string) {
mux.mu.RLock()
defer mux.mu.RUnlock()
h, pattern = mux.match(t.Type)
if h == nil {
h, pattern = NotFoundHandler(), ""
}
for i := len(mux.mws) - 1; i >= 0; i-- {
h = mux.mws[i](h)
}
return h, pattern
}
// Find a handler on a handler map given a typename string.
// Most-specific (longest) pattern wins.
func (mux *ServeMux) match(typename string) (h Handler, pattern string) {
// Check for exact match first.
v, ok := mux.m[typename]
if ok {
return v.h, v.pattern
}
// Check for longest valid match.
// mux.es contains all patterns from longest to shortest.
for _, e := range mux.es {
if strings.HasPrefix(typename, e.pattern) {
return e.h, e.pattern
}
}
return nil, ""
}
// Handle registers the handler for the given pattern.
// If a handler already exists for pattern, Handle panics.
func (mux *ServeMux) Handle(pattern string, handler Handler) {
mux.mu.Lock()
defer mux.mu.Unlock()
if pattern == "" {
panic("asynq: invalid pattern")
}
if handler == nil {
panic("asynq: nil handler")
}
if _, exist := mux.m[pattern]; exist {
panic("asynq: multiple registrations for " + pattern)
}
if mux.m == nil {
mux.m = make(map[string]muxEntry)
}
e := muxEntry{h: handler, pattern: pattern}
mux.m[pattern] = e
mux.es = appendSorted(mux.es, e)
}
func appendSorted(es []muxEntry, e muxEntry) []muxEntry {
n := len(es)
i := sort.Search(n, func(i int) bool {
return len(es[i].pattern) < len(e.pattern)
})
if i == n {
return append(es, e)
}
// we now know that i points at where we want to insert.
es = append(es, muxEntry{}) // try to grow the slice in place, any entry works.
copy(es[i+1:], es[i:]) // shift shorter entries down.
es[i] = e
return es
}
// HandleFunc registers the handler function for the given pattern.
func (mux *ServeMux) HandleFunc(pattern string, handler func(context.Context, *Task) error) {
if handler == nil {
panic("asynq: nil handler")
}
mux.Handle(pattern, HandlerFunc(handler))
}
// Use appends a MiddlewareFunc to the chain.
// Middlewares are executed in the order that they are applied to the ServeMux.
func (mux *ServeMux) Use(mws ...MiddlewareFunc) {
mux.mu.Lock()
defer mux.mu.Unlock()
for _, fn := range mws {
mux.mws = append(mux.mws, fn)
}
}
// NotFound returns an error indicating that the handler was not found for the given task.
func NotFound(ctx context.Context, task *Task) error {
return fmt.Errorf("handler not found for task %q", task.Type)
}
// NotFoundHandler returns a simple task handler that returns a ``not found`` error.
func NotFoundHandler() Handler { return HandlerFunc(NotFound) }
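
To make the matching rules above concrete, here is a hypothetical wiring of a mux; the handler bodies and the logging middleware are placeholders, not part of this change:

func newMuxSketch() *ServeMux {
	mux := NewServeMux()
	// Middleware runs for every dispatched task, in registration order.
	mux.Use(func(next Handler) Handler {
		return HandlerFunc(func(ctx context.Context, t *Task) error {
			log.Printf("processing task %q", t.Type)
			return next.ProcessTask(ctx, t)
		})
	})
	// Longest pattern wins: "email:signup" tasks hit the first handler,
	// every other task type starting with "email:" falls back to the second.
	mux.HandleFunc("email:signup", func(ctx context.Context, t *Task) error { return nil })
	mux.HandleFunc("email:", func(ctx context.Context, t *Task) error { return nil })
	return mux
}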
servemux_test.go (new file, 170 lines added)
@@ -0,0 +1,170 @@
// Copyright 2020 Kentaro Hibino. All rights reserved.
// Use of this source code is governed by a MIT license
// that can be found in the LICENSE file.
package asynq
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
)
var called string // identity of the handler that was called.
var invoked []string // list of middlewares in the order they were invoked.
// makeFakeHandler returns a handler that updates the global called variable
// to the given identity.
func makeFakeHandler(identity string) Handler {
return HandlerFunc(func(ctx context.Context, t *Task) error {
called = identity
return nil
})
}
// makeFakeMiddleware returns a middleware function that appends the given identity
// to the global invoked slice.
func makeFakeMiddleware(identity string) MiddlewareFunc {
return func(next Handler) Handler {
return HandlerFunc(func(ctx context.Context, t *Task) error {
invoked = append(invoked, identity)
return next.ProcessTask(ctx, t)
})
}
}
// A list of pattern, handler pair that is registered with mux.
var serveMuxRegister = []struct {
pattern string
h Handler
}{
{"email:", makeFakeHandler("default email handler")},
{"email:signup", makeFakeHandler("signup email handler")},
{"csv:export", makeFakeHandler("csv export handler")},
}
var serveMuxTests = []struct {
typename string // task's type name
want string // identifier of the handler that should be called
}{
{"email:signup", "signup email handler"},
{"csv:export", "csv export handler"},
{"email:daily", "default email handler"},
}
func TestServeMux(t *testing.T) {
mux := NewServeMux()
for _, e := range serveMuxRegister {
mux.Handle(e.pattern, e.h)
}
for _, tc := range serveMuxTests {
called = "" // reset to zero value
task := NewTask(tc.typename, nil)
if err := mux.ProcessTask(context.Background(), task); err != nil {
t.Fatal(err)
}
if called != tc.want {
t.Errorf("%q handler was called for task %q, want %q to be called", called, task.Type, tc.want)
}
}
}
func TestServeMuxRegisterNilHandler(t *testing.T) {
defer func() {
if err := recover(); err == nil {
t.Error("expected call to mux.HandleFunc to panic")
}
}()
mux := NewServeMux()
mux.HandleFunc("email:signup", nil)
}
func TestServeMuxRegisterEmptyPattern(t *testing.T) {
defer func() {
if err := recover(); err == nil {
t.Error("expected call to mux.HandleFunc to panic")
}
}()
mux := NewServeMux()
mux.Handle("", makeFakeHandler("email"))
}
func TestServeMuxRegisterDuplicatePattern(t *testing.T) {
defer func() {
if err := recover(); err == nil {
t.Error("expected call to mux.HandleFunc to panic")
}
}()
mux := NewServeMux()
mux.Handle("email", makeFakeHandler("email"))
mux.Handle("email", makeFakeHandler("email:default"))
}
var notFoundTests = []struct {
typename string // task's type name
}{
{"image:minimize"},
{"csv:"}, // registered patterns match the task's type prefix, not the other way around.
}
func TestServeMuxNotFound(t *testing.T) {
mux := NewServeMux()
for _, e := range serveMuxRegister {
mux.Handle(e.pattern, e.h)
}
for _, tc := range notFoundTests {
task := NewTask(tc.typename, nil)
err := mux.ProcessTask(context.Background(), task)
if err == nil {
t.Errorf("ProcessTask did not return error for task %q, should return 'not found' error", task.Type)
}
}
}
var middlewareTests = []struct {
typename string // task's type name
middlewares []string // middlewares to use. They should be called in this order.
want string // identifier of the handler that should be called
}{
{"email:signup", []string{"logging", "expiration"}, "signup email handler"},
{"csv:export", []string{}, "csv export handler"},
{"email:daily", []string{"expiration", "logging"}, "default email handler"},
}
func TestServeMuxMiddlewares(t *testing.T) {
for _, tc := range middlewareTests {
mux := NewServeMux()
for _, e := range serveMuxRegister {
mux.Handle(e.pattern, e.h)
}
var mws []MiddlewareFunc
for _, s := range tc.middlewares {
mws = append(mws, makeFakeMiddleware(s))
}
mux.Use(mws...)
invoked = []string{} // reset to empty slice
called = "" // reset to zero value
task := NewTask(tc.typename, nil)
if err := mux.ProcessTask(context.Background(), task); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(invoked, tc.middlewares); diff != "" {
t.Errorf("invoked middlewares were %v, want %v", invoked, tc.middlewares)
}
if called != tc.want {
t.Errorf("%q handler was called for task %q, want %q to be called", called, task.Type, tc.want)
}
}
}
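The table-driven cases above encode prefix-based dispatch. A short, hypothetical illustration of the resolution they exercise, reusing the makeFakeHandler helper defined earlier (the longest-prefix rule is an assumption drawn from the test expectations and from the net/http.ServeMux API this type mirrors):

mux := NewServeMux()
mux.Handle("email:", makeFakeHandler("default email handler"))
mux.Handle("email:signup", makeFakeHandler("signup email handler"))

// Dispatch outcomes:
//   "email:signup" -> "signup email handler"  (most specific registered pattern wins)
//   "email:daily"  -> "default email handler" (falls back to the "email:" pattern)
//   "image:resize" -> no match; ProcessTask returns the "not found" error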

subscriber.go

@@ -12,7 +12,8 @@ import (
 )
 
 type subscriber struct {
-	rdb *rdb.RDB
+	logger Logger
+	rdb    *rdb.RDB
 
 	// channel to communicate back to the long running "subscriber" goroutine.
 	done chan struct{}
@@ -21,8 +22,9 @@ type subscriber struct {
 	cancelations *base.Cancelations
 }
 
-func newSubscriber(rdb *rdb.RDB, cancelations *base.Cancelations) *subscriber {
+func newSubscriber(l Logger, rdb *rdb.RDB, cancelations *base.Cancelations) *subscriber {
 	return &subscriber{
+		logger:       l,
 		rdb:          rdb,
 		done:         make(chan struct{}),
 		cancelations: cancelations,
@@ -30,7 +32,7 @@ func newSubscriber(rdb *rdb.RDB, cancelations *base.Cancelations) *subscriber {
 }
 
 func (s *subscriber) terminate() {
-	logger.info("Subscriber shutting down...")
+	s.logger.Info("Subscriber shutting down...")
 	// Signal the subscriber goroutine to stop.
 	s.done <- struct{}{}
 }
@@ -39,7 +41,7 @@ func (s *subscriber) start(wg *sync.WaitGroup) {
 	pubsub, err := s.rdb.CancelationPubSub()
 	cancelCh := pubsub.Channel()
 	if err != nil {
-		logger.error("cannot subscribe to cancelation channel: %v", err)
+		s.logger.Error("cannot subscribe to cancelation channel: %v", err)
 		return
 	}
 	wg.Add(1)
@@ -49,7 +51,7 @@ func (s *subscriber) start(wg *sync.WaitGroup) {
 		select {
 		case <-s.done:
 			pubsub.Close()
-			logger.info("Subscriber done")
+			s.logger.Info("Subscriber done")
 			return
 		case msg := <-cancelCh:
 			cancel, ok := s.cancelations.Get(msg.Payload)
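Because this change threads a Logger value through the subscriber (and, further down, the syncer), a minimal sketch of a custom Logger may be useful. Assumption: the interface expects printf-style methods; only Info and Error appear in this diff, so the Debug, Warn, and Fatal methods below are assumed rather than confirmed by it.

package example

import "log"

// stdLogger is an illustrative adapter that satisfies a printf-style Logger
// interface by delegating to the standard library logger.
type stdLogger struct{ l *log.Logger }

func (s *stdLogger) Debug(format string, v ...interface{}) { s.l.Printf("DEBUG: "+format, v...) }
func (s *stdLogger) Info(format string, v ...interface{})  { s.l.Printf("INFO: "+format, v...) }
func (s *stdLogger) Warn(format string, v ...interface{})  { s.l.Printf("WARN: "+format, v...) }
func (s *stdLogger) Error(format string, v ...interface{}) { s.l.Printf("ERROR: "+format, v...) }
func (s *stdLogger) Fatal(format string, v ...interface{}) { s.l.Fatalf("FATAL: "+format, v...) }

A value like this could be handed to constructors such as newSubscriber in place of the testLogger used by the tests.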

subscriber_test.go

@@ -37,7 +37,7 @@ func TestSubscriber(t *testing.T) {
 		cancelations := base.NewCancelations()
 		cancelations.Add(tc.registeredID, fakeCancelFunc)
 
-		subscriber := newSubscriber(rdbClient, cancelations)
+		subscriber := newSubscriber(testLogger, rdbClient, cancelations)
 		var wg sync.WaitGroup
 		subscriber.start(&wg)

syncer.go

@@ -12,6 +12,8 @@ import (
 // syncer is responsible for queuing up failed requests to redis and retry
 // those requests to sync state between the background process and redis.
 type syncer struct {
+	logger Logger
+
 	requestsCh <-chan *syncRequest
 
 	// channel to communicate back to the long running "syncer" goroutine.
@@ -26,8 +28,9 @@ type syncRequest struct {
 	errMsg string // error message
 }
 
-func newSyncer(requestsCh <-chan *syncRequest, interval time.Duration) *syncer {
+func newSyncer(l Logger, requestsCh <-chan *syncRequest, interval time.Duration) *syncer {
 	return &syncer{
+		logger:     l,
 		requestsCh: requestsCh,
 		done:       make(chan struct{}),
 		interval:   interval,
@@ -35,7 +38,7 @@ func newSyncer(requestsCh <-chan *syncRequest, interval time.Duration) *syncer {
 }
 
 func (s *syncer) terminate() {
-	logger.info("Syncer shutting down...")
+	s.logger.Info("Syncer shutting down...")
 	// Signal the syncer goroutine to stop.
 	s.done <- struct{}{}
 }
@@ -51,10 +54,10 @@ func (s *syncer) start(wg *sync.WaitGroup) {
 			// Try sync one last time before shutting down.
 			for _, req := range requests {
 				if err := req.fn(); err != nil {
-					logger.error(req.errMsg)
+					s.logger.Error(req.errMsg)
 				}
 			}
-			logger.info("Syncer done")
+			s.logger.Info("Syncer done")
 			return
 		case req := <-s.requestsCh:
 			requests = append(requests, req)
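To make the queue-and-retry behaviour described in the comment at the top of this file concrete, here is a simplified, hypothetical sketch of the pattern: failed operations are queued, retried on an interval, and flushed one last time on shutdown. It is an illustration of the idea only, not the actual loop from this file; all names are made up.

package example

import "time"

// retryLoop queues failed operations and retries them every interval,
// attempting each queued operation once more before shutting down.
func retryLoop(requestsCh <-chan func() error, done <-chan struct{}, interval time.Duration) {
	var pending []func() error
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-done:
			for _, fn := range pending {
				_ = fn() // one final attempt before returning
			}
			return
		case fn := <-requestsCh:
			pending = append(pending, fn)
		case <-ticker.C:
			var failed []func() error
			for _, fn := range pending {
				if err := fn(); err != nil {
					failed = append(failed, fn) // keep it for the next tick
				}
			}
			pending = failed
		}
	}
}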

syncer_test.go

@@ -27,7 +27,7 @@ func TestSyncer(t *testing.T) {
 	const interval = time.Second
 	syncRequestCh := make(chan *syncRequest)
-	syncer := newSyncer(syncRequestCh, interval)
+	syncer := newSyncer(testLogger, syncRequestCh, interval)
 	var wg sync.WaitGroup
 	syncer.start(&wg)
 	defer syncer.terminate()
@@ -52,7 +52,7 @@ func TestSyncer(t *testing.T) {
 func TestSyncerRetry(t *testing.T) {
 	const interval = time.Second
 	syncRequestCh := make(chan *syncRequest)
-	syncer := newSyncer(syncRequestCh, interval)
+	syncer := newSyncer(testLogger, syncRequestCh, interval)
 	var wg sync.WaitGroup
 	syncer.start(&wg)

workers.go (asynqmon CLI)

@@ -22,7 +22,7 @@ var workersCmd = &cobra.Command{
 	Short: "Shows all running workers information",
 	Long: `Workers (asynqmon workers) will show all running workers information.
 
-The command shows the follwoing for each worker:
+The command shows the following for each worker:
 * Process in which the worker is running
 * ID of the task worker is processing
 * Type of the task worker is processing