diff --git a/patricia.go b/patricia.go index ec10514..2bcc3f2 100644 --- a/patricia.go +++ b/patricia.go @@ -17,30 +17,18 @@ package cmux import ( "bytes" "io" - "sync" ) // patriciaTree is a simple patricia tree that handles []byte instead of string // and cannot be changed after instantiation. type patriciaTree struct { root *ptNode - mu struct { - sync.Mutex - buf []byte // preallocated buffer to read data while matching - } } -func newPatriciaTree(bs ...[]byte) *patriciaTree { - max := 0 - for _, b := range bs { - if max < len(b) { - max = len(b) - } +func newPatriciaTree(b ...[]byte) *patriciaTree { + return &patriciaTree{ + root: newNode(b), } - t := patriciaTree{root: newNode(bs)} - t.mu.buf = make([]byte, max+1) - - return &t } func newPatriciaTreeString(strs ...string) *patriciaTree { @@ -48,23 +36,17 @@ func newPatriciaTreeString(strs ...string) *patriciaTree { for i, s := range strs { b[i] = []byte(s) } - return newPatriciaTree(b...) + return &patriciaTree{ + root: newNode(b), + } } func (t *patriciaTree) matchPrefix(r io.Reader) bool { - t.mu.Lock() - defer t.mu.Unlock() - - n, _ := io.ReadFull(r, t.mu.buf) - return t.root.match(t.mu.buf[:n], true) + return t.root.match(r, true) } func (t *patriciaTree) match(r io.Reader) bool { - t.mu.Lock() - defer t.mu.Unlock() - - n, _ := io.ReadFull(r, t.mu.buf) - return t.root.match(t.mu.buf[:n], false) + return t.root.match(r, false) } type ptNode struct { @@ -154,30 +136,52 @@ func splitPrefix(bss [][]byte) (prefix []byte, rest [][]byte) { return prefix, rest } -func (n *ptNode) match(b []byte, prefix bool) bool { - l := len(n.prefix) - if l > 0 { - if l > len(b) { - l = len(b) +func readBytes(r io.Reader, n int) (b []byte, err error) { + b = make([]byte, n) + o := 0 + for o < n { + nr, err := r.Read(b[o:]) + if err != nil && err != io.EOF { + return b, err } - if !bytes.Equal(b[:l], n.prefix) { + + o += nr + + if err == io.EOF { + break + } + } + return b[:o], nil +} + +func (n *ptNode) match(r io.Reader, prefix bool) bool { + if l := len(n.prefix); l > 0 { + b, err := readBytes(r, l) + if err != nil || len(b) != l || !bytes.Equal(b, n.prefix) { return false } } - if n.terminal && (prefix || len(n.prefix) == len(b)) { + if prefix && n.terminal { return true } - nextN, ok := n.next[b[l]] - if !ok { - return false + b := make([]byte, 1) + for { + nr, err := r.Read(b) + if nr != 0 { + break + } + + if err == io.EOF { + return n.terminal + } + + if err != nil { + return false + } } - if l == len(b) { - b = b[l:l] - } else { - b = b[l+1:] - } - return nextN.match(b, prefix) + nextN, ok := n.next[b[0]] + return ok && nextN.match(r, prefix) }