2
0
mirror of https://github.com/soheilhy/cmux.git synced 2024-11-10 03:31:52 +08:00

Merge pull request #34 from ekle/master

SetReadDeadline for Matching
This commit is contained in:
Soheil Hassas Yeganeh 2016-09-05 16:05:55 -04:00 committed by GitHub
commit 13f520d62c
4 changed files with 120 additions and 13 deletions

View File

@ -3,6 +3,7 @@ language: go
go: go:
- 1.5 - 1.5
- 1.6 - 1.6
- 1.7
- tip - tip
matrix: matrix:

View File

@ -20,6 +20,7 @@ import (
"net" "net"
"sync" "sync"
"testing" "testing"
"time"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
@ -43,6 +44,10 @@ func (c *mockConn) Read(b []byte) (n int, err error) {
return c.r.Read(b) return c.r.Read(b)
} }
func (c *mockConn) SetReadDeadline(time.Time) error {
return nil
}
func discard(l net.Listener) { func discard(l net.Listener) {
for { for {
if _, err := l.Accept(); err != nil { if _, err := l.Accept(); err != nil {

36
cmux.go
View File

@ -19,6 +19,7 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"time"
) )
// Matcher matches a connection based on its content. // Matcher matches a connection based on its content.
@ -60,13 +61,17 @@ func (e errListenerClosed) Timeout() bool { return false }
// listener is closed. // listener is closed.
var ErrListenerClosed = errListenerClosed("mux: listener closed") var ErrListenerClosed = errListenerClosed("mux: listener closed")
// for readability of readTimeout
var noTimeout time.Duration
// New instantiates a new connection multiplexer. // New instantiates a new connection multiplexer.
func New(l net.Listener) CMux { func New(l net.Listener) CMux {
return &cMux{ return &cMux{
root: l, root: l,
bufLen: 1024, bufLen: 1024,
errh: func(_ error) bool { return true }, errh: func(_ error) bool { return true },
donec: make(chan struct{}), donec: make(chan struct{}),
readTimeout: noTimeout,
} }
} }
@ -90,6 +95,8 @@ type CMux interface {
Serve() error Serve() error
// HandleError registers an error handler that handles listener errors. // HandleError registers an error handler that handles listener errors.
HandleError(ErrorHandler) HandleError(ErrorHandler)
// sets a timeout for the read of matchers
SetReadTimeout(time.Duration)
} }
type matchersListener struct { type matchersListener struct {
@ -98,11 +105,12 @@ type matchersListener struct {
} }
type cMux struct { type cMux struct {
root net.Listener root net.Listener
bufLen int bufLen int
errh ErrorHandler errh ErrorHandler
donec chan struct{} donec chan struct{}
sls []matchersListener sls []matchersListener
readTimeout time.Duration
} }
func matchersToMatchWriters(matchers []Matcher) []MatchWriter { func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
@ -129,6 +137,10 @@ func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
return ml return ml
} }
func (m *cMux) SetReadTimeout(t time.Duration) {
m.readTimeout = t
}
func (m *cMux) Serve() error { func (m *cMux) Serve() error {
var wg sync.WaitGroup var wg sync.WaitGroup
@ -163,11 +175,17 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
muc := newMuxConn(c) muc := newMuxConn(c)
if m.readTimeout > noTimeout {
_ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
}
for _, sl := range m.sls { for _, sl := range m.sls {
for _, s := range sl.ss { for _, s := range sl.ss {
matched := s(muc.Conn, muc.startSniffing()) matched := s(muc.Conn, muc.startSniffing())
if matched { if matched {
muc.doneSniffing() muc.doneSniffing()
if m.readTimeout > noTimeout {
_ = c.SetReadDeadline(time.Time{})
}
select { select {
case sl.l.connc <- muc: case sl.l.connc <- muc:
case <-donec: case <-donec:

View File

@ -19,6 +19,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
@ -73,14 +74,17 @@ func (l *chanListener) Accept() (net.Conn, error) {
} }
func testListener(t *testing.T) (net.Listener, func()) { func testListener(t *testing.T) (net.Listener, func()) {
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp4", ":0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var once sync.Once
return l, func() { return l, func() {
if err := l.Close(); err != nil { once.Do(func() {
t.Fatal(err) if err := l.Close(); err != nil {
} t.Fatal(err)
}
})
} }
} }
@ -181,6 +185,84 @@ func runTestRPCClient(t *testing.T, addr net.Addr) {
} }
} }
const (
handleHttp1Close = 1
handleHttp1Request = 2
handleAnyClose = 3
handleAnyRequest = 4
)
func TestTimeout(t *testing.T) {
defer leakCheck(t)()
lis, Close := testListener(t)
defer Close()
result := make(chan int, 5)
testDuration := time.Millisecond * 100
m := New(lis)
m.SetReadTimeout(testDuration)
http1 := m.Match(HTTP1Fast())
any := m.Match(Any())
go func() {
_ = m.Serve()
}()
go func() {
con, err := http1.Accept()
if err != nil {
result <- handleHttp1Close
} else {
_, _ = con.Write([]byte("http1"))
_ = con.Close()
result <- handleHttp1Request
}
}()
go func() {
con, err := any.Accept()
if err != nil {
result <- handleAnyClose
} else {
_, _ = con.Write([]byte("any"))
_ = con.Close()
result <- handleAnyRequest
}
}()
time.Sleep(testDuration) // wait to prevent timeouts on slow test-runners
client, err := net.Dial("tcp", lis.Addr().String())
if err != nil {
log.Fatal("testTimeout client failed: ", err)
}
defer func() {
_ = client.Close()
}()
time.Sleep(testDuration / 2)
if len(result) != 0 {
log.Print("tcp ")
t.Fatal("testTimeout failed: accepted to fast: ", len(result))
}
_ = client.SetReadDeadline(time.Now().Add(testDuration * 3))
buffer := make([]byte, 10)
rl, err := client.Read(buffer)
if err != nil {
t.Fatal("testTimeout failed: client error: ", err, rl)
}
Close()
if rl != 3 {
log.Print("testTimeout failed: response from wrong sevice ", rl)
}
if string(buffer[0:3]) != "any" {
log.Print("testTimeout failed: response from wrong sevice ")
}
time.Sleep(testDuration * 2)
if len(result) != 2 {
t.Fatal("testTimeout failed: accepted to less: ", len(result))
}
if a := <-result; a != handleAnyRequest {
t.Fatal("testTimeout failed: any rule did not match")
}
if a := <-result; a != handleHttp1Close {
t.Fatal("testTimeout failed: no close an http rule")
}
}
func TestRead(t *testing.T) { func TestRead(t *testing.T) {
defer leakCheck(t)() defer leakCheck(t)()
errCh := make(chan error) errCh := make(chan error)
@ -439,6 +521,7 @@ func interestingGoroutines() (gs []string) {
} }
if stack == "" || if stack == "" ||
strings.Contains(stack, "main.main()") ||
strings.Contains(stack, "testing.Main(") || strings.Contains(stack, "testing.Main(") ||
strings.Contains(stack, "runtime.goexit") || strings.Contains(stack, "runtime.goexit") ||
strings.Contains(stack, "created by runtime.gc") || strings.Contains(stack, "created by runtime.gc") ||