2
0
mirror of https://github.com/soheilhy/cmux.git synced 2025-01-18 10:53:46 +08:00

Initial commit

This commit is contained in:
Soheil Hassas Yeganeh 2015-07-29 13:45:57 -04:00
commit 754f5b897d
9 changed files with 903 additions and 0 deletions

24
.gitignore vendored Normal file
View File

@ -0,0 +1,24 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof

5
.travis.yml Normal file
View File

@ -0,0 +1,5 @@
language: go
go:
- 1.3
- 1.4

57
README.md Normal file
View File

@ -0,0 +1,57 @@
# cmux: Connection Mux ![Travis Build Status](https://api.travis-ci.org/soheilhy/args.svg?branch=master "Travis Build Status") ![GoDoc](https://godoc.org/github.com/soheilhy/cmux?status.png)
cmux is a generic Go library to multiplex connections based on
their content. Using cmux, one can serve gRPC, HTTP, and Go RPC
on the same TCP listener to avoid having to use one port per
protocol.
## How-To
Simply create your main listener, create a cmux for that listener,
and then match connections:
```go
// Create the main listener.
l, err := net.Listen("tcp", ":23456")
if err != nil {
log.Fatal(err)
}
// Create a cmux.
m := cmux.New(l)
// Match connections in order.
grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc"))
httpl := m.Match(cmux.Any()) // Any means anything that is not yet matched.
// Create your protocol servers.
grpcS := grpc.NewServer()
pb.RegisterGreeterServer(grpcs, &server{})
httpS := &http.Server{
Handler: &testHTTP1Handler{},
}
// Use the muxed listeners for your servers.
go grpcS.Serve(grpcl)
go httpS.Serve(httpl)
// Start serving!
m.Serve()
```
Take a look at [other examples in the GoDoc](http://localhost:6060/pkg/github.com/soheilhy/cmux/#pkg-examples).
## Docs
* [GoDocs](https://godoc.org/github.com/soheilhy/cmux)
## Performance
There is a huge room for improvment but since we are only matching
the very first bytes of a connection, the performance overheads on
long-lived connections (i.e., RPCs and pipelined HTTP streams)
is negligible.
*TODO(soheil)*: Add benchmarks.
## Limitations
*TLS*: Since `cmux` sits in between the actual listener and the mux'ed
listeners, TLS handshake is not handled inside the actual servers.
Because of that, when you handle HTTPS using cmux `http.Request.TLS`
would not be set.

197
cmux.go Normal file
View File

@ -0,0 +1,197 @@
package cmux
import (
"bytes"
"flag"
"fmt"
"io"
"net"
)
// Matcher matches a connection based on its content.
type Matcher func(r io.Reader) (ok bool)
// ErrorHandler handles an error and returns whether
// the mux should continue serving the listener.
type ErrorHandler func(err error) (ok bool)
// ErrNotMatched is returned whenever a connection is not matched by any of
// the matchers registered in the multiplexer.
type ErrNotMatched struct {
c net.Conn
}
func (e ErrNotMatched) Error() string {
return fmt.Sprintf("mux: connection %v not matched by an matcher",
e.c.RemoteAddr())
}
func (e ErrNotMatched) Temporary() bool { return true }
func (e ErrNotMatched) Timeout() bool { return false }
type errListenerClosed string
func (e errListenerClosed) Error() string { return string(e) }
func (e errListenerClosed) Temporary() bool { return false }
func (e errListenerClosed) Timeout() bool { return false }
var (
ErrListenerClosed = errListenerClosed("mux: listener closed")
)
// New instantiates a new connection multiplexer.
func New(l net.Listener) CMux {
if !flag.Parsed() {
flag.Parse()
}
return &cMux{
root: l,
bufLen: 1024,
errh: func(err error) bool { return true },
}
}
// CMux is a multiplexer for network connections.
type CMux interface {
// Match returns a net.Listener that sees (i.e., accepts) only
// the connections matched by at least one of the matcher.
//
// The order used to call Match determines the priority of matchers.
Match(matchers ...Matcher) net.Listener
// Serve starts multiplexing the listener. Serve blocks and perhaps
// should be invoked concurrently within a go routine.
Serve() error
// HandleError registers an error handler that handles listener errors.
HandleError(h ErrorHandler)
}
type matchersListener struct {
ss []Matcher
l muxListener
}
type cMux struct {
root net.Listener
bufLen int
errh ErrorHandler
sls []matchersListener
}
func (m *cMux) Match(matchers ...Matcher) (l net.Listener) {
ml := muxListener{
Listener: m.root,
cch: make(chan net.Conn, m.bufLen),
}
m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
return ml
}
func (m *cMux) Serve() error {
defer func() {
for _, sl := range m.sls {
close(sl.l.cch)
}
}()
for {
c, err := m.root.Accept()
if err != nil {
if !m.handleErr(err) {
return err
}
continue
}
muc := newMuxConn(c)
matched := false
outer:
for _, sl := range m.sls {
for _, s := range sl.ss {
matched = s(muc.sniffer())
muc.reset()
if matched {
sl.l.cch <- muc
break outer
}
}
}
if !matched {
c.Close()
err := ErrNotMatched{c: c}
if !m.handleErr(err) {
return err
}
}
}
}
func (m *cMux) HandleError(h ErrorHandler) {
m.errh = h
}
func (m *cMux) handleErr(err error) bool {
if !m.errh(err) {
return false
}
if ne, ok := err.(net.Error); ok {
return ne.Temporary()
}
return false
}
type muxListener struct {
net.Listener
cch chan net.Conn
}
func (l muxListener) Accept() (c net.Conn, err error) {
c, ok := <-l.cch
if !ok {
return nil, ErrListenerClosed
}
return c, nil
}
type MuxConn struct {
net.Conn
prv *bytes.Buffer
nxt *bytes.Buffer
}
func newMuxConn(c net.Conn) *MuxConn {
return &MuxConn{
Conn: c,
prv: &bytes.Buffer{},
nxt: &bytes.Buffer{},
}
}
func (m *MuxConn) Read(b []byte) (n int, err error) {
if n, err = m.prv.Read(b); err == nil {
return
}
n, err = m.Conn.Read(b)
return
}
func (m *MuxConn) sniffer() io.Reader {
return io.MultiReader(io.TeeReader(m.prv, m.nxt), io.TeeReader(m.Conn, m.nxt))
}
func (m *MuxConn) reset() {
if m.nxt.Len() == 0 {
return
}
if m.prv.Len() != 0 {
io.Copy(m.nxt, m.prv)
}
m.prv, m.nxt = m.nxt, m.prv
m.nxt.Reset()
}

175
cmux_test.go Normal file
View File

@ -0,0 +1,175 @@
package cmux
import (
"fmt"
"io/ioutil"
"net"
"net/http"
"net/rpc"
"testing"
"github.com/bradfitz/http2"
)
const (
testHTTP1Resp = "http1"
rpcVal = 1234
)
var testPort = 5125
func testAddr() string {
testPort++
return fmt.Sprintf("127.0.0.1:%d", testPort)
}
func testListener(t *testing.T) (net.Listener, string) {
addr := testAddr()
l, err := net.Listen("tcp", addr)
if err != nil {
t.Fatal(err)
}
return l, addr
}
type testHTTP1Handler struct{}
func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, testHTTP1Resp)
}
func runTestHTTPServer(l net.Listener, withHTTP2 bool) {
s := &http.Server{
Handler: &testHTTP1Handler{},
}
if withHTTP2 {
http2.ConfigureServer(s, &http2.Server{})
}
s.Serve(l)
}
func runTestHTTP1Client(t *testing.T, addr string) {
r, err := http.Get("http://" + addr)
if err != nil {
t.Fatal(err)
}
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Error(err)
}
if string(b) != testHTTP1Resp {
t.Errorf("invalid response: want=%s got=%s", testHTTP1Resp, b)
}
}
type TestRPCRcvr struct{}
func (r TestRPCRcvr) Test(i int, j *int) error {
*j = i
return nil
}
func runTestRPCServer(l net.Listener) {
s := rpc.NewServer()
s.Register(TestRPCRcvr{})
for {
c, err := l.Accept()
if err != nil {
return
}
s.ServeConn(c)
}
}
func runTestRPCClient(t *testing.T, addr string) {
c, err := rpc.Dial("tcp", addr)
if err != nil {
t.Error(err)
return
}
var num int
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil {
t.Error(err)
return
}
if num != rpcVal {
t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num)
}
}
func TestAny(t *testing.T) {
l, addr := testListener(t)
defer l.Close()
muxl := New(l)
httpl := muxl.Match(Any())
go runTestHTTPServer(httpl, false)
go muxl.Serve()
r, err := http.Get("http://" + addr)
if err != nil {
t.Fatal(err)
}
defer r.Body.Close()
b, err := ioutil.ReadAll(r.Body)
if string(b) != testHTTP1Resp {
t.Errorf("invalid response: want=%s got=%s", testHTTP1Resp, b)
}
}
func TestHTTPGoRPC(t *testing.T) {
l, addr := testListener(t)
defer l.Close()
muxl := New(l)
httpl := muxl.Match(HTTP2(), HTTP1Fast())
rpcl := muxl.Match(Any())
go runTestHTTPServer(httpl, true)
go runTestRPCServer(rpcl)
go muxl.Serve()
runTestHTTP1Client(t, addr)
runTestRPCClient(t, addr)
}
func TestErrorHandler(t *testing.T) {
l, addr := testListener(t)
defer l.Close()
muxl := New(l)
httpl := muxl.Match(HTTP2(), HTTP1Fast())
go runTestHTTPServer(httpl, true)
go muxl.Serve()
firstErr := true
muxl.HandleError(func(err error) bool {
if !firstErr {
return true
}
if _, ok := err.(ErrNotMatched); !ok {
t.Errorf("unexpected error: %v", err)
}
firstErr = false
return true
})
c, err := rpc.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
}
var num int
if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
t.Error("rpc got a response")
}
}

