diff --git a/cmux.go b/cmux.go index 76a5bae..8f268e4 100644 --- a/cmux.go +++ b/cmux.go @@ -158,11 +158,12 @@ type muxListener struct { } func (l muxListener) Accept() (c net.Conn, err error) { - c, ok := <-l.connc - if !ok { + select { + case c = <-l.connc: + return c, nil + case <-l.donec: return nil, ErrListenerClosed } - return c, nil } type MuxConn struct { diff --git a/cmux_test.go b/cmux_test.go index ad40be8..6b950ce 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -180,4 +180,8 @@ func TestClosed(t *testing.T) { lis := mux.Match(Any()).(muxListener) close(lis.donec) mux.serve(closerConn{}) + _, err := lis.Accept() + if _, ok := err.(errListenerClosed); !ok { + t.Errorf("expected errListenerClosed got %v", err) + } }