2
0
mirror of https://github.com/soheilhy/cmux.git synced 2024-09-20 02:55:46 +08:00
cmux/cmux_test.go
Tamir Duberstein 235d98b021 Tweak shutdown behaviour
When the root listener is closed, child listeners will not be closed
until all parked connections are served. This prevents losing
connections that have been read from.

This also allows moving the main test to package cmux_test, but that
will happen in a separate change.
2016-02-23 11:13:16 -05:00

357 lines
6.8 KiB
Go

package cmux
import (
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/rpc"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// Fixed values exchanged between the test servers and their clients.
const (
	// rpcVal is the integer the RPC clients send and expect to get back.
	rpcVal = 1234
	// testHTTP1Resp is the body that testHTTP1Handler writes for every request.
	testHTTP1Resp = "http1"
)
// safeServe runs muxl.Serve and forwards the error it returns to errCh, unless
// that error only reports that the underlying listener was closed (its message
// contains "use of closed network connection").
//
// A nil error from Serve is treated as a clean shutdown and is not reported;
// without that guard, calling err.Error() on a nil error would panic.
func safeServe(errCh chan<- error, muxl CMux) {
	err := muxl.Serve()
	if err == nil {
		return
	}
	if strings.Contains(err.Error(), "use of closed network connection") {
		return
	}
	// The send blocks until a reader is ready when errCh is unbuffered.
	errCh <- err
}
// safeDial opens a net/rpc client connection to addr and aborts the test if
// dialing fails. Alongside the client it returns a cleanup function that
// closes the client and aborts the test if Close reports an error.
func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
	network, address := addr.Network(), addr.String()
	client, err := rpc.Dial(network, address)
	if err != nil {
		t.Fatal(err)
	}

	closeClient := func() {
		cerr := client.Close()
		if cerr != nil {
			t.Fatal(cerr)
		}
	}
	return client, closeClient
}
// chanListener is a net.Listener whose Accept is fed from a channel, so a test
// controls exactly when connections arrive and when the listener "closes".
// newChanListener leaves the embedded net.Listener nil, which means only
// Accept is usable on listeners built that way.
type chanListener struct {
	net.Listener
	connCh chan net.Conn
}

// newChanListener returns a chanListener whose connection channel buffers a
// single pending connection.
func newChanListener() *chanListener {
	l := new(chanListener)
	l.connCh = make(chan net.Conn, 1)
	return l
}

// Accept blocks for the next connection sent on connCh. Once connCh has been
// closed and drained, it returns a "use of closed network connection" error.
func (l *chanListener) Accept() (net.Conn, error) {
	conn, open := <-l.connCh
	if !open {
		return nil, errors.New("use of closed network connection")
	}
	return conn, nil
}
// testListener opens a TCP listener on ":0" (a port picked by the operating
// system) and aborts the test if that fails. The returned function closes the
// listener and aborts the test if Close returns an error.
func testListener(t *testing.T) (net.Listener, func()) {
	ln, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatal(err)
	}

	closeListener := func() {
		cerr := ln.Close()
		if cerr != nil {
			t.Fatal(cerr)
		}
	}
	return ln, closeListener
}
// testHTTP1Handler is an http.Handler that answers every request with the
// fixed body testHTTP1Resp.
type testHTTP1Handler struct{}

// ServeHTTP writes testHTTP1Resp to w; the request itself is not inspected.
//
// The body is written with fmt.Fprint rather than fmt.Fprintf so that it is
// emitted verbatim and is never interpreted as a format string.
func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	fmt.Fprint(w, testHTTP1Resp)
}
// runTestHTTPServer serves HTTP on l with testHTTP1Handler until Serve
// returns. A Serve error other than ErrListenerClosed is sent on errCh.
//
// Connections are recorded through the server's ConnState hook: added on
// http.StateNew and removed on http.StateClosed. After Serve returns, every
// connection still recorded is closed, and any Close error is sent on errCh.
func runTestHTTPServer(errCh chan<- error, l net.Listener) {
	var (
		mu   sync.Mutex
		open = make(map[net.Conn]struct{})
	)

	// Runs once Serve has returned: close whatever is still being tracked.
	defer func() {
		mu.Lock()
		defer mu.Unlock()
		for conn := range open {
			if err := conn.Close(); err != nil {
				errCh <- err
			}
		}
	}()

	// track keeps the set of open connections in step with the server's view.
	track := func(conn net.Conn, state http.ConnState) {
		mu.Lock()
		defer mu.Unlock()
		if state == http.StateNew {
			open[conn] = struct{}{}
		} else if state == http.StateClosed {
			delete(open, conn)
		}
	}

	srv := &http.Server{
		Handler:   &testHTTP1Handler{},
		ConnState: track,
	}
	err := srv.Serve(l)
	if err != ErrListenerClosed {
		errCh <- err
	}
}
// runTestHTTP1Client issues an HTTP GET to addr and checks that the response
// body equals testHTTP1Resp. A failed request, a failed body read or a failed
// body close aborts the test; a mismatching body is reported with t.Errorf.
func runTestHTTP1Client(t *testing.T, addr net.Addr) {
	url := "http://" + addr.String()
	resp, err := http.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		if cerr := resp.Body.Close(); cerr != nil {
			t.Fatal(cerr)
		}
	}()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	if got := string(body); got != testHTTP1Resp {
		t.Errorf("invalid response: want=%s got=%s", testHTTP1Resp, body)
	}
}
// TestRPCRcvr is the receiver type that the test RPC server registers.
type TestRPCRcvr struct{}

