diff --git a/cmux_test.go b/cmux_test.go index 2279ded..d348aa8 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -15,6 +15,7 @@ package cmux import ( + "bytes" "errors" "fmt" "io" @@ -32,6 +33,7 @@ import ( "time" "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" ) const ( @@ -394,6 +396,72 @@ func TestHTTP2(t *testing.T) { } } +func TestHTTP2MatchHeaderField(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + name := "name" + value := "value" + writer, reader := net.Pipe() + go func() { + if _, err := io.WriteString(writer, http2.ClientPreface); err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + if err := enc.WriteField(hpack.HeaderField{Name: name, Value: value}); err != nil { + t.Fatal(err) + } + framer := http2.NewFramer(writer, nil) + err := framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: 1, + BlockFragment: buf.Bytes(), + EndStream: true, + EndHeaders: true, + }) + if err != nil { + t.Fatal(err) + } + if err := writer.Close(); err != nil { + t.Fatal(err) + } + }() + + l := newChanListener() + l.connCh <- reader + muxl := New(l) + // Register a bogus matcher that only reads one byte. + muxl.Match(func(r io.Reader) bool { + var b [1]byte + _, _ = r.Read(b[:]) + return false + }) + // Create a matcher that cannot match the response. + muxl.Match(HTTP2HeaderField(name, "another"+value)) + // Then match with the expected field. + h2l := muxl.Match(HTTP2HeaderField(name, value)) + go safeServe(errCh, muxl) + muxedConn, err := h2l.Accept() + close(l.connCh) + if err != nil { + t.Fatal(err) + } + var b [len(http2.ClientPreface)]byte + // We have the sniffed buffer first... + if _, err := muxedConn.Read(b[:]); err == io.EOF { + t.Fatal(err) + } + if string(b[:]) != http2.ClientPreface { + t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface) + } +} + func TestHTTPGoRPC(t *testing.T) { defer leakCheck(t)() errCh := make(chan error) diff --git a/matchers.go b/matchers.go index 2e7428f..485ede8 100644 --- a/matchers.go +++ b/matchers.go @@ -144,10 +144,14 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool return false } + done := false framer := http2.NewFramer(w, r) hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) { - if hf.Name == name && hf.Value == value { - matched = true + if hf.Name == name { + done = true + if hf.Value == value { + matched = true + } } }) for { @@ -161,17 +165,20 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool if err := framer.WriteSettings(); err != nil { return false } + case *http2.ContinuationFrame: + if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { + return false + } + done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 case *http2.HeadersFrame: if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { return false } - if matched { - return true - } + done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 + } - if f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 { - return false - } + if done { + return matched } } }