diff --git a/cmux.go b/cmux.go index 62712dd..76a5bae 100644 --- a/cmux.go +++ b/cmux.go @@ -80,7 +80,8 @@ type cMux struct { func (m *cMux) Match(matchers ...Matcher) (l net.Listener) { ml := muxListener{ Listener: m.root, - cch: make(chan net.Conn, m.bufLen), + connc: make(chan net.Conn, m.bufLen), + donec: make(chan struct{}), } m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) return ml @@ -89,7 +90,7 @@ func (m *cMux) Match(matchers ...Matcher) (l net.Listener) { func (m *cMux) Serve() error { defer func() { for _, sl := range m.sls { - close(sl.l.cch) + close(sl.l.donec) } }() @@ -109,14 +110,18 @@ func (m *cMux) Serve() error { func (m *cMux) serve(c net.Conn) { muc := newMuxConn(c) matched := false -outer: for _, sl := range m.sls { for _, s := range sl.ss { matched = s(muc.sniffer()) muc.reset() if matched { - sl.l.cch <- muc - break outer + select { + // TODO(soheil): threre is a possiblity of having unclosed connection. + case sl.l.connc <- muc: + case <-sl.l.donec: + c.Close() + } + return } } } @@ -148,11 +153,12 @@ func (m *cMux) handleErr(err error) bool { type muxListener struct { net.Listener - cch chan net.Conn + connc chan net.Conn + donec chan struct{} } func (l muxListener) Accept() (c net.Conn, err error) { - c, ok := <-l.cch + c, ok := <-l.connc if !ok { return nil, ErrListenerClosed } diff --git a/cmux_test.go b/cmux_test.go index 67df37d..ad40be8 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -168,3 +168,16 @@ func TestErrorHandler(t *testing.T) { t.Error("rpc got a response") } } + +type closerConn struct { + net.Conn +} + +func (c closerConn) Close() error { return nil } + +func TestClosed(t *testing.T) { + mux := &cMux{} + lis := mux.Match(Any()).(muxListener) + close(lis.donec) + mux.serve(closerConn{}) +}