mirror of
https://github.com/soheilhy/cmux.git
synced 2025-01-19 03:06:07 +08:00
commit
d5b9190ea9
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -31,12 +32,15 @@ func BenchmarkCMuxConn(b *testing.B) {
|
||||
}
|
||||
}()
|
||||
|
||||
b.ResetTimer()
|
||||
donec := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(b.N)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
c := &mockConn{
|
||||
r: bytes.NewReader(benchHTTPPayload),
|
||||
}
|
||||
m.serve(c)
|
||||
m.serve(c, donec, &wg)
|
||||
}
|
||||
}
|
||||
|
34
cmux.go
34
cmux.go
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Matcher matches a connection based on its content.
|
||||
@ -48,6 +49,7 @@ func New(l net.Listener) CMux {
|
||||
root: l,
|
||||
bufLen: 1024,
|
||||
errh: func(_ error) bool { return true },
|
||||
donec: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,6 +76,7 @@ type cMux struct {
|
||||
root net.Listener
|
||||
bufLen int
|
||||
errh ErrorHandler
|
||||
donec chan struct{}
|
||||
sls []matchersListener
|
||||
}
|
||||
|
||||
@ -81,16 +84,20 @@ func (m *cMux) Match(matchers ...Matcher) net.Listener {
|
||||
ml := muxListener{
|
||||
Listener: m.root,
|
||||
connc: make(chan net.Conn, m.bufLen),
|
||||
donec: make(chan struct{}),
|
||||
}
|
||||
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
|
||||
return ml
|
||||
}
|
||||
|
||||
func (m *cMux) Serve() error {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
defer func() {
|
||||
close(m.donec)
|
||||
wg.Wait()
|
||||
|
||||
for _, sl := range m.sls {
|
||||
close(sl.l.donec)
|
||||
close(sl.l.connc)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -103,11 +110,14 @@ func (m *cMux) Serve() error {
|
||||
continue
|
||||
}
|
||||
|
||||
go m.serve(c)
|
||||
wg.Add(1)
|
||||
go m.serve(c, m.donec, &wg)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cMux) serve(c net.Conn) {
|
||||
func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
muc := newMuxConn(c)
|
||||
for _, sl := range m.sls {
|
||||
for _, s := range sl.ss {
|
||||
@ -116,8 +126,12 @@ func (m *cMux) serve(c net.Conn) {
|
||||
if matched {
|
||||
select {
|
||||
case sl.l.connc <- muc:
|
||||
case <-sl.l.donec:
|
||||
_ = c.Close()
|
||||
default:
|
||||
select {
|
||||
case <-donec:
|
||||
_ = c.Close()
|
||||
default:
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -150,16 +164,14 @@ func (m *cMux) handleErr(err error) bool {
|
||||
type muxListener struct {
|
||||
net.Listener
|
||||
connc chan net.Conn
|
||||
donec chan struct{}
|
||||
}
|
||||
|
||||
func (l muxListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case c := <-l.connc:
|
||||
return c, nil
|
||||
case <-l.donec:
|
||||
c, ok := <-l.connc
|
||||
if !ok {
|
||||
return nil, ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
|
||||
|
76
cmux_test.go
76
cmux_test.go
@ -1,6 +1,7 @@
|
||||
package cmux
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@ -38,6 +39,22 @@ func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
|
||||
}
|
||||
}
|
||||
|
||||
type chanListener struct {
|
||||
net.Listener
|
||||
connCh chan net.Conn
|
||||
}
|
||||
|
||||
func newChanListener() *chanListener {
|
||||
return &chanListener{connCh: make(chan net.Conn, 1)}
|
||||
}
|
||||
|
||||
func (l *chanListener) Accept() (net.Conn, error) {
|
||||
if c, ok := <-l.connCh; ok {
|
||||
return c, nil
|
||||
}
|
||||
return nil, errors.New("use of closed network connection")
|
||||
}
|
||||
|
||||
func testListener(t *testing.T) (net.Listener, func()) {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
@ -227,30 +244,49 @@ func TestErrorHandler(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
var num int
|
||||
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
|
||||
// The connection is simply closed.
|
||||
t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
|
||||
}
|
||||
if atomic.LoadUint32(&errCount) == 0 {
|
||||
t.Errorf("expected at least 1 error(s), got none")
|
||||
for atomic.LoadUint32(&errCount) == 0 {
|
||||
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
|
||||
// The connection is simply closed.
|
||||
t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type closerConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c closerConn) Close() error { return nil }
|
||||
|
||||
func TestClosed(t *testing.T) {
|
||||
func TestClose(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
mux := &cMux{}
|
||||
lis := mux.Match(Any()).(muxListener)
|
||||
close(lis.donec)
|
||||
mux.serve(closerConn{})
|
||||
_, err := lis.Accept()
|
||||
if _, ok := err.(errListenerClosed); !ok {
|
||||
t.Errorf("expected errListenerClosed got %v", err)
|
||||
errCh := make(chan error)
|
||||
defer func() {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}()
|
||||
l := newChanListener()
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
|
||||
muxl := New(l)
|
||||
anyl := muxl.Match(Any())
|
||||
|
||||
go safeServe(errCh, muxl)
|
||||
|
||||
l.connCh <- c1
|
||||
|
||||
// First connection goes through.
|
||||
if _, err := anyl.Accept(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Second connection is sent
|
||||
l.connCh <- c2
|
||||
|
||||
// Listener is closed.
|
||||
close(l.connCh)
|
||||
|
||||
// Second connection goes through.
|
||||
if _, err := anyl.Accept(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user