mirror of
https://github.com/soheilhy/cmux.git
synced 2024-11-14 11:31:28 +08:00
commit
d5b9190ea9
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,12 +32,15 @@ func BenchmarkCMuxConn(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
b.ResetTimer()
|
donec := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(b.N)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
c := &mockConn{
|
c := &mockConn{
|
||||||
r: bytes.NewReader(benchHTTPPayload),
|
r: bytes.NewReader(benchHTTPPayload),
|
||||||
}
|
}
|
||||||
m.serve(c)
|
m.serve(c, donec, &wg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
32
cmux.go
32
cmux.go
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Matcher matches a connection based on its content.
|
// Matcher matches a connection based on its content.
|
||||||
@ -48,6 +49,7 @@ func New(l net.Listener) CMux {
|
|||||||
root: l,
|
root: l,
|
||||||
bufLen: 1024,
|
bufLen: 1024,
|
||||||
errh: func(_ error) bool { return true },
|
errh: func(_ error) bool { return true },
|
||||||
|
donec: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,6 +76,7 @@ type cMux struct {
|
|||||||
root net.Listener
|
root net.Listener
|
||||||
bufLen int
|
bufLen int
|
||||||
errh ErrorHandler
|
errh ErrorHandler
|
||||||
|
donec chan struct{}
|
||||||
sls []matchersListener
|
sls []matchersListener
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,16 +84,20 @@ func (m *cMux) Match(matchers ...Matcher) net.Listener {
|
|||||||
ml := muxListener{
|
ml := muxListener{
|
||||||
Listener: m.root,
|
Listener: m.root,
|
||||||
connc: make(chan net.Conn, m.bufLen),
|
connc: make(chan net.Conn, m.bufLen),
|
||||||
donec: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
|
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
|
||||||
return ml
|
return ml
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *cMux) Serve() error {
|
func (m *cMux) Serve() error {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
close(m.donec)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
for _, sl := range m.sls {
|
for _, sl := range m.sls {
|
||||||
close(sl.l.donec)
|
close(sl.l.connc)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -103,11 +110,14 @@ func (m *cMux) Serve() error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
go m.serve(c)
|
wg.Add(1)
|
||||||
|
go m.serve(c, m.donec, &wg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *cMux) serve(c net.Conn) {
|
func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
muc := newMuxConn(c)
|
muc := newMuxConn(c)
|
||||||
for _, sl := range m.sls {
|
for _, sl := range m.sls {
|
||||||
for _, s := range sl.ss {
|
for _, s := range sl.ss {
|
||||||
@ -116,8 +126,12 @@ func (m *cMux) serve(c net.Conn) {
|
|||||||
if matched {
|
if matched {
|
||||||
select {
|
select {
|
||||||
case sl.l.connc <- muc:
|
case sl.l.connc <- muc:
|
||||||
case <-sl.l.donec:
|
default:
|
||||||
|
select {
|
||||||
|
case <-donec:
|
||||||
_ = c.Close()
|
_ = c.Close()
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -150,16 +164,14 @@ func (m *cMux) handleErr(err error) bool {
|
|||||||
type muxListener struct {
|
type muxListener struct {
|
||||||
net.Listener
|
net.Listener
|
||||||
connc chan net.Conn
|
connc chan net.Conn
|
||||||
donec chan struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l muxListener) Accept() (net.Conn, error) {
|
func (l muxListener) Accept() (net.Conn, error) {
|
||||||
select {
|
c, ok := <-l.connc
|
||||||
case c := <-l.connc:
|
if !ok {
|
||||||
return c, nil
|
|
||||||
case <-l.donec:
|
|
||||||
return nil, ErrListenerClosed
|
return nil, ErrListenerClosed
|
||||||
}
|
}
|
||||||
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
|
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
|
||||||
|
68
cmux_test.go
68
cmux_test.go
@ -1,6 +1,7 @@
|
|||||||
package cmux
|
package cmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
@ -38,6 +39,22 @@ func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type chanListener struct {
|
||||||
|
net.Listener
|
||||||
|
connCh chan net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newChanListener() *chanListener {
|
||||||
|
return &chanListener{connCh: make(chan net.Conn, 1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *chanListener) Accept() (net.Conn, error) {
|
||||||
|
if c, ok := <-l.connCh; ok {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("use of closed network connection")
|
||||||
|
}
|
||||||
|
|
||||||
func testListener(t *testing.T) (net.Listener, func()) {
|
func testListener(t *testing.T) (net.Listener, func()) {
|
||||||
l, err := net.Listen("tcp", ":0")
|
l, err := net.Listen("tcp", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -227,30 +244,49 @@ func TestErrorHandler(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
var num int
|
var num int
|
||||||
|
for atomic.LoadUint32(&errCount) == 0 {
|
||||||
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
|
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
|
||||||
// The connection is simply closed.
|
// The connection is simply closed.
|
||||||
t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
|
t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
|
||||||
}
|
}
|
||||||
if atomic.LoadUint32(&errCount) == 0 {
|
|
||||||
t.Errorf("expected at least 1 error(s), got none")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type closerConn struct {
|
func TestClose(t *testing.T) {
|
||||||
net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c closerConn) Close() error { return nil }
|
|
||||||
|
|
||||||
func TestClosed(t *testing.T) {
|
|
||||||
defer leakCheck(t)()
|
defer leakCheck(t)()
|
||||||
mux := &cMux{}
|
errCh := make(chan error)
|
||||||
lis := mux.Match(Any()).(muxListener)
|
defer func() {
|
||||||
close(lis.donec)
|
select {
|
||||||
mux.serve(closerConn{})
|
case err := <-errCh:
|
||||||
_, err := lis.Accept()
|
t.Fatal(err)
|
||||||
if _, ok := err.(errListenerClosed); !ok {
|
default:
|
||||||
t.Errorf("expected errListenerClosed got %v", err)
|
}
|
||||||
|
}()
|
||||||
|
l := newChanListener()
|
||||||
|
|
||||||
|
c1, c2 := net.Pipe()
|
||||||
|
|
||||||
|
muxl := New(l)
|
||||||
|
anyl := muxl.Match(Any())
|
||||||
|
|
||||||
|
go safeServe(errCh, muxl)
|
||||||
|
|
||||||
|
l.connCh <- c1
|
||||||
|
|
||||||
|
// First connection goes through.
|
||||||
|
if _, err := anyl.Accept(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second connection is sent
|
||||||
|
l.connCh <- c2
|
||||||
|
|
||||||
|
// Listener is closed.
|
||||||
|
close(l.connCh)
|
||||||
|
|
||||||
|
// Second connection goes through.
|
||||||
|
if _, err := anyl.Accept(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user