diff --git a/cmux_test.go b/cmux_test.go index 8746c0d..32221ed 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -16,14 +16,19 @@ package cmux import ( "bytes" + "crypto/rand" + "crypto/tls" "errors" "fmt" + "go/build" "io" "io/ioutil" "log" "net" "net/http" "net/rpc" + "os" + "os/exec" "runtime" "sort" "strings" @@ -128,8 +133,57 @@ func runTestHTTPServer(errCh chan<- error, l net.Listener) { } } +func generateTLSCert(t *testing.T) { + err := exec.Command("go", "run", build.Default.GOROOT+"/src/crypto/tls/generate_cert.go", "--host", "*").Run() + if err != nil { + t.Fatal(err) + } +} + +func cleanupTLSCert(t *testing.T) { + err := os.Remove("cert.pem") + if err != nil { + t.Error(err) + } + err = os.Remove("key.pem") + if err != nil { + t.Error(err) + } +} + +func runTestTLSServer(errCh chan<- error, l net.Listener) { + certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem") + if err != nil { + errCh <- err + log.Printf("1") + return + } + + config := &tls.Config{ + Certificates: []tls.Certificate{certificate}, + Rand: rand.Reader, + } + + tlsl := tls.NewListener(l, config) + runTestHTTPServer(errCh, tlsl) +} + func runTestHTTP1Client(t *testing.T, addr net.Addr) { - r, err := http.Get("http://" + addr.String()) + runTestHTTPClient(t, "http", addr) +} + +func runTestTLSClient(t *testing.T, addr net.Addr) { + runTestHTTPClient(t, "https", addr) +} + +func runTestHTTPClient(t *testing.T, proto string, addr net.Addr) { + client := http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + r, err := client.Get(proto + "://" + addr.String()) if err != nil { t.Fatal(err) } @@ -345,6 +399,33 @@ func TestAny(t *testing.T) { runTestHTTP1Client(t, l.Addr()) } +func TestTLS(t *testing.T) { + generateTLSCert(t) + defer cleanupTLSCert(t) + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + l, cleanup := testListener(t) + defer cleanup() + + muxl := New(l) + tlsl := muxl.Match(TLS()) + httpl := muxl.Match(Any()) + + go runTestTLSServer(errCh, tlsl) + go runTestHTTPServer(errCh, httpl) + go safeServe(errCh, muxl) + + runTestHTTP1Client(t, l.Addr()) + runTestTLSClient(t, l.Addr()) +} + func TestHTTP2(t *testing.T) { defer leakCheck(t)() errCh := make(chan error) diff --git a/matchers.go b/matchers.go index cfc24c7..6ccd7a8 100644 --- a/matchers.go +++ b/matchers.go @@ -16,6 +16,7 @@ package cmux import ( "bufio" + "crypto/tls" "io" "io/ioutil" "net/http" @@ -37,6 +38,11 @@ func PrefixMatcher(strs ...string) Matcher { return pt.matchPrefix } +func prefixByteMatcher(list ...[]byte) Matcher { + pt := newPatriciaTree(list...) + return pt.matchPrefix +} + var defaultHTTPMethods = []string{ "OPTIONS", "GET", @@ -57,6 +63,27 @@ func HTTP1Fast(extMethods ...string) Matcher { return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...) } +// TLS matches HTTPS requests. +// +// By default, any TLS handshake packet is matched. An optional whitelist +// of versions can be passed in to restrict the matcher, for example: +// TLS(tls.VersionTLS11, tls.VersionTLS12) +func TLS(versions ...int) Matcher { + if len(versions) == 0 { + versions = []int{ + tls.VersionSSL30, + tls.VersionTLS10, + tls.VersionTLS11, + tls.VersionTLS12, + } + } + prefixes := [][]byte{} + for _, v := range versions { + prefixes = append(prefixes, []byte{22, byte(v >> 8 & 0xff), byte(v & 0xff)}) + } + return prefixByteMatcher(prefixes...) +} + const maxHTTPRead = 4096 // HTTP1 parses the first line or upto 4096 bytes of the request to see if