diff --git a/cmux_test.go b/cmux_test.go index 8746c0d..32221ed 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -16,14 +16,19 @@ package cmux import ( "bytes" + "crypto/rand" + "crypto/tls" "errors" "fmt" + "go/build" "io" "io/ioutil" "log" "net" "net/http" "net/rpc" + "os" + "os/exec" "runtime" "sort" "strings" @@ -128,8 +133,57 @@ func runTestHTTPServer(errCh chan<- error, l net.Listener) { } } +func generateTLSCert(t *testing.T) { + err := exec.Command("go", "run", build.Default.GOROOT+"/src/crypto/tls/generate_cert.go", "--host", "*").Run() + if err != nil { + t.Fatal(err) + } +} + +func cleanupTLSCert(t *testing.T) { + err := os.Remove("cert.pem") + if err != nil { + t.Error(err) + } + err = os.Remove("key.pem") + if err != nil { + t.Error(err) + } +} + +func runTestTLSServer(errCh chan<- error, l net.Listener) { + certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem") + if err != nil { + errCh <- err + log.Printf("1") + return + } + + config := &tls.Config{ + Certificates: []tls.Certificate{certificate}, + Rand: rand.Reader, + } + + tlsl := tls.NewListener(l, config) + runTestHTTPServer(errCh, tlsl) +} + func runTestHTTP1Client(t *testing.T, addr net.Addr) { - r, err := http.Get("http://" + addr.String()) + runTestHTTPClient(t, "http", addr) +} + +func runTestTLSClient(t *testing.T, addr net.Addr) { + runTestHTTPClient(t, "https", addr) +} + +func runTestHTTPClient(t *testing.T, proto string, addr net.Addr) { + client := http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + r, err := client.Get(proto + "://" + addr.String()) if err != nil { t.Fatal(err) } @@ -345,6 +399,33 @@ func TestAny(t *testing.T) { runTestHTTP1Client(t, l.Addr()) } +func TestTLS(t *testing.T) { + generateTLSCert(t) + defer cleanupTLSCert(t) + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + l, cleanup := testListener(t) + defer cleanup() + + muxl := New(l) + tlsl := muxl.Match(TLS()) + httpl := muxl.Match(Any()) + + go runTestTLSServer(errCh, tlsl) + go runTestHTTPServer(errCh, httpl) + go safeServe(errCh, muxl) + + runTestHTTP1Client(t, l.Addr()) + runTestTLSClient(t, l.Addr()) +} + func TestHTTP2(t *testing.T) { defer leakCheck(t)() errCh := make(chan error)