mirror of
https://github.com/soheilhy/cmux.git
synced 2025-01-19 03:06:07 +08:00
Tweak shutdown behaviour
When the root listener is closed, child listeners will not be closed until all parked connections are served. This prevents losing connections that have been read from. This also allows moving the main test to package cmux_test, but that will happen in a separate change.
This commit is contained in:
parent
9a9119af9d
commit
235d98b021
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,12 +32,15 @@ func BenchmarkCMuxConn(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
b.ResetTimer()
|
donec := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(b.N)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
c := &mockConn{
|
c := &mockConn{
|
||||||
r: bytes.NewReader(benchHTTPPayload),
|
r: bytes.NewReader(benchHTTPPayload),
|
||||||
}
|
}
|
||||||
m.serve(c)
|
m.serve(c, donec, &wg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
34
cmux.go
34
cmux.go
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Matcher matches a connection based on its content.
|
// Matcher matches a connection based on its content.
|
||||||
@ -48,6 +49,7 @@ func New(l net.Listener) CMux {
|
|||||||
root: l,
|
root: l,
|
||||||
bufLen: 1024,
|
bufLen: 1024,
|
||||||
errh: func(_ error) bool { return true },
|
errh: func(_ error) bool { return true },
|
||||||
|
donec: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,6 +76,7 @@ type cMux struct {
|
|||||||
root net.Listener
|
root net.Listener
|
||||||
bufLen int
|
bufLen int
|
||||||
errh ErrorHandler
|
errh ErrorHandler
|
||||||
|
donec chan struct{}
|
||||||
sls []matchersListener
|
sls []matchersListener
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,16 +84,20 @@ func (m *cMux) Match(matchers ...Matcher) net.Listener {
|
|||||||
ml := muxListener{
|
ml := muxListener{
|
||||||
Listener: m.root,
|
Listener: m.root,
|
||||||
connc: make(chan net.Conn, m.bufLen),
|
connc: make(chan net.Conn, m.bufLen),
|
||||||
donec: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
|
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
|
||||||
return ml
|
return ml
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *cMux) Serve() error {
|
func (m *cMux) Serve() error {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
close(m.donec)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
for _, sl := range m.sls {
|
for _, sl := range m.sls {
|
||||||
close(sl.l.donec)
|
close(sl.l.connc)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -103,11 +110,14 @@ func (m *cMux) Serve() error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
go m.serve(c)
|
wg.Add(1)
|
||||||
|
go m.serve(c, m.donec, &wg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *cMux) serve(c net.Conn) {
|
func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
muc := newMuxConn(c)
|
muc := newMuxConn(c)
|
||||||
for _, sl := range m.sls {
|
for _, sl := range m.sls {
|
||||||
for _, s := range sl.ss {
|
for _, s := range sl.ss {
|
||||||
@ -116,8 +126,12 @@ func (m *cMux) serve(c net.Conn) {
|
|||||||
if matched {
|
if matched {
|
||||||
select {
|
select {
|
||||||
case sl.l.connc <- muc:
|
case sl.l.connc <- muc:
|
||||||
case <-sl.l.donec:
|
default:
|
||||||
_ = c.Close()
|
select {
|
||||||
|
case <-donec:
|
||||||
|
_ = c.Close()
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -150,16 +164,14 @@ func (m *cMux) handleErr(err error) bool {
|
|||||||
type muxListener struct {
|
type muxListener struct {
|
||||||
net.Listener
|
net.Listener
|
||||||
connc chan net.Conn
|
connc chan net.Conn
|
||||||
donec chan struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l muxListener) Accept() (net.Conn, error) {
|
func (l muxListener) Accept() (net.Conn, error) {
|
||||||
select {
|
c, ok := <-l.connc
|
||||||
case c := <-l.connc:
|
if !ok {
|
||||||
return c, nil
|
|
||||||
case <-l.donec:
|
|
||||||
return nil, ErrListenerClosed
|
return nil, ErrListenerClosed
|
||||||
}
|
}
|
||||||
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
|
// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
|
||||||
|
65
cmux_test.go
65
cmux_test.go
@ -1,6 +1,7 @@
|
|||||||
package cmux
|
package cmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
@ -38,6 +39,22 @@ func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type chanListener struct {
|
||||||
|
net.Listener
|
||||||
|
connCh chan net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newChanListener() *chanListener {
|
||||||
|
return &chanListener{connCh: make(chan net.Conn, 1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *chanListener) Accept() (net.Conn, error) {
|
||||||
|
if c, ok := <-l.connCh; ok {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("use of closed network connection")
|
||||||
|
}
|
||||||
|
|
||||||
func testListener(t *testing.T) (net.Listener, func()) {
|
func testListener(t *testing.T) (net.Listener, func()) {
|
||||||
l, err := net.Listen("tcp", ":0")
|
l, err := net.Listen("tcp", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -235,21 +252,41 @@ func TestErrorHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type closerConn struct {
|
func TestClose(t *testing.T) {
|
||||||
net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c closerConn) Close() error { return nil }
|
|
||||||
|
|
||||||
func TestClosed(t *testing.T) {
|
|
||||||
defer leakCheck(t)()
|
defer leakCheck(t)()
|
||||||
mux := &cMux{}
|
errCh := make(chan error)
|
||||||
lis := mux.Match(Any()).(muxListener)
|
defer func() {
|
||||||
close(lis.donec)
|
select {
|
||||||
mux.serve(closerConn{})
|
case err := <-errCh:
|
||||||
_, err := lis.Accept()
|
t.Fatal(err)
|
||||||
if _, ok := err.(errListenerClosed); !ok {
|
default:
|
||||||
t.Errorf("expected errListenerClosed got %v", err)
|
}
|
||||||
|
}()
|
||||||
|
l := newChanListener()
|
||||||
|
|
||||||
|
c1, c2 := net.Pipe()
|
||||||
|
|
||||||
|
muxl := New(l)
|
||||||
|
anyl := muxl.Match(Any())
|
||||||
|
|
||||||
|
go safeServe(errCh, muxl)
|
||||||
|
|
||||||
|
l.connCh <- c1
|
||||||
|
|
||||||
|
// First connection goes through.
|
||||||
|
if _, err := anyl.Accept(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second connection is sent
|
||||||
|
l.connCh <- c2
|
||||||
|
|
||||||
|
// Listener is closed.
|
||||||
|
close(l.connCh)
|
||||||
|
|
||||||
|
// Second connection goes through.
|
||||||
|
if _, err := anyl.Accept(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user