diff --git a/internal/base/base.go b/internal/base/base.go
index 7070c18..809beca 100644
--- a/internal/base/base.go
+++ b/internal/base/base.go
@@ -95,15 +95,15 @@ type TaskMessage struct {
 //
 // ProcessStates are safe for concurrent use by multiple goroutines.
 type ProcessState struct {
-	mu                sync.Mutex // guards all data fields
-	concurrency       int
-	queues            map[string]int
-	strictPriority    bool
-	pid               int
-	host              string
-	status            PStatus
-	started           time.Time
-	activeWorkerCount int
+	mu             sync.Mutex // guards all data fields
+	concurrency    int
+	queues         map[string]int
+	strictPriority bool
+	pid            int
+	host           string
+	status         PStatus
+	started        time.Time
+	workers        map[string]*workerStats
 }
 
 // PStatus represents status of a process.
@@ -133,6 +133,11 @@ func (s PStatus) String() string {
 	return "unknown status"
 }
 
+type workerStats struct {
+	msg     *TaskMessage
+	started time.Time
+}
+
 // NewProcessState returns a new instance of ProcessState.
 func NewProcessState(host string, pid, concurrency int, queues map[string]int, strict bool) *ProcessState {
 	return &ProcessState{
@@ -142,6 +147,7 @@ func NewProcessState(host string, pid, concurrency int, queues map[string]int, s
 		queues:         cloneQueueConfig(queues),
 		strictPriority: strict,
 		status:         StatusIdle,
+		workers:        make(map[string]*workerStats),
 	}
 }
 
@@ -159,11 +165,18 @@ func (ps *ProcessState) SetStarted(t time.Time) {
 	ps.started = t
 }
 
-// IncrWorkerCount increments the worker count by delta.
-func (ps *ProcessState) IncrWorkerCount(delta int) {
+// AddWorkerStats registers msg as in progress, keyed by its task ID, along with its start time.
+func (ps *ProcessState) AddWorkerStats(msg *TaskMessage, started time.Time) {
 	ps.mu.Lock()
 	defer ps.mu.Unlock()
-	ps.activeWorkerCount += delta
+	ps.workers[msg.ID.String()] = &workerStats{msg: msg, started: started}
+}
+
+// DeleteWorkerStats drops the entry recorded for msg, keyed by its task ID.
+func (ps *ProcessState) DeleteWorkerStats(msg *TaskMessage) {
+	ps.mu.Lock()
+	defer ps.mu.Unlock()
+	delete(ps.workers, msg.ID.String())
 }
 
 // Get returns current state of process as a ProcessInfo.
@@ -178,10 +191,29 @@ func (ps *ProcessState) Get() *ProcessInfo {
 		StrictPriority:    ps.strictPriority,
 		Status:            ps.status.String(),
 		Started:           ps.started,
-		ActiveWorkerCount: ps.activeWorkerCount,
+		ActiveWorkerCount: len(ps.workers),
 	}
 }
 
+// GetWorkers returns a snapshot describing every worker currently recorded; payloads are copied.
+func (ps *ProcessState) GetWorkers() []*WorkerInfo {
+	ps.mu.Lock()
+	defer ps.mu.Unlock()
+	var infos []*WorkerInfo
+	for _, ws := range ps.workers {
+		infos = append(infos, &WorkerInfo{
+			Host:    ps.host,
+			PID:     ps.pid,
+			ID:      ws.msg.ID,
+			Type:    ws.msg.Type,
+			Queue:   ws.msg.Queue,
+			Payload: clonePayload(ws.msg.Payload),
+			Started: ws.started,
+		})
+	}
+	return infos
+}
+
 func cloneQueueConfig(qcfg map[string]int) map[string]int {
 	res := make(map[string]int)
 	for qname, n := range qcfg {
@@ -190,18 +222,37 @@ func cloneQueueConfig(qcfg map[string]int) map[string]int {
 	return res
 }
 
-// ProcessInfo holds information about running background worker process.
+func clonePayload(payload map[string]interface{}) map[string]interface{} {
+	dup := make(map[string]interface{})
+	for key, val := range payload {
+		dup[key] = val
+	}
+	return dup
+}
+
+// ProcessInfo holds information about a running background worker process.
 type ProcessInfo struct {
+	Host              string
+	PID               int
 	Concurrency       int
 	Queues            map[string]int
 	StrictPriority    bool
-	PID               int
-	Host              string
 	Status            string
 	Started           time.Time
 	ActiveWorkerCount int
 }
 
+// WorkerInfo describes a single running worker and the task it is processing.
+type WorkerInfo struct {
+	Host    string
+	PID     int
+	ID      xid.ID
+	Type    string
+	Queue   string
+	Payload map[string]interface{}
+	Started time.Time
+}
+
 // Cancelations is a collection that holds cancel functions for all in-progress tasks.
 //
 // Cancelations are safe for concurrent use by multipel goroutines.
@@ -232,10 +283,11 @@ func (c *Cancelations) Delete(id string) {
 }
 
 // Get returns a cancel func given an id.
-func (c *Cancelations) Get(id string) context.CancelFunc {
+func (c *Cancelations) Get(id string) (fn context.CancelFunc, ok bool) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
-	return c.cancelFuncs[id]
+	fn, ok = c.cancelFuncs[id]
+	return fn, ok
 }
 
 // GetAll returns all cancel funcs.
diff --git a/internal/base/base_test.go b/internal/base/base_test.go
index b6156b0..ee43185 100644
--- a/internal/base/base_test.go
+++ b/internal/base/base_test.go
@@ -5,8 +5,14 @@
 package base
 
 import (
+	"context"
+	"math/rand"
+	"sync"
 	"testing"
 	"time"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/rs/xid"
 )
 
 func TestQueueKey(t *testing.T) {
@@ -96,3 +102,115 @@ func TestWorkersKey(t *testing.T) {
 		}
 	}
 }
+
+// Test for process state being accessed by multiple goroutines.
+// Run with -race flag to check for data race.
+func TestProcessStateConcurrentAccess(t *testing.T) {
+	ps := NewProcessState("127.0.0.1", 1234, 10, map[string]int{"default": 1}, false)
+	var wg sync.WaitGroup
+	started := time.Now()
+	msgs := []*TaskMessage{
+		&TaskMessage{ID: xid.New(), Type: "type1", Payload: map[string]interface{}{"user_id": 42}},
+		&TaskMessage{ID: xid.New(), Type: "type2"},
+		&TaskMessage{ID: xid.New(), Type: "type3"},
+	}
+
+	// Simulate heartbeater calling SetStatus and SetStarted.
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		ps.SetStarted(started)
+		ps.SetStatus(StatusRunning)
+	}()
+
+	// Simulate processor starting worker goroutines.
+	for _, msg := range msgs {
+		wg.Add(1)
+		ps.AddWorkerStats(msg, time.Now())
+		go func(msg *TaskMessage) {
+			defer wg.Done()
+			time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond)
+			ps.DeleteWorkerStats(msg)
+		}(msg)
+	}
+
+	// Simulate heartbeater calling Get and GetWorkers.
+	wg.Add(1)
+	go func() {
+		defer wg.Done() // keep wg.Wait() blocked until the polling loop below has finished
+		for i := 0; i < 5; i++ {
+			ps.Get()
+			ps.GetWorkers()
+			time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
+		}
+	}()
+
+	wg.Wait()
+
+	want := &ProcessInfo{
+		Host:              "127.0.0.1",
+		PID:               1234,
+		Concurrency:       10,
+		Queues:            map[string]int{"default": 1},
+		StrictPriority:    false,
+		Status:            "running",
+		Started:           started,
+		ActiveWorkerCount: 0,
+	}
+
+	got := ps.Get()
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("(*ProcessState).Get() = %+v, want %+v; (-want,+got)\n%s",
+			got, want, diff)
+	}
+}
+
+// Test for cancelations being accessed by multiple goroutines.
+// Run with -race flag to check for data race.
+func TestCancelationsConcurrentAccess(t *testing.T) {
+	c := NewCancelations()
+
+	_, cancel1 := context.WithCancel(context.Background())
+	_, cancel2 := context.WithCancel(context.Background())
+	_, cancel3 := context.WithCancel(context.Background())
+	var key1, key2, key3 = "key1", "key2", "key3"
+
+	var wg sync.WaitGroup
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		c.Add(key1, cancel1)
+	}()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		c.Add(key2, cancel2)
+		time.Sleep(200 * time.Millisecond)
+		c.Delete(key2)
+	}()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		c.Add(key3, cancel3)
+	}()
+
+	wg.Wait()
+
+	_, ok := c.Get(key1)
+	if !ok {
+		t.Errorf("(*Cancelations).Get(%q) = _, false, want , true", key1)
+	}
+
+	_, ok = c.Get(key2)
+	if ok {
+		t.Errorf("(*Cancelations).Get(%q) = _, true, want , false", key2)
+	}
+
+	funcs := c.GetAll()
+	if len(funcs) != 2 {
+		t.Errorf("(*Cancelations).GetAll() returns %d functions, want 2", len(funcs))
+	}
+}
diff --git a/processor.go b/processor.go
index b5248c8..0852461 100644
--- a/processor.go
+++ b/processor.go
@@ -165,10 +165,10 @@ func (p *processor) exec() {
 		p.requeue(msg)
 		return
 	case p.sema <- struct{}{}: // acquire token
-		p.ps.IncrWorkerCount(1)
+		p.ps.AddWorkerStats(msg, time.Now())
 		go func() {
 			defer func() {
-				p.ps.IncrWorkerCount(-1)
+				p.ps.DeleteWorkerStats(msg)
 				<-p.sema /* release token */
 			}()
 
diff --git a/subscriber.go b/subscriber.go
index fc50c06..e8d3731 100644
--- a/subscriber.go
+++ b/subscriber.go
@@ -52,8 +52,8 @@ func (s *subscriber) start(wg *sync.WaitGroup) {
 				logger.info("Subscriber done")
 				return
 			case msg := <-cancelCh:
-				cancel := s.cancelations.Get(msg.Payload)
-				if cancel != nil {
+				cancel, ok := s.cancelations.Get(msg.Payload)
+				if ok {
 					cancel()
 				}
 			}