From d83a667cb2ae6cc19ab06d923144de80c9c4c78f Mon Sep 17 00:00:00 2001 From: Soheil Hassas Yeganeh Date: Fri, 15 Apr 2016 19:16:33 -0400 Subject: [PATCH] Add Matchers that can write back on the channel As reported in issue #22 reports that Java gRPC clients cannot handshake with cmux'ed gRPC server, since the client does not immediately send a header with the content-type field. The reason is that the java client, block on receiving the first SETTING frame. Add MatchWriter that can match and write on the connection. Implement a MatchWriter that writes a SETTING frame once it receives a SETTING frame. --- cmux.go | 30 ++++++++++++++++++++++++++++-- matchers.go | 19 ++++++++++++++++--- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/cmux.go b/cmux.go index 8caf5d4..f92e203 100644 --- a/cmux.go +++ b/cmux.go @@ -10,6 +10,9 @@ import ( // Matcher matches a connection based on its content. type Matcher func(io.Reader) bool +// MatchWriter is a match that can also write response (say to do handshake). +type MatchWriter func(io.Writer, io.Reader) bool + // ErrorHandler handles an error and returns whether // the mux should continue serving the listener. type ErrorHandler func(error) bool @@ -60,6 +63,14 @@ type CMux interface { // // The order used to call Match determines the priority of matchers. Match(...Matcher) net.Listener + // MatchWithWriters returns a net.Listener that accepts only the + // connections that matched by at least of the matcher writers. + // + // Prefer Matchers over MatchWriters, since the latter can write on the + // connection before the actual handler. + // + // The order used to call Match determines the priority of matchers. + MatchWithWriters(...MatchWriter) net.Listener // Serve starts multiplexing the listener. Serve blocks and perhaps // should be invoked concurrently within a go routine. Serve() error @@ -68,7 +79,7 @@ type CMux interface { } type matchersListener struct { - ss []Matcher + ss []MatchWriter l muxListener } @@ -80,7 +91,22 @@ type cMux struct { sls []matchersListener } +func matchersToMatchWriters(matchers []Matcher) []MatchWriter { + mws := make([]MatchWriter, 0, len(matchers)) + for _, m := range matchers { + mws = append(mws, func(w io.Writer, r io.Reader) bool { + return m(r) + }) + } + return mws +} + func (m *cMux) Match(matchers ...Matcher) net.Listener { + mws := matchersToMatchWriters(matchers) + return m.MatchWithWriters(mws...) +} + +func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener { ml := muxListener{ Listener: m.root, connc: make(chan net.Conn, m.bufLen), @@ -125,7 +151,7 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { muc := newMuxConn(c) for _, sl := range m.sls { for _, s := range sl.ss { - matched := s(muc.startSniffing()) + matched := s(muc.Conn, muc.startSniffing()) if matched { muc.doneSniffing() select { diff --git a/matchers.go b/matchers.go index abc30f6..b5471a1 100644 --- a/matchers.go +++ b/matchers.go @@ -94,7 +94,16 @@ func HTTP1HeaderField(name, value string) Matcher { // headers frame. func HTTP2HeaderField(name, value string) Matcher { return func(r io.Reader) bool { - return matchHTTP2Field(r, name, value) + return matchHTTP2Field(ioutil.Discard, r, name, value) + } +} + +// HTTP2MatchHeaderFieldSendSettings matches the header field and writes the +// settings to the server. Prefer HTTP2HeaderField over this one, if the client +// does not block on receiving a SETTING frame. +func HTTP2MatchHeaderFieldSendSettings(name, value string) MatchWriter { + return func(w io.Writer, r io.Reader) bool { + return matchHTTP2Field(w, r, name, value) } } @@ -116,12 +125,12 @@ func matchHTTP1Field(r io.Reader, name, value string) (matched bool) { return req.Header.Get(name) == value } -func matchHTTP2Field(r io.Reader, name, value string) (matched bool) { +func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool) { if !hasHTTP2Preface(r) { return false } - framer := http2.NewFramer(ioutil.Discard, r) + framer := http2.NewFramer(w, r) hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) { if hf.Name == name && hf.Value == value { matched = true @@ -134,6 +143,10 @@ func matchHTTP2Field(r io.Reader, name, value string) (matched bool) { } switch f := f.(type) { + case *http2.SettingsFrame: + if err := framer.WriteSettings(); err != nil { + return false + } case *http2.HeadersFrame: if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { return false