2
0
mirror of https://github.com/soheilhy/cmux.git synced 2024-09-20 02:55:46 +08:00

Merge pull request #20 from tamird/fix-read-again

(*MuxConn).Read: fix erroneous io.EOF return
This commit is contained in:
Soheil Hassas Yeganeh 2016-02-25 23:16:54 -05:00
commit d710784914
2 changed files with 25 additions and 12 deletions

View File

@ -199,12 +199,12 @@ func newMuxConn(c net.Conn) *MuxConn {
//
// This function implements the latter behaviour, returning the
// (non-nil) error from the same call.
func (m *MuxConn) Read(b []byte) (int, error) {
n1, err := m.buf.Read(b)
if n1 == len(b) || err != io.EOF {
func (m *MuxConn) Read(p []byte) (int, error) {
n1, err := m.buf.Read(p)
if err != io.EOF {
return n1, err
}
n2, err := m.Conn.Read(b[n1:])
n2, err := m.Conn.Read(p[n1:])
return n1 + n2, err
}

View File

@ -183,27 +183,40 @@ func TestRead(t *testing.T) {
if _, err := io.WriteString(writer, strings.Repeat(payload, mult)); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
}()
l := newChanListener()
defer close(l.connCh)
l.connCh <- reader
muxl := New(l)
// Register a bogus matcher to force reading from the conn.
muxl.Match(HTTP2())
// Register a bogus matcher to force buffering exactly the right amount.
// Before this fix, this would trigger a bug where `Read` would incorrectly
// report `io.EOF` when only the buffer had been consumed.
muxl.Match(func(r io.Reader) bool {
var b [len(payload)]byte
_, _ = r.Read(b[:])
return false
})
anyl := muxl.Match(Any())
go safeServe(errCh, muxl)
muxedConn, err := anyl.Accept()
if err != nil {
t.Fatal(err)
}
var b [mult * len(payload)]byte
n, err := muxedConn.Read(b[:])
if err != nil {
t.Fatal(err)
for i := 0; i < mult; i++ {
var b [len(payload)]byte
if n, err := muxedConn.Read(b[:]); err != nil {
t.Error(err)
} else if e := len(b); n != e {
t.Errorf("expected to read %d bytes, but read %d bytes", e, n)
}
}
if e := len(b); n != e {
t.Errorf("expected to read %d bytes, but read %d bytes", e, n)
var b [1]byte
if _, err := muxedConn.Read(b[:]); err != io.EOF {
t.Errorf("unexpected error %v, expected %v", err, io.EOF)
}
}