2
0
mirror of https://github.com/hibiken/asynq.git synced 2024-11-10 03:21:55 +08:00

Add pre and post enqueue callback options for Scheduler

This commit is contained in:
Chih Sean Hsu 2022-05-28 01:50:02 +08:00 committed by GitHub
parent 30d409371b
commit e0e5d1ac24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 115 additions and 36 deletions

View File

@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
- `PreEnqueueFunc`, `PostEnqueueFunc` is added in `Scheduler` and deprecated `EnqueueErrorHandler` (PR: https://github.com/hibiken/asynq/pull/476)
## [0.23.0] - 2022-04-11 ## [0.23.0] - 2022-04-11
### Added ### Added

View File

@ -26,14 +26,16 @@ type Scheduler struct {
state *serverState state *serverState
logger *log.Logger logger *log.Logger
client *Client client *Client
rdb *rdb.RDB rdb *rdb.RDB
cron *cron.Cron cron *cron.Cron
location *time.Location location *time.Location
done chan struct{} done chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
errHandler func(task *Task, opts []Option, err error) preEnqueueFunc func(task *Task, opts []Option)
postEnqueueFunc func(info *TaskInfo, err error)
errHandler func(task *Task, opts []Option, err error)
// guards idmap // guards idmap
mu sync.Mutex mu sync.Mutex
@ -67,16 +69,18 @@ func NewScheduler(r RedisConnOpt, opts *SchedulerOpts) *Scheduler {
} }
return &Scheduler{ return &Scheduler{
id: generateSchedulerID(), id: generateSchedulerID(),
state: &serverState{value: srvStateNew}, state: &serverState{value: srvStateNew},
logger: logger, logger: logger,
client: NewClient(r), client: NewClient(r),
rdb: rdb.NewRDB(c), rdb: rdb.NewRDB(c),
cron: cron.New(cron.WithLocation(loc)), cron: cron.New(cron.WithLocation(loc)),
location: loc, location: loc,
done: make(chan struct{}), done: make(chan struct{}),
errHandler: opts.EnqueueErrorHandler, preEnqueueFunc: opts.PreEnqueueFunc,
idmap: make(map[string]cron.EntryID), postEnqueueFunc: opts.PostEnqueueFunc,
errHandler: opts.EnqueueErrorHandler,
idmap: make(map[string]cron.EntryID),
} }
} }
@ -105,6 +109,15 @@ type SchedulerOpts struct {
// If unset, the UTC time zone (time.UTC) is used. // If unset, the UTC time zone (time.UTC) is used.
Location *time.Location Location *time.Location
// PreEnqueueFunc, if provided, is called before a task gets enqueued by Scheduler.
// The callback function should return quickly to not block the current thread.
PreEnqueueFunc func(task *Task, opts []Option)
// PostEnqueueFunc, if provided, is called after a task gets enqueued by Scheduler.
// The callback function should return quickly to not block the current thread.
PostEnqueueFunc func(info *TaskInfo, err error)
// Deprecated: Use PostEnqueueFunc instead
// EnqueueErrorHandler gets called when scheduler cannot enqueue a registered task // EnqueueErrorHandler gets called when scheduler cannot enqueue a registered task
// due to an error. // due to an error.
EnqueueErrorHandler func(task *Task, opts []Option, err error) EnqueueErrorHandler func(task *Task, opts []Option, err error)
@ -112,19 +125,27 @@ type SchedulerOpts struct {
// enqueueJob encapsulates the job of enqueuing a task and recording the event. // enqueueJob encapsulates the job of enqueuing a task and recording the event.
type enqueueJob struct { type enqueueJob struct {
id uuid.UUID id uuid.UUID
cronspec string cronspec string
task *Task task *Task
opts []Option opts []Option
location *time.Location location *time.Location
logger *log.Logger logger *log.Logger
client *Client client *Client
rdb *rdb.RDB rdb *rdb.RDB
errHandler func(task *Task, opts []Option, err error) preEnqueueFunc func(task *Task, opts []Option)
postEnqueueFunc func(info *TaskInfo, err error)
errHandler func(task *Task, opts []Option, err error)
} }
func (j *enqueueJob) Run() { func (j *enqueueJob) Run() {
if j.preEnqueueFunc != nil {
j.preEnqueueFunc(j.task, j.opts)
}
info, err := j.client.Enqueue(j.task, j.opts...) info, err := j.client.Enqueue(j.task, j.opts...)
if j.postEnqueueFunc != nil {
j.postEnqueueFunc(info, err)
}
if err != nil { if err != nil {
j.logger.Errorf("scheduler could not enqueue a task %+v: %v", j.task, err) j.logger.Errorf("scheduler could not enqueue a task %+v: %v", j.task, err)
if j.errHandler != nil { if j.errHandler != nil {
@ -147,15 +168,17 @@ func (j *enqueueJob) Run() {
// It returns an ID of the newly registered entry. // It returns an ID of the newly registered entry.
func (s *Scheduler) Register(cronspec string, task *Task, opts ...Option) (entryID string, err error) { func (s *Scheduler) Register(cronspec string, task *Task, opts ...Option) (entryID string, err error) {
job := &enqueueJob{ job := &enqueueJob{
id: uuid.New(), id: uuid.New(),
cronspec: cronspec, cronspec: cronspec,
task: task, task: task,
opts: opts, opts: opts,
location: s.location, location: s.location,
client: s.client, client: s.client,
rdb: s.rdb, rdb: s.rdb,
logger: s.logger, logger: s.logger,
errHandler: s.errHandler, preEnqueueFunc: s.preEnqueueFunc,
postEnqueueFunc: s.postEnqueueFunc,
errHandler: s.errHandler,
} }
cronID, err := s.cron.AddJob(cronspec, job) cronID, err := s.cron.AddJob(cronspec, job)
if err != nil { if err != nil {

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/testutil" "github.com/hibiken/asynq/internal/testutil"
) )
@ -154,3 +155,56 @@ func TestSchedulerUnregister(t *testing.T) {
} }
} }
} }
func TestSchedulerPostAndPreEnqueueHandler(t *testing.T) {
var (
preMu sync.Mutex
preCounter int
postMu sync.Mutex
postCounter int
)
preHandler := func(task *Task, opts []Option) {
preMu.Lock()
preCounter++
preMu.Unlock()
}
postHandler := func(info *TaskInfo, err error) {
postMu.Lock()
postCounter++
postMu.Unlock()
}
// Connect to non-existent redis instance to simulate a redis server being down.
scheduler := NewScheduler(
getRedisConnOpt(t),
&SchedulerOpts{
PreEnqueueFunc: preHandler,
PostEnqueueFunc: postHandler,
},
)
task := NewTask("test", nil)
if _, err := scheduler.Register("@every 3s", task); err != nil {
t.Fatal(err)
}
if err := scheduler.Start(); err != nil {
t.Fatal(err)
}
// Scheduler should attempt to enqueue the task three times (every 3s).
time.Sleep(10 * time.Second)
scheduler.Shutdown()
preMu.Lock()
if preCounter != 3 {
t.Errorf("PreEnqueueFunc was called %d times, want 3", preCounter)
}
preMu.Unlock()
postMu.Lock()
if postCounter != 3 {
t.Errorf("PostEnqueueFunc was called %d times, want 3", postCounter)
}
postMu.Unlock()
}