81
example_test.go Normal file
View File

@ -0,0 +1,81 @@
package cmux_test
import (
"fmt"
"log"
"net"
"net/http"
"net/rpc"
"google.golang.org/grpc"
"golang.org/x/net/context"
grpchello "github.com/grpc/grpc-common/go/helloworld"
"github.com/soheilhy/cmux"
)
type exampleHTTPHandler struct{}
func (h *exampleHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "example http response")
}
func serveHTTP(l net.Listener) {
s := &http.Server{
Handler: &exampleHTTPHandler{},
}
s.Serve(l)
}
type ExampleRPCRcvr struct{}
func (r *ExampleRPCRcvr) Cube(i int, j *int) error {
*j = i * i
return nil
}
func serveRPC(l net.Listener) {
s := rpc.NewServer()
s.Register(&ExampleRPCRcvr{})
s.Accept(l)
}
type grpcServer struct{}
func (s *grpcServer) SayHello(ctx context.Context, in *grpchello.HelloRequest) (
*grpchello.HelloReply, error) {
return &grpchello.HelloReply{Message: "Hello " + in.Name + " from cmux"}, nil
}
func serveGRPC(l net.Listener) {
grpcs := grpc.NewServer()
grpchello.RegisterGreeterServer(grpcs, &grpcServer{})
grpcs.Serve(l)
}
func Example() {
l, err := net.Listen("tcp", "127.0.0.1:50051")
if err != nil {
log.Fatal(err)
}
m := cmux.New(l)
// We first match the connection against HTTP2 fields. If matched, the
// connection will be sent through the "grpcl" listener.
grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc"))
// Otherwise, we match it againts HTTP1 methods and HTTP2. If matched by
// any of them, it is sent through the "httpl" listener.
httpl := m.Match(cmux.HTTP1Fast(), cmux.HTTP2())
// If not matched by HTTP, we assume it is an RPC connection.
rpcl := m.Match(cmux.Any())
// Then we used the muxed listeners.
go serveGRPC(grpcl)
go serveHTTP(httpl)
go serveRPC(rpcl)
m.Serve()
}

