mirror of
https://github.com/soheilhy/cmux.git
synced 2025-10-17 20:58:14 +08:00
Compare commits
17 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
e09e9389d8 | ||
|
444ce56efe | ||
|
cfc68f9888 | ||
|
e96bd75f84 | ||
|
be5b383fd5 | ||
|
b9e684ba4e | ||
|
bb79a83465 | ||
|
7e08502c7a | ||
|
0c129dc694 | ||
|
34a8ab6cda | ||
|
3b204bab2a | ||
|
9a3402ad7a | ||
|
4f90533583 | ||
|
8cd60510aa | ||
|
f671b41193 | ||
|
885b8d8a14 | ||
|
0068a46c9c |
@@ -32,7 +32,7 @@ httpS := &http.Server{
|
||||
}
|
||||
|
||||
trpcS := rpc.NewServer()
|
||||
s.Register(&ExampleRPCRcvr{})
|
||||
trpcS.Register(&ExampleRPCRcvr{})
|
||||
|
||||
// Use the muxed listeners for your servers.
|
||||
go grpcS.Serve(grpcL)
|
||||
|
3
cmux.go
3
cmux.go
@@ -116,8 +116,9 @@ type cMux struct {
|
||||
func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
|
||||
mws := make([]MatchWriter, 0, len(matchers))
|
||||
for _, m := range matchers {
|
||||
cm := m
|
||||
mws = append(mws, func(w io.Writer, r io.Reader) bool {
|
||||
return m(r)
|
||||
return cm(r)
|
||||
})
|
||||
}
|
||||
return mws
|
||||
|
141
cmux_test.go
141
cmux_test.go
@@ -16,14 +16,19 @@ package cmux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/build"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -76,7 +81,7 @@ func (l *chanListener) Accept() (net.Conn, error) {
|
||||
}
|
||||
|
||||
func testListener(t *testing.T) (net.Listener, func()) {
|
||||
l, err := net.Listen("tcp4", ":0")
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -128,8 +133,57 @@ func runTestHTTPServer(errCh chan<- error, l net.Listener) {
|
||||
}
|
||||
}
|
||||
|
||||
func generateTLSCert(t *testing.T) {
|
||||
err := exec.Command("go", "run", build.Default.GOROOT+"/src/crypto/tls/generate_cert.go", "--host", "*").Run()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupTLSCert(t *testing.T) {
|
||||
err := os.Remove("cert.pem")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = os.Remove("key.pem")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func runTestTLSServer(errCh chan<- error, l net.Listener) {
|
||||
certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem")
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
log.Printf("1")
|
||||
return
|
||||
}
|
||||
|
||||
config := &tls.Config{
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
Rand: rand.Reader,
|
||||
}
|
||||
|
||||
tlsl := tls.NewListener(l, config)
|
||||
runTestHTTPServer(errCh, tlsl)
|
||||
}
|
||||
|
||||
func runTestHTTP1Client(t *testing.T, addr net.Addr) {
|
||||
r, err := http.Get("http://" + addr.String())
|
||||
runTestHTTPClient(t, "http", addr)
|
||||
}
|
||||
|
||||
func runTestTLSClient(t *testing.T, addr net.Addr) {
|
||||
runTestHTTPClient(t, "https", addr)
|
||||
}
|
||||
|
||||
func runTestHTTPClient(t *testing.T, proto string, addr net.Addr) {
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
r, err := client.Get(proto + "://" + addr.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -199,7 +253,7 @@ func TestTimeout(t *testing.T) {
|
||||
lis, Close := testListener(t)
|
||||
defer Close()
|
||||
result := make(chan int, 5)
|
||||
testDuration := time.Millisecond * 100
|
||||
testDuration := time.Millisecond * 500
|
||||
m := New(lis)
|
||||
m.SetReadTimeout(testDuration)
|
||||
http1 := m.Match(HTTP1Fast())
|
||||
@@ -345,6 +399,33 @@ func TestAny(t *testing.T) {
|
||||
runTestHTTP1Client(t, l.Addr())
|
||||
}
|
||||
|
||||
func TestTLS(t *testing.T) {
|
||||
generateTLSCert(t)
|
||||
defer cleanupTLSCert(t)
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
defer func() {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}()
|
||||
l, cleanup := testListener(t)
|
||||
defer cleanup()
|
||||
|
||||
muxl := New(l)
|
||||
tlsl := muxl.Match(TLS())
|
||||
httpl := muxl.Match(Any())
|
||||
|
||||
go runTestTLSServer(errCh, tlsl)
|
||||
go runTestHTTPServer(errCh, httpl)
|
||||
go safeServe(errCh, muxl)
|
||||
|
||||
runTestHTTP1Client(t, l.Addr())
|
||||
runTestTLSClient(t, l.Addr())
|
||||
}
|
||||
|
||||
func TestHTTP2(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
@@ -397,6 +478,20 @@ func TestHTTP2(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHTTP2MatchHeaderField(t *testing.T) {
|
||||
testHTTP2MatchHeaderField(t, HTTP2HeaderField, "value", "value", "anothervalue")
|
||||
}
|
||||
|
||||
func TestHTTP2MatchHeaderFieldPrefix(t *testing.T) {
|
||||
testHTTP2MatchHeaderField(t, HTTP2HeaderFieldPrefix, "application/grpc+proto", "application/grpc", "application/json")
|
||||
}
|
||||
|
||||
func testHTTP2MatchHeaderField(
|
||||
t *testing.T,
|
||||
matcherConstructor func(string, string) Matcher,
|
||||
headerValue string,
|
||||
matchValue string,
|
||||
notMatchValue string,
|
||||
) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
defer func() {
|
||||
@@ -407,7 +502,6 @@ func TestHTTP2MatchHeaderField(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
name := "name"
|
||||
value := "value"
|
||||
writer, reader := net.Pipe()
|
||||
go func() {
|
||||
if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
|
||||
@@ -415,7 +509,7 @@ func TestHTTP2MatchHeaderField(t *testing.T) {
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
enc := hpack.NewEncoder(&buf)
|
||||
if err := enc.WriteField(hpack.HeaderField{Name: name, Value: value}); err != nil {
|
||||
if err := enc.WriteField(hpack.HeaderField{Name: name, Value: headerValue}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
framer := http2.NewFramer(writer, nil)
|
||||
@@ -443,9 +537,9 @@ func TestHTTP2MatchHeaderField(t *testing.T) {
|
||||
return false
|
||||
})
|
||||
// Create a matcher that cannot match the response.
|
||||
muxl.Match(HTTP2HeaderField(name, "another"+value))
|
||||
muxl.Match(matcherConstructor(name, notMatchValue))
|
||||
// Then match with the expected field.
|
||||
h2l := muxl.Match(HTTP2HeaderField(name, value))
|
||||
h2l := muxl.Match(matcherConstructor(name, matchValue))
|
||||
go safeServe(errCh, muxl)
|
||||
muxedConn, err := h2l.Accept()
|
||||
close(l.connCh)
|
||||
@@ -528,6 +622,35 @@ func TestErrorHandler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleMatchers(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
defer func() {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatal(err)
|
||||
default:
|
||||
}
|
||||
}()
|
||||
l, cleanup := testListener(t)
|
||||
defer cleanup()
|
||||
|
||||
matcher := func(r io.Reader) bool {
|
||||
return true
|
||||
}
|
||||
unmatcher := func(r io.Reader) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
muxl := New(l)
|
||||
lis := muxl.Match(unmatcher, matcher, unmatcher)
|
||||
|
||||
go runTestHTTPServer(errCh, lis)
|
||||
go safeServe(errCh, muxl)
|
||||
|
||||
runTestHTTP1Client(t, l.Addr())
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
errCh := make(chan error)
|
||||
@@ -565,7 +688,9 @@ func TestClose(t *testing.T) {
|
||||
if err != ErrListenerClosed {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := c2.Read([]byte{}); err != io.ErrClosedPipe {
|
||||
// The error is either io.ErrClosedPipe or net.OpError wrapping
|
||||
// a net.pipeError depending on the go version.
|
||||
if _, err := c2.Read([]byte{}); !strings.Contains(err.Error(), "closed") {
|
||||
t.Fatalf("connection is not closed and is leaked: %v", err)
|
||||
}
|
||||
}
|
||||
|
@@ -112,7 +112,7 @@ func Example() {
|
||||
|
||||
// We first match the connection against HTTP2 fields. If matched, the
|
||||
// connection will be sent through the "grpcl" listener.
|
||||
grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc"))
|
||||
grpcl := m.Match(cmux.HTTP2HeaderFieldPrefix("content-type", "application/grpc"))
|
||||
//Otherwise, we match it againts a websocket upgrade request.
|
||||
wsl := m.Match(cmux.HTTP1HeaderField("Upgrade", "websocket"))
|
||||
|
||||
|
87
matchers.go
87
matchers.go
@@ -16,6 +16,7 @@ package cmux
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -37,6 +38,11 @@ func PrefixMatcher(strs ...string) Matcher {
|
||||
return pt.matchPrefix
|
||||
}
|
||||
|
||||
func prefixByteMatcher(list ...[]byte) Matcher {
|
||||
pt := newPatriciaTree(list...)
|
||||
return pt.matchPrefix
|
||||
}
|
||||
|
||||
var defaultHTTPMethods = []string{
|
||||
"OPTIONS",
|
||||
"GET",
|
||||
@@ -57,6 +63,27 @@ func HTTP1Fast(extMethods ...string) Matcher {
|
||||
return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...)
|
||||
}
|
||||
|
||||
// TLS matches HTTPS requests.
|
||||
//
|
||||
// By default, any TLS handshake packet is matched. An optional whitelist
|
||||
// of versions can be passed in to restrict the matcher, for example:
|
||||
// TLS(tls.VersionTLS11, tls.VersionTLS12)
|
||||
func TLS(versions ...int) Matcher {
|
||||
if len(versions) == 0 {
|
||||
versions = []int{
|
||||
tls.VersionSSL30,
|
||||
tls.VersionTLS10,
|
||||
tls.VersionTLS11,
|
||||
tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
prefixes := [][]byte{}
|
||||
for _, v := range versions {
|
||||
prefixes = append(prefixes, []byte{22, byte(v >> 8 & 0xff), byte(v & 0xff)})
|
||||
}
|
||||
return prefixByteMatcher(prefixes...)
|
||||
}
|
||||
|
||||
const maxHTTPRead = 4096
|
||||
|
||||
// HTTP1 parses the first line or upto 4096 bytes of the request to see if
|
||||
@@ -100,15 +127,41 @@ func HTTP2() Matcher {
|
||||
// request of an HTTP 1 connection.
|
||||
func HTTP1HeaderField(name, value string) Matcher {
|
||||
return func(r io.Reader) bool {
|
||||
return matchHTTP1Field(r, name, value)
|
||||
return matchHTTP1Field(r, name, func(gotValue string) bool {
|
||||
return gotValue == value
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP2HeaderField resturns a matcher matching the header fields of the first
|
||||
// HTTP1HeaderFieldPrefix returns a matcher matching the header fields of the
|
||||
// first request of an HTTP 1 connection. If the header with key name has a
|
||||
// value prefixed with valuePrefix, this will match.
|
||||
func HTTP1HeaderFieldPrefix(name, valuePrefix string) Matcher {
|
||||
return func(r io.Reader) bool {
|
||||
return matchHTTP1Field(r, name, func(gotValue string) bool {
|
||||
return strings.HasPrefix(gotValue, valuePrefix)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP2HeaderField returns a matcher matching the header fields of the first
|
||||
// headers frame.
|
||||
func HTTP2HeaderField(name, value string) Matcher {
|
||||
return func(r io.Reader) bool {
|
||||
return matchHTTP2Field(ioutil.Discard, r, name, value)
|
||||
return matchHTTP2Field(ioutil.Discard, r, name, func(gotValue string) bool {
|
||||
return gotValue == value
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP2HeaderFieldPrefix returns a matcher matching the header fields of the
|
||||
// first headers frame. If the header with key name has a value prefixed with
|
||||
// valuePrefix, this will match.
|
||||
func HTTP2HeaderFieldPrefix(name, valuePrefix string) Matcher {
|
||||
return func(r io.Reader) bool {
|
||||
return matchHTTP2Field(ioutil.Discard, r, name, func(gotValue string) bool {
|
||||
return strings.HasPrefix(gotValue, valuePrefix)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,7 +170,20 @@ func HTTP2HeaderField(name, value string) Matcher {
|
||||
// does not block on receiving a SETTING frame.
|
||||
func HTTP2MatchHeaderFieldSendSettings(name, value string) MatchWriter {
|
||||
return func(w io.Writer, r io.Reader) bool {
|
||||
return matchHTTP2Field(w, r, name, value)
|
||||
return matchHTTP2Field(w, r, name, func(gotValue string) bool {
|
||||
return gotValue == value
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HTTP2MatchHeaderFieldPrefixSendSettings matches the header field prefix
|
||||
// and writes the settings to the server. Prefer HTTP2HeaderFieldPrefix over
|
||||
// this one, if the client does not block on receiving a SETTING frame.
|
||||
func HTTP2MatchHeaderFieldPrefixSendSettings(name, valuePrefix string) MatchWriter {
|
||||
return func(w io.Writer, r io.Reader) bool {
|
||||
return matchHTTP2Field(w, r, name, func(gotValue string) bool {
|
||||
return strings.HasPrefix(gotValue, valuePrefix)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,16 +208,16 @@ func hasHTTP2Preface(r io.Reader) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func matchHTTP1Field(r io.Reader, name, value string) (matched bool) {
|
||||
func matchHTTP1Field(r io.Reader, name string, matches func(string) bool) (matched bool) {
|
||||
req, err := http.ReadRequest(bufio.NewReader(r))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return req.Header.Get(name) == value
|
||||
return matches(req.Header.Get(name))
|
||||
}
|
||||
|
||||
func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool) {
|
||||
func matchHTTP2Field(w io.Writer, r io.Reader, name string, matches func(string) bool) (matched bool) {
|
||||
if !hasHTTP2Preface(r) {
|
||||
return false
|
||||
}
|
||||
@@ -161,7 +227,7 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool
|
||||
hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) {
|
||||
if hf.Name == name {
|
||||
done = true
|
||||
if hf.Value == value {
|
||||
if matches(hf.Value) {
|
||||
matched = true
|
||||
}
|
||||
}
|
||||
@@ -174,6 +240,11 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool
|
||||
|
||||
switch f := f.(type) {
|
||||
case *http2.SettingsFrame:
|
||||
// Sender acknoweldged the SETTINGS frame. No need to write
|
||||
// SETTINGS again.
|
||||
if f.IsAck() {
|
||||
break
|
||||
}
|
||||
if err := framer.WriteSettings(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
Reference in New Issue
Block a user