2
0
mirror of https://github.com/soheilhy/cmux.git synced 2025-01-19 03:06:07 +08:00

Tweak shutdown behaviour

When the root listener is closed, child listeners will not be closed
until all parked connections are served. This prevents losing
connections that have been read from.

This also allows moving the main test to package cmux_test, but that
will happen in a separate change.
This commit is contained in:
Tamir Duberstein 2015-12-11 14:33:39 -05:00
parent 9a9119af9d
commit 235d98b021
3 changed files with 80 additions and 27 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"io" "io"
"net" "net"
"sync"
"testing" "testing"
) )
@ -31,12 +32,15 @@ func BenchmarkCMuxConn(b *testing.B) {
} }
}() }()
b.ResetTimer() donec := make(chan struct{})
var wg sync.WaitGroup
wg.Add(b.N)
b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
c := &mockConn{ c := &mockConn{
r: bytes.NewReader(benchHTTPPayload), r: bytes.NewReader(benchHTTPPayload),
} }
m.serve(c) m.serve(c, donec, &wg)
} }
} }

34
cmux.go
View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
) )
// Matcher matches a connection based on its content. // Matcher matches a connection based on its content.
@ -48,6 +49,7 @@ func New(l net.Listener) CMux {
root: l, root: l,
bufLen: 1024, bufLen: 1024,
errh: func(_ error) bool { return true }, errh: func(_ error) bool { return true },
donec: make(chan struct{}),
} }
} }
@ -74,6 +76,7 @@ type cMux struct {
root net.Listener root net.Listener
bufLen int bufLen int
errh ErrorHandler errh ErrorHandler
donec chan struct{}
sls []matchersListener sls []matchersListener
} }
@ -81,16 +84,20 @@ func (m *cMux) Match(matchers ...Matcher) net.Listener {
ml := muxListener{ ml := muxListener{
Listener: m.root, Listener: m.root,
connc: make(chan net.Conn, m.bufLen), connc: make(chan net.Conn, m.bufLen),
donec: make(chan struct{}),
} }
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
return ml return ml
} }
func (m *cMux) Serve() error { func (m *cMux) Serve() error {
var wg sync.WaitGroup
defer func() { defer func() {
close(m.donec)
wg.Wait()
for _, sl := range m.sls { for _, sl := range m.sls {
close(sl.l.donec) close(sl.l.connc)
} }
}() }()
@ -103,11 +110,14 @@ func (m *cMux) Serve() error {
continue continue
} }
go m.serve(c) wg.Add(1)
go m.serve(c, m.donec, &wg)
} }
} }
func (m *cMux) serve(c net.Conn) { func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()
muc := newMuxConn(c) muc := newMuxConn(c)
for _, sl := range m.sls { for _, sl := range m.sls {
for _, s := range sl.ss { for _, s := range sl.ss {
@ -116,8 +126,12 @@ func (m *cMux) serve(c net.Conn) {
if matched { if matched {
select { select {
case sl.l.connc <- muc: case sl.l.connc <- muc:
case <-sl.l.donec: default:
_ = c.Close() select {
case <-donec:
_ = c.Close()
default:
}
} }
return return
} }
@ -150,16 +164,14 @@ func (m *cMux) handleErr(err error) bool {
type muxListener struct { type muxListener struct {
net.Listener net.Listener
connc chan net.Conn connc chan net.Conn
donec chan struct{}
} }
func (l muxListener) Accept() (net.Conn, error) { func (l muxListener) Accept() (net.Conn, error) {
select { c, ok := <-l.connc
case c := <-l.connc: if !ok {
return c, nil
case <-l.donec:
return nil, ErrListenerClosed return nil, ErrListenerClosed
} }
return c, nil
} }
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data. // MuxConn wraps a net.Conn and provides transparent sniffing of connection data.

View File

@ -1,6 +1,7 @@
package cmux package cmux
import ( import (
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -38,6 +39,22 @@ func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
} }
} }
type chanListener struct {
net.Listener
connCh chan net.Conn
}
func newChanListener() *chanListener {
return &chanListener{connCh: make(chan net.Conn, 1)}
}
func (l *chanListener) Accept() (net.Conn, error) {
if c, ok := <-l.connCh; ok {
return c, nil
}
return nil, errors.New("use of closed network connection")
}
func testListener(t *testing.T) (net.Listener, func()) { func testListener(t *testing.T) (net.Listener, func()) {
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
@ -235,21 +252,41 @@ func TestErrorHandler(t *testing.T) {
} }
} }
type closerConn struct { func TestClose(t *testing.T) {
net.Conn
}
func (c closerConn) Close() error { return nil }
func TestClosed(t *testing.T) {
defer leakCheck(t)() defer leakCheck(t)()
mux := &cMux{} errCh := make(chan error)
lis := mux.Match(Any()).(muxListener) defer func() {
close(lis.donec) select {
mux.serve(closerConn{}) case err := <-errCh:
_, err := lis.Accept() t.Fatal(err)
if _, ok := err.(errListenerClosed); !ok { default:
t.Errorf("expected errListenerClosed got %v", err) }
}()
l := newChanListener()
c1, c2 := net.Pipe()
muxl := New(l)
anyl := muxl.Match(Any())
go safeServe(errCh, muxl)
l.connCh <- c1
// First connection goes through.
if _, err := anyl.Accept(); err != nil {
t.Fatal(err)
}
// Second connection is sent
l.connCh <- c2
// Listener is closed.
close(l.connCh)
// Second connection goes through.
if _, err := anyl.Accept(); err != nil {
t.Fatal(err)
} }
} }