diff --git a/cmux.go b/cmux.go index 8040342..5f964b5 100644 --- a/cmux.go +++ b/cmux.go @@ -93,6 +93,8 @@ type CMux interface { // Serve starts multiplexing the listener. Serve blocks and perhaps // should be invoked concurrently within a go routine. Serve() error + // Closes cmux server and stops accepting any connections on listener + Close() // HandleError registers an error handler that handles listener errors. HandleError(ErrorHandler) // sets a timeout for the read of matchers @@ -111,6 +113,7 @@ type cMux struct { donec chan struct{} sls []matchersListener readTimeout time.Duration + mu sync.Mutex } func matchersToMatchWriters(matchers []Matcher) []MatchWriter { @@ -146,7 +149,7 @@ func (m *cMux) Serve() error { var wg sync.WaitGroup defer func() { - close(m.donec) + m.closeDoneChanLocked() wg.Wait() for _, sl := range m.sls { @@ -161,6 +164,11 @@ func (m *cMux) Serve() error { for { c, err := m.root.Accept() if err != nil { + select { + case <-m.getDoneChan(): + // cmux was closed with cmux.Close() + return nil + } if !m.handleErr(err) { return err } @@ -189,7 +197,7 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { } select { case sl.l.connc <- muc: - case <-donec: + case <-m.getDoneChan(): _ = c.Close() } return @@ -204,6 +212,35 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { } } +func (m *cMux) Close() { + m.mu.Lock() + defer m.mu.Unlock() + m.closeDoneChanLocked() +} + +func (m *cMux) getDoneChan() chan struct{} { + m.mu.Lock() + defer m.mu.Unlock() + return m.getDoneChanLocked() +} + +func (m *cMux) getDoneChanLocked() chan struct{} { + if m.donec == nil { + m.donec = make(chan struct{}) + } + return m.donec +} + +func (m *cMux) closeDoneChanLocked() { + ch := m.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again + default: + close(ch) + } +} + func (m *cMux) HandleError(h ErrorHandler) { m.errh = h }