diff --git a/patricia.go b/patricia.go index 56ec4e7..7fc8a5b 100644 --- a/patricia.go +++ b/patricia.go @@ -9,11 +9,19 @@ import ( // and cannot be changed after instantiation. type patriciaTree struct { root *ptNode + buf []byte // preallocated buffer to read data while matching } -func newPatriciaTree(b ...[]byte) *patriciaTree { +func newPatriciaTree(bs ...[]byte) *patriciaTree { + max := 0 + for _, b := range bs { + if max < len(b) { + max = len(b) + } + } return &patriciaTree{ - root: newNode(b), + root: newNode(bs), + buf: make([]byte, max+1), } } @@ -22,17 +30,17 @@ func newPatriciaTreeString(strs ...string) *patriciaTree { for i, s := range strs { b[i] = []byte(s) } - return &patriciaTree{ - root: newNode(b), - } + return newPatriciaTree(b...) } func (t *patriciaTree) matchPrefix(r io.Reader) bool { - return t.root.match(r, true) + n, _ := io.ReadFull(r, t.buf) + return t.root.match(t.buf[:n], true) } func (t *patriciaTree) match(r io.Reader) bool { - return t.root.match(r, false) + n, _ := io.ReadFull(r, t.buf) + return t.root.match(t.buf[:n], false) } type ptNode struct { @@ -122,52 +130,30 @@ func splitPrefix(bss [][]byte) (prefix []byte, rest [][]byte) { return prefix, rest } -func readBytes(r io.Reader, n int) (b []byte, err error) { - b = make([]byte, n) - o := 0 - for o < n { - nr, err := r.Read(b[o:]) - if err != nil && err != io.EOF { - return b, err +func (n *ptNode) match(b []byte, prefix bool) bool { + l := len(n.prefix) + if l > 0 { + if l > len(b) { + l = len(b) } - - o += nr - - if err == io.EOF { - break - } - } - return b[:o], nil -} - -func (n *ptNode) match(r io.Reader, prefix bool) bool { - if l := len(n.prefix); l > 0 { - b, err := readBytes(r, l) - if err != nil || len(b) != l || !bytes.Equal(b, n.prefix) { + if !bytes.Equal(b[:l], n.prefix) { return false } } - if prefix && n.terminal { + if n.terminal && (prefix || len(n.prefix) == len(b)) { return true } - b := make([]byte, 1) - for { - nr, err := r.Read(b) - if nr != 0 { - break - } - - if err == io.EOF { - return n.terminal - } - - if err != nil { - return false - } + nextN, ok := n.next[b[l]] + if !ok { + return false } - nextN, ok := n.next[b[0]] - return ok && nextN.match(r, prefix) + if l == len(b) { + b = b[l:l] + } else { + b = b[l+1:] + } + return nextN.match(b, prefix) }