mirror of
https://github.com/soheilhy/cmux.git
synced 2025-01-18 10:53:46 +08:00
commit
34a8ab6cda
83
cmux_test.go
83
cmux_test.go
@ -16,14 +16,19 @@ package cmux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/build"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -128,8 +133,57 @@ func runTestHTTPServer(errCh chan<- error, l net.Listener) {
|
||||
}
|
||||
}
|
||||
|
||||
func generateTLSCert(t *testing.T) {
|
||||
err := exec.Command("go", "run", build.Default.GOROOT+"/src/crypto/tls/generate_cert.go", "--host", "*").Run()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupTLSCert(t *testing.T) {
|
||||
err := os.Remove("cert.pem")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = os.Remove("key.pem")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func runTestTLSServer(errCh chan<- error, l net.Listener) {
|
||||
certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem")
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
log.Printf("1")
|
||||
return
|
||||
}
|
||||
|
||||
config := &tls.Config{
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
Rand: rand.Reader,
|
||||
}
|
||||
|
||||
tlsl := tls.NewListener(l, config)
|
||||
runTestHTTPServer(errCh, tlsl)
|
||||
}
|
||||
|
||||
func runTestHTTP1Client(t *testing.T, addr net.Addr) {
|
||||
r, err := http.Get("http://" + addr.String())
|
||||
runTestHTTPClient(t, "http", addr)
|
||||
}
|
||||
|
||||
func runTestTLSClient(t *testing.T, addr net.Addr) {
|
||||
runTestHTTPClient(t, "https", addr)
|
||||
}
|
||||
|
||||
func runTestHTTPClient(t *testing.T, proto string, addr net.Addr) {
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
r, err := client.Get(proto + "://" + addr.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -345,6 +399,33 @@ func TestAny(t *testing.T) {
|
||||
runTestHTTP1Client(t, l.Addr())
|
||||
}
|
||||
|
||||
func TestTLS(t *testing.T) {
|
||||
generateTLSCert(t)
|
||||
defer cleanupTLSCert(t)
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
defer func() {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}()
|
||||
l, cleanup := testListener(t)
|
||||
defer cleanup()
|
||||
|
||||
muxl := New(l)
|
||||
tlsl := muxl.Match(TLS())
|
||||
httpl := muxl.Match(Any())
|
||||
|
||||
go runTestTLSServer(errCh, tlsl)
|
||||
go runTestHTTPServer(errCh, httpl)
|
||||
go safeServe(errCh, muxl)
|
||||
|
||||
runTestHTTP1Client(t, l.Addr())
|
||||
runTestTLSClient(t, l.Addr())
|
||||
}
|
||||
|
||||
func TestHTTP2(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
|
27
matchers.go
27
matchers.go
@ -16,6 +16,7 @@ package cmux
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@ -37,6 +38,11 @@ func PrefixMatcher(strs ...string) Matcher {
|
||||
return pt.matchPrefix
|
||||
}
|
||||
|
||||
func prefixByteMatcher(list ...[]byte) Matcher {
|
||||
pt := newPatriciaTree(list...)
|
||||
return pt.matchPrefix
|
||||
}
|
||||
|
||||
var defaultHTTPMethods = []string{
|
||||
"OPTIONS",
|
||||
"GET",
|
||||
@ -57,6 +63,27 @@ func HTTP1Fast(extMethods ...string) Matcher {
|
||||
return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...)
|
||||
}
|
||||
|
||||
// TLS matches HTTPS requests.
|
||||
//
|
||||
// By default, any TLS handshake packet is matched. An optional whitelist
|
||||
// of versions can be passed in to restrict the matcher, for example:
|
||||
// TLS(tls.VersionTLS11, tls.VersionTLS12)
|
||||
func TLS(versions ...int) Matcher {
|
||||
if len(versions) == 0 {
|
||||
versions = []int{
|
||||
tls.VersionSSL30,
|
||||
tls.VersionTLS10,
|
||||
tls.VersionTLS11,
|
||||
tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
prefixes := [][]byte{}
|
||||
for _, v := range versions {
|
||||
prefixes = append(prefixes, []byte{22, byte(v >> 8 & 0xff), byte(v & 0xff)})
|
||||
}
|
||||
return prefixByteMatcher(prefixes...)
|
||||
}
|
||||
|
||||
const maxHTTPRead = 4096
|
||||
|
||||
// HTTP1 parses the first line or upto 4096 bytes of the request to see if
|
||||
|
Loading…
Reference in New Issue
Block a user