diff --git a/cmux_test.go b/cmux_test.go index 5cd7185..0df6470 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -6,7 +6,12 @@ import ( "net" "net/http" "net/rpc" + "runtime" + "sort" + "strings" + "sync" "testing" + "time" ) const ( @@ -14,6 +19,24 @@ const ( rpcVal = 1234 ) +func safeServe(errCh chan<- error, muxl CMux) { + if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed network connection") { + errCh <- err + } +} + +func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) { + c, err := rpc.Dial(addr.Network(), addr.String()) + if err != nil { + t.Fatal(err) + } + return c, func() { + if err := c.Close(); err != nil { + t.Fatal(err) + } + } +} + func testListener(t *testing.T) (net.Listener, func()) { l, err := net.Listen("tcp", ":0") if err != nil { @@ -21,7 +44,7 @@ func testListener(t *testing.T) (net.Listener, func()) { } return l, func() { if err := l.Close(); err != nil { - t.Error(err) + t.Fatal(err) } } } @@ -32,12 +55,35 @@ func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, testHTTP1Resp) } -func runTestHTTPServer(t *testing.T, l net.Listener) { +func runTestHTTPServer(errCh chan<- error, l net.Listener) { + var mu sync.Mutex + conns := make(map[net.Conn]struct{}) + + defer func() { + mu.Lock() + for c := range conns { + if err := c.Close(); err != nil { + errCh <- err + } + } + mu.Unlock() + }() + s := &http.Server{ Handler: &testHTTP1Handler{}, + ConnState: func(c net.Conn, state http.ConnState) { + mu.Lock() + switch state { + case http.StateNew: + conns[c] = struct{}{} + case http.StateClosed: + delete(conns, c) + } + mu.Unlock() + }, } - if err := s.Serve(l); err != nil && err != ErrListenerClosed { - t.Log(err) + if err := s.Serve(l); err != ErrListenerClosed { + errCh <- err } } @@ -49,12 +95,12 @@ func runTestHTTP1Client(t *testing.T, addr net.Addr) { defer func() { if err := r.Body.Close(); err != nil { - t.Log(err) + t.Fatal(err) } }() b, err := ioutil.ReadAll(r.Body) if err != nil { - t.Error(err) + t.Fatal(err) } if string(b) != testHTTP1Resp { @@ -69,15 +115,17 @@ func (r TestRPCRcvr) Test(i int, j *int) error { return nil } -func runTestRPCServer(t *testing.T, l net.Listener) { +func runTestRPCServer(errCh chan<- error, l net.Listener) { s := rpc.NewServer() if err := s.Register(TestRPCRcvr{}); err != nil { - t.Fatal(err) + errCh <- err } for { c, err := l.Accept() if err != nil { - t.Log(err) + if err != ErrListenerClosed { + errCh <- err + } return } go s.ServeConn(c) @@ -85,16 +133,12 @@ func runTestRPCServer(t *testing.T, l net.Listener) { } func runTestRPCClient(t *testing.T, addr net.Addr) { - c, err := rpc.Dial(addr.Network(), addr.String()) - if err != nil { - t.Error(err) - return - } + c, cleanup := safeDial(t, addr) + defer cleanup() var num int if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil { - t.Error(err) - return + t.Fatal(err) } if num != rpcVal { @@ -103,23 +147,37 @@ func runTestRPCClient(t *testing.T, addr net.Addr) { } func TestAny(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() l, cleanup := testListener(t) defer cleanup() muxl := New(l) httpl := muxl.Match(Any()) - go runTestHTTPServer(t, httpl) - go func() { - if err := muxl.Serve(); err != nil { - t.Log(err) - } - }() + go runTestHTTPServer(errCh, httpl) + go safeServe(errCh, muxl) runTestHTTP1Client(t, l.Addr()) } func TestHTTPGoRPC(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() l, cleanup := testListener(t) defer cleanup() @@ -127,31 +185,32 @@ func TestHTTPGoRPC(t *testing.T) { httpl := muxl.Match(HTTP2(), HTTP1Fast()) rpcl := muxl.Match(Any()) - go runTestHTTPServer(t, httpl) - go runTestRPCServer(t, rpcl) - go func() { - if err := muxl.Serve(); err != nil { - t.Log(err) - } - }() + go runTestHTTPServer(errCh, httpl) + go runTestRPCServer(errCh, rpcl) + go safeServe(errCh, muxl) runTestHTTP1Client(t, l.Addr()) runTestRPCClient(t, l.Addr()) } func TestErrorHandler(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() l, cleanup := testListener(t) defer cleanup() muxl := New(l) httpl := muxl.Match(HTTP2(), HTTP1Fast()) - go runTestHTTPServer(t, httpl) - go func() { - if err := muxl.Serve(); err != nil { - t.Log(err) - } - }() + go runTestHTTPServer(errCh, httpl) + go safeServe(errCh, muxl) firstErr := true muxl.HandleError(func(err error) bool { @@ -165,11 +224,8 @@ func TestErrorHandler(t *testing.T) { return true }) - addr := l.Addr() - c, err := rpc.Dial(addr.Network(), addr.String()) - if err != nil { - t.Fatal(err) - } + c, cleanup := safeDial(t, l.Addr()) + defer cleanup() var num int if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil { @@ -184,6 +240,7 @@ type closerConn struct { func (c closerConn) Close() error { return nil } func TestClosed(t *testing.T) { + defer leakCheck(t)() mux := &cMux{} lis := mux.Match(Any()).(muxListener) close(lis.donec) @@ -193,3 +250,68 @@ func TestClosed(t *testing.T) { t.Errorf("expected errListenerClosed got %v", err) } } + +// Cribbed from google.golang.org/grpc/test/end2end_test.go. + +// interestingGoroutines returns all goroutines we care about for the purpose +// of leak checking. It excludes testing or runtime ones. +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if strings.HasPrefix(stack, "testing.RunTests") { + continue + } + + if stack == "" || + strings.Contains(stack, "testing.Main(") || + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "interestingGoroutines") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, g) + } + sort.Strings(gs) + return +} + +// leakCheck snapshots the currently-running goroutines and returns a +// function to be run at the end of tests to see whether any +// goroutines leaked. +func leakCheck(t testing.TB) func() { + orig := map[string]bool{} + for _, g := range interestingGoroutines() { + orig[g] = true + } + return func() { + // Loop, waiting for goroutines to shut down. + // Wait up to 5 seconds, but finish as quickly as possible. + deadline := time.Now().Add(5 * time.Second) + for { + var leaked []string + for _, g := range interestingGoroutines() { + if !orig[g] { + leaked = append(leaked, g) + } + } + if len(leaked) == 0 { + return + } + if time.Now().Before(deadline) { + time.Sleep(50 * time.Millisecond) + continue + } + for _, g := range leaked { + t.Errorf("Leaked goroutine: %v", g) + } + return + } + } +}