2
0
mirror of https://github.com/soheilhy/cmux.git synced 2025-01-18 18:56:26 +08:00

Merge pull request #11 from tamird/fix-close

Tweak shutdown behaviour
This commit is contained in:
Soheil Hassas Yeganeh 2016-02-23 11:20:27 -05:00
commit d5b9190ea9
3 changed files with 85 additions and 33 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes"
"io"
"net"
"sync"
"testing"
)
@ -31,12 +32,15 @@ func BenchmarkCMuxConn(b *testing.B) {
}
}()
b.ResetTimer()
donec := make(chan struct{})
var wg sync.WaitGroup
wg.Add(b.N)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c := &mockConn{
r: bytes.NewReader(benchHTTPPayload),
}
m.serve(c)
m.serve(c, donec, &wg)
}
}

34
cmux.go
View File

@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net"
"sync"
)
// Matcher matches a connection based on its content.
@ -48,6 +49,7 @@ func New(l net.Listener) CMux {
root: l,
bufLen: 1024,
errh: func(_ error) bool { return true },
donec: make(chan struct{}),
}
}
@ -74,6 +76,7 @@ type cMux struct {
root net.Listener
bufLen int
errh ErrorHandler
donec chan struct{}
sls []matchersListener
}
@ -81,16 +84,20 @@ func (m *cMux) Match(matchers ...Matcher) net.Listener {
ml := muxListener{
Listener: m.root,
connc: make(chan net.Conn, m.bufLen),
donec: make(chan struct{}),
}
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
return ml
}
func (m *cMux) Serve() error {
var wg sync.WaitGroup
defer func() {
close(m.donec)
wg.Wait()
for _, sl := range m.sls {
close(sl.l.donec)
close(sl.l.connc)
}
}()
@ -103,11 +110,14 @@ func (m *cMux) Serve() error {
continue
}
go m.serve(c)
wg.Add(1)
go m.serve(c, m.donec, &wg)
}
}
func (m *cMux) serve(c net.Conn) {
func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()
muc := newMuxConn(c)
for _, sl := range m.sls {
for _, s := range sl.ss {
@ -116,8 +126,12 @@ func (m *cMux) serve(c net.Conn) {
if matched {
select {
case sl.l.connc <- muc:
case <-sl.l.donec:
_ = c.Close()
default:
select {
case <-donec:
_ = c.Close()
default:
}
}
return
}
@ -150,16 +164,14 @@ func (m *cMux) handleErr(err error) bool {
type muxListener struct {
net.Listener
connc chan net.Conn
donec chan struct{}
}
func (l muxListener) Accept() (net.Conn, error) {
select {
case c := <-l.connc:
return c, nil
case <-l.donec:
c, ok := <-l.connc
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.

View File

@ -1,6 +1,7 @@
package cmux
import (
"errors"
"fmt"
"io/ioutil"
"net"
@ -38,6 +39,22 @@ func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
}
}
type chanListener struct {
net.Listener
connCh chan net.Conn
}
func newChanListener() *chanListener {
return &chanListener{connCh: make(chan net.Conn, 1)}
}
func (l *chanListener) Accept() (net.Conn, error) {
if c, ok := <-l.connCh; ok {
return c, nil
}
return nil, errors.New("use of closed network connection")
}
func testListener(t *testing.T) (net.Listener, func()) {
l, err := net.Listen("tcp", ":0")
if err != nil {
@ -227,30 +244,49 @@ func TestErrorHandler(t *testing.T) {
defer cleanup()
var num int
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
// The connection is simply closed.
t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
}
if atomic.LoadUint32(&errCount) == 0 {
t.Errorf("expected at least 1 error(s), got none")
for atomic.LoadUint32(&errCount) == 0 {
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
// The connection is simply closed.
t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
}
}
}
type closerConn struct {
net.Conn
}
func (c closerConn) Close() error { return nil }
func TestClosed(t *testing.T) {
func TestClose(t *testing.T) {
defer leakCheck(t)()
mux := &cMux{}
lis := mux.Match(Any()).(muxListener)
close(lis.donec)
mux.serve(closerConn{})
_, err := lis.Accept()
if _, ok := err.(errListenerClosed); !ok {
t.Errorf("expected errListenerClosed got %v", err)
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
l := newChanListener()
c1, c2 := net.Pipe()
muxl := New(l)
anyl := muxl.Match(Any())
go safeServe(errCh, muxl)
l.connCh <- c1
// First connection goes through.
if _, err := anyl.Accept(); err != nil {
t.Fatal(err)
}
// Second connection is sent
l.connCh <- c2
// Listener is closed.
close(l.connCh)
// Second connection goes through.
if _, err := anyl.Accept(); err != nil {
t.Fatal(err)
}
}