173
patricia.go Normal file
View File

@ -0,0 +1,173 @@
package cmux
import (
"bytes"
"io"
)
// patriciaTree is a simple patricia tree that handles []byte instead of string
// and cannot be changed after instantiation.
type patriciaTree struct {
root *ptNode
}
func newPatriciaTree(b ...[]byte) *patriciaTree {
return &patriciaTree{
root: newNode(b),
}
}
func newPatriciaTreeString(strs ...string) *patriciaTree {
b := make([][]byte, len(strs))
for i, s := range strs {
b[i] = []byte(s)
}
return &patriciaTree{
root: newNode(b),
}
}
func (t *patriciaTree) matchPrefix(r io.Reader) bool {
return t.root.match(r, true)
}
func (t *patriciaTree) match(r io.Reader) bool {
return t.root.match(r, false)
}
type ptNode struct {
prefix []byte
next map[byte]*ptNode
terminal bool
}
func newNode(strs [][]byte) *ptNode {
if len(strs) == 0 {
return &ptNode{
prefix: []byte{},
terminal: true,
}
}
if len(strs) == 1 {
return &ptNode{
prefix: strs[0],
terminal: true,
}
}
p, strs := splitPrefix(strs)
n := &ptNode{
prefix: p,
}
nexts := make(map[byte][][]byte)
for _, s := range strs {
if len(s) == 0 {
n.terminal = true
continue
}
nexts[s[0]] = append(nexts[s[0]], s[1:])
}
n.next = make(map[byte]*ptNode)
for first, rests := range nexts {
n.next[first] = newNode(rests)
}
return n
}
func splitPrefix(bss [][]byte) (prefix []byte, rest [][]byte) {
if len(bss) == 0 || len(bss[0]) == 0 {
return prefix, bss
}
if len(bss) == 1 {
return bss[0], [][]byte{[]byte{}}
}
for i := 0; ; i++ {
var cur byte
eq := true
for j, b := range bss {
if len(b) <= i {
eq = false
break
}
if j == 0 {
cur = b[i]
continue
}
if cur != b[i] {
eq = false
break
}
}
if !eq {
break
}
prefix = append(prefix, cur)
}
rest = make([][]byte, 0, len(bss))
for _, b := range bss {
rest = append(rest, b[len(prefix):])
}
return prefix, rest
}
func readBytes(r io.Reader, n int) (b []byte, err error) {
b = make([]byte, n)
o := 0
for o < n {
nr, err := r.Read(b[o:])
if err != nil && err != io.EOF {
return b, err
}
o += nr
if err == io.EOF {
break
}
}
return b[:o], nil
}
func (n *ptNode) match(r io.Reader, prefix bool) bool {
if l := len(n.prefix); l > 0 {
b, err := readBytes(r, l)
if err != nil || len(b) != l || !bytes.Equal(b, n.prefix) {
return false
}
}
if prefix && n.terminal {
return true
}
b := make([]byte, 1)
for {
nr, err := r.Read(b)
if nr != 0 {
break
}
if err == io.EOF {
return n.terminal
}
if err != nil {
return false
}
}
nextN, ok := n.next[b[0]]
return ok && nextN.match(r, prefix)
}

