2
0
mirror of https://github.com/soheilhy/cmux.git synced 2024-09-20 02:55:46 +08:00
cmux/cmux_test.go

789 lines
17 KiB
Go
Raw Normal View History

// Copyright 2016 The CMux Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.
2015-07-30 01:45:57 +08:00
package cmux
import (
"bytes"
2017-07-20 23:35:05 +08:00
"crypto/rand"
"crypto/tls"
"errors"
2015-07-30 01:45:57 +08:00
"fmt"
2017-07-20 23:35:05 +08:00
"go/build"
"io"
2015-07-30 01:45:57 +08:00
"io/ioutil"
2016-09-06 01:56:50 +08:00
"log"
2015-07-30 01:45:57 +08:00
"net"
"net/http"
"net/rpc"
2017-07-20 23:35:05 +08:00
"os"
"os/exec"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
2015-07-30 01:45:57 +08:00
"testing"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
2015-07-30 01:45:57 +08:00
)
const (
testHTTP1Resp = "http1"
rpcVal = 1234
)
func safeServe(errCh chan<- error, muxl CMux) {
if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed") {
errCh <- err
}
}
func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
c, err := rpc.Dial(addr.Network(), addr.String())
if err != nil {
t.Fatal(err)
}
return c, func() {
if err := c.Close(); err != nil {
t.Fatal(err)
}
}
}
type chanListener struct {
net.Listener
connCh chan net.Conn
}
func newChanListener() *chanListener {
return &chanListener{connCh: make(chan net.Conn, 1)}
}
func (l *chanListener) Accept() (net.Conn, error) {
if c, ok := <-l.connCh; ok {
return c, nil
}
return nil, errors.New("use of closed network connection")
}
2016-02-21 04:43:18 +08:00
func testListener(t *testing.T) (net.Listener, func()) {
l, err := net.Listen("tcp", "127.0.0.1:0")
2015-07-30 01:45:57 +08:00
if err != nil {
t.Fatal(err)
}
2016-09-06 01:56:50 +08:00
var once sync.Once
2016-02-21 04:43:18 +08:00
return l, func() {
2016-09-06 01:56:50 +08:00
once.Do(func() {
if err := l.Close(); err != nil {
t.Fatal(err)
}
})
2016-02-21 04:43:18 +08:00
}
2015-07-30 01:45:57 +08:00
}
type testHTTP1Handler struct{}
func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, testHTTP1Resp)
}
func runTestHTTPServer(errCh chan<- error, l net.Listener) {
var mu sync.Mutex
conns := make(map[net.Conn]struct{})
defer func() {
mu.Lock()
for c := range conns {
if err := c.Close(); err != nil {
errCh <- err
}
}
mu.Unlock()
}()
2015-07-30 01:45:57 +08:00
s := &http.Server{
Handler: &testHTTP1Handler{},
ConnState: func(c net.Conn, state http.ConnState) {
mu.Lock()
switch state {
case http.StateNew:
conns[c] = struct{}{}
case http.StateClosed:
delete(conns, c)
}
mu.Unlock()
},
2015-07-30 01:45:57 +08:00
}
if err := s.Serve(l); err != ErrListenerClosed && err != ErrServerClosed {
errCh <- err
errcheck go/src/github.com/soheilhy/cmux/cmux.go:127:13 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:134:9 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:137:15 m.root.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:43:9 s.Serve(l) go/src/github.com/soheilhy/cmux/cmux_test.go:52:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:72:12 s.Register(TestRPCRcvr{}) go/src/github.com/soheilhy/cmux/cmux_test.go:103:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:109:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:116:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:125:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:133:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:141:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:147:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:27:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_recursive_test.go:56:12 s.Register(&RecursiveRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_recursive_test.go:88:15 go tlsm.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:89:12 tcpm.Serve() go/src/github.com/soheilhy/cmux/example_test.go:30:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:34:9 io.Copy(ws, ws) go/src/github.com/soheilhy/cmux/example_test.go:41:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:53:12 s.Register(&ExampleRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_test.go:68:13 grpcs.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:97:9 m.Serve() go/src/github.com/soheilhy/cmux/example_tls_test.go:24:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_tls_test.go:69:9 m.Serve() go/src/github.com/soheilhy/cmux/matchers.go:151:14 hdec.Write(f.HeaderBlockFragment())
2016-01-07 23:15:03 +08:00
}
2015-07-30 01:45:57 +08:00
}
2017-07-20 23:35:05 +08:00
func generateTLSCert(t *testing.T) {
err := exec.Command("go", "run", build.Default.GOROOT+"/src/crypto/tls/generate_cert.go", "--host", "*").Run()
if err != nil {
t.Fatal(err)
}
}
func cleanupTLSCert(t *testing.T) {
err := os.Remove("cert.pem")
if err != nil {
t.Error(err)
}
err = os.Remove("key.pem")
if err != nil {
t.Error(err)
}
}
func runTestTLSServer(errCh chan<- error, l net.Listener) {
certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem")
if err != nil {
errCh <- err
log.Printf("1")
return
}
config := &tls.Config{
Certificates: []tls.Certificate{certificate},
Rand: rand.Reader,
}
tlsl := tls.NewListener(l, config)
runTestHTTPServer(errCh, tlsl)
}
2016-02-21 04:43:18 +08:00
func runTestHTTP1Client(t *testing.T, addr net.Addr) {
2017-07-20 23:35:05 +08:00
runTestHTTPClient(t, "http", addr)
}
func runTestTLSClient(t *testing.T, addr net.Addr) {
runTestHTTPClient(t, "https", addr)
}
func runTestHTTPClient(t *testing.T, proto string, addr net.Addr) {
client := http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
r, err := client.Get(proto + "://" + addr.String())
if err != nil {
2015-07-30 01:45:57 +08:00
t.Fatal(err)
}
defer func() {
if err = r.Body.Close(); err != nil {
t.Fatal(err)
errcheck go/src/github.com/soheilhy/cmux/cmux.go:127:13 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:134:9 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:137:15 m.root.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:43:9 s.Serve(l) go/src/github.com/soheilhy/cmux/cmux_test.go:52:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:72:12 s.Register(TestRPCRcvr{}) go/src/github.com/soheilhy/cmux/cmux_test.go:103:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:109:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:116:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:125:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:133:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:141:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:147:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:27:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_recursive_test.go:56:12 s.Register(&RecursiveRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_recursive_test.go:88:15 go tlsm.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:89:12 tcpm.Serve() go/src/github.com/soheilhy/cmux/example_test.go:30:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:34:9 io.Copy(ws, ws) go/src/github.com/soheilhy/cmux/example_test.go:41:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:53:12 s.Register(&ExampleRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_test.go:68:13 grpcs.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:97:9 m.Serve() go/src/github.com/soheilhy/cmux/example_tls_test.go:24:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_tls_test.go:69:9 m.Serve() go/src/github.com/soheilhy/cmux/matchers.go:151:14 hdec.Write(f.HeaderBlockFragment())
2016-01-07 23:15:03 +08:00
}
}()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
if string(b) != testHTTP1Resp {
t.Fatalf("invalid response: want=%s got=%s", testHTTP1Resp, b)
2015-07-30 01:45:57 +08:00
}
}
type TestRPCRcvr struct{}
func (r TestRPCRcvr) Test(i int, j *int) error {
*j = i
return nil
}
func runTestRPCServer(errCh chan<- error, l net.Listener) {
2015-07-30 01:45:57 +08:00
s := rpc.NewServer()
errcheck go/src/github.com/soheilhy/cmux/cmux.go:127:13 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:134:9 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:137:15 m.root.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:43:9 s.Serve(l) go/src/github.com/soheilhy/cmux/cmux_test.go:52:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:72:12 s.Register(TestRPCRcvr{}) go/src/github.com/soheilhy/cmux/cmux_test.go:103:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:109:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:116:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:125:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:133:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:141:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:147:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:27:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_recursive_test.go:56:12 s.Register(&RecursiveRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_recursive_test.go:88:15 go tlsm.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:89:12 tcpm.Serve() go/src/github.com/soheilhy/cmux/example_test.go:30:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:34:9 io.Copy(ws, ws) go/src/github.com/soheilhy/cmux/example_test.go:41:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:53:12 s.Register(&ExampleRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_test.go:68:13 grpcs.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:97:9 m.Serve() go/src/github.com/soheilhy/cmux/example_tls_test.go:24:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_tls_test.go:69:9 m.Serve() go/src/github.com/soheilhy/cmux/matchers.go:151:14 hdec.Write(f.HeaderBlockFragment())
2016-01-07 23:15:03 +08:00
if err := s.Register(TestRPCRcvr{}); err != nil {
errCh <- err
errcheck go/src/github.com/soheilhy/cmux/cmux.go:127:13 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:134:9 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:137:15 m.root.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:43:9 s.Serve(l) go/src/github.com/soheilhy/cmux/cmux_test.go:52:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:72:12 s.Register(TestRPCRcvr{}) go/src/github.com/soheilhy/cmux/cmux_test.go:103:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:109:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:116:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:125:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:133:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:141:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:147:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:27:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_recursive_test.go:56:12 s.Register(&RecursiveRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_recursive_test.go:88:15 go tlsm.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:89:12 tcpm.Serve() go/src/github.com/soheilhy/cmux/example_test.go:30:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:34:9 io.Copy(ws, ws) go/src/github.com/soheilhy/cmux/example_test.go:41:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:53:12 s.Register(&ExampleRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_test.go:68:13 grpcs.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:97:9 m.Serve() go/src/github.com/soheilhy/cmux/example_tls_test.go:24:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_tls_test.go:69:9 m.Serve() go/src/github.com/soheilhy/cmux/matchers.go:151:14 hdec.Write(f.HeaderBlockFragment())
2016-01-07 23:15:03 +08:00
}
2015-07-30 01:45:57 +08:00
for {
c, err := l.Accept()
if err != nil {
if err != ErrListenerClosed && err != ErrServerClosed {
errCh <- err
}
2015-07-30 01:45:57 +08:00
return
}
errcheck go/src/github.com/soheilhy/cmux/cmux.go:127:13 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:134:9 c.Close() go/src/github.com/soheilhy/cmux/cmux.go:137:15 m.root.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:43:9 s.Serve(l) go/src/github.com/soheilhy/cmux/cmux_test.go:52:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:72:12 s.Register(TestRPCRcvr{}) go/src/github.com/soheilhy/cmux/cmux_test.go:103:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:109:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:116:20 defer r.Body.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:125:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:133:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/cmux_test.go:141:15 defer l.Close() go/src/github.com/soheilhy/cmux/cmux_test.go:147:15 go muxl.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:27:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_recursive_test.go:56:12 s.Register(&RecursiveRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_recursive_test.go:88:15 go tlsm.Serve() go/src/github.com/soheilhy/cmux/example_recursive_test.go:89:12 tcpm.Serve() go/src/github.com/soheilhy/cmux/example_test.go:30:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:34:9 io.Copy(ws, ws) go/src/github.com/soheilhy/cmux/example_test.go:41:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:53:12 s.Register(&ExampleRPCRcvr{}) go/src/github.com/soheilhy/cmux/example_test.go:68:13 grpcs.Serve(l) go/src/github.com/soheilhy/cmux/example_test.go:97:9 m.Serve() go/src/github.com/soheilhy/cmux/example_tls_test.go:24:9 s.Serve(l) go/src/github.com/soheilhy/cmux/example_tls_test.go:69:9 m.Serve() go/src/github.com/soheilhy/cmux/matchers.go:151:14 hdec.Write(f.HeaderBlockFragment())
2016-01-07 23:15:03 +08:00
go s.ServeConn(c)
2015-07-30 01:45:57 +08:00
}
}
2016-02-21 04:43:18 +08:00
func runTestRPCClient(t *testing.T, addr net.Addr) {
c, cleanup := safeDial(t, addr)
defer cleanup()
2015-07-30 01:45:57 +08:00
var num int
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil {
t.Fatal(err)
2015-07-30 01:45:57 +08:00
}
if num != rpcVal {
t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num)
}
}
2016-09-06 01:56:50 +08:00
const (
2016-09-25 12:58:27 +08:00
handleHTTP1Close = 1
handleHTTP1Request = 2
2016-09-06 01:56:50 +08:00
handleAnyClose = 3
handleAnyRequest = 4
)
func TestTimeout(t *testing.T) {
defer leakCheck(t)()
lis, Close := testListener(t)
defer Close()
result := make(chan int, 5)
testDuration := time.Millisecond * 500
2016-09-06 01:56:50 +08:00
m := New(lis)
m.SetReadTimeout(testDuration)
http1 := m.Match(HTTP1Fast())
any := m.Match(Any())
go func() {
_ = m.Serve()
}()
go func() {
con, err := http1.Accept()
if err != nil {
2016-09-25 12:58:27 +08:00
result <- handleHTTP1Close
2016-09-06 01:56:50 +08:00
} else {
_, _ = con.Write([]byte("http1"))
_ = con.Close()
2016-09-25 12:58:27 +08:00
result <- handleHTTP1Request
2016-09-06 01:56:50 +08:00
}
}()
go func() {
con, err := any.Accept()
if err != nil {
result <- handleAnyClose
} else {
_, _ = con.Write([]byte("any"))
_ = con.Close()
result <- handleAnyRequest
}
}()
time.Sleep(testDuration) // wait to prevent timeouts on slow test-runners
client, err := net.Dial("tcp", lis.Addr().String())
if err != nil {
log.Fatal("testTimeout client failed: ", err)
}
defer func() {
_ = client.Close()
}()
time.Sleep(testDuration / 2)
if len(result) != 0 {
log.Print("tcp ")
t.Fatal("testTimeout failed: accepted to fast: ", len(result))
}
_ = client.SetReadDeadline(time.Now().Add(testDuration * 3))
buffer := make([]byte, 10)
rl, err := client.Read(buffer)
if err != nil {
t.Fatal("testTimeout failed: client error: ", err, rl)
}
Close()
if rl != 3 {
log.Print("testTimeout failed: response from wrong sevice ", rl)
}
if string(buffer[0:3]) != "any" {
log.Print("testTimeout failed: response from wrong sevice ")
}
time.Sleep(testDuration * 2)
if len(result) != 2 {
t.Fatal("testTimeout failed: accepted to less: ", len(result))
}
if a := <-result; a != handleAnyRequest {
t.Fatal("testTimeout failed: any rule did not match")
}
2016-09-25 12:58:27 +08:00
if a := <-result; a != handleHTTP1Close {
2016-09-06 01:56:50 +08:00
t.Fatal("testTimeout failed: no close an http rule")
}
}
func TestRead(t *testing.T) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
const payload = "hello world\r\n"
const mult = 2
writer, reader := net.Pipe()
go func() {
if _, err := io.WriteString(writer, strings.Repeat(payload, mult)); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
}()
l := newChanListener()
defer close(l.connCh)
l.connCh <- reader
muxl := New(l)
// Register a bogus matcher to force buffering exactly the right amount.
// Before this fix, this would trigger a bug where `Read` would incorrectly
// report `io.EOF` when only the buffer had been consumed.
muxl.Match(func(r io.Reader) bool {
var b [len(payload)]byte
_, _ = r.Read(b[:])
return false
})
anyl := muxl.Match(Any())
go safeServe(errCh, muxl)
muxedConn, err := anyl.Accept()
if err != nil {
t.Fatal(err)
}
for i := 0; i < mult; i++ {
var b [len(payload)]byte
n, err := muxedConn.Read(b[:])
if err != nil {
t.Error(err)
continue
}
if e := len(b); n != e {
t.Errorf("expected to read %d bytes, but read %d bytes", e, n)
}
}
var b [1]byte
if _, err := muxedConn.Read(b[:]); err != io.EOF {
t.Errorf("unexpected error %v, expected %v", err, io.EOF)
}
}
2015-07-30 01:45:57 +08:00
func TestAny(t *testing.T) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
2016-02-21 04:43:18 +08:00
l, cleanup := testListener(t)
defer cleanup()
2015-07-30 01:45:57 +08:00
muxl := New(l)
httpl := muxl.Match(Any())
go runTestHTTPServer(errCh, httpl)
go safeServe(errCh, muxl)
2015-07-30 01:45:57 +08:00
2016-02-22 01:57:37 +08:00
runTestHTTP1Client(t, l.Addr())
2015-07-30 01:45:57 +08:00
}
2017-07-20 23:35:05 +08:00
func TestTLS(t *testing.T) {
generateTLSCert(t)
defer cleanupTLSCert(t)
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
l, cleanup := testListener(t)
defer cleanup()
muxl := New(l)
tlsl := muxl.Match(TLS())
httpl := muxl.Match(Any())
go runTestTLSServer(errCh, tlsl)
go runTestHTTPServer(errCh, httpl)
go safeServe(errCh, muxl)
runTestHTTP1Client(t, l.Addr())
runTestTLSClient(t, l.Addr())
}
func TestHTTP2(t *testing.T) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
writer, reader := net.Pipe()
go func() {
if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
}()
l := newChanListener()
l.connCh <- reader
muxl := New(l)
// Register a bogus matcher that only reads one byte.
muxl.Match(func(r io.Reader) bool {
var b [1]byte
_, _ = r.Read(b[:])
return false
})
h2l := muxl.Match(HTTP2())
go safeServe(errCh, muxl)
muxedConn, err := h2l.Accept()
close(l.connCh)
if err != nil {
t.Fatal(err)
}
var b [len(http2.ClientPreface)]byte
var n int
// We have the sniffed buffer first...
if n, err = muxedConn.Read(b[:]); err == io.EOF {
t.Fatal(err)
}
// and then we read from the source.
if _, err = muxedConn.Read(b[n:]); err != io.EOF {
t.Fatal(err)
}
if string(b[:]) != http2.ClientPreface {
t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface)
}
}
func TestHTTP2MatchHeaderField(t *testing.T) {
2017-08-15 00:23:22 +08:00
testHTTP2MatchHeaderField(t, HTTP2HeaderField, "value", "value", "anothervalue")
}
func TestHTTP2MatchHeaderFieldPrefix(t *testing.T) {
testHTTP2MatchHeaderField(t, HTTP2HeaderFieldPrefix, "application/grpc+proto", "application/grpc", "application/json")
}
func testHTTP2MatchHeaderField(
t *testing.T,
matcherConstructor func(string, string) Matcher,
headerValue string,
matchValue string,
notMatchValue string,
) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
name := "name"
writer, reader := net.Pipe()
go func() {
if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
t.Fatal(err)
}
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
2017-08-15 00:23:22 +08:00
if err := enc.WriteField(hpack.HeaderField{Name: name, Value: headerValue}); err != nil {
t.Fatal(err)
}
framer := http2.NewFramer(writer, nil)
err := framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1,
BlockFragment: buf.Bytes(),
EndStream: true,
EndHeaders: true,
})
if err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
}()
l := newChanListener()
l.connCh <- reader
muxl := New(l)
// Register a bogus matcher that only reads one byte.
muxl.Match(func(r io.Reader) bool {
var b [1]byte
_, _ = r.Read(b[:])
return false
})
// Create a matcher that cannot match the response.
2017-08-15 00:23:22 +08:00
muxl.Match(matcherConstructor(name, notMatchValue))
// Then match with the expected field.
2017-08-15 00:23:22 +08:00
h2l := muxl.Match(matcherConstructor(name, matchValue))
go safeServe(errCh, muxl)
muxedConn, err := h2l.Accept()
close(l.connCh)
if err != nil {
t.Fatal(err)
}
var b [len(http2.ClientPreface)]byte
// We have the sniffed buffer first...
if _, err := muxedConn.Read(b[:]); err == io.EOF {
t.Fatal(err)
}
if string(b[:]) != http2.ClientPreface {
t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface)
}
}
2015-07-30 01:45:57 +08:00
func TestHTTPGoRPC(t *testing.T) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
2016-02-21 04:43:18 +08:00
l, cleanup := testListener(t)
defer cleanup()
2015-07-30 01:45:57 +08:00
muxl := New(l)
httpl := muxl.Match(HTTP2(), HTTP1Fast())
rpcl := muxl.Match(Any())
go runTestHTTPServer(errCh, httpl)
go runTestRPCServer(errCh, rpcl)
go safeServe(errCh, muxl)
2015-07-30 01:45:57 +08:00
2016-02-21 04:43:18 +08:00
runTestHTTP1Client(t, l.Addr())
runTestRPCClient(t, l.Addr())
2015-07-30 01:45:57 +08:00
}
func TestErrorHandler(t *testing.T) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
2016-02-21 04:43:18 +08:00
l, cleanup := testListener(t)
defer cleanup()
2015-07-30 01:45:57 +08:00
muxl := New(l)
httpl := muxl.Match(HTTP2(), HTTP1Fast())
go runTestHTTPServer(errCh, httpl)
go safeServe(errCh, muxl)
2015-07-30 01:45:57 +08:00
var errCount uint32
2015-07-30 01:45:57 +08:00
muxl.HandleError(func(err error) bool {
if atomic.AddUint32(&errCount, 1) == 1 {
if _, ok := err.(ErrNotMatched); !ok {
t.Errorf("unexpected error: %v", err)
}
2015-07-30 01:45:57 +08:00
}
return true
})
c, cleanup := safeDial(t, l.Addr())
defer cleanup()
2015-07-30 01:45:57 +08:00
var num int
for atomic.LoadUint32(&errCount) == 0 {
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
// The connection is simply closed.
t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
}
2015-07-30 01:45:57 +08:00
}
}
2018-01-24 01:15:21 +08:00
func TestMultipleMatchers(t *testing.T) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
l, cleanup := testListener(t)
defer cleanup()
matcher := func(r io.Reader) bool {
return true
}
unmatcher := func(r io.Reader) bool {
return false
}
muxl := New(l)
lis := muxl.Match(unmatcher, matcher, unmatcher)
go runTestHTTPServer(errCh, lis)
go safeServe(errCh, muxl)
runTestHTTP1Client(t, l.Addr())
}
func TestListenerClose(t *testing.T) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
l := newChanListener()
c1, c2 := net.Pipe()
muxl := New(l)
anyl := muxl.Match(Any())
go safeServe(errCh, muxl)
l.connCh <- c1
// First connection goes through.
if _, err := anyl.Accept(); err != nil {
t.Fatal(err)
}
// Second connection is sent
l.connCh <- c2
// Listener is closed.
close(l.connCh)
// Second connection either goes through or it is closed.
if _, err := anyl.Accept(); err != nil {
if err != ErrListenerClosed && err != ErrServerClosed {
t.Fatal(err)
}
// The error is either io.ErrClosedPipe or net.OpError wrapping
// a net.pipeError depending on the go version.
if _, err := c2.Read([]byte{}); !strings.Contains(err.Error(), "closed") {
t.Fatalf("connection is not closed and is leaked: %v", err)
}
}
}
func TestClose(t *testing.T) {
defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()
l, cleanup := testListener(t)
defer cleanup()
muxl := New(l)
anyl := muxl.Match(Any())
go safeServe(errCh, muxl)
muxl.Close()
if _, err := anyl.Accept(); err != ErrServerClosed {
t.Fatal(err)
}
}
// Cribbed from google.golang.org/grpc/test/end2end_test.go.
// interestingGoroutines returns all goroutines we care about for the purpose
// of leak checking. It excludes testing or runtime ones.
func interestingGoroutines() (gs []string) {
buf := make([]byte, 2<<20)
buf = buf[:runtime.Stack(buf, true)]
for _, g := range strings.Split(string(buf), "\n\n") {
sl := strings.SplitN(g, "\n", 2)
if len(sl) != 2 {
continue
}
stack := strings.TrimSpace(sl[1])
if strings.HasPrefix(stack, "testing.RunTests") {
continue
}
if stack == "" ||
2016-09-06 01:56:50 +08:00
strings.Contains(stack, "main.main()") ||
strings.Contains(stack, "testing.Main(") ||
strings.Contains(stack, "runtime.goexit") ||
strings.Contains(stack, "created by runtime.gc") ||
strings.Contains(stack, "interestingGoroutines") ||
strings.Contains(stack, "runtime.MHeap_Scavenger") {
continue
}
gs = append(gs, g)
}
sort.Strings(gs)
return
}
// leakCheck snapshots the currently-running goroutines and returns a
// function to be run at the end of tests to see whether any
// goroutines leaked.
func leakCheck(t testing.TB) func() {
orig := map[string]bool{}
for _, g := range interestingGoroutines() {
orig[g] = true
}
return func() {
// Loop, waiting for goroutines to shut down.
// Wait up to 5 seconds, but finish as quickly as possible.
deadline := time.Now().Add(5 * time.Second)
for {
var leaked []string
for _, g := range interestingGoroutines() {
if !orig[g] {
leaked = append(leaked, g)
}
}
if len(leaked) == 0 {
return
}
if time.Now().Before(deadline) {
time.Sleep(50 * time.Millisecond)
continue
}
for _, g := range leaked {
t.Errorf("Leaked goroutine: %v", g)
}
return
}
}
}