diff --git a/.travis.yml b/.travis.yml index 6be7119..da824cb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,7 @@ language: go go: - 1.5 - 1.6 + - 1.7 - tip matrix: diff --git a/bench_test.go b/bench_test.go index 782b07b..53b0367 100644 --- a/bench_test.go +++ b/bench_test.go @@ -20,6 +20,7 @@ import ( "net" "sync" "testing" + "time" "golang.org/x/net/http2" ) @@ -43,6 +44,10 @@ func (c *mockConn) Read(b []byte) (n int, err error) { return c.r.Read(b) } +func (c *mockConn) SetReadDeadline(time.Time) error { + return nil +} + func discard(l net.Listener) { for { if _, err := l.Accept(); err != nil { diff --git a/cmux.go b/cmux.go index 69e83fb..9de6b0a 100644 --- a/cmux.go +++ b/cmux.go @@ -19,6 +19,7 @@ import ( "io" "net" "sync" + "time" ) // Matcher matches a connection based on its content. @@ -60,13 +61,17 @@ func (e errListenerClosed) Timeout() bool { return false } // listener is closed. var ErrListenerClosed = errListenerClosed("mux: listener closed") +// for readability of readTimeout +var noTimeout time.Duration + // New instantiates a new connection multiplexer. func New(l net.Listener) CMux { return &cMux{ - root: l, - bufLen: 1024, - errh: func(_ error) bool { return true }, - donec: make(chan struct{}), + root: l, + bufLen: 1024, + errh: func(_ error) bool { return true }, + donec: make(chan struct{}), + readTimeout: noTimeout, } } @@ -90,6 +95,8 @@ type CMux interface { Serve() error // HandleError registers an error handler that handles listener errors. HandleError(ErrorHandler) + // sets a timeout for the read of matchers + SetReadTimeout(time.Duration) } type matchersListener struct { @@ -98,11 +105,12 @@ type matchersListener struct { } type cMux struct { - root net.Listener - bufLen int - errh ErrorHandler - donec chan struct{} - sls []matchersListener + root net.Listener + bufLen int + errh ErrorHandler + donec chan struct{} + sls []matchersListener + readTimeout time.Duration } func matchersToMatchWriters(matchers []Matcher) []MatchWriter { @@ -129,6 +137,10 @@ func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener { return ml } +func (m *cMux) SetReadTimeout(t time.Duration) { + m.readTimeout = t +} + func (m *cMux) Serve() error { var wg sync.WaitGroup @@ -163,11 +175,17 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { defer wg.Done() muc := newMuxConn(c) + if m.readTimeout > noTimeout { + _ = c.SetReadDeadline(time.Now().Add(m.readTimeout)) + } for _, sl := range m.sls { for _, s := range sl.ss { matched := s(muc.Conn, muc.startSniffing()) if matched { muc.doneSniffing() + if m.readTimeout > noTimeout { + _ = c.SetReadDeadline(time.Time{}) + } select { case sl.l.connc <- muc: case <-donec: diff --git a/cmux_test.go b/cmux_test.go index 676016a..2279ded 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "io/ioutil" + "log" "net" "net/http" "net/rpc" @@ -73,14 +74,17 @@ func (l *chanListener) Accept() (net.Conn, error) { } func testListener(t *testing.T) (net.Listener, func()) { - l, err := net.Listen("tcp", ":0") + l, err := net.Listen("tcp4", ":0") if err != nil { t.Fatal(err) } + var once sync.Once return l, func() { - if err := l.Close(); err != nil { - t.Fatal(err) - } + once.Do(func() { + if err := l.Close(); err != nil { + t.Fatal(err) + } + }) } } @@ -181,6 +185,84 @@ func runTestRPCClient(t *testing.T, addr net.Addr) { } } +const ( + handleHttp1Close = 1 + handleHttp1Request = 2 + handleAnyClose = 3 + handleAnyRequest = 4 +) + +func TestTimeout(t *testing.T) { + defer leakCheck(t)() + lis, Close := testListener(t) + defer Close() + result := make(chan int, 5) + testDuration := time.Millisecond * 100 + m := New(lis) + m.SetReadTimeout(testDuration) + http1 := m.Match(HTTP1Fast()) + any := m.Match(Any()) + go func() { + _ = m.Serve() + }() + go func() { + con, err := http1.Accept() + if err != nil { + result <- handleHttp1Close + } else { + _, _ = con.Write([]byte("http1")) + _ = con.Close() + result <- handleHttp1Request + } + }() + go func() { + con, err := any.Accept() + if err != nil { + result <- handleAnyClose + } else { + _, _ = con.Write([]byte("any")) + _ = con.Close() + result <- handleAnyRequest + } + }() + time.Sleep(testDuration) // wait to prevent timeouts on slow test-runners + client, err := net.Dial("tcp", lis.Addr().String()) + if err != nil { + log.Fatal("testTimeout client failed: ", err) + } + defer func() { + _ = client.Close() + }() + time.Sleep(testDuration / 2) + if len(result) != 0 { + log.Print("tcp ") + t.Fatal("testTimeout failed: accepted to fast: ", len(result)) + } + _ = client.SetReadDeadline(time.Now().Add(testDuration * 3)) + buffer := make([]byte, 10) + rl, err := client.Read(buffer) + if err != nil { + t.Fatal("testTimeout failed: client error: ", err, rl) + } + Close() + if rl != 3 { + log.Print("testTimeout failed: response from wrong sevice ", rl) + } + if string(buffer[0:3]) != "any" { + log.Print("testTimeout failed: response from wrong sevice ") + } + time.Sleep(testDuration * 2) + if len(result) != 2 { + t.Fatal("testTimeout failed: accepted to less: ", len(result)) + } + if a := <-result; a != handleAnyRequest { + t.Fatal("testTimeout failed: any rule did not match") + } + if a := <-result; a != handleHttp1Close { + t.Fatal("testTimeout failed: no close an http rule") + } +} + func TestRead(t *testing.T) { defer leakCheck(t)() errCh := make(chan error) @@ -439,6 +521,7 @@ func interestingGoroutines() (gs []string) { } if stack == "" || + strings.Contains(stack, "main.main()") || strings.Contains(stack, "testing.Main(") || strings.Contains(stack, "runtime.goexit") || strings.Contains(stack, "created by runtime.gc") ||