package cmux import ( "fmt" "io/ioutil" "net" "net/http" "net/rpc" "runtime" "sort" "strings" "sync" "testing" "time" ) const ( testHTTP1Resp = "http1" rpcVal = 1234 ) func safeServe(t *testing.T, muxl CMux) { if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed network connection") { t.Fatal(err) } } func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) { c, err := rpc.Dial(addr.Network(), addr.String()) if err != nil { t.Fatal(err) } return c, func() { if err := c.Close(); err != nil { t.Fatal(err) } } } func testListener(t *testing.T) (net.Listener, func()) { l, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } return l, func() { if err := l.Close(); err != nil { t.Fatal(err) } } } type testHTTP1Handler struct{} func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, testHTTP1Resp) } func runTestHTTPServer(t *testing.T, l net.Listener) { var mu sync.Mutex conns := make(map[net.Conn]struct{}) defer func() { mu.Lock() for c := range conns { if err := c.Close(); err != nil { t.Fatal(err) } } mu.Unlock() }() s := &http.Server{ Handler: &testHTTP1Handler{}, ConnState: func(c net.Conn, state http.ConnState) { mu.Lock() switch state { case http.StateNew: conns[c] = struct{}{} case http.StateClosed: delete(conns, c) } mu.Unlock() }, } if err := s.Serve(l); err != ErrListenerClosed { t.Fatal(err) } } func runTestHTTP1Client(t *testing.T, addr net.Addr) { r, err := http.Get("http://" + addr.String()) if err != nil { t.Fatal(err) } defer func() { if err := r.Body.Close(); err != nil { t.Fatal(err) } }() b, err := ioutil.ReadAll(r.Body) if err != nil { t.Fatal(err) } if string(b) != testHTTP1Resp { t.Errorf("invalid response: want=%s got=%s", testHTTP1Resp, b) } } type TestRPCRcvr struct{} func (r TestRPCRcvr) Test(i int, j *int) error { *j = i return nil } func runTestRPCServer(t *testing.T, l net.Listener) { s := rpc.NewServer() if err := s.Register(TestRPCRcvr{}); err != nil { t.Fatal(err) } 
for { c, err := l.Accept() if err != nil { if err != ErrListenerClosed { t.Fatal(err) } return } go s.ServeConn(c) } } func runTestRPCClient(t *testing.T, addr net.Addr) { c, cleanup := safeDial(t, addr) defer cleanup() var num int if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil { t.Fatal(err) } if num != rpcVal { t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num) } } func TestAny(t *testing.T) { defer leakCheck(t)() l, cleanup := testListener(t) defer cleanup() muxl := New(l) httpl := muxl.Match(Any()) go runTestHTTPServer(t, httpl) go safeServe(t, muxl) runTestHTTP1Client(t, l.Addr()) } func TestHTTPGoRPC(t *testing.T) { defer leakCheck(t)() l, cleanup := testListener(t) defer cleanup() muxl := New(l) httpl := muxl.Match(HTTP2(), HTTP1Fast()) rpcl := muxl.Match(Any()) go runTestHTTPServer(t, httpl) go runTestRPCServer(t, rpcl) go safeServe(t, muxl) runTestHTTP1Client(t, l.Addr()) runTestRPCClient(t, l.Addr()) } func TestErrorHandler(t *testing.T) { defer leakCheck(t)() l, cleanup := testListener(t) defer cleanup() muxl := New(l) httpl := muxl.Match(HTTP2(), HTTP1Fast()) go runTestHTTPServer(t, httpl) go safeServe(t, muxl) firstErr := true muxl.HandleError(func(err error) bool { if !firstErr { return true } if _, ok := err.(ErrNotMatched); !ok { t.Errorf("unexpected error: %v", err) } firstErr = false return true }) c, cleanup := safeDial(t, l.Addr()) defer cleanup() var num int if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil { t.Error("rpc got a response") } } type closerConn struct { net.Conn } func (c closerConn) Close() error { return nil } func TestClosed(t *testing.T) { defer leakCheck(t)() mux := &cMux{} lis := mux.Match(Any()).(muxListener) close(lis.donec) mux.serve(closerConn{}) _, err := lis.Accept() if _, ok := err.(errListenerClosed); !ok { t.Errorf("expected errListenerClosed got %v", err) } } // Cribbed from google.golang.org/grpc/test/end2end_test.go. 
// interestingGoroutines dumps the stacks of all goroutines and returns, in
// sorted order, the dumps that matter for leak checking. Goroutines whose
// stack belongs to the testing framework, the runtime, or this function
// itself are dropped.
func interestingGoroutines() (gs []string) {
	// Substrings that mark a stack as uninteresting.
	ignored := []string{
		"testing.Main(",
		"runtime.goexit",
		"created by runtime.gc",
		"interestingGoroutines",
		"runtime.MHeap_Scavenger",
	}

	dump := make([]byte, 2<<20)
	n := runtime.Stack(dump, true)

	// runtime.Stack separates goroutines with a blank line.
dumps:
	for _, g := range strings.Split(string(dump[:n]), "\n\n") {
		// The first line is the goroutine header; the rest is its stack.
		parts := strings.SplitN(g, "\n", 2)
		if len(parts) < 2 {
			continue
		}
		stack := strings.TrimSpace(parts[1])
		if stack == "" || strings.HasPrefix(stack, "testing.RunTests") {
			continue
		}
		for _, marker := range ignored {
			if strings.Contains(stack, marker) {
				continue dumps
			}
		}
		gs = append(gs, g)
	}

	sort.Strings(gs)
	return gs
}

// leakCheck records the goroutines that are running when it is called and
// returns a function meant to be run at the end of a test. That function
// reports, via t.Errorf, every goroutine that was not in the initial snapshot
// and is still present once the grace period has elapsed.
func leakCheck(t testing.TB) func() {
	baseline := make(map[string]bool)
	for _, g := range interestingGoroutines() {
		baseline[g] = true
	}

	// unexpected lists the goroutines that are running now but were not part
	// of the baseline snapshot.
	unexpected := func() []string {
		var extra []string
		for _, g := range interestingGoroutines() {
			if baseline[g] {
				continue
			}
			extra = append(extra, g)
		}
		return extra
	}

	return func() {
		// Poll for up to 5 seconds so goroutines that are still shutting
		// down get a chance to exit, but return as soon as none are left.
		deadline := time.Now().Add(5 * time.Second)
		for {
			leaked := unexpected()
			if len(leaked) == 0 {
				return
			}
			if !time.Now().Before(deadline) {
				for _, g := range leaked {
					t.Errorf("Leaked goroutine: %v", g)
				}
				return
			}
			time.Sleep(50 * time.Millisecond)
		}
	}
}