From e0e5d1ac2487495954a31dc3dc5f6949e502090b Mon Sep 17 00:00:00 2001 From: Chih Sean Hsu Date: Sat, 28 May 2022 01:50:02 +0800 Subject: [PATCH] Add pre and post enqueue callback options for Scheduler --- CHANGELOG.md | 2 + scheduler.go | 95 +++++++++++++++++++++++++++++------------------ scheduler_test.go | 54 +++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f41de1..d001908 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- `PreEnqueueFunc`, `PostEnqueueFunc` is added in `Scheduler` and deprecated `EnqueueErrorHandler` (PR: https://github.com/hibiken/asynq/pull/476) + ## [0.23.0] - 2022-04-11 ### Added diff --git a/scheduler.go b/scheduler.go index 4a9eed5..f57c40d 100644 --- a/scheduler.go +++ b/scheduler.go @@ -26,14 +26,16 @@ type Scheduler struct { state *serverState - logger *log.Logger - client *Client - rdb *rdb.RDB - cron *cron.Cron - location *time.Location - done chan struct{} - wg sync.WaitGroup - errHandler func(task *Task, opts []Option, err error) + logger *log.Logger + client *Client + rdb *rdb.RDB + cron *cron.Cron + location *time.Location + done chan struct{} + wg sync.WaitGroup + preEnqueueFunc func(task *Task, opts []Option) + postEnqueueFunc func(info *TaskInfo, err error) + errHandler func(task *Task, opts []Option, err error) // guards idmap mu sync.Mutex @@ -67,16 +69,18 @@ func NewScheduler(r RedisConnOpt, opts *SchedulerOpts) *Scheduler { } return &Scheduler{ - id: generateSchedulerID(), - state: &serverState{value: srvStateNew}, - logger: logger, - client: NewClient(r), - rdb: rdb.NewRDB(c), - cron: cron.New(cron.WithLocation(loc)), - location: loc, - done: make(chan struct{}), - errHandler: opts.EnqueueErrorHandler, - idmap: make(map[string]cron.EntryID), + id: generateSchedulerID(), + state: &serverState{value: srvStateNew}, + logger: logger, + client: NewClient(r), + rdb: rdb.NewRDB(c), + cron: cron.New(cron.WithLocation(loc)), + location: loc, + done: make(chan struct{}), + preEnqueueFunc: opts.PreEnqueueFunc, + postEnqueueFunc: opts.PostEnqueueFunc, + errHandler: opts.EnqueueErrorHandler, + idmap: make(map[string]cron.EntryID), } } @@ -105,6 +109,15 @@ type SchedulerOpts struct { // If unset, the UTC time zone (time.UTC) is used. Location *time.Location + // PreEnqueueFunc, if provided, is called before a task gets enqueued by Scheduler. + // The callback function should return quickly to not block the current thread. + PreEnqueueFunc func(task *Task, opts []Option) + + // PostEnqueueFunc, if provided, is called after a task gets enqueued by Scheduler. + // The callback function should return quickly to not block the current thread. + PostEnqueueFunc func(info *TaskInfo, err error) + + // Deprecated: Use PostEnqueueFunc instead // EnqueueErrorHandler gets called when scheduler cannot enqueue a registered task // due to an error. EnqueueErrorHandler func(task *Task, opts []Option, err error) @@ -112,19 +125,27 @@ type SchedulerOpts struct { // enqueueJob encapsulates the job of enqueuing a task and recording the event. type enqueueJob struct { - id uuid.UUID - cronspec string - task *Task - opts []Option - location *time.Location - logger *log.Logger - client *Client - rdb *rdb.RDB - errHandler func(task *Task, opts []Option, err error) + id uuid.UUID + cronspec string + task *Task + opts []Option + location *time.Location + logger *log.Logger + client *Client + rdb *rdb.RDB + preEnqueueFunc func(task *Task, opts []Option) + postEnqueueFunc func(info *TaskInfo, err error) + errHandler func(task *Task, opts []Option, err error) } func (j *enqueueJob) Run() { + if j.preEnqueueFunc != nil { + j.preEnqueueFunc(j.task, j.opts) + } info, err := j.client.Enqueue(j.task, j.opts...) + if j.postEnqueueFunc != nil { + j.postEnqueueFunc(info, err) + } if err != nil { j.logger.Errorf("scheduler could not enqueue a task %+v: %v", j.task, err) if j.errHandler != nil { @@ -147,15 +168,17 @@ func (j *enqueueJob) Run() { // It returns an ID of the newly registered entry. func (s *Scheduler) Register(cronspec string, task *Task, opts ...Option) (entryID string, err error) { job := &enqueueJob{ - id: uuid.New(), - cronspec: cronspec, - task: task, - opts: opts, - location: s.location, - client: s.client, - rdb: s.rdb, - logger: s.logger, - errHandler: s.errHandler, + id: uuid.New(), + cronspec: cronspec, + task: task, + opts: opts, + location: s.location, + client: s.client, + rdb: s.rdb, + logger: s.logger, + preEnqueueFunc: s.preEnqueueFunc, + postEnqueueFunc: s.postEnqueueFunc, + errHandler: s.errHandler, } cronID, err := s.cron.AddJob(cronspec, job) if err != nil { diff --git a/scheduler_test.go b/scheduler_test.go index 2685667..fea048e 100644 --- a/scheduler_test.go +++ b/scheduler_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/testutil" ) @@ -154,3 +155,56 @@ func TestSchedulerUnregister(t *testing.T) { } } } + +func TestSchedulerPostAndPreEnqueueHandler(t *testing.T) { + var ( + preMu sync.Mutex + preCounter int + postMu sync.Mutex + postCounter int + ) + preHandler := func(task *Task, opts []Option) { + preMu.Lock() + preCounter++ + preMu.Unlock() + } + postHandler := func(info *TaskInfo, err error) { + postMu.Lock() + postCounter++ + postMu.Unlock() + } + + // Connect to non-existent redis instance to simulate a redis server being down. + scheduler := NewScheduler( + getRedisConnOpt(t), + &SchedulerOpts{ + PreEnqueueFunc: preHandler, + PostEnqueueFunc: postHandler, + }, + ) + + task := NewTask("test", nil) + + if _, err := scheduler.Register("@every 3s", task); err != nil { + t.Fatal(err) + } + + if err := scheduler.Start(); err != nil { + t.Fatal(err) + } + // Scheduler should attempt to enqueue the task three times (every 3s). + time.Sleep(10 * time.Second) + scheduler.Shutdown() + + preMu.Lock() + if preCounter != 3 { + t.Errorf("PreEnqueueFunc was called %d times, want 3", preCounter) + } + preMu.Unlock() + + postMu.Lock() + if postCounter != 3 { + t.Errorf("PostEnqueueFunc was called %d times, want 3", postCounter) + } + postMu.Unlock() +}