...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package cmux
16
17 import (
18 "bytes"
19 "io"
20 )
21
22
23
24 type patriciaTree struct {
25 root *ptNode
26 maxDepth int
27 }
28
29 func newPatriciaTree(bs ...[]byte) *patriciaTree {
30 max := 0
31 for _, b := range bs {
32 if max < len(b) {
33 max = len(b)
34 }
35 }
36 return &patriciaTree{
37 root: newNode(bs),
38 maxDepth: max + 1,
39 }
40 }
41
42 func newPatriciaTreeString(strs ...string) *patriciaTree {
43 b := make([][]byte, len(strs))
44 for i, s := range strs {
45 b[i] = []byte(s)
46 }
47 return newPatriciaTree(b...)
48 }
49
50 func (t *patriciaTree) matchPrefix(r io.Reader) bool {
51 buf := make([]byte, t.maxDepth)
52 n, _ := io.ReadFull(r, buf)
53 return t.root.match(buf[:n], true)
54 }
55
56 func (t *patriciaTree) match(r io.Reader) bool {
57 buf := make([]byte, t.maxDepth)
58 n, _ := io.ReadFull(r, buf)
59 return t.root.match(buf[:n], false)
60 }
61
// ptNode is a single node of a patriciaTree.
type ptNode struct {
	// prefix is the run of bytes the input must contain at this node.
	prefix []byte
	// next maps the byte that immediately follows prefix to the child
	// matching the rest of the input; that key byte is not repeated in
	// the child's own prefix.
	next map[byte]*ptNode
	// terminal is set when some pattern ends right after prefix.
	terminal bool
}
67
68 func newNode(strs [][]byte) *ptNode {
69 if len(strs) == 0 {
70 return &ptNode{
71 prefix: []byte{},
72 terminal: true,
73 }
74 }
75
76 if len(strs) == 1 {
77 return &ptNode{
78 prefix: strs[0],
79 terminal: true,
80 }
81 }
82
83 p, strs := splitPrefix(strs)
84 n := &ptNode{
85 prefix: p,
86 }
87
88 nexts := make(map[byte][][]byte)
89 for _, s := range strs {
90 if len(s) == 0 {
91 n.terminal = true
92 continue
93 }
94 nexts[s[0]] = append(nexts[s[0]], s[1:])
95 }
96
97 n.next = make(map[byte]*ptNode)
98 for first, rests := range nexts {
99 n.next[first] = newNode(rests)
100 }
101
102 return n
103 }
104
// splitPrefix returns the longest prefix shared by every slice in bss,
// together with what remains of each slice once that prefix is removed.
//
// The result depends on the input.
//   - If bss is empty or its first element is empty, it returns a nil
//     prefix and bss unchanged.
//   - A single element counts entirely as prefix and leaves one empty
//     remainder.
//   - In every other case the prefix is a freshly allocated copy (nil when
//     nothing is shared), and the remainders are sub-slices of the inputs.
func splitPrefix(bss [][]byte) (prefix []byte, rest [][]byte) {
	switch {
	case len(bss) == 0 || len(bss[0]) == 0:
		return nil, bss
	case len(bss) == 1:
		return bss[0], [][]byte{{}}
	}

	head := bss[0]
	shared := 0
scan:
	for ; shared < len(head); shared++ {
		for _, other := range bss[1:] {
			if shared >= len(other) || other[shared] != head[shared] {
				break scan
			}
		}
	}

	// Appending to the nil named result keeps prefix nil when shared == 0
	// and otherwise gives a copy that does not alias head.
	prefix = append(prefix, head[:shared]...)

	rest = make([][]byte, 0, len(bss))
	for _, b := range bss {
		rest = append(rest, b[shared:])
	}
	return prefix, rest
}
148
149 func (n *ptNode) match(b []byte, prefix bool) bool {
150 l := len(n.prefix)
151 if l > 0 {
152 if l > len(b) {
153 l = len(b)
154 }
155 if !bytes.Equal(b[:l], n.prefix) {
156 return false
157 }
158 }
159
160 if n.terminal && (prefix || len(n.prefix) == len(b)) {
161 return true
162 }
163
164 if l >= len(b) {
165 return false
166 }
167
168 nextN, ok := n.next[b[l]]
169 if !ok {
170 return false
171 }
172
173 if l == len(b) {
174 b = b[l:l]
175 } else {
176 b = b[l+1:]
177 }
178 return nextN.match(b, prefix)
179 }
180
View as plain text