diff --git a/cmux_test.go b/cmux_test.go index bb919af..266fe6f 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -478,6 +478,20 @@ func TestHTTP2(t *testing.T) { } func TestHTTP2MatchHeaderField(t *testing.T) { + testHTTP2MatchHeaderField(t, HTTP2HeaderField, "value", "value", "anothervalue") +} + +func TestHTTP2MatchHeaderFieldPrefix(t *testing.T) { + testHTTP2MatchHeaderField(t, HTTP2HeaderFieldPrefix, "application/grpc+proto", "application/grpc", "application/json") +} + +func testHTTP2MatchHeaderField( + t *testing.T, + matcherConstructor func(string, string) Matcher, + headerValue string, + matchValue string, + notMatchValue string, +) { defer leakCheck(t)() errCh := make(chan error) defer func() { @@ -488,7 +502,6 @@ func TestHTTP2MatchHeaderField(t *testing.T) { } }() name := "name" - value := "value" writer, reader := net.Pipe() go func() { if _, err := io.WriteString(writer, http2.ClientPreface); err != nil { @@ -496,7 +509,7 @@ func TestHTTP2MatchHeaderField(t *testing.T) { } var buf bytes.Buffer enc := hpack.NewEncoder(&buf) - if err := enc.WriteField(hpack.HeaderField{Name: name, Value: value}); err != nil { + if err := enc.WriteField(hpack.HeaderField{Name: name, Value: headerValue}); err != nil { t.Fatal(err) } framer := http2.NewFramer(writer, nil) @@ -524,9 +537,9 @@ func TestHTTP2MatchHeaderField(t *testing.T) { return false }) // Create a matcher that cannot match the response. - muxl.Match(HTTP2HeaderField(name, "another"+value)) + muxl.Match(matcherConstructor(name, notMatchValue)) // Then match with the expected field. - h2l := muxl.Match(HTTP2HeaderField(name, value)) + h2l := muxl.Match(matcherConstructor(name, matchValue)) go safeServe(errCh, muxl) muxedConn, err := h2l.Accept() close(l.connCh) diff --git a/example_test.go b/example_test.go index 25396fa..7144f5e 100644 --- a/example_test.go +++ b/example_test.go @@ -112,7 +112,7 @@ func Example() { // We first match the connection against HTTP2 fields. If matched, the // connection will be sent through the "grpcl" listener. - grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) + grpcl := m.Match(cmux.HTTP2HeaderFieldPrefix("content-type", "application/grpc")) //Otherwise, we match it againts a websocket upgrade request. wsl := m.Match(cmux.HTTP1HeaderField("Upgrade", "websocket")) diff --git a/matchers.go b/matchers.go index 6ccd7a8..652fd86 100644 --- a/matchers.go +++ b/matchers.go @@ -127,15 +127,41 @@ func HTTP2() Matcher { // request of an HTTP 1 connection. func HTTP1HeaderField(name, value string) Matcher { return func(r io.Reader) bool { - return matchHTTP1Field(r, name, value) + return matchHTTP1Field(r, name, func(gotValue string) bool { + return gotValue == value + }) } } -// HTTP2HeaderField resturns a matcher matching the header fields of the first +// HTTP1HeaderFieldPrefix returns a matcher matching the header fields of the +// first request of an HTTP 1 connection. If the header with key name has a +// value prefixed with valuePrefix, this will match. +func HTTP1HeaderFieldPrefix(name, valuePrefix string) Matcher { + return func(r io.Reader) bool { + return matchHTTP1Field(r, name, func(gotValue string) bool { + return strings.HasPrefix(gotValue, valuePrefix) + }) + } +} + +// HTTP2HeaderField returns a matcher matching the header fields of the first // headers frame. func HTTP2HeaderField(name, value string) Matcher { return func(r io.Reader) bool { - return matchHTTP2Field(ioutil.Discard, r, name, value) + return matchHTTP2Field(ioutil.Discard, r, name, func(gotValue string) bool { + return gotValue == value + }) + } +} + +// HTTP2HeaderFieldPrefix returns a matcher matching the header fields of the +// first headers frame. If the header with key name has a value prefixed with +// valuePrefix, this will match. +func HTTP2HeaderFieldPrefix(name, valuePrefix string) Matcher { + return func(r io.Reader) bool { + return matchHTTP2Field(ioutil.Discard, r, name, func(gotValue string) bool { + return strings.HasPrefix(gotValue, valuePrefix) + }) } } @@ -144,7 +170,20 @@ func HTTP2HeaderField(name, value string) Matcher { // does not block on receiving a SETTING frame. func HTTP2MatchHeaderFieldSendSettings(name, value string) MatchWriter { return func(w io.Writer, r io.Reader) bool { - return matchHTTP2Field(w, r, name, value) + return matchHTTP2Field(w, r, name, func(gotValue string) bool { + return gotValue == value + }) + } +} + +// HTTP2MatchHeaderFieldPrefixSendSettings matches the header field prefix +// and writes the settings to the server. Prefer HTTP2HeaderFieldPrefix over +// this one, if the client does not block on receiving a SETTING frame. +func HTTP2MatchHeaderFieldPrefixSendSettings(name, valuePrefix string) MatchWriter { + return func(w io.Writer, r io.Reader) bool { + return matchHTTP2Field(w, r, name, func(gotValue string) bool { + return strings.HasPrefix(gotValue, valuePrefix) + }) } } @@ -169,16 +208,16 @@ func hasHTTP2Preface(r io.Reader) bool { } } -func matchHTTP1Field(r io.Reader, name, value string) (matched bool) { +func matchHTTP1Field(r io.Reader, name string, matches func(string) bool) (matched bool) { req, err := http.ReadRequest(bufio.NewReader(r)) if err != nil { return false } - return req.Header.Get(name) == value + return matches(req.Header.Get(name)) } -func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool) { +func matchHTTP2Field(w io.Writer, r io.Reader, name string, matches func(string) bool) (matched bool) { if !hasHTTP2Preface(r) { return false } @@ -188,7 +227,7 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) { if hf.Name == name { done = true - if hf.Value == value { + if matches(hf.Value) { matched = true } }