mirror of
				https://github.com/soheilhy/cmux.git
				synced 2025-10-26 16:26:31 +08:00 
			
		
		
		
	| @@ -4,6 +4,7 @@ import ( | ||||
| 	"bytes" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| @@ -31,12 +32,15 @@ func BenchmarkCMuxConn(b *testing.B) { | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	b.ResetTimer() | ||||
| 	donec := make(chan struct{}) | ||||
| 	var wg sync.WaitGroup | ||||
| 	wg.Add(b.N) | ||||
|  | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		c := &mockConn{ | ||||
| 			r: bytes.NewReader(benchHTTPPayload), | ||||
| 		} | ||||
| 		m.serve(c) | ||||
| 		m.serve(c, donec, &wg) | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										34
									
								
								cmux.go
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								cmux.go
									
									
									
									
									
								
							| @@ -4,6 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"sync" | ||||
| ) | ||||
|  | ||||
| // Matcher matches a connection based on its content. | ||||
| @@ -48,6 +49,7 @@ func New(l net.Listener) CMux { | ||||
| 		root:   l, | ||||
| 		bufLen: 1024, | ||||
| 		errh:   func(_ error) bool { return true }, | ||||
| 		donec:  make(chan struct{}), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -74,6 +76,7 @@ type cMux struct { | ||||
| 	root   net.Listener | ||||
| 	bufLen int | ||||
| 	errh   ErrorHandler | ||||
| 	donec  chan struct{} | ||||
| 	sls    []matchersListener | ||||
| } | ||||
|  | ||||
| @@ -81,16 +84,20 @@ func (m *cMux) Match(matchers ...Matcher) net.Listener { | ||||
| 	ml := muxListener{ | ||||
| 		Listener: m.root, | ||||
| 		connc:    make(chan net.Conn, m.bufLen), | ||||
| 		donec:    make(chan struct{}), | ||||
| 	} | ||||
| 	m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) | ||||
| 	return ml | ||||
| } | ||||
|  | ||||
| func (m *cMux) Serve() error { | ||||
| 	var wg sync.WaitGroup | ||||
|  | ||||
| 	defer func() { | ||||
| 		close(m.donec) | ||||
| 		wg.Wait() | ||||
|  | ||||
| 		for _, sl := range m.sls { | ||||
| 			close(sl.l.donec) | ||||
| 			close(sl.l.connc) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| @@ -103,11 +110,14 @@ func (m *cMux) Serve() error { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		go m.serve(c) | ||||
| 		wg.Add(1) | ||||
| 		go m.serve(c, m.donec, &wg) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (m *cMux) serve(c net.Conn) { | ||||
| func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { | ||||
| 	defer wg.Done() | ||||
|  | ||||
| 	muc := newMuxConn(c) | ||||
| 	for _, sl := range m.sls { | ||||
| 		for _, s := range sl.ss { | ||||
| @@ -116,8 +126,12 @@ func (m *cMux) serve(c net.Conn) { | ||||
| 			if matched { | ||||
| 				select { | ||||
| 				case sl.l.connc <- muc: | ||||
| 				case <-sl.l.donec: | ||||
| 					_ = c.Close() | ||||
| 				default: | ||||
| 					select { | ||||
| 					case <-donec: | ||||
| 						_ = c.Close() | ||||
| 					default: | ||||
| 					} | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| @@ -150,16 +164,14 @@ func (m *cMux) handleErr(err error) bool { | ||||
| type muxListener struct { | ||||
| 	net.Listener | ||||
| 	connc chan net.Conn | ||||
| 	donec chan struct{} | ||||
| } | ||||
|  | ||||
| func (l muxListener) Accept() (net.Conn, error) { | ||||
| 	select { | ||||
| 	case c := <-l.connc: | ||||
| 		return c, nil | ||||
| 	case <-l.donec: | ||||
| 	c, ok := <-l.connc | ||||
| 	if !ok { | ||||
| 		return nil, ErrListenerClosed | ||||
| 	} | ||||
| 	return c, nil | ||||
| } | ||||
|  | ||||
| // MuxConn wraps a net.Conn and provides transparent sniffing of connection data. | ||||
|   | ||||
							
								
								
									
										76
									
								
								cmux_test.go
									
									
									
									
									
								
							
							
						
						
									
										76
									
								
								cmux_test.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | ||||
| package cmux | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net" | ||||
| @@ -38,6 +39,22 @@ func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type chanListener struct { | ||||
| 	net.Listener | ||||
| 	connCh chan net.Conn | ||||
| } | ||||
|  | ||||
| func newChanListener() *chanListener { | ||||
| 	return &chanListener{connCh: make(chan net.Conn, 1)} | ||||
| } | ||||
|  | ||||
| func (l *chanListener) Accept() (net.Conn, error) { | ||||
| 	if c, ok := <-l.connCh; ok { | ||||
| 		return c, nil | ||||
| 	} | ||||
| 	return nil, errors.New("use of closed network connection") | ||||
| } | ||||
|  | ||||
| func testListener(t *testing.T) (net.Listener, func()) { | ||||
| 	l, err := net.Listen("tcp", ":0") | ||||
| 	if err != nil { | ||||
| @@ -227,30 +244,49 @@ func TestErrorHandler(t *testing.T) { | ||||
| 	defer cleanup() | ||||
|  | ||||
| 	var num int | ||||
| 	if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil { | ||||
| 		// The connection is simply closed. | ||||
| 		t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount)) | ||||
| 	} | ||||
| 	if atomic.LoadUint32(&errCount) == 0 { | ||||
| 		t.Errorf("expected at least 1 error(s), got none") | ||||
| 	for atomic.LoadUint32(&errCount) == 0 { | ||||
| 		if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil { | ||||
| 			// The connection is simply closed. | ||||
| 			t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type closerConn struct { | ||||
| 	net.Conn | ||||
| } | ||||
|  | ||||
| func (c closerConn) Close() error { return nil } | ||||
|  | ||||
| func TestClosed(t *testing.T) { | ||||
| func TestClose(t *testing.T) { | ||||
| 	defer leakCheck(t)() | ||||
| 	mux := &cMux{} | ||||
| 	lis := mux.Match(Any()).(muxListener) | ||||
| 	close(lis.donec) | ||||
| 	mux.serve(closerConn{}) | ||||
| 	_, err := lis.Accept() | ||||
| 	if _, ok := err.(errListenerClosed); !ok { | ||||
| 		t.Errorf("expected errListenerClosed got %v", err) | ||||
| 	errCh := make(chan error) | ||||
| 	defer func() { | ||||
| 		select { | ||||
| 		case err := <-errCh: | ||||
| 			t.Fatal(err) | ||||
| 		default: | ||||
| 		} | ||||
| 	}() | ||||
| 	l := newChanListener() | ||||
|  | ||||
| 	c1, c2 := net.Pipe() | ||||
|  | ||||
| 	muxl := New(l) | ||||
| 	anyl := muxl.Match(Any()) | ||||
|  | ||||
| 	go safeServe(errCh, muxl) | ||||
|  | ||||
| 	l.connCh <- c1 | ||||
|  | ||||
| 	// First connection goes through. | ||||
| 	if _, err := anyl.Accept(); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Second connection is sent | ||||
| 	l.connCh <- c2 | ||||
|  | ||||
| 	// Listener is closed. | ||||
| 	close(l.connCh) | ||||
|  | ||||
| 	// Second connection goes through. | ||||
| 	if _, err := anyl.Accept(); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user