diff --git a/README.md b/README.md index 9317546..16890ce 100644 --- a/README.md +++ b/README.md @@ -67,3 +67,10 @@ would not be set in your handlers. when it's accepted. For example, one connection can be either gRPC or REST, but not both. That is, we assume that a client connection is either used for gRPC or REST. + +* *Java gRPC Clients*: Java gRPC client blocks until it receives a SETTINGS +frame from the server. If you are using the Java client to connect to a cmux'ed +gRPC server please match with writers: +```go +grpcl := m.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc")) +``` diff --git a/cmux.go b/cmux.go index 8caf5d4..f92e203 100644 --- a/cmux.go +++ b/cmux.go @@ -10,6 +10,9 @@ import ( // Matcher matches a connection based on its content. type Matcher func(io.Reader) bool +// MatchWriter is a match that can also write response (say to do handshake). +type MatchWriter func(io.Writer, io.Reader) bool + // ErrorHandler handles an error and returns whether // the mux should continue serving the listener. type ErrorHandler func(error) bool @@ -60,6 +63,14 @@ type CMux interface { // // The order used to call Match determines the priority of matchers. Match(...Matcher) net.Listener + // MatchWithWriters returns a net.Listener that accepts only the + // connections that matched by at least of the matcher writers. + // + // Prefer Matchers over MatchWriters, since the latter can write on the + // connection before the actual handler. + // + // The order used to call Match determines the priority of matchers. + MatchWithWriters(...MatchWriter) net.Listener // Serve starts multiplexing the listener. Serve blocks and perhaps // should be invoked concurrently within a go routine. Serve() error @@ -68,7 +79,7 @@ type CMux interface { } type matchersListener struct { - ss []Matcher + ss []MatchWriter l muxListener } @@ -80,7 +91,22 @@ type cMux struct { sls []matchersListener } +func matchersToMatchWriters(matchers []Matcher) []MatchWriter { + mws := make([]MatchWriter, 0, len(matchers)) + for _, m := range matchers { + mws = append(mws, func(w io.Writer, r io.Reader) bool { + return m(r) + }) + } + return mws +} + func (m *cMux) Match(matchers ...Matcher) net.Listener { + mws := matchersToMatchWriters(matchers) + return m.MatchWithWriters(mws...) +} + +func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener { ml := muxListener{ Listener: m.root, connc: make(chan net.Conn, m.bufLen), @@ -125,7 +151,7 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { muc := newMuxConn(c) for _, sl := range m.sls { for _, s := range sl.ss { - matched := s(muc.startSniffing()) + matched := s(muc.Conn, muc.startSniffing()) if matched { muc.doneSniffing() select { diff --git a/matchers.go b/matchers.go index abc30f6..b5471a1 100644 --- a/matchers.go +++ b/matchers.go @@ -94,7 +94,16 @@ func HTTP1HeaderField(name, value string) Matcher { // headers frame. func HTTP2HeaderField(name, value string) Matcher { return func(r io.Reader) bool { - return matchHTTP2Field(r, name, value) + return matchHTTP2Field(ioutil.Discard, r, name, value) + } +} + +// HTTP2MatchHeaderFieldSendSettings matches the header field and writes the +// settings to the server. Prefer HTTP2HeaderField over this one, if the client +// does not block on receiving a SETTING frame. +func HTTP2MatchHeaderFieldSendSettings(name, value string) MatchWriter { + return func(w io.Writer, r io.Reader) bool { + return matchHTTP2Field(w, r, name, value) } } @@ -116,12 +125,12 @@ func matchHTTP1Field(r io.Reader, name, value string) (matched bool) { return req.Header.Get(name) == value } -func matchHTTP2Field(r io.Reader, name, value string) (matched bool) { +func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool) { if !hasHTTP2Preface(r) { return false } - framer := http2.NewFramer(ioutil.Discard, r) + framer := http2.NewFramer(w, r) hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) { if hf.Name == name && hf.Value == value { matched = true @@ -134,6 +143,10 @@ func matchHTTP2Field(r io.Reader, name, value string) (matched bool) { } switch f := f.(type) { + case *http2.SettingsFrame: + if err := framer.WriteSettings(); err != nil { + return false + } case *http2.HeadersFrame: if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { return false