From 6490dea1995d1a4f0512071b0fb4a13046ea6746 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Sun, 21 Feb 2016 04:19:59 -0500 Subject: [PATCH] Reduce the number of calls needed to (*MuxConn).Read Also affects (*buffer).Read. --- buffer.go | 32 +++++++++++++++++---------- buffer_test.go | 44 +++++++++++++++++++++++++----------- cmux.go | 25 ++++++++++++++++----- cmux_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 130 insertions(+), 31 deletions(-) diff --git a/buffer.go b/buffer.go index 50e206a..cb979a3 100644 --- a/buffer.go +++ b/buffer.go @@ -7,19 +7,27 @@ type buffer struct { data []byte } -func (b *buffer) Read(p []byte) (n int, err error) { - n = len(b.data) - b.read - if n == 0 { - return 0, io.EOF - } - - if len(p) < n { - n = len(p) - } - - copy(p[:n], b.data[b.read:b.read+n]) +// From the io.Reader documentation: +// +// When Read encounters an error or end-of-file condition after +// successfully reading n > 0 bytes, it returns the number of +// bytes read. It may return the (non-nil) error from the same call +// or return the error (and n == 0) from a subsequent call. +// An instance of this general case is that a Reader returning +// a non-zero number of bytes at the end of the input stream may +// return either err == EOF or err == nil. The next Read should +// return 0, EOF. +// +// This function implements the latter behaviour, returning the +// (non-nil) error from the same call. +func (b *buffer) Read(p []byte) (int, error) { + var err error + n := copy(p, b.data[b.read:]) b.read += n - return + if b.read == len(b.data) { + err = io.EOF + } + return n, err } func (b *buffer) Len() int { diff --git a/buffer_test.go b/buffer_test.go index 8a7a564..ba1f00d 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -9,10 +9,12 @@ import ( func TestBuffer(t *testing.T) { writeBytes := []byte("deadbeef") + const numWrites = 10 + var b buffer - for i := 0; i < 10; i++ { + for i := 0; i < numWrites; i++ { n, err := b.Write(writeBytes) - if err != nil { + if err != nil && err != io.EOF { t.Fatal(err) } if n != len(writeBytes) { @@ -22,9 +24,14 @@ func TestBuffer(t *testing.T) { for j := 0; j < 2; j++ { readBytes := make([]byte, len(writeBytes)) - for i := 0; i < 10; i++ { + for i := 0; i < numWrites; i++ { n, err := b.Read(readBytes) - if err != nil { + if i == numWrites-1 { + // The last read should report EOF. + if err != io.EOF { + t.Fatal(err) + } + } else if err != nil { t.Fatal(err) } if n != len(readBytes) { @@ -34,10 +41,13 @@ func TestBuffer(t *testing.T) { t.Errorf("different bytes read: want=%d got=%d", writeBytes, readBytes) } } - _, err := b.Read(readBytes) + n, err := b.Read(readBytes) if err != io.EOF { t.Errorf("expected EOF") } + if n != 0 { + t.Errorf("expected buffer to be empty, but got %d bytes", n) + } b.resetRead() } @@ -55,18 +65,26 @@ func TestBufferOffset(t *testing.T) { t.Fatalf("cannot write all the bytes: want=%d got=%d", len(writeBytes), n) } - for i := 0; i < len(writeBytes)/2; i++ { - readBytes := make([]byte, 2) + const readSize = 2 + + numReads := len(writeBytes) / readSize + + for i := 0; i < numReads; i++ { + readBytes := make([]byte, readSize) n, err := b.Read(readBytes) - if err != nil { + if i == numReads-1 { + // The last read should report EOF. + if err != io.EOF { + t.Fatal(err) + } + } else if err != nil { t.Fatal(err) } - if n != 2 { - t.Fatalf("cannot read the bytes: want=%d got=%d", 2, n) + if n != readSize { + t.Fatalf("cannot read the bytes: want=%d got=%d", readSize, n) } - if !bytes.Equal(readBytes, writeBytes[i*2:i*2+2]) { - t.Fatalf("different bytes read: want=%s got=%s", - readBytes, writeBytes[i*2:i*2+2]) + if got := writeBytes[i*readSize : i*readSize+readSize]; !bytes.Equal(got, readBytes) { + t.Fatalf("different bytes read: want=%s got=%s", readBytes, got) } } } diff --git a/cmux.go b/cmux.go index cbc259b..47db1ea 100644 --- a/cmux.go +++ b/cmux.go @@ -174,13 +174,26 @@ func newMuxConn(c net.Conn) *MuxConn { } } -func (m *MuxConn) Read(b []byte) (n int, err error) { - if n, err = m.buf.Read(b); err == nil { - return +// From the io.Reader documentation: +// +// When Read encounters an error or end-of-file condition after +// successfully reading n > 0 bytes, it returns the number of +// bytes read. It may return the (non-nil) error from the same call +// or return the error (and n == 0) from a subsequent call. +// An instance of this general case is that a Reader returning +// a non-zero number of bytes at the end of the input stream may +// return either err == EOF or err == nil. The next Read should +// return 0, EOF. +// +// This function implements the latter behaviour, returning the +// (non-nil) error from the same call. +func (m *MuxConn) Read(b []byte) (int, error) { + n1, err := m.buf.Read(b) + if n1 == len(b) || err != io.EOF { + return n1, err } - - n, err = m.Conn.Read(b) - return + n2, err := m.Conn.Read(b[n1:]) + return n1 + n2, err } func (m *MuxConn) sniffer() io.Reader { diff --git a/cmux_test.go b/cmux_test.go index cada54e..719d3d8 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -1,7 +1,9 @@ package cmux import ( + "errors" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -38,6 +40,22 @@ func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) { } } +type chanListener struct { + net.Listener + connCh chan net.Conn +} + +func newChanListener() *chanListener { + return &chanListener{connCh: make(chan net.Conn, 1)} +} + +func (l *chanListener) Accept() (net.Conn, error) { + if c, ok := <-l.connCh; ok { + return c, nil + } + return nil, errors.New("use of closed network connection") +} + func testListener(t *testing.T) (net.Listener, func()) { l, err := net.Listen("tcp", ":0") if err != nil { @@ -147,6 +165,48 @@ func runTestRPCClient(t *testing.T, addr net.Addr) { } } +func TestRead(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + const payload = "hello world\r\n" + const mult = 2 + + writer, reader := net.Pipe() + go func() { + if _, err := io.WriteString(writer, strings.Repeat(payload, mult)); err != nil { + t.Fatal(err) + } + }() + + l := newChanListener() + defer close(l.connCh) + l.connCh <- reader + muxl := New(l) + // Register a bogus matcher to force reading from the conn. + muxl.Match(HTTP2()) + anyl := muxl.Match(Any()) + go safeServe(errCh, muxl) + muxedConn, err := anyl.Accept() + if err != nil { + t.Fatal(err) + } + var b [mult * len(payload)]byte + n, err := muxedConn.Read(b[:]) + if err != nil { + t.Fatal(err) + } + if e := len(b); n != e { + t.Errorf("expected to read %d bytes, but read %d bytes", e, n) + } +} + func TestAny(t *testing.T) { defer leakCheck(t)() errCh := make(chan error)