diff --git a/buffer.go b/buffer.go index f8cf30a..014e500 100644 --- a/buffer.go +++ b/buffer.go @@ -29,6 +29,7 @@ type bufferedReader struct { buffer bytes.Buffer bufferRead int bufferSize int + stagePoint int sniffing bool lastErr error } @@ -65,3 +66,11 @@ func (s *bufferedReader) reset(snif bool) { s.bufferRead = 0 s.bufferSize = s.buffer.Len() } + +func (s *bufferedReader) newStage() { + s.stagePoint = s.buffer.Len() +} + +func (s *bufferedReader) discard() { + s.buffer.Truncate(s.stagePoint) +} diff --git a/matchers.go b/matchers.go index 878ae98..625124b 100644 --- a/matchers.go +++ b/matchers.go @@ -217,11 +217,28 @@ func matchHTTP1Field(r io.Reader, name string, matches func(string) bool) (match return matches(req.Header.Get(name)) } +type stageReader interface { + io.Reader + newStage() + discard() +} + +type nonStageReader struct{ io.Reader } + +func (nonStageReader) newStage() {} +func (nonStageReader) discard() {} + func matchHTTP2Field(w io.Writer, r io.Reader, name string, matches func(string) bool) (matched bool) { if !hasHTTP2Preface(r) { return false } + sr := (stageReader)(nil) + if sr, _ = r.(stageReader); sr == nil { + sr = nonStageReader{} + } + + waitAcks := 0 done := false framer := http2.NewFramer(w, r) hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) { @@ -233,6 +250,7 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name string, matches func(string) } }) for { + sr.newStage() f, err := framer.ReadFrame() if err != nil { return false @@ -243,11 +261,15 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name string, matches func(string) // Sender acknoweldged the SETTINGS frame. No need to write // SETTINGS again. if f.IsAck() { + // Avoid causing golang.org/x/net/http2.serverConn.unackedSettings PROTOCOL_ERROR + sr.discard() + waitAcks-- break } if err := framer.WriteSettings(); err != nil { return false } + waitAcks++ case *http2.ContinuationFrame: if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { return false @@ -260,7 +282,7 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name string, matches func(string) done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 } - if done { + if done && waitAcks == 0 { return matched } }