From 25832e5e952cea0256c047650efe3826e6464f60 Mon Sep 17 00:00:00 2001 From: Ken Hibino Date: Wed, 12 Jan 2022 09:10:56 -0800 Subject: [PATCH] Fix bug related to concurrently executing server state changes --- heartbeat.go | 10 ++-- heartbeat_test.go | 18 +++++--- internal/base/base.go | 62 ------------------------- internal/base/base_test.go | 24 ---------- scheduler.go | 41 +++++++++++++---- server.go | 94 ++++++++++++++++++++++++++++++-------- 6 files changed, 124 insertions(+), 125 deletions(-) diff --git a/heartbeat.go b/heartbeat.go index fd695ae..4921031 100644 --- a/heartbeat.go +++ b/heartbeat.go @@ -41,7 +41,7 @@ type heartbeater struct { workers map[string]*workerInfo // state is shared with other goroutine but is concurrency safe. - state *base.ServerState + state *serverState // channels to receive updates on active workers. starting <-chan *workerInfo @@ -55,7 +55,7 @@ type heartbeaterParams struct { concurrency int queues map[string]int strictPriority bool - state *base.ServerState + state *serverState starting <-chan *workerInfo finished <-chan *base.TaskMessage } @@ -135,6 +135,10 @@ func (h *heartbeater) start(wg *sync.WaitGroup) { } func (h *heartbeater) beat() { + h.state.mu.Lock() + srvStatus := h.state.value.String() + h.state.mu.Unlock() + info := base.ServerInfo{ Host: h.host, PID: h.pid, @@ -142,7 +146,7 @@ func (h *heartbeater) beat() { Concurrency: h.concurrency, Queues: h.queues, StrictPriority: h.strictPriority, - Status: h.state.String(), + Status: srvStatus, Started: h.started, ActiveWorkerCount: len(h.workers), } diff --git a/heartbeat_test.go b/heartbeat_test.go index 518c5b0..4004072 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -38,7 +38,7 @@ func TestHeartbeater(t *testing.T) { for _, tc := range tests { h.FlushDB(t, r) - state := base.NewServerState() + srvState := &serverState{} hb := newHeartbeater(heartbeaterParams{ logger: testLogger, broker: rdbClient, @@ -46,7 +46,7 @@ func TestHeartbeater(t *testing.T) { concurrency: tc.concurrency, queues: tc.queues, strictPriority: false, - state: state, + state: srvState, starting: make(chan *workerInfo), finished: make(chan *base.TaskMessage), }) @@ -55,7 +55,10 @@ func TestHeartbeater(t *testing.T) { hb.host = tc.host hb.pid = tc.pid - state.Set(base.StateActive) + srvState.mu.Lock() + srvState.value = srvStateActive // simulating Server.Start + srvState.mu.Unlock() + var wg sync.WaitGroup hb.start(&wg) @@ -90,8 +93,10 @@ func TestHeartbeater(t *testing.T) { continue } - // status change - state.Set(base.StateClosed) + // server state change; simulating Server.Shutdown + srvState.mu.Lock() + srvState.value = srvStateClosed + srvState.mu.Unlock() // allow for heartbeater to write to redis time.Sleep(tc.interval * 2) @@ -131,8 +136,7 @@ func TestHeartbeaterWithRedisDown(t *testing.T) { r := rdb.NewRDB(setup(t)) defer r.Close() testBroker := testbroker.NewTestBroker(r) - state := base.NewServerState() - state.Set(base.StateActive) + state := &serverState{value: srvStateActive} hb := newHeartbeater(heartbeaterParams{ logger: testLogger, broker: testBroker, diff --git a/internal/base/base.go b/internal/base/base.go index f5a7db5..88078d1 100644 --- a/internal/base/base.go +++ b/internal/base/base.go @@ -318,68 +318,6 @@ type Z struct { Score int64 } -// ServerState represents state of a server. -// ServerState methods are concurrency safe. -type ServerState struct { - mu sync.Mutex - val ServerStateValue -} - -// NewServerState returns a new state instance. -// Initial state is set to StateNew. 
-func NewServerState() *ServerState { - return &ServerState{val: StateNew} -} - -type ServerStateValue int - -const ( - // StateNew represents a new server. Server begins in - // this state and then transition to StatusActive when - // Start or Run is callled. - StateNew ServerStateValue = iota - - // StateActive indicates the server is up and active. - StateActive - - // StateStopped indicates the server is up but no longer processing new tasks. - StateStopped - - // StateClosed indicates the server has been shutdown. - StateClosed -) - -var serverStates = []string{ - "new", - "active", - "stopped", - "closed", -} - -func (s *ServerState) String() string { - s.mu.Lock() - defer s.mu.Unlock() - if StateNew <= s.val && s.val <= StateClosed { - return serverStates[s.val] - } - return "unknown status" -} - -// Get returns the status value. -func (s *ServerState) Get() ServerStateValue { - s.mu.Lock() - v := s.val - s.mu.Unlock() - return v -} - -// Set sets the status value. -func (s *ServerState) Set(v ServerStateValue) { - s.mu.Lock() - s.val = v - s.mu.Unlock() -} - // ServerInfo holds information about a running server. type ServerInfo struct { Host string diff --git a/internal/base/base_test.go b/internal/base/base_test.go index a3f8c41..209c726 100644 --- a/internal/base/base_test.go +++ b/internal/base/base_test.go @@ -583,30 +583,6 @@ func TestSchedulerEnqueueEventEncoding(t *testing.T) { } } -// Test for status being accessed by multiple goroutines. -// Run with -race flag to check for data race. -func TestStatusConcurrentAccess(t *testing.T) { - status := NewServerState() - - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - status.Get() - _ = status.String() - }() - - wg.Add(1) - go func() { - defer wg.Done() - status.Set(StateClosed) - _ = status.String() - }() - - wg.Wait() -} - // Test for cancelations being accessed by multiple goroutines. // Run with -race flag to check for data race. func TestCancelationsConcurrentAccess(t *testing.T) { diff --git a/scheduler.go b/scheduler.go index 8e2d264..dcb0aa0 100644 --- a/scheduler.go +++ b/scheduler.go @@ -22,8 +22,10 @@ import ( // // Schedulers are safe for concurrent use by multiple goroutines. type Scheduler struct { - id string - state *base.ServerState + id string + + state *serverState + logger *log.Logger client *Client rdb *rdb.RDB @@ -66,7 +68,7 @@ func NewScheduler(r RedisConnOpt, opts *SchedulerOpts) *Scheduler { return &Scheduler{ id: generateSchedulerID(), - state: base.NewServerState(), + state: &serverState{value: srvStateNew}, logger: logger, client: NewClient(r), rdb: rdb.NewRDB(c), @@ -193,23 +195,43 @@ func (s *Scheduler) Run() error { // Start starts the scheduler. // It returns an error if the scheduler is already running or has been shutdown. func (s *Scheduler) Start() error { - switch s.state.Get() { - case base.StateActive: - return fmt.Errorf("asynq: the scheduler is already running") - case base.StateClosed: - return fmt.Errorf("asynq: the scheduler has already been stopped") + if err := s.start(); err != nil { + return err } s.logger.Info("Scheduler starting") s.logger.Infof("Scheduler timezone is set to %v", s.location) s.cron.Start() s.wg.Add(1) go s.runHeartbeater() - s.state.Set(base.StateActive) + return nil +} + +// Checks server state and returns an error if pre-condition is not met. +// Otherwise it sets the server state to active. 
+func (s *Scheduler) start() error {
+	s.state.mu.Lock()
+	defer s.state.mu.Unlock()
+	switch s.state.value {
+	case srvStateActive:
+		return fmt.Errorf("asynq: the scheduler is already running")
+	case srvStateClosed:
+		return fmt.Errorf("asynq: the scheduler has already been stopped")
+	}
+	s.state.value = srvStateActive
 	return nil
 }
 
 // Shutdown stops and shuts down the scheduler.
 func (s *Scheduler) Shutdown() {
+	s.state.mu.Lock()
+	if s.state.value == srvStateNew || s.state.value == srvStateClosed {
+		// scheduler is not running, do nothing and return.
+		s.state.mu.Unlock()
+		return
+	}
+	s.state.value = srvStateClosed
+	s.state.mu.Unlock()
+
 	s.logger.Info("Scheduler shutting down")
 	close(s.done) // signal heartbeater to stop
 	ctx := s.cron.Stop()
@@ -219,7 +241,6 @@ func (s *Scheduler) Shutdown() {
 	s.clearHistory()
 	s.client.Close()
 	s.rdb.Close()
-	s.state.Set(base.StateClosed)
 	s.logger.Info("Scheduler stopped")
 }
 
diff --git a/server.go b/server.go
index cdc14c7..1cf5bed 100644
--- a/server.go
+++ b/server.go
@@ -38,7 +38,7 @@ type Server struct {
 
 	broker base.Broker
 
-	state *base.ServerState
+	state *serverState
 
 	// wait group to wait for all goroutines to finish.
 	wg sync.WaitGroup
@@ -52,6 +52,43 @@ type Server struct {
 	janitor       *janitor
 }
 
+type serverState struct {
+	mu    sync.Mutex
+	value serverStateValue
+}
+
+type serverStateValue int
+
+const (
+	// srvStateNew represents a new server. Server begins in
+	// this state and then transitions to srvStateActive when
+	// Start or Run is called.
+	srvStateNew serverStateValue = iota
+
+	// srvStateActive indicates the server is up and active.
+	srvStateActive
+
+	// srvStateStopped indicates the server is up but no longer processing new tasks.
+	srvStateStopped
+
+	// srvStateClosed indicates the server has been shut down.
+	srvStateClosed
+)
+
+var serverStates = []string{
+	"new",
+	"active",
+	"stopped",
+	"closed",
+}
+
+func (s serverStateValue) String() string {
+	if srvStateNew <= s && s <= srvStateClosed {
+		return serverStates[s]
+	}
+	return "unknown status"
+}
+
 // Config specifies the server's background-task processing behavior.
 type Config struct {
 	// Maximum number of concurrent processing of tasks.
@@ -351,7 +388,7 @@ func NewServer(r RedisConnOpt, cfg Config) *Server {
 	starting := make(chan *workerInfo)
 	finished := make(chan *base.TaskMessage)
 	syncCh := make(chan *syncRequest)
-	state := base.NewServerState()
+	srvState := &serverState{value: srvStateNew}
 	cancels := base.NewCancelations()
 
 	syncer := newSyncer(syncerParams{
@@ -366,7 +403,7 @@ func NewServer(r RedisConnOpt, cfg Config) *Server {
 		concurrency:    n,
 		queues:         queues,
 		strictPriority: cfg.StrictPriority,
-		state:          state,
+		state:          srvState,
 		starting:       starting,
 		finished:       finished,
 	})
@@ -423,7 +460,7 @@ func NewServer(r RedisConnOpt, cfg Config) *Server {
 	return &Server{
 		logger:          logger,
 		broker:          rdb,
-		state:           state,
+		state:           srvState,
 		forwarder:       forwarder,
 		processor:       processor,
 		syncer:          syncer,
@@ -493,17 +530,11 @@ func (srv *Server) Start(handler Handler) error {
 	if handler == nil {
 		return fmt.Errorf("asynq: server cannot run with nil handler")
 	}
-	switch srv.state.Get() {
-	case base.StateActive:
-		return fmt.Errorf("asynq: the server is already running")
-	case base.StateStopped:
-		return fmt.Errorf("asynq: the server is in the stopped state. Waiting for shutdown.")
-	case base.StateClosed:
-		return ErrServerClosed
-	}
-	srv.state.Set(base.StateActive)
 
 	srv.processor.handler = handler
+	if err := srv.start(); err != nil {
+		return err
+	}
 	srv.logger.Info("Starting processing")
 
 	srv.heartbeater.start(&srv.wg)
@@ -517,16 +548,36 @@ func (srv *Server) Start(handler Handler) error {
 	return nil
 }
 
+// Checks server state and returns an error if pre-condition is not met.
+// Otherwise it sets the server state to active.
+func (srv *Server) start() error {
+	srv.state.mu.Lock()
+	defer srv.state.mu.Unlock()
+	switch srv.state.value {
+	case srvStateActive:
+		return fmt.Errorf("asynq: the server is already running")
+	case srvStateStopped:
+		return fmt.Errorf("asynq: the server is in the stopped state. Waiting for shutdown.")
+	case srvStateClosed:
+		return ErrServerClosed
+	}
+	srv.state.value = srvStateActive
+	return nil
+}
+
 // Shutdown gracefully shuts down the server.
 // It gracefully closes all active workers. The server will wait for
 // active workers to finish processing tasks for duration specified in Config.ShutdownTimeout.
 // If worker didn't finish processing a task during the timeout, the task will be pushed back to Redis.
 func (srv *Server) Shutdown() {
-	switch srv.state.Get() {
-	case base.StateNew, base.StateClosed:
+	srv.state.mu.Lock()
+	if srv.state.value == srvStateNew || srv.state.value == srvStateClosed {
+		srv.state.mu.Unlock()
 		// server is not running, do nothing and return.
 		return
 	}
+	srv.state.value = srvStateClosed
+	srv.state.mu.Unlock()
 
 	srv.logger.Info("Starting graceful shutdown")
 	// Note: The order of shutdown is important.
@@ -541,12 +592,9 @@ func (srv *Server) Shutdown() {
 	srv.janitor.shutdown()
 	srv.healthchecker.shutdown()
 	srv.heartbeater.shutdown()
-
 	srv.wg.Wait()
 
 	srv.broker.Close()
-	srv.state.Set(base.StateClosed)
-
 	srv.logger.Info("Exiting")
 }
 
@@ -556,8 +604,16 @@ func (srv *Server) Shutdown() {
 //
 // Stop does not shutdown the server, make sure to call Shutdown before exit.
 func (srv *Server) Stop() {
+	srv.state.mu.Lock()
+	if srv.state.value != srvStateActive {
+		// Invalid call to Stop, server can only go from Active state to Stopped state.
+		srv.state.mu.Unlock()
+		return
+	}
+	srv.state.value = srvStateStopped
+	srv.state.mu.Unlock()
+
 	srv.logger.Info("Stopping processor")
 	srv.processor.stop()
-	srv.state.Set(base.StateStopped)
 	srv.logger.Info("Processor stopped")
 }
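
Editor's note (not part of the patch): below is a minimal, self-contained sketch of the locking pattern the patch introduces. Every state transition is checked and applied while the mutex is held, so concurrent Start/Stop/Shutdown calls cannot race the way the old Get-then-Set sequence could. The type and constant names mirror the patch; the error messages and the main function are illustrative assumptions only.

// state_sketch.go -- illustrative only; mirrors the serverState pattern added in server.go.
package main

import (
	"fmt"
	"sync"
)

type serverStateValue int

const (
	srvStateNew serverStateValue = iota
	srvStateActive
	srvStateStopped
	srvStateClosed
)

// serverState guards the state value with a mutex, as in the patch.
type serverState struct {
	mu    sync.Mutex
	value serverStateValue
}

// start validates and applies the new->active transition in one critical section.
func (s *serverState) start() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	switch s.value {
	case srvStateActive:
		return fmt.Errorf("already running")
	case srvStateStopped:
		return fmt.Errorf("stopped; waiting for shutdown")
	case srvStateClosed:
		return fmt.Errorf("already closed")
	}
	s.value = srvStateActive
	return nil
}

// shutdown closes the state unless the server never started; the check and the
// write happen under the same lock, which is the point of the fix.
func (s *serverState) shutdown() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.value == srvStateNew || s.value == srvStateClosed {
		return false
	}
	s.value = srvStateClosed
	return true
}

func main() {
	state := &serverState{value: srvStateNew}

	// Hammer start from several goroutines; exactly one call succeeds and
	// running with `go run -race` reports no data race.
	results := make(chan error, 8)
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			results <- state.start()
		}()
	}
	wg.Wait()
	close(results)

	ok := 0
	for err := range results {
		if err == nil {
			ok++
		}
	}
	fmt.Println("successful start calls:", ok)            // always 1
	fmt.Println("shutdown performed:", state.shutdown())  // true
}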