mirror of
https://github.com/soheilhy/cmux.git
synced 2025-01-18 18:56:26 +08:00
SetReadDeadline for Matching
This commit is contained in:
parent
b26951527b
commit
e132036cce
@ -3,6 +3,7 @@ language: go
|
||||
go:
|
||||
- 1.5
|
||||
- 1.6
|
||||
- 1.7
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
@ -43,6 +44,10 @@ func (c *mockConn) Read(b []byte) (n int, err error) {
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *mockConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func discard(l net.Listener) {
|
||||
for {
|
||||
if _, err := l.Accept(); err != nil {
|
||||
|
18
cmux.go
18
cmux.go
@ -19,6 +19,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Matcher matches a connection based on its content.
|
||||
@ -60,6 +61,9 @@ func (e errListenerClosed) Timeout() bool { return false }
|
||||
// listener is closed.
|
||||
var ErrListenerClosed = errListenerClosed("mux: listener closed")
|
||||
|
||||
// for readability of readTimeout
|
||||
var noTimeout time.Duration
|
||||
|
||||
// New instantiates a new connection multiplexer.
|
||||
func New(l net.Listener) CMux {
|
||||
return &cMux{
|
||||
@ -67,6 +71,7 @@ func New(l net.Listener) CMux {
|
||||
bufLen: 1024,
|
||||
errh: func(_ error) bool { return true },
|
||||
donec: make(chan struct{}),
|
||||
readTimeout: noTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,6 +95,8 @@ type CMux interface {
|
||||
Serve() error
|
||||
// HandleError registers an error handler that handles listener errors.
|
||||
HandleError(ErrorHandler)
|
||||
// sets a timeout for the read of matchers
|
||||
SetReadTimeout(time.Duration)
|
||||
}
|
||||
|
||||
type matchersListener struct {
|
||||
@ -103,6 +110,7 @@ type cMux struct {
|
||||
errh ErrorHandler
|
||||
donec chan struct{}
|
||||
sls []matchersListener
|
||||
readTimeout time.Duration
|
||||
}
|
||||
|
||||
func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
|
||||
@ -129,6 +137,10 @@ func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
|
||||
return ml
|
||||
}
|
||||
|
||||
func (m *cMux) SetReadTimeout(t time.Duration) {
|
||||
m.readTimeout = t
|
||||
}
|
||||
|
||||
func (m *cMux) Serve() error {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
@ -163,11 +175,17 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
muc := newMuxConn(c)
|
||||
if m.readTimeout > noTimeout {
|
||||
_ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
|
||||
}
|
||||
for _, sl := range m.sls {
|
||||
for _, s := range sl.ss {
|
||||
matched := s(muc.Conn, muc.startSniffing())
|
||||
if matched {
|
||||
muc.doneSniffing()
|
||||
if m.readTimeout > noTimeout {
|
||||
_ = c.SetReadDeadline(time.Time{})
|
||||
}
|
||||
select {
|
||||
case sl.l.connc <- muc:
|
||||
case <-donec:
|
||||
|
85
cmux_test.go
85
cmux_test.go
@ -19,6 +19,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
@ -73,14 +74,17 @@ func (l *chanListener) Accept() (net.Conn, error) {
|
||||
}
|
||||
|
||||
func testListener(t *testing.T) (net.Listener, func()) {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
l, err := net.Listen("tcp4", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var once sync.Once
|
||||
return l, func() {
|
||||
once.Do(func() {
|
||||
if err := l.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -181,6 +185,84 @@ func runTestRPCClient(t *testing.T, addr net.Addr) {
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
handleHttp1Close = 1
|
||||
handleHttp1Request = 2
|
||||
handleAnyClose = 3
|
||||
handleAnyRequest = 4
|
||||
)
|
||||
|
||||
func TestTimeout(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
lis, Close := testListener(t)
|
||||
defer Close()
|
||||
result := make(chan int, 5)
|
||||
testDuration := time.Millisecond * 100
|
||||
m := New(lis)
|
||||
m.SetReadTimeout(testDuration)
|
||||
http1 := m.Match(HTTP1Fast())
|
||||
any := m.Match(Any())
|
||||
go func() {
|
||||
_ = m.Serve()
|
||||
}()
|
||||
go func() {
|
||||
con, err := http1.Accept()
|
||||
if err != nil {
|
||||
result <- handleHttp1Close
|
||||
} else {
|
||||
_, _ = con.Write([]byte("http1"))
|
||||
_ = con.Close()
|
||||
result <- handleHttp1Request
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
con, err := any.Accept()
|
||||
if err != nil {
|
||||
result <- handleAnyClose
|
||||
} else {
|
||||
_, _ = con.Write([]byte("any"))
|
||||
_ = con.Close()
|
||||
result <- handleAnyRequest
|
||||
}
|
||||
}()
|
||||
time.Sleep(testDuration) // wait to prevent timeouts on slow test-runners
|
||||
client, err := net.Dial("tcp", lis.Addr().String())
|
||||
if err != nil {
|
||||
log.Fatal("testTimeout client failed: ", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Close()
|
||||
}()
|
||||
time.Sleep(testDuration / 2)
|
||||
if len(result) != 0 {
|
||||
log.Print("tcp ")
|
||||
t.Fatal("testTimeout failed: accepted to fast: ", len(result))
|
||||
}
|
||||
_ = client.SetReadDeadline(time.Now().Add(testDuration * 3))
|
||||
buffer := make([]byte, 10)
|
||||
rl, err := client.Read(buffer)
|
||||
if err != nil {
|
||||
t.Fatal("testTimeout failed: client error: ", err, rl)
|
||||
}
|
||||
Close()
|
||||
if rl != 3 {
|
||||
log.Print("testTimeout failed: response from wrong sevice ", rl)
|
||||
}
|
||||
if string(buffer[0:3]) != "any" {
|
||||
log.Print("testTimeout failed: response from wrong sevice ")
|
||||
}
|
||||
time.Sleep(testDuration * 2)
|
||||
if len(result) != 2 {
|
||||
t.Fatal("testTimeout failed: accepted to less: ", len(result))
|
||||
}
|
||||
if a := <-result; a != handleAnyRequest {
|
||||
t.Fatal("testTimeout failed: any rule did not match")
|
||||
}
|
||||
if a := <-result; a != handleHttp1Close {
|
||||
t.Fatal("testTimeout failed: no close an http rule")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
@ -439,6 +521,7 @@ func interestingGoroutines() (gs []string) {
|
||||
}
|
||||
|
||||
if stack == "" ||
|
||||
strings.Contains(stack, "main.main()") ||
|
||||
strings.Contains(stack, "testing.Main(") ||
|
||||
strings.Contains(stack, "runtime.goexit") ||
|
||||
strings.Contains(stack, "created by runtime.gc") ||
|
||||
|
Loading…
Reference in New Issue
Block a user