mirror of
https://github.com/soheilhy/cmux.git
synced 2025-10-17 20:58:14 +08:00
Compare commits
12 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
6a5d332559 | ||
|
c0f3570a02 | ||
|
b6ec57c1a4 | ||
|
79b9df6ccf | ||
|
210139db95 | ||
|
bf4a8ede9e | ||
|
f661dcfb59 | ||
|
526b64db7a | ||
|
3ac8d3a667 | ||
|
861c99e0fc | ||
|
13f520d62c | ||
|
e132036cce |
@@ -1,8 +1,9 @@
|
|||||||
language: go
|
language: go
|
||||||
|
|
||||||
go:
|
go:
|
||||||
- 1.5
|
|
||||||
- 1.6
|
- 1.6
|
||||||
|
- 1.7
|
||||||
|
- 1.8
|
||||||
- tip
|
- tip
|
||||||
|
|
||||||
matrix:
|
matrix:
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
# Auto-generated with:
|
# Auto-generated with:
|
||||||
# git log --oneline --pretty=format:'%an <%aE>' | sort -u
|
# git log --oneline --pretty=format:'%an <%aE>' | sort -u
|
||||||
#
|
#
|
||||||
|
Andreas Jaekle <andreas@jaekle.net>
|
||||||
Dmitri Shuralyov <shurcooL@gmail.com>
|
Dmitri Shuralyov <shurcooL@gmail.com>
|
||||||
Ethan Mosbaugh <emosbaugh@gmail.com>
|
Ethan Mosbaugh <emosbaugh@gmail.com>
|
||||||
Soheil Hassas Yeganeh <soheil.h.y@gmail.com>
|
Soheil Hassas Yeganeh <soheil.h.y@gmail.com>
|
||||||
|
@@ -20,6 +20,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
@@ -43,6 +44,10 @@ func (c *mockConn) Read(b []byte) (n int, err error) {
|
|||||||
return c.r.Read(b)
|
return c.r.Read(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *mockConn) SetReadDeadline(time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func discard(l net.Listener) {
|
func discard(l net.Listener) {
|
||||||
for {
|
for {
|
||||||
if _, err := l.Accept(); err != nil {
|
if _, err := l.Accept(); err != nil {
|
||||||
|
@@ -42,6 +42,10 @@ func (s *bufferedReader) Read(p []byte) (int, error) {
|
|||||||
bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize])
|
bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize])
|
||||||
s.bufferRead += bn
|
s.bufferRead += bn
|
||||||
return bn, s.lastErr
|
return bn, s.lastErr
|
||||||
|
} else if !s.sniffing && s.buffer.Cap() != 0 {
|
||||||
|
// We don't need the buffer anymore.
|
||||||
|
// Reset it to release the internal slice.
|
||||||
|
s.buffer = bytes.Buffer{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is nothing more to return in the sniffed buffer, read from the
|
// If there is nothing more to return in the sniffed buffer, read from the
|
||||||
|
36
cmux.go
36
cmux.go
@@ -19,6 +19,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Matcher matches a connection based on its content.
|
// Matcher matches a connection based on its content.
|
||||||
@@ -60,13 +61,17 @@ func (e errListenerClosed) Timeout() bool { return false }
|
|||||||
// listener is closed.
|
// listener is closed.
|
||||||
var ErrListenerClosed = errListenerClosed("mux: listener closed")
|
var ErrListenerClosed = errListenerClosed("mux: listener closed")
|
||||||
|
|
||||||
|
// for readability of readTimeout
|
||||||
|
var noTimeout time.Duration
|
||||||
|
|
||||||
// New instantiates a new connection multiplexer.
|
// New instantiates a new connection multiplexer.
|
||||||
func New(l net.Listener) CMux {
|
func New(l net.Listener) CMux {
|
||||||
return &cMux{
|
return &cMux{
|
||||||
root: l,
|
root: l,
|
||||||
bufLen: 1024,
|
bufLen: 1024,
|
||||||
errh: func(_ error) bool { return true },
|
errh: func(_ error) bool { return true },
|
||||||
donec: make(chan struct{}),
|
donec: make(chan struct{}),
|
||||||
|
readTimeout: noTimeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,6 +95,8 @@ type CMux interface {
|
|||||||
Serve() error
|
Serve() error
|
||||||
// HandleError registers an error handler that handles listener errors.
|
// HandleError registers an error handler that handles listener errors.
|
||||||
HandleError(ErrorHandler)
|
HandleError(ErrorHandler)
|
||||||
|
// sets a timeout for the read of matchers
|
||||||
|
SetReadTimeout(time.Duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
type matchersListener struct {
|
type matchersListener struct {
|
||||||
@@ -98,11 +105,12 @@ type matchersListener struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type cMux struct {
|
type cMux struct {
|
||||||
root net.Listener
|
root net.Listener
|
||||||
bufLen int
|
bufLen int
|
||||||
errh ErrorHandler
|
errh ErrorHandler
|
||||||
donec chan struct{}
|
donec chan struct{}
|
||||||
sls []matchersListener
|
sls []matchersListener
|
||||||
|
readTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
|
func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
|
||||||
@@ -129,6 +137,10 @@ func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
|
|||||||
return ml
|
return ml
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *cMux) SetReadTimeout(t time.Duration) {
|
||||||
|
m.readTimeout = t
|
||||||
|
}
|
||||||
|
|
||||||
func (m *cMux) Serve() error {
|
func (m *cMux) Serve() error {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
@@ -163,11 +175,17 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
muc := newMuxConn(c)
|
muc := newMuxConn(c)
|
||||||
|
if m.readTimeout > noTimeout {
|
||||||
|
_ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
|
||||||
|
}
|
||||||
for _, sl := range m.sls {
|
for _, sl := range m.sls {
|
||||||
for _, s := range sl.ss {
|
for _, s := range sl.ss {
|
||||||
matched := s(muc.Conn, muc.startSniffing())
|
matched := s(muc.Conn, muc.startSniffing())
|
||||||
if matched {
|
if matched {
|
||||||
muc.doneSniffing()
|
muc.doneSniffing()
|
||||||
|
if m.readTimeout > noTimeout {
|
||||||
|
_ = c.SetReadDeadline(time.Time{})
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case sl.l.connc <- muc:
|
case sl.l.connc <- muc:
|
||||||
case <-donec:
|
case <-donec:
|
||||||
|
161
cmux_test.go
161
cmux_test.go
@@ -15,10 +15,12 @@
|
|||||||
package cmux
|
package cmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/rpc"
|
"net/rpc"
|
||||||
@@ -31,6 +33,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/hpack"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -39,7 +42,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func safeServe(errCh chan<- error, muxl CMux) {
|
func safeServe(errCh chan<- error, muxl CMux) {
|
||||||
if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed network connection") {
|
if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed") {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -73,14 +76,17 @@ func (l *chanListener) Accept() (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testListener(t *testing.T) (net.Listener, func()) {
|
func testListener(t *testing.T) (net.Listener, func()) {
|
||||||
l, err := net.Listen("tcp", ":0")
|
l, err := net.Listen("tcp4", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
var once sync.Once
|
||||||
return l, func() {
|
return l, func() {
|
||||||
if err := l.Close(); err != nil {
|
once.Do(func() {
|
||||||
t.Fatal(err)
|
if err := l.Close(); err != nil {
|
||||||
}
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,6 +187,84 @@ func runTestRPCClient(t *testing.T, addr net.Addr) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
handleHTTP1Close = 1
|
||||||
|
handleHTTP1Request = 2
|
||||||
|
handleAnyClose = 3
|
||||||
|
handleAnyRequest = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTimeout(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
lis, Close := testListener(t)
|
||||||
|
defer Close()
|
||||||
|
result := make(chan int, 5)
|
||||||
|
testDuration := time.Millisecond * 100
|
||||||
|
m := New(lis)
|
||||||
|
m.SetReadTimeout(testDuration)
|
||||||
|
http1 := m.Match(HTTP1Fast())
|
||||||
|
any := m.Match(Any())
|
||||||
|
go func() {
|
||||||
|
_ = m.Serve()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
con, err := http1.Accept()
|
||||||
|
if err != nil {
|
||||||
|
result <- handleHTTP1Close
|
||||||
|
} else {
|
||||||
|
_, _ = con.Write([]byte("http1"))
|
||||||
|
_ = con.Close()
|
||||||
|
result <- handleHTTP1Request
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
con, err := any.Accept()
|
||||||
|
if err != nil {
|
||||||
|
result <- handleAnyClose
|
||||||
|
} else {
|
||||||
|
_, _ = con.Write([]byte("any"))
|
||||||
|
_ = con.Close()
|
||||||
|
result <- handleAnyRequest
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
time.Sleep(testDuration) // wait to prevent timeouts on slow test-runners
|
||||||
|
client, err := net.Dial("tcp", lis.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("testTimeout client failed: ", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = client.Close()
|
||||||
|
}()
|
||||||
|
time.Sleep(testDuration / 2)
|
||||||
|
if len(result) != 0 {
|
||||||
|
log.Print("tcp ")
|
||||||
|
t.Fatal("testTimeout failed: accepted to fast: ", len(result))
|
||||||
|
}
|
||||||
|
_ = client.SetReadDeadline(time.Now().Add(testDuration * 3))
|
||||||
|
buffer := make([]byte, 10)
|
||||||
|
rl, err := client.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("testTimeout failed: client error: ", err, rl)
|
||||||
|
}
|
||||||
|
Close()
|
||||||
|
if rl != 3 {
|
||||||
|
log.Print("testTimeout failed: response from wrong sevice ", rl)
|
||||||
|
}
|
||||||
|
if string(buffer[0:3]) != "any" {
|
||||||
|
log.Print("testTimeout failed: response from wrong sevice ")
|
||||||
|
}
|
||||||
|
time.Sleep(testDuration * 2)
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatal("testTimeout failed: accepted to less: ", len(result))
|
||||||
|
}
|
||||||
|
if a := <-result; a != handleAnyRequest {
|
||||||
|
t.Fatal("testTimeout failed: any rule did not match")
|
||||||
|
}
|
||||||
|
if a := <-result; a != handleHTTP1Close {
|
||||||
|
t.Fatal("testTimeout failed: no close an http rule")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRead(t *testing.T) {
|
func TestRead(t *testing.T) {
|
||||||
defer leakCheck(t)()
|
defer leakCheck(t)()
|
||||||
errCh := make(chan error)
|
errCh := make(chan error)
|
||||||
@@ -312,6 +396,72 @@ func TestHTTP2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHTTP2MatchHeaderField(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
errCh := make(chan error)
|
||||||
|
defer func() {
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatal(err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
name := "name"
|
||||||
|
value := "value"
|
||||||
|
writer, reader := net.Pipe()
|
||||||
|
go func() {
|
||||||
|
if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := hpack.NewEncoder(&buf)
|
||||||
|
if err := enc.WriteField(hpack.HeaderField{Name: name, Value: value}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
framer := http2.NewFramer(writer, nil)
|
||||||
|
err := framer.WriteHeaders(http2.HeadersFrameParam{
|
||||||
|
StreamID: 1,
|
||||||
|
BlockFragment: buf.Bytes(),
|
||||||
|
EndStream: true,
|
||||||
|
EndHeaders: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
l := newChanListener()
|
||||||
|
l.connCh <- reader
|
||||||
|
muxl := New(l)
|
||||||
|
// Register a bogus matcher that only reads one byte.
|
||||||
|
muxl.Match(func(r io.Reader) bool {
|
||||||
|
var b [1]byte
|
||||||
|
_, _ = r.Read(b[:])
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
// Create a matcher that cannot match the response.
|
||||||
|
muxl.Match(HTTP2HeaderField(name, "another"+value))
|
||||||
|
// Then match with the expected field.
|
||||||
|
h2l := muxl.Match(HTTP2HeaderField(name, value))
|
||||||
|
go safeServe(errCh, muxl)
|
||||||
|
muxedConn, err := h2l.Accept()
|
||||||
|
close(l.connCh)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var b [len(http2.ClientPreface)]byte
|
||||||
|
// We have the sniffed buffer first...
|
||||||
|
if _, err := muxedConn.Read(b[:]); err == io.EOF {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if string(b[:]) != http2.ClientPreface {
|
||||||
|
t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHTTPGoRPC(t *testing.T) {
|
func TestHTTPGoRPC(t *testing.T) {
|
||||||
defer leakCheck(t)()
|
defer leakCheck(t)()
|
||||||
errCh := make(chan error)
|
errCh := make(chan error)
|
||||||
@@ -439,6 +589,7 @@ func interestingGoroutines() (gs []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if stack == "" ||
|
if stack == "" ||
|
||||||
|
strings.Contains(stack, "main.main()") ||
|
||||||
strings.Contains(stack, "testing.Main(") ||
|
strings.Contains(stack, "testing.Main(") ||
|
||||||
strings.Contains(stack, "runtime.goexit") ||
|
strings.Contains(stack, "runtime.goexit") ||
|
||||||
strings.Contains(stack, "created by runtime.gc") ||
|
strings.Contains(stack, "created by runtime.gc") ||
|
||||||
|
43
matchers.go
43
matchers.go
@@ -123,11 +123,23 @@ func HTTP2MatchHeaderFieldSendSettings(name, value string) MatchWriter {
|
|||||||
|
|
||||||
func hasHTTP2Preface(r io.Reader) bool {
|
func hasHTTP2Preface(r io.Reader) bool {
|
||||||
var b [len(http2.ClientPreface)]byte
|
var b [len(http2.ClientPreface)]byte
|
||||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
last := 0
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(b[:]) == http2.ClientPreface
|
for {
|
||||||
|
n, err := r.Read(b[last:])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
last += n
|
||||||
|
eq := string(b[:last]) == http2.ClientPreface[:last]
|
||||||
|
if last == len(http2.ClientPreface) {
|
||||||
|
return eq
|
||||||
|
}
|
||||||
|
if !eq {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func matchHTTP1Field(r io.Reader, name, value string) (matched bool) {
|
func matchHTTP1Field(r io.Reader, name, value string) (matched bool) {
|
||||||
@@ -144,10 +156,14 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
done := false
|
||||||
framer := http2.NewFramer(w, r)
|
framer := http2.NewFramer(w, r)
|
||||||
hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) {
|
hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) {
|
||||||
if hf.Name == name && hf.Value == value {
|
if hf.Name == name {
|
||||||
matched = true
|
done = true
|
||||||
|
if hf.Value == value {
|
||||||
|
matched = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
for {
|
for {
|
||||||
@@ -161,17 +177,20 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool
|
|||||||
if err := framer.WriteSettings(); err != nil {
|
if err := framer.WriteSettings(); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
case *http2.ContinuationFrame:
|
||||||
|
if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0
|
||||||
case *http2.HeadersFrame:
|
case *http2.HeadersFrame:
|
||||||
if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil {
|
if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if matched {
|
done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0
|
||||||
return true
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 {
|
if done {
|
||||||
return false
|
return matched
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user