diff --git a/buffer.go b/buffer.go index 7c10522..8bb66ff 100644 --- a/buffer.go +++ b/buffer.go @@ -1,57 +1,42 @@ package cmux -import "io" +import ( + "bytes" + "io" +) -var _ io.ReadWriter = (*buffer)(nil) - -type buffer struct { - read int - data []byte +// bufferedReader is an optimized implementation of io.Reader that behaves like +// ``` +// io.MultiReader(bytes.NewReader(buffer.Bytes()), io.TeeReader(source, buffer)) +// ``` +// without allocating. +type bufferedReader struct { + source io.Reader + buffer bytes.Buffer + bufferRead int + bufferSize int + sniffing bool } -// From the io.Reader documentation: -// -// When Read encounters an error or end-of-file condition after -// successfully reading n > 0 bytes, it returns the number of -// bytes read. It may return the (non-nil) error from the same call -// or return the error (and n == 0) from a subsequent call. -// An instance of this general case is that a Reader returning -// a non-zero number of bytes at the end of the input stream may -// return either err == EOF or err == nil. The next Read should -// return 0, EOF. -// -// This function implements the latter behaviour, returning the -// (non-nil) error from the same call. -func (b *buffer) Read(p []byte) (int, error) { - var err error - n := copy(p, b.data[b.read:]) - b.read += n - if b.read == len(b.data) { - err = io.EOF +func (s *bufferedReader) Read(p []byte) (int, error) { + // Functionality of bytes.Reader. + bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize]) + s.bufferRead += bn + + p = p[bn:] + + // Funtionality of io.TeeReader. + sn, sErr := s.source.Read(p) + if sn > 0 && s.sniffing { + if wn, wErr := s.buffer.Write(p[:sn]); wErr != nil { + return bn + wn, wErr + } } - return n, err + return bn + sn, sErr } -func (b *buffer) Len() int { - return len(b.data) - b.read -} - -func (b *buffer) resetRead() { - b.read = 0 -} - -// From the io.Writer documentation: -// -// Write writes len(p) bytes from p to the underlying data stream. -// It returns the number of bytes written from p (0 <= n <= len(p)) -// and any error encountered that caused the write to stop early. -// Write must return a non-nil error if it returns n < len(p). -// Write must not modify the slice data, even temporarily. -// -// Implementations must not retain p. -// -// In a previous incarnation, this implementation retained the incoming slice. -func (b *buffer) Write(p []byte) (int, error) { - b.data = append(b.data, p...) - return len(p), nil +func (s *bufferedReader) reset(snif bool) { + s.sniffing = snif + s.bufferRead = 0 + s.bufferSize = s.buffer.Len() } diff --git a/buffer_test.go b/buffer_test.go deleted file mode 100644 index f098b80..0000000 --- a/buffer_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package cmux - -import ( - "bytes" - "io" - "testing" -) - -func TestWriteNoModify(t *testing.T) { - var b buffer - - const origWriteByte = 0 - const postWriteByte = 1 - - writeBytes := []byte{origWriteByte} - if _, err := b.Write(writeBytes); err != nil { - t.Fatal(err) - } - writeBytes[0] = postWriteByte - readBytes := make([]byte, 1) - if _, err := b.Read(readBytes); err != io.EOF { - t.Fatal(err) - } - - if readBytes[0] != origWriteByte { - t.Fatalf("expected to read %x, but read %x; buffer retained passed-in slice", origWriteByte, postWriteByte) - } -} - -const writeString = "deadbeef" - -func TestBuffer(t *testing.T) { - writeBytes := []byte(writeString) - - const numWrites = 10 - - var b buffer - for i := 0; i < numWrites; i++ { - n, err := b.Write(writeBytes) - if err != nil && err != io.EOF { - t.Fatal(err) - } - if n != len(writeBytes) { - t.Fatalf("cannot write all the bytes: want=%d got=%d", len(writeBytes), n) - } - } - - for j := 0; j < 2; j++ { - readBytes := make([]byte, len(writeBytes)) - for i := 0; i < numWrites; i++ { - n, err := b.Read(readBytes) - if i == numWrites-1 { - // The last read should report EOF. - if err != io.EOF { - t.Fatal(err) - } - } else if err != nil { - t.Fatal(err) - } - if n != len(readBytes) { - t.Fatalf("cannot read all the bytes: want=%d got=%d", len(readBytes), n) - } - if !bytes.Equal(writeBytes, readBytes) { - t.Errorf("different bytes read: want=%d got=%d", writeBytes, readBytes) - } - } - n, err := b.Read(readBytes) - if err != io.EOF { - t.Errorf("expected EOF") - } - if n != 0 { - t.Errorf("expected buffer to be empty, but got %d bytes", n) - } - - b.resetRead() - } -} - -func TestBufferOffset(t *testing.T) { - writeBytes := []byte(writeString) - - var b buffer - n, err := b.Write(writeBytes) - if err != nil { - t.Fatal(err) - } - if n != len(writeBytes) { - t.Fatalf("cannot write all the bytes: want=%d got=%d", len(writeBytes), n) - } - - const readSize = 2 - - numReads := len(writeBytes) / readSize - - for i := 0; i < numReads; i++ { - readBytes := make([]byte, readSize) - n, err := b.Read(readBytes) - if i == numReads-1 { - // The last read should report EOF. - if err != io.EOF { - t.Fatal(err) - } - } else if err != nil { - t.Fatal(err) - } - if n != readSize { - t.Fatalf("cannot read the bytes: want=%d got=%d", readSize, n) - } - if got := writeBytes[i*readSize : i*readSize+readSize]; !bytes.Equal(got, readBytes) { - t.Fatalf("different bytes read: want=%s got=%s", readBytes, got) - } - } -} diff --git a/cmux.go b/cmux.go index e9a886d..8caf5d4 100644 --- a/cmux.go +++ b/cmux.go @@ -125,9 +125,9 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { muc := newMuxConn(c) for _, sl := range m.sls { for _, s := range sl.ss { - matched := s(muc.sniffer()) - muc.reset() + matched := s(muc.startSniffing()) if matched { + muc.doneSniffing() select { case sl.l.connc <- muc: case <-donec: @@ -177,12 +177,13 @@ func (l muxListener) Accept() (net.Conn, error) { // MuxConn wraps a net.Conn and provides transparent sniffing of connection data. type MuxConn struct { net.Conn - buf buffer + buf bufferedReader } func newMuxConn(c net.Conn) *MuxConn { return &MuxConn{ Conn: c, + buf: bufferedReader{source: c}, } } @@ -196,22 +197,15 @@ func newMuxConn(c net.Conn) *MuxConn { // a non-zero number of bytes at the end of the input stream may // return either err == EOF or err == nil. The next Read should // return 0, EOF. -// -// This function implements the latter behaviour, returning the -// (non-nil) error from the same call. func (m *MuxConn) Read(p []byte) (int, error) { - n1, err := m.buf.Read(p) - if err != io.EOF { - return n1, err - } - n2, err := m.Conn.Read(p[n1:]) - return n1 + n2, err + return m.buf.Read(p) } -func (m *MuxConn) sniffer() io.Reader { - return io.MultiReader(&m.buf, io.TeeReader(m.Conn, &m.buf)) +func (m *MuxConn) startSniffing() io.Reader { + m.buf.reset(true) + return &m.buf } -func (m *MuxConn) reset() { - m.buf.resetRead() +func (m *MuxConn) doneSniffing() { + m.buf.reset(false) }