commit 754f5b897d311f0a05cfce27ef92d47e6fdd2c95 Author: Soheil Hassas Yeganeh Date: Wed Jul 29 13:45:57 2015 -0400 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..daf913b --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..4adcaca --- /dev/null +++ b/.travis.yml @@ -0,0 +1,5 @@ +language: go + +go: + - 1.3 + - 1.4 diff --git a/README.md b/README.md new file mode 100644 index 0000000..1fd6818 --- /dev/null +++ b/README.md @@ -0,0 +1,57 @@ +# cmux: Connection Mux ![Travis Build Status](https://api.travis-ci.org/soheilhy/args.svg?branch=master "Travis Build Status") ![GoDoc](https://godoc.org/github.com/soheilhy/cmux?status.png) +cmux is a generic Go library to multiplex connections based on +their content. Using cmux, one can serve gRPC, HTTP, and Go RPC +on the same TCP listener to avoid having to use one port per +protocol. + +## How-To +Simply create your main listener, create a cmux for that listener, +and then match connections: +```go +// Create the main listener. +l, err := net.Listen("tcp", ":23456") +if err != nil { + log.Fatal(err) +} + +// Create a cmux. +m := cmux.New(l) + +// Match connections in order. +grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) +httpl := m.Match(cmux.Any()) // Any means anything that is not yet matched. + +// Create your protocol servers. +grpcS := grpc.NewServer() +pb.RegisterGreeterServer(grpcs, &server{}) + +httpS := &http.Server{ + Handler: &testHTTP1Handler{}, +} + +// Use the muxed listeners for your servers. +go grpcS.Serve(grpcl) +go httpS.Serve(httpl) + +// Start serving! +m.Serve() +``` + +Take a look at [other examples in the GoDoc](http://localhost:6060/pkg/github.com/soheilhy/cmux/#pkg-examples). + +## Docs +* [GoDocs](https://godoc.org/github.com/soheilhy/cmux) + +## Performance +There is a huge room for improvment but since we are only matching +the very first bytes of a connection, the performance overheads on +long-lived connections (i.e., RPCs and pipelined HTTP streams) +is negligible. + +*TODO(soheil)*: Add benchmarks. + +## Limitations +*TLS*: Since `cmux` sits in between the actual listener and the mux'ed +listeners, TLS handshake is not handled inside the actual servers. +Because of that, when you handle HTTPS using cmux `http.Request.TLS` +would not be set. diff --git a/cmux.go b/cmux.go new file mode 100644 index 0000000..8b38f30 --- /dev/null +++ b/cmux.go @@ -0,0 +1,197 @@ +package cmux + +import ( + "bytes" + "flag" + "fmt" + "io" + "net" +) + +// Matcher matches a connection based on its content. +type Matcher func(r io.Reader) (ok bool) + +// ErrorHandler handles an error and returns whether +// the mux should continue serving the listener. +type ErrorHandler func(err error) (ok bool) + +// ErrNotMatched is returned whenever a connection is not matched by any of +// the matchers registered in the multiplexer. +type ErrNotMatched struct { + c net.Conn +} + +func (e ErrNotMatched) Error() string { + return fmt.Sprintf("mux: connection %v not matched by an matcher", + e.c.RemoteAddr()) +} + +func (e ErrNotMatched) Temporary() bool { return true } +func (e ErrNotMatched) Timeout() bool { return false } + +type errListenerClosed string + +func (e errListenerClosed) Error() string { return string(e) } +func (e errListenerClosed) Temporary() bool { return false } +func (e errListenerClosed) Timeout() bool { return false } + +var ( + ErrListenerClosed = errListenerClosed("mux: listener closed") +) + +// New instantiates a new connection multiplexer. +func New(l net.Listener) CMux { + if !flag.Parsed() { + flag.Parse() + } + + return &cMux{ + root: l, + bufLen: 1024, + errh: func(err error) bool { return true }, + } +} + +// CMux is a multiplexer for network connections. +type CMux interface { + // Match returns a net.Listener that sees (i.e., accepts) only + // the connections matched by at least one of the matcher. + // + // The order used to call Match determines the priority of matchers. + Match(matchers ...Matcher) net.Listener + // Serve starts multiplexing the listener. Serve blocks and perhaps + // should be invoked concurrently within a go routine. + Serve() error + // HandleError registers an error handler that handles listener errors. + HandleError(h ErrorHandler) +} + +type matchersListener struct { + ss []Matcher + l muxListener +} + +type cMux struct { + root net.Listener + bufLen int + errh ErrorHandler + sls []matchersListener +} + +func (m *cMux) Match(matchers ...Matcher) (l net.Listener) { + ml := muxListener{ + Listener: m.root, + cch: make(chan net.Conn, m.bufLen), + } + m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) + return ml +} + +func (m *cMux) Serve() error { + defer func() { + for _, sl := range m.sls { + close(sl.l.cch) + } + }() + + for { + c, err := m.root.Accept() + if err != nil { + if !m.handleErr(err) { + return err + } + continue + } + + muc := newMuxConn(c) + matched := false + outer: + for _, sl := range m.sls { + for _, s := range sl.ss { + matched = s(muc.sniffer()) + muc.reset() + if matched { + sl.l.cch <- muc + break outer + } + } + } + + if !matched { + c.Close() + err := ErrNotMatched{c: c} + if !m.handleErr(err) { + return err + } + } + } +} + +func (m *cMux) HandleError(h ErrorHandler) { + m.errh = h +} + +func (m *cMux) handleErr(err error) bool { + if !m.errh(err) { + return false + } + + if ne, ok := err.(net.Error); ok { + return ne.Temporary() + } + + return false +} + +type muxListener struct { + net.Listener + cch chan net.Conn +} + +func (l muxListener) Accept() (c net.Conn, err error) { + c, ok := <-l.cch + if !ok { + return nil, ErrListenerClosed + } + return c, nil +} + +type MuxConn struct { + net.Conn + prv *bytes.Buffer + nxt *bytes.Buffer +} + +func newMuxConn(c net.Conn) *MuxConn { + return &MuxConn{ + Conn: c, + prv: &bytes.Buffer{}, + nxt: &bytes.Buffer{}, + } +} + +func (m *MuxConn) Read(b []byte) (n int, err error) { + if n, err = m.prv.Read(b); err == nil { + return + } + + n, err = m.Conn.Read(b) + return +} + +func (m *MuxConn) sniffer() io.Reader { + return io.MultiReader(io.TeeReader(m.prv, m.nxt), io.TeeReader(m.Conn, m.nxt)) +} + +func (m *MuxConn) reset() { + if m.nxt.Len() == 0 { + return + } + + if m.prv.Len() != 0 { + io.Copy(m.nxt, m.prv) + } + + m.prv, m.nxt = m.nxt, m.prv + m.nxt.Reset() +} diff --git a/cmux_test.go b/cmux_test.go new file mode 100644 index 0000000..343d294 --- /dev/null +++ b/cmux_test.go @@ -0,0 +1,175 @@ +package cmux + +import ( + "fmt" + "io/ioutil" + "net" + "net/http" + "net/rpc" + "testing" + + "github.com/bradfitz/http2" +) + +const ( + testHTTP1Resp = "http1" + rpcVal = 1234 +) + +var testPort = 5125 + +func testAddr() string { + testPort++ + return fmt.Sprintf("127.0.0.1:%d", testPort) +} + +func testListener(t *testing.T) (net.Listener, string) { + addr := testAddr() + l, err := net.Listen("tcp", addr) + if err != nil { + t.Fatal(err) + } + return l, addr +} + +type testHTTP1Handler struct{} + +func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, testHTTP1Resp) +} + +func runTestHTTPServer(l net.Listener, withHTTP2 bool) { + s := &http.Server{ + Handler: &testHTTP1Handler{}, + } + if withHTTP2 { + http2.ConfigureServer(s, &http2.Server{}) + } + s.Serve(l) +} + +func runTestHTTP1Client(t *testing.T, addr string) { + r, err := http.Get("http://" + addr) + if err != nil { + t.Fatal(err) + } + + defer r.Body.Close() + b, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Error(err) + } + + if string(b) != testHTTP1Resp { + t.Errorf("invalid response: want=%s got=%s", testHTTP1Resp, b) + } +} + +type TestRPCRcvr struct{} + +func (r TestRPCRcvr) Test(i int, j *int) error { + *j = i + return nil +} + +func runTestRPCServer(l net.Listener) { + s := rpc.NewServer() + s.Register(TestRPCRcvr{}) + + for { + c, err := l.Accept() + if err != nil { + return + } + s.ServeConn(c) + } +} + +func runTestRPCClient(t *testing.T, addr string) { + c, err := rpc.Dial("tcp", addr) + if err != nil { + t.Error(err) + return + } + + var num int + if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil { + t.Error(err) + return + } + + if num != rpcVal { + t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num) + } +} + +func TestAny(t *testing.T) { + l, addr := testListener(t) + defer l.Close() + + muxl := New(l) + httpl := muxl.Match(Any()) + + go runTestHTTPServer(httpl, false) + go muxl.Serve() + + r, err := http.Get("http://" + addr) + if err != nil { + t.Fatal(err) + } + + defer r.Body.Close() + b, err := ioutil.ReadAll(r.Body) + if string(b) != testHTTP1Resp { + t.Errorf("invalid response: want=%s got=%s", testHTTP1Resp, b) + } +} + +func TestHTTPGoRPC(t *testing.T) { + l, addr := testListener(t) + defer l.Close() + + muxl := New(l) + httpl := muxl.Match(HTTP2(), HTTP1Fast()) + rpcl := muxl.Match(Any()) + + go runTestHTTPServer(httpl, true) + go runTestRPCServer(rpcl) + go muxl.Serve() + + runTestHTTP1Client(t, addr) + runTestRPCClient(t, addr) +} + +func TestErrorHandler(t *testing.T) { + l, addr := testListener(t) + defer l.Close() + + muxl := New(l) + httpl := muxl.Match(HTTP2(), HTTP1Fast()) + + go runTestHTTPServer(httpl, true) + go muxl.Serve() + + firstErr := true + muxl.HandleError(func(err error) bool { + if !firstErr { + return true + } + if _, ok := err.(ErrNotMatched); !ok { + t.Errorf("unexpected error: %v", err) + } + firstErr = false + return true + }) + + c, err := rpc.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + + var num int + if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil { + t.Error("rpc got a response") + } +} diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..7271421 --- /dev/null +++ b/example_test.go @@ -0,0 +1,81 @@ +package cmux_test + +import ( + "fmt" + "log" + "net" + "net/http" + "net/rpc" + + "google.golang.org/grpc" + + "golang.org/x/net/context" + + grpchello "github.com/grpc/grpc-common/go/helloworld" + "github.com/soheilhy/cmux" +) + +type exampleHTTPHandler struct{} + +func (h *exampleHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "example http response") +} + +func serveHTTP(l net.Listener) { + s := &http.Server{ + Handler: &exampleHTTPHandler{}, + } + s.Serve(l) +} + +type ExampleRPCRcvr struct{} + +func (r *ExampleRPCRcvr) Cube(i int, j *int) error { + *j = i * i + return nil +} + +func serveRPC(l net.Listener) { + s := rpc.NewServer() + s.Register(&ExampleRPCRcvr{}) + s.Accept(l) +} + +type grpcServer struct{} + +func (s *grpcServer) SayHello(ctx context.Context, in *grpchello.HelloRequest) ( + *grpchello.HelloReply, error) { + + return &grpchello.HelloReply{Message: "Hello " + in.Name + " from cmux"}, nil +} + +func serveGRPC(l net.Listener) { + grpcs := grpc.NewServer() + grpchello.RegisterGreeterServer(grpcs, &grpcServer{}) + grpcs.Serve(l) +} + +func Example() { + l, err := net.Listen("tcp", "127.0.0.1:50051") + if err != nil { + log.Fatal(err) + } + + m := cmux.New(l) + + // We first match the connection against HTTP2 fields. If matched, the + // connection will be sent through the "grpcl" listener. + grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) + // Otherwise, we match it againts HTTP1 methods and HTTP2. If matched by + // any of them, it is sent through the "httpl" listener. + httpl := m.Match(cmux.HTTP1Fast(), cmux.HTTP2()) + // If not matched by HTTP, we assume it is an RPC connection. + rpcl := m.Match(cmux.Any()) + + // Then we used the muxed listeners. + go serveGRPC(grpcl) + go serveHTTP(httpl) + go serveRPC(rpcl) + + m.Serve() +} diff --git a/patricia.go b/patricia.go new file mode 100644 index 0000000..718f951 --- /dev/null +++ b/patricia.go @@ -0,0 +1,173 @@ +package cmux + +import ( + "bytes" + "io" +) + +// patriciaTree is a simple patricia tree that handles []byte instead of string +// and cannot be changed after instantiation. +type patriciaTree struct { + root *ptNode +} + +func newPatriciaTree(b ...[]byte) *patriciaTree { + return &patriciaTree{ + root: newNode(b), + } +} + +func newPatriciaTreeString(strs ...string) *patriciaTree { + b := make([][]byte, len(strs)) + for i, s := range strs { + b[i] = []byte(s) + } + return &patriciaTree{ + root: newNode(b), + } +} + +func (t *patriciaTree) matchPrefix(r io.Reader) bool { + return t.root.match(r, true) +} + +func (t *patriciaTree) match(r io.Reader) bool { + return t.root.match(r, false) +} + +type ptNode struct { + prefix []byte + next map[byte]*ptNode + terminal bool +} + +func newNode(strs [][]byte) *ptNode { + if len(strs) == 0 { + return &ptNode{ + prefix: []byte{}, + terminal: true, + } + } + + if len(strs) == 1 { + return &ptNode{ + prefix: strs[0], + terminal: true, + } + } + + p, strs := splitPrefix(strs) + n := &ptNode{ + prefix: p, + } + + nexts := make(map[byte][][]byte) + for _, s := range strs { + if len(s) == 0 { + n.terminal = true + continue + } + nexts[s[0]] = append(nexts[s[0]], s[1:]) + } + + n.next = make(map[byte]*ptNode) + for first, rests := range nexts { + n.next[first] = newNode(rests) + } + + return n +} + +func splitPrefix(bss [][]byte) (prefix []byte, rest [][]byte) { + if len(bss) == 0 || len(bss[0]) == 0 { + return prefix, bss + } + + if len(bss) == 1 { + return bss[0], [][]byte{[]byte{}} + } + + for i := 0; ; i++ { + var cur byte + eq := true + for j, b := range bss { + if len(b) <= i { + eq = false + break + } + + if j == 0 { + cur = b[i] + continue + } + + if cur != b[i] { + eq = false + break + } + } + + if !eq { + break + } + + prefix = append(prefix, cur) + } + + rest = make([][]byte, 0, len(bss)) + for _, b := range bss { + rest = append(rest, b[len(prefix):]) + } + + return prefix, rest +} + +func readBytes(r io.Reader, n int) (b []byte, err error) { + b = make([]byte, n) + o := 0 + for o < n { + nr, err := r.Read(b[o:]) + if err != nil && err != io.EOF { + return b, err + } + + o += nr + + if err == io.EOF { + break + } + } + return b[:o], nil +} + +func (n *ptNode) match(r io.Reader, prefix bool) bool { + if l := len(n.prefix); l > 0 { + b, err := readBytes(r, l) + if err != nil || len(b) != l || !bytes.Equal(b, n.prefix) { + return false + } + } + + if prefix && n.terminal { + return true + } + + b := make([]byte, 1) + for { + nr, err := r.Read(b) + if nr != 0 { + break + } + + if err == io.EOF { + return n.terminal + } + + if err != nil { + return false + } + } + + nextN, ok := n.next[b[0]] + return ok && nextN.match(r, prefix) +} diff --git a/patricia_test.go b/patricia_test.go new file mode 100644 index 0000000..16b0f40 --- /dev/null +++ b/patricia_test.go @@ -0,0 +1,35 @@ +package cmux + +import ( + "strings" + "testing" +) + +func testPTree(t *testing.T, strs ...string) { + pt := newPatriciaTreeString(strs...) + for _, s := range strs { + if !pt.match(strings.NewReader(s)) { + t.Errorf("%s is not matched by %s", s, s) + } + + if !pt.matchPrefix(strings.NewReader(s + s)) { + t.Errorf("%s is not matched as a prefix by %s", s+s, s) + } + + if pt.match(strings.NewReader(s + s)) { + t.Errorf("%s matches %s", s+s, s) + } + } +} + +func TestPatriciaOnePrefix(t *testing.T) { + testPTree(t, "prefix") +} + +func TestPatriciaNonOverlapping(t *testing.T) { + testPTree(t, "foo", "bar", "dummy") +} + +func TestPatriciaOverlapping(t *testing.T) { + testPTree(t, "foo", "far", "farther", "boo", "bar") +} diff --git a/selectors.go b/selectors.go new file mode 100644 index 0000000..635579c --- /dev/null +++ b/selectors.go @@ -0,0 +1,156 @@ +package cmux + +import ( + "bufio" + "bytes" + "io" + "io/ioutil" + "net/http" + "strings" + + "github.com/bradfitz/http2" + "github.com/bradfitz/http2/hpack" +) + +// Any is a Matcher that matches any connection. +func Any() Matcher { + return func(r io.Reader) bool { return true } +} + +// PrefixMatcher returns a matcher that matches a connection if it +// starts with any of the strings in strs. +func PrefixMatcher(strs ...string) Matcher { + pt := newPatriciaTreeString(strs...) + return func(r io.Reader) bool { + return pt.matchPrefix(r) + } +} + +var defaultHTTPMethods = []string{ + "OPTIONS", + "GET", + "HEAD", + "POST", + "PUT", + "DELETE", + "TRACE", + "CONNECT", +} + +// HTTP1Fast only matches the methods in the HTTP request. +// +// This matcher is very optimistic: if it returns true, it does not mean that +// the request is a valid HTTP response. If you want a correct but slower HTTP1 +// matcher, use HTTP1 instead. +func HTTP1Fast(extMethods ...string) Matcher { + return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...) +} + +const ( + maxHTTPRead = 4096 +) + +// HTTP1 parses the first line or upto 4096 bytes of the request to see if +// the conection contains an HTTP request. +func HTTP1() Matcher { + return func(r io.Reader) bool { + br := bufio.NewReader(&io.LimitedReader{R: r, N: maxHTTPRead}) + l, part, err := br.ReadLine() + if err != nil || part { + return false + } + + _, _, proto, ok := parseRequestLine(string(l)) + if !ok { + return false + } + + v, _, ok := http.ParseHTTPVersion(proto) + return ok && v == 1 + } +} + +// grabbed from net/http. +func parseRequestLine(line string) (method, uri, proto string, ok bool) { + s1 := strings.Index(line, " ") + s2 := strings.Index(line[s1+1:], " ") + if s1 < 0 || s2 < 0 { + return + } + s2 += s1 + 1 + return line[:s1], line[s1+1 : s2], line[s2+1:], true +} + +var ( + http2Preface = []byte(http2.ClientPreface) +) + +// HTTP2 parses the frame header of the first frame to detect whether the +// connection is an HTTP2 connection. +func HTTP2() Matcher { + return func(r io.Reader) bool { + return hasHTTP2Preface(r) + } +} + +// HTTP1HeaderField returns a matcher matching the header fields of the first +// request of an HTTP 1 connection. +func HTTP1HeaderField(name, value string) Matcher { + return func(r io.Reader) bool { + return matchHTTP1Field(r, name, value) + } +} + +// HTTP2HeaderField resturns a matcher matching the header fields of the first +// headers frame. +func HTTP2HeaderField(name, value string) Matcher { + return func(r io.Reader) bool { + return matchHTTP2Field(r, name, value) + } +} + +func hasHTTP2Preface(r io.Reader) (ok bool) { + b := make([]byte, len(http2Preface)) + n, err := r.Read(b) + if err != nil { + return false + } + + b = b[:n] + return bytes.Equal(b, http2Preface) +} + +func matchHTTP1Field(r io.Reader, name, value string) (matched bool) { + return +} + +func matchHTTP2Field(r io.Reader, name, value string) (matched bool) { + if !hasHTTP2Preface(r) { + return false + } + + framer := http2.NewFramer(ioutil.Discard, r) + hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) { + if hf.Name == name && hf.Value == value { + matched = true + } + }) + for { + f, err := framer.ReadFrame() + if err != nil { + return false + } + + switch f := f.(type) { + case *http2.HeadersFrame: + hdec.Write(f.HeaderBlockFragment()) + if matched { + return true + } + + if f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 { + return false + } + } + } +}