// Test is an echo method: it stores the argument i in *j and always returns a
// nil error.
func (TestRPCRcvr) Test(i int, j *int) error {
	*j = i
	return nil
}
// runTestRPCServer registers TestRPCRcvr on a fresh rpc.Server and serves each
// connection accepted from l on its own goroutine. A registration error is
// sent on errCh, after which the accept loop still runs. The loop ends on the
// first Accept error; that error is sent on errCh unless it is
// ErrListenerClosed.
func runTestRPCServer(errCh chan<- error, l net.Listener) {
	srv := rpc.NewServer()
	if err := srv.Register(TestRPCRcvr{}); err != nil {
		errCh <- err
	}

	for {
		conn, err := l.Accept()
		switch {
		case err == nil:
			go srv.ServeConn(conn)
		case err == ErrListenerClosed:
			return
		default:
			errCh <- err
			return
		}
	}
}
// runTestRPCClient dials addr with safeDial, calls "TestRPCRcvr.Test" with
// rpcVal and checks that the reply equals rpcVal. A failed call aborts the
// test; a wrong reply is reported with t.Errorf. The client is closed on
// return.
func runTestRPCClient(t *testing.T, addr net.Addr) {
	client, closeClient := safeDial(t, addr)
	defer closeClient()

	var reply int
	err := client.Call("TestRPCRcvr.Test", rpcVal, &reply)
	if err != nil {
		t.Fatal(err)
	}
	if reply != rpcVal {
		t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, reply)
	}
}
// TestAny muxes a TCP listener with a single Match(Any()) listener, runs the
// HTTP test server on it and performs one HTTP/1 request against the root
// address. leakCheck runs last to report goroutines that outlive the test.
func TestAny(t *testing.T) {
	defer leakCheck(t)()

	errCh := make(chan error)
	// On the way out, fail if a server goroutine is waiting to report an error.
	reportPending := func() {
		select {
		case err := <-errCh:
			t.Fatal(err)
		default:
		}
	}
	defer reportPending()

	root, closeRoot := testListener(t)
	defer closeRoot()

	mux := New(root)
	httpl := mux.Match(Any())

	go runTestHTTPServer(errCh, httpl)
	go safeServe(errCh, mux)

	runTestHTTP1Client(t, root.Addr())
}
// TestHTTPGoRPC serves HTTP and net/rpc behind one mux: the listener returned
// by Match(HTTP2(), HTTP1Fast()) is given to the HTTP test server and the
// Match(Any()) listener to the RPC test server. An HTTP/1 client and then an
// RPC client are run against the same root address.
func TestHTTPGoRPC(t *testing.T) {
	defer leakCheck(t)()

	errCh := make(chan error)
	// On the way out, fail if a server goroutine is waiting to report an error.
	reportPending := func() {
		select {
		case err := <-errCh:
			t.Fatal(err)
		default:
		}
	}
	defer reportPending()

	root, closeRoot := testListener(t)
	defer closeRoot()

	mux := New(root)
	httpl := mux.Match(HTTP2(), HTTP1Fast())
	rpcl := mux.Match(Any())

	go runTestHTTPServer(errCh, httpl)
	go runTestRPCServer(errCh, rpcl)
	go safeServe(errCh, mux)

	runTestHTTP1Client(t, root.Addr())
	runTestRPCClient(t, root.Addr())
}
// TestErrorHandler checks that the handler installed with HandleError is
// invoked when a connection is not matched. Only an HTTP listener is
// registered, so a net/rpc client is used as the non-matching peer. The first
// error the handler sees must be an ErrNotMatched, and no RPC call may succeed.
func TestErrorHandler(t *testing.T) {
	defer leakCheck(t)()

	errCh := make(chan error)
	// On the way out, fail if a server goroutine is waiting to report an error.
	defer func() {
		select {
		case err := <-errCh:
			t.Fatal(err)
		default:
		}
	}()

	l, closeListener := testListener(t)
	defer closeListener()

	muxl := New(l)
	httpl := muxl.Match(HTTP2(), HTTP1Fast())

	// Install the handler before any serving goroutine is started, so that the
	// goroutines below are created after the handler is in place.
	// NOTE(review): whether HandleError synchronizes internally is not visible
	// from this file; registering first avoids depending on it.
	var errCount uint32
	muxl.HandleError(func(err error) bool {
		// Only the first reported error is type-checked.
		if atomic.AddUint32(&errCount, 1) == 1 {
			if _, ok := err.(ErrNotMatched); !ok {
				t.Errorf("unexpected error: %v", err)
			}
		}
		return true
	})

	go runTestHTTPServer(errCh, httpl)
	go safeServe(errCh, muxl)

	c, closeClient := safeDial(t, l.Addr())
	defer closeClient()

	// Keep calling until the handler has run at least once.
	var num int
	for atomic.LoadUint32(&errCount) == 0 {
		if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
			// The connection is simply closed.
			t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
		}
	}
}
// TestClose checks shutdown ordering with a channel-driven root listener: a
// connection handed to the root listener just before it is closed must still
// come out of the matched listener's Accept.
func TestClose(t *testing.T) {
	defer leakCheck(t)()

	errCh := make(chan error)
	// On the way out, fail if the serving goroutine is waiting to report an error.
	reportPending := func() {
		select {
		case err := <-errCh:
			t.Fatal(err)
		default:
		}
	}
	defer reportPending()

	root := newChanListener()
	first, second := net.Pipe()

	mux := New(root)
	anyl := mux.Match(Any())
	go safeServe(errCh, mux)

	// mustAccept takes one connection from the matched listener or aborts the test.
	mustAccept := func() {
		if _, err := anyl.Accept(); err != nil {
			t.Fatal(err)
		}
	}

	// First connection goes through.
	root.connCh <- first
	mustAccept()

	// Second connection is sent, then the root listener is closed.
	root.connCh <- second
	close(root.connCh)

	// Second connection still goes through.
	mustAccept()
}
// Cribbed from google.golang.org/grpc/test/end2end_test.go.

