diff --git a/client.go b/client.go index 7948036..9ab1db1 100644 --- a/client.go +++ b/client.go @@ -426,3 +426,8 @@ func (c *Client) addToGroup(ctx context.Context, msg *base.TaskMessage, group st } return c.broker.AddToGroup(ctx, msg, group) } + +// StateChanged watchs state updates, with more customized detail +func (c *Client) SetTaskProber(prober base.TaskProber) { + c.broker.SetTaskProber(prober) +} diff --git a/internal/base/base.go b/internal/base/base.go index ec342f8..64be48a 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -53,6 +53,11 @@ const ( TaskStateAggregating // describes a state where task is waiting in a group to be aggregated ) +type TaskProber interface { + Changed(map[string]interface{}) + Result(state TaskState, raw *TaskInfo) (string, interface{}) +} + func (s TaskState) String() string { switch s { case TaskStateActive: @@ -752,4 +757,7 @@ type Broker interface { PublishCancelation(id string) error WriteResult(qname, id string, data []byte) (n int, err error) + + // StateChanged watch state updates, with more customized detail + SetTaskProber(prober TaskProber) } diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 7b25f15..c07ccdf 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -7,6 +7,7 @@ package rdb import ( "context" + "encoding/json" "fmt" "math" "time" @@ -131,6 +132,7 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error { if n == 0 { return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) } + r.state(ctx, msg, base.TaskStatePending) return nil } @@ -198,6 +200,7 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time if n == 0 { return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) } + r.state(ctx, msg, base.TaskStatePending) return nil } @@ -464,9 +467,13 @@ func (r *RDB) MarkAsComplete(ctx context.Context, msg *base.TaskMessage) error { // Note: We cannot pass empty unique key when running this script in redis-cluster. if len(msg.UniqueKey) > 0 { keys = append(keys, msg.UniqueKey) - return r.runScript(ctx, op, markAsCompleteUniqueCmd, keys, argv...) + err := r.runScript(ctx, op, markAsCompleteUniqueCmd, keys, argv...) + r.state(ctx, msg, base.TaskStateCompleted, err) + return err } - return r.runScript(ctx, op, markAsCompleteCmd, keys, argv...) + err = r.runScript(ctx, op, markAsCompleteCmd, keys, argv...) + r.state(ctx, msg, base.TaskStateCompleted, err) + return err } // KEYS[1] -> asynq:{}:active @@ -495,7 +502,9 @@ func (r *RDB) Requeue(ctx context.Context, msg *base.TaskMessage) error { base.PendingKey(msg.Queue), base.TaskKey(msg.Queue, msg.ID), } - return r.runScript(ctx, op, requeueCmd, keys, msg.ID) + err := r.runScript(ctx, op, requeueCmd, keys, msg.ID) + r.state(ctx, msg, base.TaskStatePending, err) + return err } // KEYS[1] -> asynq:{}:t: @@ -550,6 +559,7 @@ func (r *RDB) AddToGroup(ctx context.Context, msg *base.TaskMessage, groupKey st if n == 0 { return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) } + r.state(ctx, msg, base.TaskStateAggregating) return nil } @@ -617,6 +627,7 @@ func (r *RDB) AddToGroupUnique(ctx context.Context, msg *base.TaskMessage, group if n == 0 { return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) } + r.state(ctx, msg, base.TaskStateAggregating) return nil } @@ -667,6 +678,7 @@ func (r *RDB) Schedule(ctx context.Context, msg *base.TaskMessage, processAt tim if n == 0 { return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) } + r.state(ctx, msg, base.TaskStateScheduled) return nil } @@ -731,6 +743,7 @@ func (r *RDB) ScheduleUnique(ctx context.Context, msg *base.TaskMessage, process if n == 0 { return errors.E(op, errors.AlreadyExists, errors.ErrTaskIdConflict) } + r.state(ctx, msg, base.TaskStateScheduled) return nil } @@ -813,7 +826,9 @@ func (r *RDB) Retry(ctx context.Context, msg *base.TaskMessage, processAt time.T isFailure, int64(math.MaxInt64), } - return r.runScript(ctx, op, retryCmd, keys, argv...) + err = r.runScript(ctx, op, retryCmd, keys, argv...) + r.state(ctx, msg, base.TaskStateRetry, err) + return err } const ( @@ -899,7 +914,9 @@ func (r *RDB) Archive(ctx context.Context, msg *base.TaskMessage, errMsg string) expireAt.Unix(), int64(math.MaxInt64), } - return r.runScript(ctx, op, archiveCmd, keys, argv...) + err = r.runScript(ctx, op, archiveCmd, keys, argv...) + r.state(ctx, msg, base.TaskStateArchived, err) + return err } // ForwardIfReady checks scheduled and retry sets of the given queues @@ -1444,6 +1461,79 @@ func (r *RDB) ClearSchedulerEntries(scheduelrID string) error { return nil } +func (r *RDB) state(ctx context.Context, msg *base.TaskMessage, state base.TaskState, errs ...error) { + var err error + if len(errs) > 0 { + if errs[0] != nil { + err = errs[0] + } + } + out := map[string]interface{}{ + "queue": msg.Queue, + "id": msg.ID, + "state": state, + } + if err != nil { + out["err"] = err.Error() + } + if len(msg.GroupKey) > 0 { + out["group"] = msg.GroupKey + } + if len(msg.UniqueKey) > 0 { + out["unique"] = msg.UniqueKey + } + + payload, _ := json.Marshal(out) + r.client.Publish(ctx, "state-changed", payload) +} + +// StateChanged watch state updates, with more customized detail +func (r *RDB) SetTaskProber(prober base.TaskProber) { + ctx := context.Background() + pubsub := r.client.Subscribe(ctx, "state-changed") + + changed := func(out map[string]interface{}, err error) { + if err != nil { + out["err"] = "prober: " + err.Error() + } + go prober.Changed(out) + } + + handler := func(payload string) { + var out map[string]interface{} + err := json.Unmarshal([]byte(payload), &out) + if err != nil { + changed(out, err) + return + } + s, ok := out["state"].(float64) + if !ok { + changed(out, fmt.Errorf("invalid state %v", out["state"])) + return + } + state := base.TaskState(s) + out["state"] = state.String() + res, err := r.GetTaskInfo(out["queue"].(string), out["id"].(string)) + if err != nil { + changed(out, err) + return + } + msg := res.Message + if state == base.TaskStateCompleted { + out["at"] = msg.CompletedAt + } + key, data := prober.Result(state, res) + if data != nil { + out[key] = data + } + changed(out, nil) + } + + for m := range pubsub.Channel() { + handler(m.Payload) + } +} + // CancelationPubSub returns a pubsub for cancelation messages. func (r *RDB) CancelationPubSub() (*redis.PubSub, error) { var op errors.Op = "rdb.CancelationPubSub" diff --git a/server.go b/server.go index 4bf04e0..77761e8 100644 --- a/server.go +++ b/server.go @@ -220,6 +220,53 @@ type Config struct { // // If unset or nil, the group aggregation feature will be disabled on the server. GroupAggregator GroupAggregator + + // StateChanged called when a task state changed + // + TaskStateProber *TaskStateProber +} + +// TaskStateProber tell there's a state changed happening +type TaskStateProber struct { + Probers map[string]string // map[state-string]data-name + Handler func(map[string]interface{}) +} + +func (p TaskStateProber) Changed(out map[string]interface{}) { + if p.Handler != nil { + p.Handler(out) + } +} + +func (p TaskStateProber) Result(state base.TaskState, raw *base.TaskInfo) (key string, data interface{}) { + if state == base.TaskStateCompleted { + data = *newTaskInfo(raw.Message, raw.State, raw.NextProcessAt, raw.Result) + return + } + + probers := p.Probers + if len(probers) == 0 { + probers = map[string]string{"*": "task"} + } + key, ok := probers["*"] + if !ok { + key, ok = probers[state.String()] + } + if !ok { + return + } + + switch key { + case "next": + data = raw.NextProcessAt + case "task": + data = *newTaskInfo(raw.Message, raw.State, raw.NextProcessAt, raw.Result) + default: + if len(raw.Result) > 0 { + data = raw.Result + } + } + return } // GroupAggregator aggregates a group of tasks into one before the tasks are passed to the Handler. @@ -458,6 +505,11 @@ func NewServer(r RedisConnOpt, cfg Config) *Server { srvState := &serverState{value: srvStateNew} cancels := base.NewCancelations() + taskStateProber := cfg.TaskStateProber + if taskStateProber != nil { + rdb.SetTaskProber(*taskStateProber) + } + syncer := newSyncer(syncerParams{ logger: logger, requestsCh: syncCh, @@ -697,3 +749,8 @@ func (srv *Server) Stop() { srv.processor.stop() srv.logger.Info("Processor stopped") } + +// StateChanged watch state updates, with more customized detail +func (srv *Server) SetTaskStateProber(prober base.TaskProber) { + srv.broker.SetTaskProber(prober) +}