From 49c66ff242fc6f260163ab3cccf025c7e5181065 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Thu, 25 Feb 2016 20:57:46 -0500 Subject: [PATCH] (*MuxConn).Read: fix erroneous io.EOF return This fixes a bug where (*MuxConn).Read would return io.EOF if the buffer was exactly the right size to fill the passed-in slice. --- cmux.go | 8 ++++---- cmux_test.go | 29 +++++++++++++++++++++-------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/cmux.go b/cmux.go index 3e899e7..e9a886d 100644 --- a/cmux.go +++ b/cmux.go @@ -199,12 +199,12 @@ func newMuxConn(c net.Conn) *MuxConn { // // This function implements the latter behaviour, returning the // (non-nil) error from the same call. -func (m *MuxConn) Read(b []byte) (int, error) { - n1, err := m.buf.Read(b) - if n1 == len(b) || err != io.EOF { +func (m *MuxConn) Read(p []byte) (int, error) { + n1, err := m.buf.Read(p) + if err != io.EOF { return n1, err } - n2, err := m.Conn.Read(b[n1:]) + n2, err := m.Conn.Read(p[n1:]) return n1 + n2, err } diff --git a/cmux_test.go b/cmux_test.go index 24d6249..669b713 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -183,27 +183,40 @@ func TestRead(t *testing.T) { if _, err := io.WriteString(writer, strings.Repeat(payload, mult)); err != nil { t.Fatal(err) } + if err := writer.Close(); err != nil { + t.Fatal(err) + } }() l := newChanListener() defer close(l.connCh) l.connCh <- reader muxl := New(l) - // Register a bogus matcher to force reading from the conn. - muxl.Match(HTTP2()) + // Register a bogus matcher to force buffering exactly the right amount. + // Before this fix, this would trigger a bug where `Read` would incorrectly + // report `io.EOF` when only the buffer had been consumed. + muxl.Match(func(r io.Reader) bool { + var b [len(payload)]byte + _, _ = r.Read(b[:]) + return false + }) anyl := muxl.Match(Any()) go safeServe(errCh, muxl) muxedConn, err := anyl.Accept() if err != nil { t.Fatal(err) } - var b [mult * len(payload)]byte - n, err := muxedConn.Read(b[:]) - if err != nil { - t.Fatal(err) + for i := 0; i < mult; i++ { + var b [len(payload)]byte + if n, err := muxedConn.Read(b[:]); err != nil { + t.Error(err) + } else if e := len(b); n != e { + t.Errorf("expected to read %d bytes, but read %d bytes", e, n) + } } - if e := len(b); n != e { - t.Errorf("expected to read %d bytes, but read %d bytes", e, n) + var b [1]byte + if _, err := muxedConn.Read(b[:]); err != io.EOF { + t.Errorf("unexpected error %v, expected %v", err, io.EOF) } }