diff --git a/cmux.go b/cmux.go index 5f964b5..848d742 100644 --- a/cmux.go +++ b/cmux.go @@ -113,7 +113,6 @@ type cMux struct { donec chan struct{} sls []matchersListener readTimeout time.Duration - mu sync.Mutex } func matchersToMatchWriters(matchers []Matcher) []MatchWriter { @@ -149,7 +148,7 @@ func (m *cMux) Serve() error { var wg sync.WaitGroup defer func() { - m.closeDoneChanLocked() + m.closeDoneChan() wg.Wait() for _, sl := range m.sls { @@ -162,13 +161,16 @@ func (m *cMux) Serve() error { }() for { + select { + case <-m.donec: + // cmux was closed with cmux.Close() + return nil + default: + // do nothing + } + c, err := m.root.Accept() if err != nil { - select { - case <-m.getDoneChan(): - // cmux was closed with cmux.Close() - return nil - } if !m.handleErr(err) { return err } @@ -197,7 +199,7 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { } select { case sl.l.connc <- muc: - case <-m.getDoneChan(): + case <-m.donec: _ = c.Close() } return @@ -213,31 +215,15 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { } func (m *cMux) Close() { - m.mu.Lock() - defer m.mu.Unlock() - m.closeDoneChanLocked() + m.closeDoneChan() } -func (m *cMux) getDoneChan() chan struct{} { - m.mu.Lock() - defer m.mu.Unlock() - return m.getDoneChanLocked() -} - -func (m *cMux) getDoneChanLocked() chan struct{} { - if m.donec == nil { - m.donec = make(chan struct{}) - } - return m.donec -} - -func (m *cMux) closeDoneChanLocked() { - ch := m.getDoneChanLocked() +func (m *cMux) closeDoneChan() { select { - case <-ch: + case <-m.donec: // Already closed. Don't close again default: - close(ch) + close(m.donec) } }