// interestingGoroutines dumps the stacks of all goroutines and returns, in
// sorted order, the dump entries that matter for leak checking. An entry is
// dropped when it has no stack after its first line, when its stack starts
// with "testing.RunTests", or when its stack mentions one of the testing or
// runtime markers below (including this function itself).
func interestingGoroutines() (gs []string) {
	dump := make([]byte, 2<<20)
	dump = dump[:runtime.Stack(dump, true)]

	// Substrings that mark a goroutine as uninteresting.
	ignored := []string{
		"testing.Main(",
		"runtime.goexit",
		"created by runtime.gc",
		"interestingGoroutines",
		"runtime.MHeap_Scavenger",
	}

entries:
	for _, entry := range strings.Split(string(dump), "\n\n") {
		// Separate the entry's first line from the stack that follows it.
		parts := strings.SplitN(entry, "\n", 2)
		if len(parts) != 2 {
			continue
		}
		frames := strings.TrimSpace(parts[1])
		if frames == "" || strings.HasPrefix(frames, "testing.RunTests") {
			continue
		}
		for _, marker := range ignored {
			if strings.Contains(frames, marker) {
				continue entries
			}
		}
		gs = append(gs, entry)
	}

	sort.Strings(gs)
	return gs
}
// leakCheck records the goroutines currently reported by
// interestingGoroutines and returns a function to run at the end of a test.
// That function polls every 50ms, for up to 5 seconds, until no goroutine
// outside the recorded set remains; any that are still present once the
// deadline has passed are reported through t.Errorf.
func leakCheck(t testing.TB) func() {
	baseline := make(map[string]bool)
	for _, g := range interestingGoroutines() {
		baseline[g] = true
	}

	// newSince lists the goroutines that are not part of the baseline.
	newSince := func() []string {
		var extra []string
		for _, g := range interestingGoroutines() {
			if !baseline[g] {
				extra = append(extra, g)
			}
		}
		return extra
	}

	return func() {
		// Give stragglers time to exit, but return as soon as none remain.
		deadline := time.Now().Add(5 * time.Second)
		leaked := newSince()
		for len(leaked) > 0 && time.Now().Before(deadline) {
			time.Sleep(50 * time.Millisecond)
			leaked = newSince()
		}
		for _, g := range leaked {
			t.Errorf("Leaked goroutine: %v", g)
		}
	}
}