From a38f628f3b24090a63bececc44d87e11d9edbb75 Mon Sep 17 00:00:00 2001 From: Ken Hibino Date: Mon, 18 May 2020 20:47:35 -0700 Subject: [PATCH] Refactor server state management --- heartbeat.go | 113 +++++++++-- heartbeat_test.go | 36 ++-- internal/asynqtest/asynqtest.go | 2 +- internal/base/base.go | 156 +++------------ internal/base/base_test.go | 66 +------ internal/rdb/inspect_test.go | 86 +++------ internal/rdb/rdb.go | 27 ++- internal/rdb/rdb_test.go | 304 +++++++++++++++++------------- internal/testbroker/testbroker.go | 8 +- processor.go | 38 ++-- processor_test.go | 67 +++++-- server.go | 47 ++--- 12 files changed, 482 insertions(+), 468 deletions(-) diff --git a/heartbeat.go b/heartbeat.go index 61b7706..26907f3 100644 --- a/heartbeat.go +++ b/heartbeat.go @@ -5,11 +5,13 @@ package asynq import ( + "os" "sync" "time" "github.com/hibiken/asynq/internal/base" "github.com/hibiken/asynq/internal/log" + "github.com/rs/xid" ) // heartbeater is responsible for writing process info to redis periodically to @@ -18,29 +20,69 @@ type heartbeater struct { logger *log.Logger broker base.Broker - ss *base.ServerState - // channel to communicate back to the long running "heartbeater" goroutine. done chan struct{} // interval between heartbeats. interval time.Duration + + // following fields are initialized at construction time and are immutable. + host string + pid int + serverID string + concurrency int + queues map[string]int + strictPriority bool + + // following fields are mutable and should be accessed only by the + // heartbeater goroutine. In other words, confine these variables + // to this goroutine only. + started time.Time + workers map[string]workerStat + + // status is shared with other goroutine but is concurrency safe. + status *base.ServerStatus + + // channels to receive updates on active workers. + starting <-chan *base.TaskMessage + finished <-chan *base.TaskMessage } type heartbeaterParams struct { - logger *log.Logger - broker base.Broker - serverState *base.ServerState - interval time.Duration + logger *log.Logger + broker base.Broker + interval time.Duration + concurrency int + queues map[string]int + strictPriority bool + status *base.ServerStatus + starting <-chan *base.TaskMessage + finished <-chan *base.TaskMessage } func newHeartbeater(params heartbeaterParams) *heartbeater { + host, err := os.Hostname() + if err != nil { + host = "unknown-host" + } + return &heartbeater{ logger: params.logger, broker: params.broker, - ss: params.serverState, done: make(chan struct{}), interval: params.interval, + + host: host, + pid: os.Getpid(), + serverID: xid.New().String(), + concurrency: params.concurrency, + queues: params.queues, + strictPriority: params.strictPriority, + + status: params.status, + workers: make(map[string]workerStat), + starting: params.starting, + finished: params.finished, } } @@ -50,31 +92,74 @@ func (h *heartbeater) terminate() { h.done <- struct{}{} } +// A workerStat records the message a worker is working on +// and the time the worker has started processing the message. 
+type workerStat struct { + started time.Time + msg *base.TaskMessage +} + func (h *heartbeater) start(wg *sync.WaitGroup) { - h.ss.SetStarted(time.Now()) - h.ss.SetStatus(base.StatusRunning) wg.Add(1) go func() { defer wg.Done() + + h.started = time.Now() + h.beat() + + timer := time.NewTimer(h.interval) for { select { case <-h.done: - h.broker.ClearServerState(h.ss) + h.broker.ClearServerState(h.host, h.pid, h.serverID) h.logger.Debug("Heartbeater done") + timer.Stop() return - case <-time.After(h.interval): + + case <-timer.C: h.beat() + timer.Reset(h.interval) + + case msg := <-h.starting: + h.workers[msg.ID.String()] = workerStat{time.Now(), msg} + + case msg := <-h.finished: + delete(h.workers, msg.ID.String()) } } }() } func (h *heartbeater) beat() { + info := base.ServerInfo{ + Host: h.host, + PID: h.pid, + ServerID: h.serverID, + Concurrency: h.concurrency, + Queues: h.queues, + StrictPriority: h.strictPriority, + Status: h.status.String(), + Started: h.started, + ActiveWorkerCount: len(h.workers), + } + + var ws []*base.WorkerInfo + for id, stat := range h.workers { + ws = append(ws, &base.WorkerInfo{ + Host: h.host, + PID: h.pid, + ID: id, + Type: stat.msg.Type, + Queue: stat.msg.Queue, + Payload: stat.msg.Payload, + Started: stat.started, + }) + } + // Note: Set TTL to be long enough so that it won't expire before we write again // and short enough to expire quickly once the process is shut down or killed. - err := h.broker.WriteServerState(h.ss, h.interval*2) - if err != nil { - h.logger.Errorf("could not write heartbeat data: %v", err) + if err := h.broker.WriteServerState(&info, ws, h.interval*2); err != nil { + h.logger.Errorf("could not write server state data: %v", err) } } diff --git a/heartbeat_test.go b/heartbeat_test.go index 92ac9d0..ebd8a75 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -37,14 +37,24 @@ func TestHeartbeater(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r) - state := base.NewServerState(tc.host, tc.pid, tc.concurrency, tc.queues, false) + status := base.NewServerStatus(base.StatusIdle) hb := newHeartbeater(heartbeaterParams{ - logger: testLogger, - broker: rdbClient, - serverState: state, - interval: tc.interval, + logger: testLogger, + broker: rdbClient, + interval: tc.interval, + concurrency: tc.concurrency, + queues: tc.queues, + strictPriority: false, + status: status, + starting: make(chan *base.TaskMessage), + finished: make(chan *base.TaskMessage), }) + // Change host and pid fields for testing purpose. 
+ hb.host = tc.host + hb.pid = tc.pid + + status.Set(base.StatusRunning) var wg sync.WaitGroup hb.start(&wg) @@ -80,7 +90,7 @@ func TestHeartbeater(t *testing.T) { } // status change - state.SetStatus(base.StatusStopped) + status.Set(base.StatusStopped) // allow for heartbeater to write to redis time.Sleep(tc.interval * 2) @@ -119,12 +129,16 @@ func TestHeartbeaterWithRedisDown(t *testing.T) { }() r := rdb.NewRDB(setup(t)) testBroker := testbroker.NewTestBroker(r) - ss := base.NewServerState("localhost", 1234, 10, map[string]int{"default": 1}, false) hb := newHeartbeater(heartbeaterParams{ - logger: testLogger, - broker: testBroker, - serverState: ss, - interval: time.Second, + logger: testLogger, + broker: testBroker, + interval: time.Second, + concurrency: 10, + queues: map[string]int{"default": 1}, + strictPriority: false, + status: base.NewServerStatus(base.StatusRunning), + starting: make(chan *base.TaskMessage), + finished: make(chan *base.TaskMessage), }) testBroker.Sleep() diff --git a/internal/asynqtest/asynqtest.go b/internal/asynqtest/asynqtest.go index 03be697..a713deb 100644 --- a/internal/asynqtest/asynqtest.go +++ b/internal/asynqtest/asynqtest.go @@ -57,7 +57,7 @@ var SortServerInfoOpt = cmp.Transformer("SortServerInfo", func(in []*base.Server var SortWorkerInfoOpt = cmp.Transformer("SortWorkerInfo", func(in []*base.WorkerInfo) []*base.WorkerInfo { out := append([]*base.WorkerInfo(nil), in...) // Copy input to avoid mutating it sort.Slice(out, func(i, j int) bool { - return out[i].ID.String() < out[j].ID.String() + return out[i].ID < out[j].ID }) return out }) diff --git a/internal/base/base.go b/internal/base/base.go index ca71232..e6bd364 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -105,28 +105,23 @@ type TaskMessage struct { UniqueKey string } -// ServerState holds process level information. -// -// ServerStates are safe for concurrent use by multiple goroutines. -type ServerState struct { - mu sync.Mutex // guards all data fields - id xid.ID - concurrency int - queues map[string]int - strictPriority bool - pid int - host string - status ServerStatus - started time.Time - workers map[string]*workerStats +// ServerStatus represents status of a server. +// ServerStatus methods are concurrency safe. +type ServerStatus struct { + mu sync.Mutex + val ServerStatusValue } -// ServerStatus represents status of a server. -type ServerStatus int +// NewServerStatus returns a new status instance given an initial value. +func NewServerStatus(v ServerStatusValue) *ServerStatus { + return &ServerStatus{val: v} +} + +type ServerStatusValue int const ( // StatusIdle indicates the server is in idle state. - StatusIdle ServerStatus = iota + StatusIdle ServerStatusValue = iota // StatusRunning indicates the servier is up and processing tasks. StatusRunning @@ -145,117 +140,28 @@ var statuses = []string{ "stopped", } -func (s ServerStatus) String() string { - if StatusIdle <= s && s <= StatusStopped { - return statuses[s] +func (s *ServerStatus) String() string { + s.mu.Lock() + defer s.mu.Unlock() + if StatusIdle <= s.val && s.val <= StatusStopped { + return statuses[s.val] } return "unknown status" } -type workerStats struct { - msg *TaskMessage - started time.Time +// Get returns the status value. +func (s *ServerStatus) Get() ServerStatusValue { + s.mu.Lock() + v := s.val + s.mu.Unlock() + return v } -// NewServerState returns a new instance of ServerState. 
-func NewServerState(host string, pid, concurrency int, queues map[string]int, strict bool) *ServerState { - return &ServerState{ - host: host, - pid: pid, - id: xid.New(), - concurrency: concurrency, - queues: cloneQueueConfig(queues), - strictPriority: strict, - status: StatusIdle, - workers: make(map[string]*workerStats), - } -} - -// SetStatus updates the status of server. -func (ss *ServerState) SetStatus(status ServerStatus) { - ss.mu.Lock() - defer ss.mu.Unlock() - ss.status = status -} - -// Status returns the status of server. -func (ss *ServerState) Status() ServerStatus { - ss.mu.Lock() - defer ss.mu.Unlock() - return ss.status -} - -// SetStarted records when the process started processing. -func (ss *ServerState) SetStarted(t time.Time) { - ss.mu.Lock() - defer ss.mu.Unlock() - ss.started = t -} - -// AddWorkerStats records when a worker started and which task it's processing. -func (ss *ServerState) AddWorkerStats(msg *TaskMessage, started time.Time) { - ss.mu.Lock() - defer ss.mu.Unlock() - ss.workers[msg.ID.String()] = &workerStats{msg, started} -} - -// DeleteWorkerStats removes a worker's entry from the process state. -func (ss *ServerState) DeleteWorkerStats(msg *TaskMessage) { - ss.mu.Lock() - defer ss.mu.Unlock() - delete(ss.workers, msg.ID.String()) -} - -// GetInfo returns current state of server as a ServerInfo. -func (ss *ServerState) GetInfo() *ServerInfo { - ss.mu.Lock() - defer ss.mu.Unlock() - return &ServerInfo{ - Host: ss.host, - PID: ss.pid, - ServerID: ss.id.String(), - Concurrency: ss.concurrency, - Queues: cloneQueueConfig(ss.queues), - StrictPriority: ss.strictPriority, - Status: ss.status.String(), - Started: ss.started, - ActiveWorkerCount: len(ss.workers), - } -} - -// GetWorkers returns a list of currently running workers' info. -func (ss *ServerState) GetWorkers() []*WorkerInfo { - ss.mu.Lock() - defer ss.mu.Unlock() - var res []*WorkerInfo - for _, w := range ss.workers { - res = append(res, &WorkerInfo{ - Host: ss.host, - PID: ss.pid, - ID: w.msg.ID, - Type: w.msg.Type, - Queue: w.msg.Queue, - Payload: clonePayload(w.msg.Payload), - Started: w.started, - }) - } - return res -} - -func cloneQueueConfig(qcfg map[string]int) map[string]int { - res := make(map[string]int) - for qname, n := range qcfg { - res[qname] = n - } - return res -} - -func clonePayload(payload map[string]interface{}) map[string]interface{} { - res := make(map[string]interface{}) - for k, v := range payload { - res[k] = v - } - return res +// Set sets the status value. +func (s *ServerStatus) Set(v ServerStatusValue) { + s.mu.Lock() + s.val = v + s.mu.Unlock() } // ServerInfo holds information about a running server. 
@@ -275,7 +181,7 @@ type ServerInfo struct { type WorkerInfo struct { Host string PID int - ID xid.ID + ID string Type string Queue string Payload map[string]interface{} @@ -345,8 +251,8 @@ type Broker interface { Kill(msg *TaskMessage, errMsg string) error RequeueAll() (int64, error) CheckAndEnqueue(qnames ...string) error - WriteServerState(ss *ServerState, ttl time.Duration) error - ClearServerState(ss *ServerState) error + WriteServerState(info *ServerInfo, workers []*WorkerInfo, ttl time.Duration) error + ClearServerState(host string, pid int, serverID string) error CancelationPubSub() (*redis.PubSub, error) // TODO: Need to decouple from redis to support other brokers PublishCancelation(id string) error Close() error diff --git a/internal/base/base_test.go b/internal/base/base_test.go index 65c223f..cfe8414 100644 --- a/internal/base/base_test.go +++ b/internal/base/base_test.go @@ -6,14 +6,9 @@ package base import ( "context" - "math/rand" "sync" "testing" "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/rs/xid" ) func TestQueueKey(t *testing.T) { @@ -108,69 +103,28 @@ func TestWorkersKey(t *testing.T) { } } -// Test for server state being accessed by multiple goroutines. +// Test for status being accessed by multiple goroutines. // Run with -race flag to check for data race. -func TestServerStateConcurrentAccess(t *testing.T) { - ss := NewServerState("127.0.0.1", 1234, 10, map[string]int{"default": 1}, false) - var wg sync.WaitGroup - started := time.Now() - msgs := []*TaskMessage{ - {ID: xid.New(), Type: "type1", Payload: map[string]interface{}{"user_id": 42}}, - {ID: xid.New(), Type: "type2"}, - {ID: xid.New(), Type: "type3"}, - } +func TestStatusConcurrentAccess(t *testing.T) { + status := NewServerStatus(StatusIdle) + + var wg sync.WaitGroup - // Simulate hearbeater calling SetStatus and SetStarted. wg.Add(1) go func() { defer wg.Done() - ss.SetStarted(started) - ss.SetStatus(StatusRunning) - if status := ss.Status(); status != StatusRunning { - t.Errorf("(*ServerState).Status() = %v, want %v", status, StatusRunning) - } + status.Get() + status.String() }() - // Simulate processor starting worker goroutines. - for _, msg := range msgs { - wg.Add(1) - ss.AddWorkerStats(msg, time.Now()) - go func(msg *TaskMessage) { - defer wg.Done() - time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) - ss.DeleteWorkerStats(msg) - }(msg) - } - - // Simulate hearbeater calling Get and GetWorkers wg.Add(1) go func() { - wg.Done() - for i := 0; i < 5; i++ { - ss.GetInfo() - ss.GetWorkers() - time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) - } + defer wg.Done() + status.Set(StatusStopped) + status.String() }() wg.Wait() - - want := &ServerInfo{ - Host: "127.0.0.1", - PID: 1234, - Concurrency: 10, - Queues: map[string]int{"default": 1}, - StrictPriority: false, - Status: "running", - Started: started, - ActiveWorkerCount: 0, - } - - got := ss.GetInfo() - if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(ServerInfo{}, "ServerID")); diff != "" { - t.Errorf("(*ServerState).GetInfo() = %+v, want %+v; (-want,+got)\n%s", - got, want, diff) - } } // Test for cancelations being accessed by multiple goroutines. 
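The base.ServerStatus type introduced above is now the only piece of server state shared between goroutines; everything else the heartbeater needs is either immutable after construction or confined to its own goroutine. The following minimal usage sketch is illustrative only and not part of the patch; because internal/base is an internal package, it compiles only when placed inside the asynq module itself.

package main

import (
	"fmt"
	"sync"

	"github.com/hibiken/asynq/internal/base"
)

func main() {
	// One status value is shared by the server, heartbeater, and processor.
	status := base.NewServerStatus(base.StatusIdle)

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// Get, Set, and String are guarded by an internal mutex,
		// so concurrent access from other goroutines is safe.
		status.Set(base.StatusRunning)
	}()
	wg.Wait()

	if status.Get() == base.StatusRunning {
		fmt.Println("status:", status.String())
	}
}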
diff --git a/internal/rdb/inspect_test.go b/internal/rdb/inspect_test.go index 1a1e931..c0b2c09 100644 --- a/internal/rdb/inspect_test.go +++ b/internal/rdb/inspect_test.go @@ -2055,60 +2055,48 @@ func TestListServers(t *testing.T) { r := setup(t) started1 := time.Now().Add(-time.Hour) - ss1 := base.NewServerState("do.droplet1", 1234, 10, map[string]int{"default": 1}, false) - ss1.SetStarted(started1) - ss1.SetStatus(base.StatusRunning) info1 := &base.ServerInfo{ - Concurrency: 10, - Queues: map[string]int{"default": 1}, Host: "do.droplet1", PID: 1234, + ServerID: "server123", + Concurrency: 10, + Queues: map[string]int{"default": 1}, Status: "running", Started: started1, ActiveWorkerCount: 0, } started2 := time.Now().Add(-2 * time.Hour) - ss2 := base.NewServerState("do.droplet2", 9876, 20, map[string]int{"email": 1}, false) - ss2.SetStarted(started2) - ss2.SetStatus(base.StatusStopped) - ss2.AddWorkerStats(h.NewTaskMessage("send_email", nil), time.Now()) info2 := &base.ServerInfo{ - Concurrency: 20, - Queues: map[string]int{"email": 1}, Host: "do.droplet2", PID: 9876, + ServerID: "server456", + Concurrency: 20, + Queues: map[string]int{"email": 1}, Status: "stopped", Started: started2, ActiveWorkerCount: 1, } tests := []struct { - serverStates []*base.ServerState - want []*base.ServerInfo + data []*base.ServerInfo }{ { - serverStates: []*base.ServerState{}, - want: []*base.ServerInfo{}, + data: []*base.ServerInfo{}, }, { - serverStates: []*base.ServerState{ss1}, - want: []*base.ServerInfo{info1}, + data: []*base.ServerInfo{info1}, }, { - serverStates: []*base.ServerState{ss1, ss2}, - want: []*base.ServerInfo{info1, info2}, + data: []*base.ServerInfo{info1, info2}, }, } - ignoreOpt := cmpopts.IgnoreUnexported(base.ServerInfo{}) - ignoreFieldOpt := cmpopts.IgnoreFields(base.ServerInfo{}, "ServerID") - for _, tc := range tests { h.FlushDB(t, r.client) - for _, ss := range tc.serverStates { - if err := r.WriteServerState(ss, 5*time.Second); err != nil { + for _, info := range tc.data { + if err := r.WriteServerState(info, []*base.WorkerInfo{}, 5*time.Second); err != nil { t.Fatal(err) } } @@ -2117,9 +2105,9 @@ func TestListServers(t *testing.T) { if err != nil { t.Errorf("r.ListServers returned an error: %v", err) } - if diff := cmp.Diff(tc.want, got, h.SortServerInfoOpt, ignoreOpt, ignoreFieldOpt); diff != "" { + if diff := cmp.Diff(tc.data, got, h.SortServerInfoOpt); diff != "" { t.Errorf("r.ListServers returned %v, want %v; (-want,+got)\n%s", - got, tc.serverStates, diff) + got, tc.data, diff) } } } @@ -2127,37 +2115,23 @@ func TestListServers(t *testing.T) { func TestListWorkers(t *testing.T) { r := setup(t) - const ( + var ( host = "127.0.0.1" pid = 4567 + + m1 = h.NewTaskMessage("send_email", map[string]interface{}{"user_id": "abc123"}) + m2 = h.NewTaskMessage("gen_thumbnail", map[string]interface{}{"path": "some/path/to/image/file"}) + m3 = h.NewTaskMessage("reindex", map[string]interface{}{}) ) - m1 := h.NewTaskMessage("send_email", map[string]interface{}{"user_id": "abc123"}) - m2 := h.NewTaskMessage("gen_thumbnail", map[string]interface{}{"path": "some/path/to/image/file"}) - m3 := h.NewTaskMessage("reindex", map[string]interface{}{}) - t1 := time.Now().Add(-time.Second) - t2 := time.Now().Add(-10 * time.Second) - t3 := time.Now().Add(-time.Minute) - - type workerStats struct { - msg *base.TaskMessage - started time.Time - } - tests := []struct { - workers []*workerStats - want []*base.WorkerInfo + data []*base.WorkerInfo }{ { - workers: []*workerStats{ - {m1, t1}, - {m2, t2}, - 
{m3, t3}, - }, - want: []*base.WorkerInfo{ - {Host: host, PID: pid, ID: m1.ID, Type: m1.Type, Queue: m1.Queue, Payload: m1.Payload, Started: t1}, - {Host: host, PID: pid, ID: m2.ID, Type: m2.Type, Queue: m2.Queue, Payload: m2.Payload, Started: t2}, - {Host: host, PID: pid, ID: m3.ID, Type: m3.Type, Queue: m3.Queue, Payload: m3.Payload, Started: t3}, + data: []*base.WorkerInfo{ + {Host: host, PID: pid, ID: m1.ID.String(), Type: m1.Type, Queue: m1.Queue, Payload: m1.Payload, Started: time.Now().Add(-1 * time.Second)}, + {Host: host, PID: pid, ID: m2.ID.String(), Type: m2.Type, Queue: m2.Queue, Payload: m2.Payload, Started: time.Now().Add(-5 * time.Second)}, + {Host: host, PID: pid, ID: m3.ID.String(), Type: m3.Type, Queue: m3.Queue, Payload: m3.Payload, Started: time.Now().Add(-30 * time.Second)}, }, }, } @@ -2165,13 +2139,7 @@ func TestListWorkers(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r.client) - ss := base.NewServerState(host, pid, 10, map[string]int{"default": 1}, false) - - for _, w := range tc.workers { - ss.AddWorkerStats(w.msg, w.started) - } - - err := r.WriteServerState(ss, time.Minute) + err := r.WriteServerState(&base.ServerInfo{}, tc.data, time.Minute) if err != nil { t.Errorf("could not write server state to redis: %v", err) continue @@ -2183,8 +2151,8 @@ func TestListWorkers(t *testing.T) { continue } - if diff := cmp.Diff(tc.want, got, h.SortWorkerInfoOpt); diff != "" { - t.Errorf("(*RDB).ListWorkers() = %v, want = %v; (-want,+got)\n%s", got, tc.want, diff) + if diff := cmp.Diff(tc.data, got, h.SortWorkerInfoOpt); diff != "" { + t.Errorf("(*RDB).ListWorkers() = %v, want = %v; (-want,+got)\n%s", got, tc.data, diff) } } } diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go index 7fc4481..aed147d 100644 --- a/internal/rdb/rdb.go +++ b/internal/rdb/rdb.go @@ -466,14 +466,14 @@ func (r *RDB) forwardSingle(src, dst string) error { // KEYS[1] -> asynq:servers: // KEYS[2] -> asynq:servers // KEYS[3] -> asynq:workers -// keys[4] -> asynq:workers +// KEYS[4] -> asynq:workers // ARGV[1] -> expiration time // ARGV[2] -> TTL in seconds -// ARGV[3] -> process info +// ARGV[3] -> server info // ARGV[4:] -> alternate key-value pair of (worker id, worker data) // Note: Add key to ZSET with expiration time as score. // ref: https://github.com/antirez/redis/issues/135#issuecomment-2361996 -var writeProcessInfoCmd = redis.NewScript(` +var writeServerStateCmd = redis.NewScript(` redis.call("SETEX", KEYS[1], ARGV[2], ARGV[3]) redis.call("ZADD", KEYS[2], ARGV[1], KEYS[1]) redis.call("DEL", KEYS[3]) @@ -484,27 +484,24 @@ redis.call("EXPIRE", KEYS[3], ARGV[2]) redis.call("ZADD", KEYS[4], ARGV[1], KEYS[3]) return redis.status_reply("OK")`) -// WriteServerState writes server state data to redis with expiration set to the value ttl. -func (r *RDB) WriteServerState(ss *base.ServerState, ttl time.Duration) error { - info := ss.GetInfo() +// WriteServerState writes server state data to redis with expiration set to the value ttl. 
+func (r *RDB) WriteServerState(info *base.ServerInfo, workers []*base.WorkerInfo, ttl time.Duration) error { bytes, err := json.Marshal(info) if err != nil { return err } - var args []interface{} // args to the lua script exp := time.Now().Add(ttl).UTC() - workers := ss.GetWorkers() - args = append(args, float64(exp.Unix()), ttl.Seconds(), bytes) + args := []interface{}{float64(exp.Unix()), ttl.Seconds(), bytes} // args to the lua script for _, w := range workers { bytes, err := json.Marshal(w) if err != nil { continue // skip bad data } - args = append(args, w.ID.String(), bytes) + args = append(args, w.ID, bytes) } skey := base.ServerInfoKey(info.Host, info.PID, info.ServerID) wkey := base.WorkersKey(info.Host, info.PID, info.ServerID) - return writeProcessInfoCmd.Run(r.client, + return writeServerStateCmd.Run(r.client, []string{skey, base.AllServers, wkey, base.AllWorkers}, args...).Err() } @@ -521,11 +518,9 @@ redis.call("DEL", KEYS[4]) return redis.status_reply("OK")`) // ClearServerState deletes server state data from redis. -func (r *RDB) ClearServerState(ss *base.ServerState) error { - info := ss.GetInfo() - host, pid, id := info.Host, info.PID, info.ServerID - skey := base.ServerInfoKey(host, pid, id) - wkey := base.WorkersKey(host, pid, id) +func (r *RDB) ClearServerState(host string, pid int, serverID string) error { + skey := base.ServerInfoKey(host, pid, serverID) + wkey := base.WorkersKey(host, pid, serverID) return clearProcessInfoCmd.Run(r.client, []string{base.AllServers, skey, base.AllWorkers, wkey}).Err() } diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go index cd0fe00..2172214 100644 --- a/internal/rdb/rdb_test.go +++ b/internal/rdb/rdb_test.go @@ -864,64 +864,63 @@ func TestCheckAndEnqueue(t *testing.T) { func TestWriteServerState(t *testing.T) { r := setup(t) - queues := map[string]int{"default": 2, "email": 5, "low": 1} - started := time.Now() - ss := base.NewServerState("localhost", 4242, 10, queues, false) - ss.SetStarted(started) - ss.SetStatus(base.StatusRunning) - ttl := 5 * time.Second + var ( + host = "localhost" + pid = 4242 + serverID = "server123" - h.FlushDB(t, r.client) + ttl = 5 * time.Second + ) - err := r.WriteServerState(ss, ttl) + info := base.ServerInfo{ + Host: host, + PID: pid, + ServerID: serverID, + Concurrency: 10, + Queues: map[string]int{"default": 2, "email": 5, "low": 1}, + StrictPriority: false, + Started: time.Now(), + Status: "running", + ActiveWorkerCount: 0, + } + + err := r.WriteServerState(&info, nil /* workers */, ttl) if err != nil { t.Errorf("r.WriteServerState returned an error: %v", err) } - // Check ServerInfo was written correctly - info := ss.GetInfo() - skey := base.ServerInfoKey(info.Host, info.PID, info.ServerID) + // Check ServerInfo was written correctly. 
+ skey := base.ServerInfoKey(host, pid, serverID) data := r.client.Get(skey).Val() var got base.ServerInfo err = json.Unmarshal([]byte(data), &got) if err != nil { t.Fatalf("could not decode json: %v", err) } - want := base.ServerInfo{ - Host: info.Host, - PID: info.PID, - Concurrency: info.Concurrency, - Queues: map[string]int{"default": 2, "email": 5, "low": 1}, - StrictPriority: false, - Status: "running", - Started: started, - ActiveWorkerCount: 0, - } - ignoreOpt := cmpopts.IgnoreFields(base.ServerInfo{}, "ServerID") - if diff := cmp.Diff(want, got, ignoreOpt); diff != "" { + if diff := cmp.Diff(info, got); diff != "" { t.Errorf("persisted ServerInfo was %v, want %v; (-want,+got)\n%s", - got, want, diff) + got, info, diff) } - // Check ServerInfo TTL was set correctly + // Check ServerInfo TTL was set correctly. gotTTL := r.client.TTL(skey).Val() if !cmp.Equal(ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) { t.Errorf("TTL of %q was %v, want %v", skey, gotTTL, ttl) } - // Check ServerInfo key was added to the set correctly - gotProcesses := r.client.ZRange(base.AllServers, 0, -1).Val() - wantProcesses := []string{skey} - if diff := cmp.Diff(wantProcesses, gotProcesses); diff != "" { - t.Errorf("%q contained %v, want %v", base.AllServers, gotProcesses, wantProcesses) + // Check ServerInfo key was added to the set all server keys correctly. + gotServerKeys := r.client.ZRange(base.AllServers, 0, -1).Val() + wantServerKeys := []string{skey} + if diff := cmp.Diff(wantServerKeys, gotServerKeys); diff != "" { + t.Errorf("%q contained %v, want %v", base.AllServers, gotServerKeys, wantServerKeys) } - // Check WorkersInfo was written correctly - wkey := base.WorkersKey(info.Host, info.PID, info.ServerID) + // Check WorkersInfo was written correctly. + wkey := base.WorkersKey(host, pid, serverID) workerExist := r.client.Exists(wkey).Val() if workerExist != 0 { t.Errorf("%q key exists", wkey) } - // Check WorkersInfo key was added to the set correctly + // Check WorkersInfo key was added to the set correctly. 
gotWorkerKeys := r.client.ZRange(base.AllWorkers, 0, -1).Val() wantWorkerKeys := []string{wkey} if diff := cmp.Diff(wantWorkerKeys, gotWorkerKeys); diff != "" { @@ -931,109 +930,105 @@ func TestWriteServerState(t *testing.T) { func TestWriteServerStateWithWorkers(t *testing.T) { r := setup(t) - queues := map[string]int{"default": 2, "email": 5, "low": 1} - concurrency := 10 - started := time.Now().Add(-10 * time.Minute) - w1Started := time.Now().Add(-time.Minute) - w2Started := time.Now().Add(-time.Second) - msg1 := h.NewTaskMessage("send_email", map[string]interface{}{"user_id": "123"}) - msg2 := h.NewTaskMessage("gen_thumbnail", map[string]interface{}{"path": "some/path/to/imgfile"}) - ss := base.NewServerState("127.0.01", 4242, concurrency, queues, false) - ss.SetStarted(started) - ss.SetStatus(base.StatusRunning) - ss.AddWorkerStats(msg1, w1Started) - ss.AddWorkerStats(msg2, w2Started) - ttl := 5 * time.Second + var ( + host = "127.0.0.1" + pid = 4242 + serverID = "server123" - h.FlushDB(t, r.client) + msg1 = h.NewTaskMessage("send_email", map[string]interface{}{"user_id": "123"}) + msg2 = h.NewTaskMessage("gen_thumbnail", map[string]interface{}{"path": "some/path/to/imgfile"}) - err := r.WriteServerState(ss, ttl) - if err != nil { - t.Errorf("r.WriteServerState returned an error: %v", err) + ttl = 5 * time.Second + ) + + workers := []*base.WorkerInfo{ + { + Host: host, + PID: pid, + ID: msg1.ID.String(), + Type: msg1.Type, + Queue: msg1.Queue, + Payload: msg1.Payload, + Started: time.Now().Add(-10 * time.Second), + }, + { + Host: host, + PID: pid, + ID: msg2.ID.String(), + Type: msg2.Type, + Queue: msg2.Queue, + Payload: msg2.Payload, + Started: time.Now().Add(-2 * time.Minute), + }, } - // Check ServerInfo was written correctly - info := ss.GetInfo() - skey := base.ServerInfoKey(info.Host, info.PID, info.ServerID) + serverInfo := base.ServerInfo{ + Host: host, + PID: pid, + ServerID: serverID, + Concurrency: 10, + Queues: map[string]int{"default": 2, "email": 5, "low": 1}, + StrictPriority: false, + Started: time.Now().Add(-10 * time.Minute), + Status: "running", + ActiveWorkerCount: len(workers), + } + + err := r.WriteServerState(&serverInfo, workers, ttl) + if err != nil { + t.Fatalf("r.WriteServerState returned an error: %v", err) + } + + // Check ServerInfo was written correctly. + skey := base.ServerInfoKey(host, pid, serverID) data := r.client.Get(skey).Val() var got base.ServerInfo err = json.Unmarshal([]byte(data), &got) if err != nil { t.Fatalf("could not decode json: %v", err) } - want := base.ServerInfo{ - Host: info.Host, - PID: info.PID, - ServerID: info.ServerID, - Concurrency: concurrency, - Queues: queues, - StrictPriority: false, - Status: "running", - Started: started, - ActiveWorkerCount: 2, - } - if diff := cmp.Diff(want, got); diff != "" { + if diff := cmp.Diff(serverInfo, got); diff != "" { t.Errorf("persisted ServerInfo was %v, want %v; (-want,+got)\n%s", - got, want, diff) + got, serverInfo, diff) } - // Check ServerInfo TTL was set correctly + // Check ServerInfo TTL was set correctly. 
gotTTL := r.client.TTL(skey).Val() if !cmp.Equal(ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) { t.Errorf("TTL of %q was %v, want %v", skey, gotTTL, ttl) } - // Check ServerInfo key was added to the set correctly - gotProcesses := r.client.ZRange(base.AllServers, 0, -1).Val() - wantProcesses := []string{skey} - if diff := cmp.Diff(wantProcesses, gotProcesses); diff != "" { - t.Errorf("%q contained %v, want %v", base.AllServers, gotProcesses, wantProcesses) + // Check ServerInfo key was added to the set correctly. + gotServerKeys := r.client.ZRange(base.AllServers, 0, -1).Val() + wantServerKeys := []string{skey} + if diff := cmp.Diff(wantServerKeys, gotServerKeys); diff != "" { + t.Errorf("%q contained %v, want %v", base.AllServers, gotServerKeys, wantServerKeys) } - // Check WorkersInfo was written correctly - wkey := base.WorkersKey(info.Host, info.PID, info.ServerID) + // Check WorkersInfo was written correctly. + wkey := base.WorkersKey(host, pid, serverID) wdata := r.client.HGetAll(wkey).Val() if len(wdata) != 2 { t.Fatalf("HGETALL %q returned a hash of size %d, want 2", wkey, len(wdata)) } - gotWorkers := make(map[string]*base.WorkerInfo) - for key, val := range wdata { + var gotWorkers []*base.WorkerInfo + for _, val := range wdata { var w base.WorkerInfo if err := json.Unmarshal([]byte(val), &w); err != nil { t.Fatalf("could not unmarshal worker's data: %v", err) } - gotWorkers[key] = &w + gotWorkers = append(gotWorkers, &w) } - wantWorkers := map[string]*base.WorkerInfo{ - msg1.ID.String(): { - Host: info.Host, - PID: info.PID, - ID: msg1.ID, - Type: msg1.Type, - Queue: msg1.Queue, - Payload: msg1.Payload, - Started: w1Started, - }, - msg2.ID.String(): { - Host: info.Host, - PID: info.PID, - ID: msg2.ID, - Type: msg2.Type, - Queue: msg2.Queue, - Payload: msg2.Payload, - Started: w2Started, - }, - } - if diff := cmp.Diff(wantWorkers, gotWorkers); diff != "" { + if diff := cmp.Diff(workers, gotWorkers, h.SortWorkerInfoOpt); diff != "" { t.Errorf("persisted workers info was %v, want %v; (-want,+got)\n%s", - gotWorkers, wantWorkers, diff) + gotWorkers, workers, diff) } - // Check WorkersInfo TTL was set correctly + // Check WorkersInfo TTL was set correctly. gotTTL = r.client.TTL(wkey).Val() - if !cmp.Equal(ttl, gotTTL, timeCmpOpt) { + if !cmp.Equal(ttl.Seconds(), gotTTL.Seconds(), cmpopts.EquateApprox(0, 1)) { t.Errorf("TTL of %q was %v, want %v", wkey, gotTTL, ttl) } - // Check WorkersInfo key was added to the set correctly + // Check WorkersInfo key was added to the set correctly. gotWorkerKeys := r.client.ZRange(base.AllWorkers, 0, -1).Val() wantWorkerKeys := []string{wkey} if diff := cmp.Diff(wantWorkerKeys, gotWorkerKeys); diff != "" { @@ -1043,51 +1038,96 @@ func TestWriteServerStateWithWorkers(t *testing.T) { func TestClearServerState(t *testing.T) { r := setup(t) - ss := base.NewServerState("127.0.01", 4242, 10, map[string]int{"default": 1}, false) - info := ss.GetInfo() - h.FlushDB(t, r.client) + var ( + host = "127.0.0.1" + pid = 1234 + serverID = "server123" - skey := base.ServerInfoKey(info.Host, info.PID, info.ServerID) - wkey := base.WorkersKey(info.Host, info.PID, info.ServerID) - otherSKey := base.ServerInfoKey("otherhost", 12345, "server98") - otherWKey := base.WorkersKey("otherhost", 12345, "server98") - // Populate the keys. 
- if err := r.client.Set(skey, "process-info", 0).Err(); err != nil { - t.Fatal(err) + otherHost = "127.0.0.2" + otherPID = 9876 + otherServerID = "server987" + + msg1 = h.NewTaskMessage("send_email", map[string]interface{}{"user_id": "123"}) + msg2 = h.NewTaskMessage("gen_thumbnail", map[string]interface{}{"path": "some/path/to/imgfile"}) + + ttl = 5 * time.Second + ) + + workers1 := []*base.WorkerInfo{ + { + Host: host, + PID: pid, + ID: msg1.ID.String(), + Type: msg1.Type, + Queue: msg1.Queue, + Payload: msg1.Payload, + Started: time.Now().Add(-10 * time.Second), + }, } - if err := r.client.HSet(wkey, "worker-key", "worker-info").Err(); err != nil { - t.Fatal(err) - } - if err := r.client.ZAdd(base.AllServers, &redis.Z{Member: skey}).Err(); err != nil { - t.Fatal(err) - } - if err := r.client.ZAdd(base.AllServers, &redis.Z{Member: otherSKey}).Err(); err != nil { - t.Fatal(err) - } - if err := r.client.ZAdd(base.AllWorkers, &redis.Z{Member: wkey}).Err(); err != nil { - t.Fatal(err) - } - if err := r.client.ZAdd(base.AllWorkers, &redis.Z{Member: otherWKey}).Err(); err != nil { - t.Fatal(err) + serverInfo1 := base.ServerInfo{ + Host: host, + PID: pid, + ServerID: serverID, + Concurrency: 10, + Queues: map[string]int{"default": 2, "email": 5, "low": 1}, + StrictPriority: false, + Started: time.Now().Add(-10 * time.Minute), + Status: "running", + ActiveWorkerCount: len(workers1), } - err := r.ClearServerState(ss) + workers2 := []*base.WorkerInfo{ + { + Host: otherHost, + PID: otherPID, + ID: msg2.ID.String(), + Type: msg2.Type, + Queue: msg2.Queue, + Payload: msg2.Payload, + Started: time.Now().Add(-30 * time.Second), + }, + } + serverInfo2 := base.ServerInfo{ + Host: otherHost, + PID: otherPID, + ServerID: otherServerID, + Concurrency: 10, + Queues: map[string]int{"default": 2, "email": 5, "low": 1}, + StrictPriority: false, + Started: time.Now().Add(-15 * time.Minute), + Status: "running", + ActiveWorkerCount: len(workers2), + } + + // Write server and workers data. + if err := r.WriteServerState(&serverInfo1, workers1, ttl); err != nil { + t.Fatalf("could not write server state: %v", err) + } + if err := r.WriteServerState(&serverInfo2, workers2, ttl); err != nil { + t.Fatalf("could not write server state: %v", err) + } + + err := r.ClearServerState(host, pid, serverID) if err != nil { t.Fatalf("(*RDB).ClearServerState failed: %v", err) } - // Check all keys are cleared + skey := base.ServerInfoKey(host, pid, serverID) + wkey := base.WorkersKey(host, pid, serverID) + otherSKey := base.ServerInfoKey(otherHost, otherPID, otherServerID) + otherWKey := base.WorkersKey(otherHost, otherPID, otherServerID) + // Check all keys are cleared. 
if r.client.Exists(skey).Val() != 0 { t.Errorf("Redis key %q exists", skey) } if r.client.Exists(wkey).Val() != 0 { t.Errorf("Redis key %q exists", wkey) } - gotProcessKeys := r.client.ZRange(base.AllServers, 0, -1).Val() - wantProcessKeys := []string{otherSKey} - if diff := cmp.Diff(wantProcessKeys, gotProcessKeys); diff != "" { - t.Errorf("%q contained %v, want %v", base.AllServers, gotProcessKeys, wantProcessKeys) + gotServerKeys := r.client.ZRange(base.AllServers, 0, -1).Val() + wantServerKeys := []string{otherSKey} + if diff := cmp.Diff(wantServerKeys, gotServerKeys); diff != "" { + t.Errorf("%q contained %v, want %v", base.AllServers, gotServerKeys, wantServerKeys) } gotWorkerKeys := r.client.ZRange(base.AllWorkers, 0, -1).Val() wantWorkerKeys := []string{otherWKey} diff --git a/internal/testbroker/testbroker.go b/internal/testbroker/testbroker.go index 0acab70..8bcba61 100644 --- a/internal/testbroker/testbroker.go +++ b/internal/testbroker/testbroker.go @@ -141,22 +141,22 @@ func (tb *TestBroker) CheckAndEnqueue(qnames ...string) error { return tb.real.CheckAndEnqueue() } -func (tb *TestBroker) WriteServerState(ss *base.ServerState, ttl time.Duration) error { +func (tb *TestBroker) WriteServerState(info *base.ServerInfo, workers []*base.WorkerInfo, ttl time.Duration) error { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { return errRedisDown } - return tb.real.WriteServerState(ss, ttl) + return tb.real.WriteServerState(info, workers, ttl) } -func (tb *TestBroker) ClearServerState(ss *base.ServerState) error { +func (tb *TestBroker) ClearServerState(host string, pid int, serverID string) error { tb.mu.Lock() defer tb.mu.Unlock() if tb.sleeping { return errRedisDown } - return tb.real.ClearServerState(ss) + return tb.real.ClearServerState(host, pid, serverID) } func (tb *TestBroker) CancelationPubSub() (*redis.PubSub, error) { diff --git a/processor.go b/processor.go index 109644d..6b080dc 100644 --- a/processor.go +++ b/processor.go @@ -22,8 +22,6 @@ type processor struct { logger *log.Logger broker base.Broker - ss *base.ServerState - handler Handler queueConfig map[string]int @@ -60,6 +58,9 @@ type processor struct { // cancelations is a set of cancel functions for all in-progress tasks. cancelations *base.Cancelations + + starting chan<- *base.TaskMessage + finished chan<- *base.TaskMessage } type retryDelayFunc func(n int, err error, task *Task) time.Duration @@ -67,38 +68,42 @@ type retryDelayFunc func(n int, err error, task *Task) time.Duration type processorParams struct { logger *log.Logger broker base.Broker - ss *base.ServerState retryDelayFunc retryDelayFunc syncCh chan<- *syncRequest cancelations *base.Cancelations + concurrency int + queues map[string]int + strictPriority bool errHandler ErrorHandler shutdownTimeout time.Duration + starting chan<- *base.TaskMessage + finished chan<- *base.TaskMessage } // newProcessor constructs a new processor. 
func newProcessor(params processorParams) *processor { - info := params.ss.GetInfo() - qcfg := normalizeQueueCfg(info.Queues) + queues := normalizeQueues(params.queues) orderedQueues := []string(nil) - if info.StrictPriority { - orderedQueues = sortByPriority(qcfg) + if params.strictPriority { + orderedQueues = sortByPriority(queues) } return &processor{ logger: params.logger, broker: params.broker, - ss: params.ss, - queueConfig: qcfg, + queueConfig: queues, orderedQueues: orderedQueues, retryDelayFunc: params.retryDelayFunc, syncRequestCh: params.syncCh, cancelations: params.cancelations, errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1), - sema: make(chan struct{}, info.Concurrency), + sema: make(chan struct{}, params.concurrency), done: make(chan struct{}), abort: make(chan struct{}), quit: make(chan struct{}), errHandler: params.errHandler, handler: HandlerFunc(func(ctx context.Context, t *Task) error { return fmt.Errorf("handler not set") }), + starting: params.starting, + finished: params.finished, } } @@ -183,10 +188,10 @@ func (p *processor) exec() { p.requeue(msg) return case p.sema <- struct{}{}: // acquire token - p.ss.AddWorkerStats(msg, time.Now()) + p.starting <- msg go func() { defer func() { - p.ss.DeleteWorkerStats(msg) + p.finished <- msg <-p.sema // release token }() @@ -374,16 +379,15 @@ func (x byPriority) Len() int { return len(x) } func (x byPriority) Less(i, j int) bool { return x[i].priority < x[j].priority } func (x byPriority) Swap(i, j int) { x[i], x[j] = x[j], x[i] } -// normalizeQueueCfg divides priority numbers by their -// greatest common divisor. -func normalizeQueueCfg(queueCfg map[string]int) map[string]int { +// normalizeQueues divides priority numbers by their greatest common divisor. +func normalizeQueues(queues map[string]int) map[string]int { var xs []int - for _, x := range queueCfg { + for _, x := range queues { xs = append(xs, x) } d := gcd(xs...) res := make(map[string]int) - for q, x := range queueCfg { + for q, x := range queues { res[q] = x / d } return res diff --git a/processor_test.go b/processor_test.go index 954fe0f..9ecb8fe 100644 --- a/processor_test.go +++ b/processor_test.go @@ -19,6 +19,18 @@ import ( "github.com/hibiken/asynq/internal/rdb" ) +// fakeHeartbeater receives from starting and finished channels and do nothing. 
+func fakeHeartbeater(starting, finished <-chan *base.TaskMessage, done <-chan struct{}) { + for { + select { + case <-starting: + case <-finished: + case <-done: + return + } + } +} + func TestProcessorSuccess(t *testing.T) { r := setup(t) rdbClient := rdb.NewRDB(r) @@ -63,16 +75,24 @@ func TestProcessorSuccess(t *testing.T) { processed = append(processed, task) return nil } - ss := base.NewServerState("localhost", 1234, 10, defaultQueueConfig, false) + starting := make(chan *base.TaskMessage) + finished := make(chan *base.TaskMessage) + done := make(chan struct{}) + defer func() { close(done) }() + go fakeHeartbeater(starting, finished, done) p := newProcessor(processorParams{ logger: testLogger, broker: rdbClient, - ss: ss, retryDelayFunc: defaultDelayFunc, syncCh: nil, cancelations: base.NewCancelations(), + concurrency: 10, + queues: defaultQueueConfig, + strictPriority: false, errHandler: nil, shutdownTimeout: defaultShutdownTimeout, + starting: starting, + finished: finished, }) p.handler = HandlerFunc(handler) @@ -168,16 +188,24 @@ func TestProcessorRetry(t *testing.T) { defer mu.Unlock() n++ } - ss := base.NewServerState("localhost", 1234, 10, defaultQueueConfig, false) + starting := make(chan *base.TaskMessage) + finished := make(chan *base.TaskMessage) + done := make(chan struct{}) + defer func() { close(done) }() + go fakeHeartbeater(starting, finished, done) p := newProcessor(processorParams{ logger: testLogger, broker: rdbClient, - ss: ss, retryDelayFunc: delayFunc, syncCh: nil, cancelations: base.NewCancelations(), + concurrency: 10, + queues: defaultQueueConfig, + strictPriority: false, errHandler: ErrorHandlerFunc(errHandler), shutdownTimeout: defaultShutdownTimeout, + starting: starting, + finished: finished, }) p.handler = tc.handler @@ -241,16 +269,24 @@ func TestProcessorQueues(t *testing.T) { } for _, tc := range tests { - ss := base.NewServerState("localhost", 1234, 10, tc.queueCfg, false) + starting := make(chan *base.TaskMessage) + finished := make(chan *base.TaskMessage) + done := make(chan struct{}) + defer func() { close(done) }() + go fakeHeartbeater(starting, finished, done) p := newProcessor(processorParams{ logger: testLogger, broker: nil, - ss: ss, retryDelayFunc: defaultDelayFunc, syncCh: nil, cancelations: base.NewCancelations(), + concurrency: 10, + queues: tc.queueCfg, + strictPriority: false, errHandler: nil, shutdownTimeout: defaultShutdownTimeout, + starting: starting, + finished: finished, }) got := p.queues() if diff := cmp.Diff(tc.want, got, sortOpt); diff != "" { @@ -316,17 +352,24 @@ func TestProcessorWithStrictPriority(t *testing.T) { base.DefaultQueueName: 2, "low": 1, } - // Note: Set concurrency to 1 to make sure tasks are processed one at a time. - ss := base.NewServerState("localhost", 1234, 1 /* concurrency */, queueCfg, true /*strict*/) + starting := make(chan *base.TaskMessage) + finished := make(chan *base.TaskMessage) + done := make(chan struct{}) + defer func() { close(done) }() + go fakeHeartbeater(starting, finished, done) p := newProcessor(processorParams{ logger: testLogger, broker: rdbClient, - ss: ss, retryDelayFunc: defaultDelayFunc, syncCh: nil, cancelations: base.NewCancelations(), + concurrency: 1, // Set concurrency to 1 to make sure tasks are processed one at a time. 
+ queues: queueCfg, + strictPriority: true, errHandler: nil, shutdownTimeout: defaultShutdownTimeout, + starting: starting, + finished: finished, }) p.handler = HandlerFunc(handler) @@ -412,7 +455,7 @@ func TestGCD(t *testing.T) { } } -func TestNormalizeQueueCfg(t *testing.T) { +func TestNormalizeQueues(t *testing.T) { tests := []struct { input map[string]int want map[string]int @@ -462,9 +505,9 @@ func TestNormalizeQueueCfg(t *testing.T) { } for _, tc := range tests { - got := normalizeQueueCfg(tc.input) + got := normalizeQueues(tc.input) if diff := cmp.Diff(tc.want, got); diff != "" { - t.Errorf("normalizeQueueCfg(%v) = %v, want %v; (-want, +got):\n%s", + t.Errorf("normalizeQueues(%v) = %v, want %v; (-want, +got):\n%s", tc.input, got, tc.want, diff) } } diff --git a/server.go b/server.go index e271dfb..06ca84a 100644 --- a/server.go +++ b/server.go @@ -10,7 +10,6 @@ import ( "fmt" "math" "math/rand" - "os" "runtime" "strings" "sync" @@ -34,12 +33,12 @@ import ( // (e.g., queue size reaches a certain limit, or the task has been in the // queue for a certain amount of time). type Server struct { - ss *base.ServerState - logger *log.Logger broker base.Broker + status *base.ServerStatus + // wait group to wait for all goroutines to finish. wg sync.WaitGroup scheduler *scheduler @@ -283,15 +282,11 @@ func NewServer(r RedisConnOpt, cfg Config) *Server { } logger.SetLevel(toInternalLogLevel(loglevel)) - host, err := os.Hostname() - if err != nil { - host = "unknown-host" - } - pid := os.Getpid() - rdb := rdb.NewRDB(createRedisClient(r)) - ss := base.NewServerState(host, pid, n, queues, cfg.StrictPriority) + starting := make(chan *base.TaskMessage, n) + finished := make(chan *base.TaskMessage, n) syncCh := make(chan *syncRequest) + status := base.NewServerStatus(base.StatusIdle) cancels := base.NewCancelations() syncer := newSyncer(syncerParams{ @@ -300,10 +295,15 @@ func NewServer(r RedisConnOpt, cfg Config) *Server { interval: 5 * time.Second, }) heartbeater := newHeartbeater(heartbeaterParams{ - logger: logger, - broker: rdb, - serverState: ss, - interval: 5 * time.Second, + logger: logger, + broker: rdb, + interval: 5 * time.Second, + concurrency: n, + queues: queues, + strictPriority: cfg.StrictPriority, + status: status, + starting: starting, + finished: finished, }) scheduler := newScheduler(schedulerParams{ logger: logger, @@ -319,17 +319,21 @@ func NewServer(r RedisConnOpt, cfg Config) *Server { processor := newProcessor(processorParams{ logger: logger, broker: rdb, - ss: ss, retryDelayFunc: delayFunc, syncCh: syncCh, cancelations: cancels, + concurrency: n, + queues: queues, + strictPriority: cfg.StrictPriority, errHandler: cfg.ErrorHandler, shutdownTimeout: shutdownTimeout, + starting: starting, + finished: finished, }) return &Server{ - ss: ss, logger: logger, broker: rdb, + status: status, scheduler: scheduler, processor: processor, syncer: syncer, @@ -390,13 +394,13 @@ func (srv *Server) Start(handler Handler) error { if handler == nil { return fmt.Errorf("asynq: server cannot run with nil handler") } - switch srv.ss.Status() { + switch srv.status.Get() { case base.StatusRunning: return fmt.Errorf("asynq: the server is already running") case base.StatusStopped: return ErrServerStopped } - srv.ss.SetStatus(base.StatusRunning) + srv.status.Set(base.StatusRunning) srv.processor.handler = handler srv.logger.Info("Starting processing") @@ -414,7 +418,7 @@ func (srv *Server) Start(handler Handler) error { // active workers to finish processing tasks for duration specified in 
Config.ShutdownTimeout. // If worker didn't finish processing a task during the timeout, the task will be pushed back to Redis. func (srv *Server) Stop() { - switch srv.ss.Status() { + switch srv.status.Get() { case base.StatusIdle, base.StatusStopped: // server is not running, do nothing and return. return @@ -424,6 +428,7 @@ func (srv *Server) Stop() { // Note: The order of termination is important. // Sender goroutines should be terminated before the receiver goroutines. // processor -> syncer (via syncCh) + // processor -> heartbeater (via starting, finished channels) srv.scheduler.terminate() srv.processor.terminate() srv.syncer.terminate() @@ -433,7 +438,7 @@ func (srv *Server) Stop() { srv.wg.Wait() srv.broker.Close() - srv.ss.SetStatus(base.StatusStopped) + srv.status.Set(base.StatusStopped) srv.logger.Info("Exiting") } @@ -443,6 +448,6 @@ func (srv *Server) Stop() { func (srv *Server) Quiet() { srv.logger.Info("Stopping processor") srv.processor.stop() - srv.ss.SetStatus(base.StatusQuiet) + srv.status.Set(base.StatusQuiet) srv.logger.Info("Processor stopped") }
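Taken together, the patch replaces the mutex-guarded ServerState with two simpler mechanisms: a small concurrency-safe ServerStatus, plus starting/finished channels that confine per-worker bookkeeping to the heartbeater goroutine. The self-contained sketch below illustrates that confinement pattern in isolation; the task type, channel sizes, and intervals are stand-ins for illustration, not asynq's actual types.

package main

import (
	"fmt"
	"time"
)

// task is a stand-in for *base.TaskMessage.
type task struct{ id string }

func main() {
	starting := make(chan task)
	finished := make(chan task)
	done := make(chan struct{})

	// "Heartbeater": the only goroutine that reads or writes the workers map.
	go func() {
		workers := make(map[string]time.Time)
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case t := <-starting:
				workers[t.id] = time.Now() // a worker picked up a task
			case t := <-finished:
				delete(workers, t.id) // a worker finished a task
			case <-ticker.C:
				// Periodic beat: snapshot the map (in asynq, into
				// ServerInfo/WorkerInfo) and write it to the broker.
				fmt.Println("active workers:", len(workers))
			case <-done:
				return
			}
		}
	}()

	// "Processor": reports the worker lifecycle over channels instead of
	// mutating shared state.
	starting <- task{id: "42"}
	time.Sleep(250 * time.Millisecond)
	finished <- task{id: "42"}
	time.Sleep(150 * time.Millisecond)
	close(done)
}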