diff --git a/buffer.go b/buffer.go index cb979a3..7c10522 100644 --- a/buffer.go +++ b/buffer.go @@ -2,6 +2,8 @@ package cmux import "io" +var _ io.ReadWriter = (*buffer)(nil) + type buffer struct { read int data []byte @@ -38,13 +40,18 @@ func (b *buffer) resetRead() { b.read = 0 } -func (b *buffer) Write(p []byte) (n int, err error) { - n = len(p) - if b.data == nil { - b.data = p[:n:n] - return - } - +// From the io.Writer documentation: +// +// Write writes len(p) bytes from p to the underlying data stream. +// It returns the number of bytes written from p (0 <= n <= len(p)) +// and any error encountered that caused the write to stop early. +// Write must return a non-nil error if it returns n < len(p). +// Write must not modify the slice data, even temporarily. +// +// Implementations must not retain p. +// +// In a previous incarnation, this implementation retained the incoming slice. +func (b *buffer) Write(p []byte) (int, error) { b.data = append(b.data, p...) - return + return len(p), nil } diff --git a/buffer_test.go b/buffer_test.go index ba1f00d..f098b80 100644 --- a/buffer_test.go +++ b/buffer_test.go @@ -6,8 +6,31 @@ import ( "testing" ) +func TestWriteNoModify(t *testing.T) { + var b buffer + + const origWriteByte = 0 + const postWriteByte = 1 + + writeBytes := []byte{origWriteByte} + if _, err := b.Write(writeBytes); err != nil { + t.Fatal(err) + } + writeBytes[0] = postWriteByte + readBytes := make([]byte, 1) + if _, err := b.Read(readBytes); err != io.EOF { + t.Fatal(err) + } + + if readBytes[0] != origWriteByte { + t.Fatalf("expected to read %x, but read %x; buffer retained passed-in slice", origWriteByte, postWriteByte) + } +} + +const writeString = "deadbeef" + func TestBuffer(t *testing.T) { - writeBytes := []byte("deadbeef") + writeBytes := []byte(writeString) const numWrites = 10 @@ -54,7 +77,7 @@ func TestBuffer(t *testing.T) { } func TestBufferOffset(t *testing.T) { - writeBytes := []byte("deadbeef") + writeBytes := []byte(writeString) var b buffer n, err := b.Write(writeBytes)