diff --git a/cmux.go b/cmux.go index 10a8932..3e899e7 100644 --- a/cmux.go +++ b/cmux.go @@ -98,6 +98,10 @@ func (m *cMux) Serve() error { for _, sl := range m.sls { close(sl.l.connc) + // Drain the connections enqueued for the listener. + for c := range sl.l.connc { + _ = c.Close() + } } }() @@ -126,12 +130,8 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { if matched { select { case sl.l.connc <- muc: - default: - select { - case <-donec: - _ = c.Close() - default: - } + case <-donec: + _ = c.Close() } return } diff --git a/cmux_test.go b/cmux_test.go index d7ddca4..2dde655 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -262,17 +262,17 @@ func TestHTTP2(t *testing.T) { }) h2l := muxl.Match(HTTP2()) go safeServe(errCh, muxl) + muxedConn, err := h2l.Accept() close(l.connCh) - if muxedConn, err := h2l.Accept(); err != nil { + if err != nil { t.Fatal(err) - } else { - var b [len(http2.ClientPreface)]byte - if _, err := muxedConn.Read(b[:]); err != io.EOF { - t.Fatal(err) - } - if string(b[:]) != http2.ClientPreface { - t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface) - } + } + var b [len(http2.ClientPreface)]byte + if _, err := muxedConn.Read(b[:]); err != io.EOF { + t.Fatal(err) + } + if string(b[:]) != http2.ClientPreface { + t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface) } } @@ -374,9 +374,14 @@ func TestClose(t *testing.T) { // Listener is closed. close(l.connCh) - // Second connection goes through. + // Second connection either goes through or it is closed. if _, err := anyl.Accept(); err != nil { - t.Fatal(err) + if err != ErrListenerClosed { + t.Fatal(err) + } + if _, err := c2.Read([]byte{}); err != io.ErrClosedPipe { + t.Fatalf("connection is not closed and is leaked: %v", err) + } } }