2
0
mirror of https://github.com/soheilhy/cmux.git synced 2025-01-19 03:06:07 +08:00

(*MuxConn).Read: fix erroneous io.EOF return

This fixes a bug where (*MuxConn).Read would return io.EOF if the
buffer was exactly the right size to fill the passed-in slice.
This commit is contained in:
Tamir Duberstein 2016-02-25 20:57:46 -05:00
parent d8cc6481fa
commit 49c66ff242
2 changed files with 25 additions and 12 deletions

View File

@ -199,12 +199,12 @@ func newMuxConn(c net.Conn) *MuxConn {
// //
// This function implements the latter behaviour, returning the // This function implements the latter behaviour, returning the
// (non-nil) error from the same call. // (non-nil) error from the same call.
func (m *MuxConn) Read(b []byte) (int, error) { func (m *MuxConn) Read(p []byte) (int, error) {
n1, err := m.buf.Read(b) n1, err := m.buf.Read(p)
if n1 == len(b) || err != io.EOF { if err != io.EOF {
return n1, err return n1, err
} }
n2, err := m.Conn.Read(b[n1:]) n2, err := m.Conn.Read(p[n1:])
return n1 + n2, err return n1 + n2, err
} }

View File

@ -183,27 +183,40 @@ func TestRead(t *testing.T) {
if _, err := io.WriteString(writer, strings.Repeat(payload, mult)); err != nil { if _, err := io.WriteString(writer, strings.Repeat(payload, mult)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := writer.Close(); err != nil {
t.Fatal(err)
}
}() }()
l := newChanListener() l := newChanListener()
defer close(l.connCh) defer close(l.connCh)
l.connCh <- reader l.connCh <- reader
muxl := New(l) muxl := New(l)
// Register a bogus matcher to force reading from the conn. // Register a bogus matcher to force buffering exactly the right amount.
muxl.Match(HTTP2()) // Before this fix, this would trigger a bug where `Read` would incorrectly
// report `io.EOF` when only the buffer had been consumed.
muxl.Match(func(r io.Reader) bool {
var b [len(payload)]byte
_, _ = r.Read(b[:])
return false
})
anyl := muxl.Match(Any()) anyl := muxl.Match(Any())
go safeServe(errCh, muxl) go safeServe(errCh, muxl)
muxedConn, err := anyl.Accept() muxedConn, err := anyl.Accept()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var b [mult * len(payload)]byte for i := 0; i < mult; i++ {
n, err := muxedConn.Read(b[:]) var b [len(payload)]byte
if err != nil { if n, err := muxedConn.Read(b[:]); err != nil {
t.Fatal(err) t.Error(err)
} else if e := len(b); n != e {
t.Errorf("expected to read %d bytes, but read %d bytes", e, n)
}
} }
if e := len(b); n != e { var b [1]byte
t.Errorf("expected to read %d bytes, but read %d bytes", e, n) if _, err := muxedConn.Read(b[:]); err != io.EOF {
t.Errorf("unexpected error %v, expected %v", err, io.EOF)
} }
} }