diff --git a/cmux_test.go b/cmux_test.go index 531aae3..d7ddca4 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -15,6 +15,8 @@ import ( "sync/atomic" "testing" "time" + + "golang.org/x/net/http2" ) const ( @@ -229,6 +231,51 @@ func TestAny(t *testing.T) { runTestHTTP1Client(t, l.Addr()) } +func TestHTTP2(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + writer, reader := net.Pipe() + go func() { + if _, err := io.WriteString(writer, http2.ClientPreface); err != nil { + t.Fatal(err) + } + if err := writer.Close(); err != nil { + t.Fatal(err) + } + }() + + l := newChanListener() + l.connCh <- reader + muxl := New(l) + // Register a bogus matcher that only reads one byte. + muxl.Match(func(r io.Reader) bool { + var b [1]byte + _, _ = r.Read(b[:]) + return false + }) + h2l := muxl.Match(HTTP2()) + go safeServe(errCh, muxl) + close(l.connCh) + if muxedConn, err := h2l.Accept(); err != nil { + t.Fatal(err) + } else { + var b [len(http2.ClientPreface)]byte + if _, err := muxedConn.Read(b[:]); err != io.EOF { + t.Fatal(err) + } + if string(b[:]) != http2.ClientPreface { + t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface) + } + } +} + func TestHTTPGoRPC(t *testing.T) { defer leakCheck(t)() errCh := make(chan error) diff --git a/matchers.go b/matchers.go index 5f1024b..abc30f6 100644 --- a/matchers.go +++ b/matchers.go @@ -2,7 +2,6 @@ package cmux import ( "bufio" - "bytes" "io" "io/ioutil" "net/http" @@ -77,8 +76,6 @@ func parseRequestLine(line string) (method, uri, proto string, ok bool) { return line[:s1], line[s1+1 : s2], line[s2+1:], true } -var http2Preface = []byte(http2.ClientPreface) - // HTTP2 parses the frame header of the first frame to detect whether the // connection is an HTTP2 connection. func HTTP2() Matcher { @@ -102,14 +99,12 @@ func HTTP2HeaderField(name, value string) Matcher { } func hasHTTP2Preface(r io.Reader) bool { - b := make([]byte, len(http2Preface)) - n, err := r.Read(b) - if err != nil { + var b [len(http2.ClientPreface)]byte + if _, err := io.ReadFull(r, b[:]); err != nil { return false } - b = b[:n] - return bytes.Equal(b, http2Preface) + return string(b[:]) == http2.ClientPreface } func matchHTTP1Field(r io.Reader, name, value string) (matched bool) {