mirror of
https://github.com/soheilhy/cmux.git
synced 2024-11-09 19:21:52 +08:00
639 lines
14 KiB
Go
639 lines
14 KiB
Go
// Copyright 2016 The CMux Authors. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
// implied. See the License for the specific language governing
|
|
// permissions and limitations under the License.
|
|
|
|
package cmux
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/rpc"
|
|
"runtime"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/http2"
|
|
"golang.org/x/net/http2/hpack"
|
|
)
|
|
|
|
const (
|
|
testHTTP1Resp = "http1"
|
|
rpcVal = 1234
|
|
)
|
|
|
|
func safeServe(errCh chan<- error, muxl CMux) {
|
|
if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed network connection") {
|
|
errCh <- err
|
|
}
|
|
}
|
|
|
|
func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
|
|
c, err := rpc.Dial(addr.Network(), addr.String())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return c, func() {
|
|
if err := c.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
type chanListener struct {
|
|
net.Listener
|
|
connCh chan net.Conn
|
|
}
|
|
|
|
func newChanListener() *chanListener {
|
|
return &chanListener{connCh: make(chan net.Conn, 1)}
|
|
}
|
|
|
|
func (l *chanListener) Accept() (net.Conn, error) {
|
|
if c, ok := <-l.connCh; ok {
|
|
return c, nil
|
|
}
|
|
return nil, errors.New("use of closed network connection")
|
|
}
|
|
|
|
func testListener(t *testing.T) (net.Listener, func()) {
|
|
l, err := net.Listen("tcp4", ":0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var once sync.Once
|
|
return l, func() {
|
|
once.Do(func() {
|
|
if err := l.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type testHTTP1Handler struct{}
|
|
|
|
func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintf(w, testHTTP1Resp)
|
|
}
|
|
|
|
func runTestHTTPServer(errCh chan<- error, l net.Listener) {
|
|
var mu sync.Mutex
|
|
conns := make(map[net.Conn]struct{})
|
|
|
|
defer func() {
|
|
mu.Lock()
|
|
for c := range conns {
|
|
if err := c.Close(); err != nil {
|
|
errCh <- err
|
|
}
|
|
}
|
|
mu.Unlock()
|
|
}()
|
|
|
|
s := &http.Server{
|
|
Handler: &testHTTP1Handler{},
|
|
ConnState: func(c net.Conn, state http.ConnState) {
|
|
mu.Lock()
|
|
switch state {
|
|
case http.StateNew:
|
|
conns[c] = struct{}{}
|
|
case http.StateClosed:
|
|
delete(conns, c)
|
|
}
|
|
mu.Unlock()
|
|
},
|
|
}
|
|
if err := s.Serve(l); err != ErrListenerClosed {
|
|
errCh <- err
|
|
}
|
|
}
|
|
|
|
func runTestHTTP1Client(t *testing.T, addr net.Addr) {
|
|
r, err := http.Get("http://" + addr.String())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
defer func() {
|
|
if err = r.Body.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}()
|
|
|
|
b, err := ioutil.ReadAll(r.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if string(b) != testHTTP1Resp {
|
|
t.Fatalf("invalid response: want=%s got=%s", testHTTP1Resp, b)
|
|
}
|
|
}
|
|
|
|
type TestRPCRcvr struct{}
|
|
|
|
func (r TestRPCRcvr) Test(i int, j *int) error {
|
|
*j = i
|
|
return nil
|
|
}
|
|
|
|
func runTestRPCServer(errCh chan<- error, l net.Listener) {
|
|
s := rpc.NewServer()
|
|
if err := s.Register(TestRPCRcvr{}); err != nil {
|
|
errCh <- err
|
|
}
|
|
for {
|
|
c, err := l.Accept()
|
|
if err != nil {
|
|
if err != ErrListenerClosed {
|
|
errCh <- err
|
|
}
|
|
return
|
|
}
|
|
go s.ServeConn(c)
|
|
}
|
|
}
|
|
|
|
func runTestRPCClient(t *testing.T, addr net.Addr) {
|
|
c, cleanup := safeDial(t, addr)
|
|
defer cleanup()
|
|
|
|
var num int
|
|
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if num != rpcVal {
|
|
t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num)
|
|
}
|
|
}
|
|
|
|
const (
|
|
handleHTTP1Close = 1
|
|
handleHTTP1Request = 2
|
|
handleAnyClose = 3
|
|
handleAnyRequest = 4
|
|
)
|
|
|
|
func TestTimeout(t *testing.T) {
|
|
defer leakCheck(t)()
|
|
lis, Close := testListener(t)
|
|
defer Close()
|
|
result := make(chan int, 5)
|
|
testDuration := time.Millisecond * 100
|
|
m := New(lis)
|
|
m.SetReadTimeout(testDuration)
|
|
http1 := m.Match(HTTP1Fast())
|
|
any := m.Match(Any())
|
|
go func() {
|
|
_ = m.Serve()
|
|
}()
|
|
go func() {
|
|
con, err := http1.Accept()
|
|
if err != nil {
|
|
result <- handleHTTP1Close
|
|
} else {
|
|
_, _ = con.Write([]byte("http1"))
|
|
_ = con.Close()
|
|
result <- handleHTTP1Request
|
|
}
|
|
}()
|
|
go func() {
|
|
con, err := any.Accept()
|
|
if err != nil {
|
|
result <- handleAnyClose
|
|
} else {
|
|
_, _ = con.Write([]byte("any"))
|
|
_ = con.Close()
|
|
result <- handleAnyRequest
|
|
}
|
|
}()
|
|
time.Sleep(testDuration) // wait to prevent timeouts on slow test-runners
|
|
client, err := net.Dial("tcp", lis.Addr().String())
|
|
if err != nil {
|
|
log.Fatal("testTimeout client failed: ", err)
|
|
}
|
|
defer func() {
|
|
_ = client.Close()
|
|
}()
|
|
time.Sleep(testDuration / 2)
|
|
if len(result) != 0 {
|
|
log.Print("tcp ")
|
|
t.Fatal("testTimeout failed: accepted to fast: ", len(result))
|
|
}
|
|
_ = client.SetReadDeadline(time.Now().Add(testDuration * 3))
|
|
buffer := make([]byte, 10)
|
|
rl, err := client.Read(buffer)
|
|
if err != nil {
|
|
t.Fatal("testTimeout failed: client error: ", err, rl)
|
|
}
|
|
Close()
|
|
if rl != 3 {
|
|
log.Print("testTimeout failed: response from wrong sevice ", rl)
|
|
}
|
|
if string(buffer[0:3]) != "any" {
|
|
log.Print("testTimeout failed: response from wrong sevice ")
|
|
}
|
|
time.Sleep(testDuration * 2)
|
|
if len(result) != 2 {
|
|
t.Fatal("testTimeout failed: accepted to less: ", len(result))
|
|
}
|
|
if a := <-result; a != handleAnyRequest {
|
|
t.Fatal("testTimeout failed: any rule did not match")
|
|
}
|
|
if a := <-result; a != handleHTTP1Close {
|
|
t.Fatal("testTimeout failed: no close an http rule")
|
|
}
|
|
}
|
|
|
|
func TestRead(t *testing.T) {
|
|
defer leakCheck(t)()
|
|
errCh := make(chan error)
|
|
defer func() {
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}()
|
|
const payload = "hello world\r\n"
|
|
const mult = 2
|
|
|
|
writer, reader := net.Pipe()
|
|
go func() {
|
|
if _, err := io.WriteString(writer, strings.Repeat(payload, mult)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := writer.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}()
|
|
|
|
l := newChanListener()
|
|
defer close(l.connCh)
|
|
l.connCh <- reader
|
|
muxl := New(l)
|
|
// Register a bogus matcher to force buffering exactly the right amount.
|
|
// Before this fix, this would trigger a bug where `Read` would incorrectly
|
|
// report `io.EOF` when only the buffer had been consumed.
|
|
muxl.Match(func(r io.Reader) bool {
|
|
var b [len(payload)]byte
|
|
_, _ = r.Read(b[:])
|
|
return false
|
|
})
|
|
anyl := muxl.Match(Any())
|
|
go safeServe(errCh, muxl)
|
|
muxedConn, err := anyl.Accept()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for i := 0; i < mult; i++ {
|
|
var b [len(payload)]byte
|
|
n, err := muxedConn.Read(b[:])
|
|
if err != nil {
|
|
t.Error(err)
|
|
continue
|
|
}
|
|
if e := len(b); n != e {
|
|
t.Errorf("expected to read %d bytes, but read %d bytes", e, n)
|
|
}
|
|
}
|
|
var b [1]byte
|
|
if _, err := muxedConn.Read(b[:]); err != io.EOF {
|
|
t.Errorf("unexpected error %v, expected %v", err, io.EOF)
|
|
}
|
|
}
|
|
|
|
func TestAny(t *testing.T) {
|
|
defer leakCheck(t)()
|
|
errCh := make(chan error)
|
|
defer func() {
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}()
|
|
l, cleanup := testListener(t)
|
|
defer cleanup()
|
|
|
|
muxl := New(l)
|
|
httpl := muxl.Match(Any())
|
|
|
|
go runTestHTTPServer(errCh, httpl)
|
|
go safeServe(errCh, muxl)
|
|
|
|
runTestHTTP1Client(t, l.Addr())
|
|
}
|
|
|
|
func TestHTTP2(t *testing.T) {
|
|
defer leakCheck(t)()
|
|
errCh := make(chan error)
|
|
defer func() {
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}()
|
|
writer, reader := net.Pipe()
|
|
go func() {
|
|
if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := writer.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}()
|
|
|
|
l := newChanListener()
|
|
l.connCh <- reader
|
|
muxl := New(l)
|
|
// Register a bogus matcher that only reads one byte.
|
|
muxl.Match(func(r io.Reader) bool {
|
|
var b [1]byte
|
|
_, _ = r.Read(b[:])
|
|
return false
|
|
})
|
|
h2l := muxl.Match(HTTP2())
|
|
go safeServe(errCh, muxl)
|
|
muxedConn, err := h2l.Accept()
|
|
close(l.connCh)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var b [len(http2.ClientPreface)]byte
|
|
var n int
|
|
// We have the sniffed buffer first...
|
|
if n, err = muxedConn.Read(b[:]); err == io.EOF {
|
|
t.Fatal(err)
|
|
}
|
|
// and then we read from the source.
|
|
if _, err = muxedConn.Read(b[n:]); err != io.EOF {
|
|
t.Fatal(err)
|
|
}
|
|
if string(b[:]) != http2.ClientPreface {
|
|
t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface)
|
|
}
|
|
}
|
|
|
|
func TestHTTP2MatchHeaderField(t *testing.T) {
|
|
defer leakCheck(t)()
|
|
errCh := make(chan error)
|
|
defer func() {
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}()
|
|
name := "name"
|
|
value := "value"
|
|
writer, reader := net.Pipe()
|
|
go func() {
|
|
if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var buf bytes.Buffer
|
|
enc := hpack.NewEncoder(&buf)
|
|
if err := enc.WriteField(hpack.HeaderField{Name: name, Value: value}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
framer := http2.NewFramer(writer, nil)
|
|
err := framer.WriteHeaders(http2.HeadersFrameParam{
|
|
StreamID: 1,
|
|
BlockFragment: buf.Bytes(),
|
|
EndStream: true,
|
|
EndHeaders: true,
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := writer.Close(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}()
|
|
|
|
l := newChanListener()
|
|
l.connCh <- reader
|
|
muxl := New(l)
|
|
// Register a bogus matcher that only reads one byte.
|
|
muxl.Match(func(r io.Reader) bool {
|
|
var b [1]byte
|
|
_, _ = r.Read(b[:])
|
|
return false
|
|
})
|
|
// Create a matcher that cannot match the response.
|
|
muxl.Match(HTTP2HeaderField(name, "another"+value))
|
|
// Then match with the expected field.
|
|
h2l := muxl.Match(HTTP2HeaderField(name, value))
|
|
go safeServe(errCh, muxl)
|
|
muxedConn, err := h2l.Accept()
|
|
close(l.connCh)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var b [len(http2.ClientPreface)]byte
|
|
// We have the sniffed buffer first...
|
|
if _, err := muxedConn.Read(b[:]); err == io.EOF {
|
|
t.Fatal(err)
|
|
}
|
|
if string(b[:]) != http2.ClientPreface {
|
|
t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface)
|
|
}
|
|
}
|
|
|
|
func TestHTTPGoRPC(t *testing.T) {
|
|
defer leakCheck(t)()
|
|
errCh := make(chan error)
|
|
defer func() {
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}()
|
|
l, cleanup := testListener(t)
|
|
defer cleanup()
|
|
|
|
muxl := New(l)
|
|
httpl := muxl.Match(HTTP2(), HTTP1Fast())
|
|
rpcl := muxl.Match(Any())
|
|
|
|
go runTestHTTPServer(errCh, httpl)
|
|
go runTestRPCServer(errCh, rpcl)
|
|
go safeServe(errCh, muxl)
|
|
|
|
runTestHTTP1Client(t, l.Addr())
|
|
runTestRPCClient(t, l.Addr())
|
|
}
|
|
|
|
func TestErrorHandler(t *testing.T) {
|
|
defer leakCheck(t)()
|
|
errCh := make(chan error)
|
|
defer func() {
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}()
|
|
l, cleanup := testListener(t)
|
|
defer cleanup()
|
|
|
|
muxl := New(l)
|
|
httpl := muxl.Match(HTTP2(), HTTP1Fast())
|
|
|
|
go runTestHTTPServer(errCh, httpl)
|
|
go safeServe(errCh, muxl)
|
|
|
|
var errCount uint32
|
|
muxl.HandleError(func(err error) bool {
|
|
if atomic.AddUint32(&errCount, 1) == 1 {
|
|
if _, ok := err.(ErrNotMatched); !ok {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
c, cleanup := safeDial(t, l.Addr())
|
|
defer cleanup()
|
|
|
|
var num int
|
|
for atomic.LoadUint32(&errCount) == 0 {
|
|
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
|
|
// The connection is simply closed.
|
|
t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestClose(t *testing.T) {
|
|
defer leakCheck(t)()
|
|
errCh := make(chan error)
|
|
defer func() {
|
|
select {
|
|
case err := <-errCh:
|
|
t.Fatal(err)
|
|
default:
|
|
}
|
|
}()
|
|
l := newChanListener()
|
|
|
|
c1, c2 := net.Pipe()
|
|
|
|
muxl := New(l)
|
|
anyl := muxl.Match(Any())
|
|
|
|
go safeServe(errCh, muxl)
|
|
|
|
l.connCh <- c1
|
|
|
|
// First connection goes through.
|
|
if _, err := anyl.Accept(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Second connection is sent
|
|
l.connCh <- c2
|
|
|
|
// Listener is closed.
|
|
close(l.connCh)
|
|
|
|
// Second connection either goes through or it is closed.
|
|
if _, err := anyl.Accept(); err != nil {
|
|
if err != ErrListenerClosed {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := c2.Read([]byte{}); err != io.ErrClosedPipe {
|
|
t.Fatalf("connection is not closed and is leaked: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Cribbed from google.golang.org/grpc/test/end2end_test.go.
|
|
|
|
// interestingGoroutines returns all goroutines we care about for the purpose
|
|
// of leak checking. It excludes testing or runtime ones.
|
|
func interestingGoroutines() (gs []string) {
|
|
buf := make([]byte, 2<<20)
|
|
buf = buf[:runtime.Stack(buf, true)]
|
|
for _, g := range strings.Split(string(buf), "\n\n") {
|
|
sl := strings.SplitN(g, "\n", 2)
|
|
if len(sl) != 2 {
|
|
continue
|
|
}
|
|
stack := strings.TrimSpace(sl[1])
|
|
if strings.HasPrefix(stack, "testing.RunTests") {
|
|
continue
|
|
}
|
|
|
|
if stack == "" ||
|
|
strings.Contains(stack, "main.main()") ||
|
|
strings.Contains(stack, "testing.Main(") ||
|
|
strings.Contains(stack, "runtime.goexit") ||
|
|
strings.Contains(stack, "created by runtime.gc") ||
|
|
strings.Contains(stack, "interestingGoroutines") ||
|
|
strings.Contains(stack, "runtime.MHeap_Scavenger") {
|
|
continue
|
|
}
|
|
gs = append(gs, g)
|
|
}
|
|
sort.Strings(gs)
|
|
return
|
|
}
|
|
|
|
// leakCheck snapshots the currently-running goroutines and returns a
|
|
// function to be run at the end of tests to see whether any
|
|
// goroutines leaked.
|
|
func leakCheck(t testing.TB) func() {
|
|
orig := map[string]bool{}
|
|
for _, g := range interestingGoroutines() {
|
|
orig[g] = true
|
|
}
|
|
return func() {
|
|
// Loop, waiting for goroutines to shut down.
|
|
// Wait up to 5 seconds, but finish as quickly as possible.
|
|
deadline := time.Now().Add(5 * time.Second)
|
|
for {
|
|
var leaked []string
|
|
for _, g := range interestingGoroutines() {
|
|
if !orig[g] {
|
|
leaked = append(leaked, g)
|
|
}
|
|
}
|
|
if len(leaked) == 0 {
|
|
return
|
|
}
|
|
if time.Now().Before(deadline) {
|
|
time.Sleep(50 * time.Millisecond)
|
|
continue
|
|
}
|
|
for _, g := range leaked {
|
|
t.Errorf("Leaked goroutine: %v", g)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|