35
patricia_test.go Normal file
View File

@ -0,0 +1,35 @@
package cmux
import (
"strings"
"testing"
)
func testPTree(t *testing.T, strs ...string) {
pt := newPatriciaTreeString(strs...)
for _, s := range strs {
if !pt.match(strings.NewReader(s)) {
t.Errorf("%s is not matched by %s", s, s)
}
if !pt.matchPrefix(strings.NewReader(s + s)) {
t.Errorf("%s is not matched as a prefix by %s", s+s, s)
}
if pt.match(strings.NewReader(s + s)) {
t.Errorf("%s matches %s", s+s, s)
}
}
}
func TestPatriciaOnePrefix(t *testing.T) {
testPTree(t, "prefix")
}
func TestPatriciaNonOverlapping(t *testing.T) {
testPTree(t, "foo", "bar", "dummy")
}
func TestPatriciaOverlapping(t *testing.T) {
testPTree(t, "foo", "far", "farther", "boo", "bar")
}

156
selectors.go Normal file
View File

@ -0,0 +1,156 @@
package cmux
import (
"bufio"
"bytes"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/bradfitz/http2"
"github.com/bradfitz/http2/hpack"
)
// Any is a Matcher that matches any connection.
func Any() Matcher {
return func(r io.Reader) bool { return true }
}
// PrefixMatcher returns a matcher that matches a connection if it
// starts with any of the strings in strs.
func PrefixMatcher(strs ...string) Matcher {
pt := newPatriciaTreeString(strs...)
return func(r io.Reader) bool {
return pt.matchPrefix(r)
}
}
var defaultHTTPMethods = []string{
"OPTIONS",
"GET",
"HEAD",
"POST",
"PUT",
"DELETE",
"TRACE",
"CONNECT",
}
// HTTP1Fast only matches the methods in the HTTP request.
//
// This matcher is very optimistic: if it returns true, it does not mean that
// the request is a valid HTTP response. If you want a correct but slower HTTP1
// matcher, use HTTP1 instead.
func HTTP1Fast(extMethods ...string) Matcher {
return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...)
}
const (
maxHTTPRead = 4096
)
// HTTP1 parses the first line or upto 4096 bytes of the request to see if
// the conection contains an HTTP request.
func HTTP1() Matcher {
return func(r io.Reader) bool {
br := bufio.NewReader(&io.LimitedReader{R: r, N: maxHTTPRead})
l, part, err := br.ReadLine()
if err != nil || part {
return false
}
_, _, proto, ok := parseRequestLine(string(l))
if !ok {
return false
}
v, _, ok := http.ParseHTTPVersion(proto)
return ok && v == 1
}
}
// grabbed from net/http.
func parseRequestLine(line string) (method, uri, proto string, ok bool) {
s1 := strings.Index(line, " ")
s2 := strings.Index(line[s1+1:], " ")
if s1 < 0 || s2 < 0 {
return
}
s2 += s1 + 1
return line[:s1], line[s1+1 : s2], line[s2+1:], true
}
var (
http2Preface = []byte(http2.ClientPreface)
)
// HTTP2 parses the frame header of the first frame to detect whether the
// connection is an HTTP2 connection.
func HTTP2() Matcher {
return func(r io.Reader) bool {
return hasHTTP2Preface(r)
}
}
// HTTP1HeaderField returns a matcher matching the header fields of the first
// request of an HTTP 1 connection.
func HTTP1HeaderField(name, value string) Matcher {
return func(r io.Reader) bool {
return matchHTTP1Field(r, name, value)
}
}
// HTTP2HeaderField resturns a matcher matching the header fields of the first
// headers frame.
func HTTP2HeaderField(name, value string) Matcher {
return func(r io.Reader) bool {
return matchHTTP2Field(r, name, value)
}
}
func hasHTTP2Preface(r io.Reader) (ok bool) {
b := make([]byte, len(http2Preface))
n, err := r.Read(b)
if err != nil {
return false
}
b = b[:n]
return bytes.Equal(b, http2Preface)
}
func matchHTTP1Field(r io.Reader, name, value string) (matched bool) {
return
}
func matchHTTP2Field(r io.Reader, name, value string) (matched bool) {
if !hasHTTP2Preface(r) {
return false
}
framer := http2.NewFramer(ioutil.Discard, r)
hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) {
if hf.Name == name && hf.Value == value {
matched = true
}
})
for {
f, err := framer.ReadFrame()
if err != nil {
return false
}
switch f := f.(type) {
case *http2.HeadersFrame:
hdec.Write(f.HeaderBlockFragment())
if matched {
return true
}
if f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 {
return false
}
}
}
}