2
0
mirror of https://github.com/soheilhy/cmux.git synced 2025-01-19 03:06:07 +08:00

Merge pull request #11 from tamird/fix-close

Tweak shutdown behaviour
This commit is contained in:
Soheil Hassas Yeganeh 2016-02-23 11:20:27 -05:00
commit d5b9190ea9
3 changed files with 85 additions and 33 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"io" "io"
"net" "net"
"sync"
"testing" "testing"
) )
@ -31,12 +32,15 @@ func BenchmarkCMuxConn(b *testing.B) {
} }
}() }()
b.ResetTimer() donec := make(chan struct{})
var wg sync.WaitGroup
wg.Add(b.N)
b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
c := &mockConn{ c := &mockConn{
r: bytes.NewReader(benchHTTPPayload), r: bytes.NewReader(benchHTTPPayload),
} }
m.serve(c) m.serve(c, donec, &wg)
} }
} }

34
cmux.go
View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
) )
// Matcher matches a connection based on its content. // Matcher matches a connection based on its content.
@ -48,6 +49,7 @@ func New(l net.Listener) CMux {
root: l, root: l,
bufLen: 1024, bufLen: 1024,
errh: func(_ error) bool { return true }, errh: func(_ error) bool { return true },
donec: make(chan struct{}),
} }
} }
@ -74,6 +76,7 @@ type cMux struct {
root net.Listener root net.Listener
bufLen int bufLen int
errh ErrorHandler errh ErrorHandler
donec chan struct{}
sls []matchersListener sls []matchersListener
} }
@ -81,16 +84,20 @@ func (m *cMux) Match(matchers ...Matcher) net.Listener {
ml := muxListener{ ml := muxListener{
Listener: m.root, Listener: m.root,
connc: make(chan net.Conn, m.bufLen), connc: make(chan net.Conn, m.bufLen),
donec: make(chan struct{}),
} }
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
return ml return ml
} }
func (m *cMux) Serve() error { func (m *cMux) Serve() error {
var wg sync.WaitGroup
defer func() { defer func() {
close(m.donec)
wg.Wait()
for _, sl := range m.sls { for _, sl := range m.sls {
close(sl.l.donec) close(sl.l.connc)
} }
}() }()
@ -103,11 +110,14 @@ func (m *cMux) Serve() error {
continue continue
} }
go m.serve(c) wg.Add(1)
go m.serve(c, m.donec, &wg)
} }
} }
func (m *cMux) serve(c net.Conn) { func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()
muc := newMuxConn(c) muc := newMuxConn(c)
for _, sl := range m.sls { for _, sl := range m.sls {
for _, s := range sl.ss { for _, s := range sl.ss {
@ -116,8 +126,12 @@ func (m *cMux) serve(c net.Conn) {
if matched { if matched {
select { select {
case sl.l.connc <- muc: case sl.l.connc <- muc:
case <-sl.l.donec: default:
_ = c.Close() select {
case <-donec:
_ = c.Close()
default:
}
} }
return return
} }
@ -150,16 +164,14 @@ func (m *cMux) handleErr(err error) bool {
type muxListener struct { type muxListener struct {
net.Listener net.Listener
connc chan net.Conn connc chan net.Conn
donec chan struct{}
} }
func (l muxListener) Accept() (net.Conn, error) { func (l muxListener) Accept() (net.Conn, error) {
select { c, ok := <-l.connc
case c := <-l.connc: if !ok {
return c, nil
case <-l.donec:
return nil, ErrListenerClosed return nil, ErrListenerClosed
} }
return c, nil
} }
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data. // MuxConn wraps a net.Conn and provides transparent sniffing of connection data.

View File

@ -1,6 +1,7 @@
package cmux package cmux
import ( import (
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -38,6 +39,22 @@ func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
} }
} }
type chanListener struct {
net.Listener
connCh chan net.Conn
}
func newChanListener() *chanListener {
return &chanListener{connCh: make(chan net.Conn, 1)}
}
func (l *chanListener) Accept() (net.Conn, error) {
if c, ok := <-l.connCh; ok {
return c, nil
}
return nil, errors.New("use of closed network connection")
}
func testListener(t *testing.T) (net.Listener, func()) { func testListener(t *testing.T) (net.Listener, func()) {
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
@ -227,30 +244,49 @@ func TestErrorHandler(t *testing.T) {
defer cleanup() defer cleanup()
var num int var num int
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil { for atomic.LoadUint32(&errCount) == 0 {
// The connection is simply closed. if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount)) // The connection is simply closed.
} t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
if atomic.LoadUint32(&errCount) == 0 { }
t.Errorf("expected at least 1 error(s), got none")
} }
} }
type closerConn struct { func TestClose(t *testing.T) {
net.Conn
}
func (c closerConn) Close() error { return nil }
func TestClosed(t *testing.T) {
defer leakCheck(t)() defer leakCheck(t)()
mux := &cMux{} errCh := make(chan error)
lis := mux.Match(Any()).(muxListener) defer func() {
close(lis.donec) select {
mux.serve(closerConn{}) case err := <-errCh:
_, err := lis.Accept() t.Fatal(err)
if _, ok := err.(errListenerClosed); !ok { default:
t.Errorf("expected errListenerClosed got %v", err) }
}()
l := newChanListener()
c1, c2 := net.Pipe()
muxl := New(l)
anyl := muxl.Match(Any())
go safeServe(errCh, muxl)
l.connCh <- c1
// First connection goes through.
if _, err := anyl.Accept(); err != nil {
t.Fatal(err)
}
// Second connection is sent
l.connCh <- c2
// Listener is closed.
close(l.connCh)
// Second connection goes through.
if _, err := anyl.Accept(); err != nil {
t.Fatal(err)
} }
} }