1
16
17 package iptree
18
19 import (
20 "fmt"
21 "math/bits"
22 "net/netip"
23 )
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// node is a single vertex of the binary prefix tree.
type node[T any] struct {
	// prefix is the network this node represents. Nodes created by
	// InsertPrefix keep the caller's prefix exactly as passed; the roots and
	// the internal link nodes built from findAncestor use a masked prefix.
	prefix netip.Prefix

	// public is true when the node holds a value stored by the caller.
	// Non-public nodes carry no value and only keep the tree structure
	// together (the per-family roots and internal link nodes).
	public bool
	// val is the stored value; meaningful only when public is true.
	val T

	// child holds the two subtrees, indexed (0 or 1) by the address bit
	// immediately after this node's prefix length.
	child [2]*node[T]
}
69
70
71
72
73 func (n *node[T]) mergeChild() {
74
75 if n.public {
76 return
77 }
78
79 if n.child[0] != nil &&
80 n.child[1] != nil {
81 return
82 }
83
84 if n.child[0] == nil &&
85 n.child[1] == nil {
86 return
87 }
88
89 var child *node[T]
90 if n.child[0] != nil {
91 child = n.child[0]
92 } else if n.child[1] != nil {
93 child = n.child[1]
94 }
95 n.prefix = child.prefix
96 n.public = child.public
97 n.val = child.val
98 n.child = child.child
99
100
101 child.child[0] = nil
102 child.child[1] = nil
103 }
104
105
// Tree is a binary prefix tree keyed by netip.Prefix that stores a value of
// type T per prefix. IPv4 and IPv6 prefixes live under separate roots; the
// methods pick the root with prefix.Addr().Is6().
type Tree[T any] struct {
	rootV4 *node[T]
	rootV6 *node[T]
}
110
111
// New returns an empty Tree. Each address family gets its own non-public
// root covering the whole address space (0.0.0.0/0 and ::/0).
func New[T any]() *Tree[T] {
	v4Root := &node[T]{prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
	v6Root := &node[T]{prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
	return &Tree[T]{rootV4: v4Root, rootV6: v6Root}
}
122
123
// GetPrefix returns the value stored for exactly prefix (same address and
// length, compared with ==). The boolean is false, with T's zero value, when
// no such entry exists.
func (t *Tree[T]) GetPrefix(prefix netip.Prefix) (T, bool) {
	var zero T

	cur := t.rootV4
	if prefix.Addr().Is6() {
		cur = t.rootV6
	}
	addr := prefix.Masked().Addr()
	mask := prefix.Bits()

	// Descend one node at a time; depth is the prefix length of the node we
	// are currently on, so the next edge is chosen by bit depth+1.
	for depth := 0; depth < mask; depth = cur.prefix.Bits() {
		cur = cur.child[getBitFromAddr(addr, depth+1)]
		if cur == nil || !cur.prefix.Contains(addr) {
			return zero, false
		}
	}

	if cur == nil || !cur.public || cur.prefix != prefix {
		return zero, false
	}
	return cur.val, true
}
157
158
159
160
// LongestPrefixMatch returns the longest stored prefix covering prefix,
// together with its value. Stored prefixes shorter than prefix qualify when
// they lie on the descent toward prefix's network address; a stored prefix of
// the same length qualifies only if it is exactly prefix. When nothing
// matches it returns the zero Prefix, T's zero value and false.
func (t *Tree[T]) LongestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) {
	cur := t.rootV4
	if prefix.Addr().Is6() {
		cur = t.rootV6
	}
	addr := prefix.Masked().Addr()
	mask := prefix.Bits()

	// best remembers the deepest public node seen so far on the way down.
	var best *node[T]
	for depth := 0; depth < mask; depth = cur.prefix.Bits() {
		if cur.public {
			best = cur
		}
		cur = cur.child[getBitFromAddr(addr, depth+1)]
		if cur == nil || !cur.prefix.Contains(addr) {
			break
		}
	}

	if cur != nil && cur.public && cur.prefix == prefix {
		best = cur
	}

	if best == nil {
		var zero T
		return netip.Prefix{}, zero, false
	}
	return best.prefix, best.val, true
}
201
202
203
204
// ShortestPrefixMatch returns the first stored prefix met while descending
// toward prefix, i.e. the shortest stored prefix covering it, together with
// its value. A stored prefix of the same length as prefix qualifies only if
// it is exactly prefix. When nothing matches it returns the zero Prefix, T's
// zero value and false.
func (t *Tree[T]) ShortestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) {
	var zero T

	cur := t.rootV4
	if prefix.Addr().Is6() {
		cur = t.rootV6
	}
	addr := prefix.Masked().Addr()
	mask := prefix.Bits()

	for depth := 0; depth < mask; depth = cur.prefix.Bits() {
		// The first public node on the way down is the shortest match.
		if cur.public {
			return cur.prefix, cur.val, true
		}
		cur = cur.child[getBitFromAddr(addr, depth+1)]
		if cur == nil || !cur.prefix.Contains(addr) {
			return netip.Prefix{}, zero, false
		}
	}

	if cur != nil && cur.public && cur.prefix == prefix {
		return cur.prefix, cur.val, true
	}
	return netip.Prefix{}, zero, false
}
239
240
241
// InsertPrefix stores v under prefix. It returns true only when prefix was
// already stored as a public entry, in which case its value is overwritten;
// it returns false when a new entry was created (including turning an
// existing non-public node with the same prefix into a public one).
//
// NOTE(review): nodes keep prefix exactly as passed (not masked) and the
// final match uses ==. If prefix has host bits set and the tree already has
// a node for the same masked network and length, the loop stops on that node,
// the == check fails and the call returns false without storing v. Confirm
// callers always pass masked prefixes.
func (t *Tree[T]) InsertPrefix(prefix netip.Prefix, v T) bool {
	n := t.rootV4
	if prefix.Addr().Is6() {
		n = t.rootV6
	}
	var parent *node[T]

	// bitPosition is the prefix length of the node we are currently at.
	bitPosition := 0

	// All edge lookups use the masked address so host bits are ignored.
	address := prefix.Masked().Addr()
	mask := prefix.Bits()
	for bitPosition < mask {

		// The bit right after the current node's length selects the edge.
		childIndex := getBitFromAddr(address, bitPosition+1)
		parent = n
		n = n.child[childIndex]

		// Empty edge: attach a new public leaf holding the value.
		if n == nil {
			parent.child[childIndex] = &node[T]{
				public: true,
				val:    v,
				prefix: prefix,
			}
			return false
		}

		bitPosition = n.prefix.Bits()

		// n covers address and is not longer than prefix: keep descending.
		if n.prefix.Contains(address) && bitPosition <= mask {
			continue
		}

		// n is not on the path to prefix, so prefix needs its own node,
		// placed either above n or beside it under a new link node.
		child := &node[T]{
			prefix: prefix,
			public: true,
			val:    v,
		}

		// prefix contains n: put the new node in n's slot and hang n below
		// it, on the side given by n's bit right after prefix's length.
		if prefix.Contains(n.prefix.Addr()) && bitPosition > mask {

			parent.child[childIndex] = child
			pos := prefix.Bits() + 1

			child.child[getBitFromAddr(n.prefix.Addr(), pos)] = n
			return false
		}

		// prefix and n diverge: insert a non-public link node for their
		// common prefix and place both under it on opposite sides.
		ancestor := findAncestor(prefix, n.prefix)
		link := &node[T]{
			prefix: ancestor,
		}
		pos := parent.prefix.Bits() + 1
		parent.child[getBitFromAddr(ancestor.Addr(), pos)] = link

		pos = ancestor.Bits() + 1
		idxChild := getBitFromAddr(prefix.Addr(), pos)
		idxN := getBitFromAddr(n.prefix.Addr(), pos)
		// The first bit after the common prefix must differ; if it does not,
		// findAncestor returned a prefix that is too short.
		if idxChild == idxN {
			panic(fmt.Sprintf("wrong ancestor %s: child %s N %s", ancestor.String(), prefix.String(), n.prefix.String()))
		}
		link.child[idxChild] = child
		link.child[idxN] = n
		return false
	}

	// The loop ended either at the root (mask <= 0) or at a node of length
	// mask that covers address. Store the value there if it is exactly
	// prefix; the result reports whether a public entry already existed.
	if n != nil && n.prefix == prefix {
		if n.public {
			n.val = v
			n.public = true
			return true
		}
		n.val = v
		n.public = true
		return false
	}

	return false
}
330
331
// DeletePrefix removes prefix from the tree and reports whether it was
// stored. Only an exact match (compared with ==) on a node that holds a value
// counts. Internal link nodes and an empty root are not stored prefixes, so
// deleting their prefix returns false and leaves the tree untouched.
//
// After the value is cleared the tree is compacted:
//   - a node left without children is unlinked from its parent;
//   - a non-root node left with a single child is merged with it;
//   - a non-root, value-less parent left with a single child is merged too.
func (t *Tree[T]) DeletePrefix(prefix netip.Prefix) bool {
	root := t.rootV4
	if prefix.Addr().Is6() {
		root = t.rootV6
	}
	var parent *node[T]
	n := root

	// bitPosition is the prefix length of the node we are currently at.
	bitPosition := 0
	// Edge lookups use the masked address so host bits are ignored.
	address := prefix.Masked().Addr()
	mask := prefix.Bits()
	for bitPosition < mask {
		parent = n
		n = n.child[getBitFromAddr(address, bitPosition+1)]
		if n == nil || !n.prefix.Contains(address) {
			return false
		}
		bitPosition = n.prefix.Bits()
	}

	// A node with the same prefix that is not public carries no value (a
	// root or a link node), so there is nothing to delete.
	if n.prefix != prefix || !n.public {
		return false
	}

	// Drop the value; resetting val also releases anything it references.
	var zeroT T
	n.public = false
	n.val = zeroT

	nodeChildren := 0
	for _, c := range n.child {
		if c != nil {
			nodeChildren++
		}
	}

	// Leaf: unlink it from its parent.
	if parent != nil && nodeChildren == 0 {
		switch n {
		case parent.child[0]:
			parent.child[0] = nil
		case parent.child[1]:
			parent.child[1] = nil
		default:
			panic("wrong parent")
		}
		n = nil
	}

	// A non-root node with one child no longer needs to exist on its own;
	// fold the child into it. The roots are never merged.
	if n != root && nodeChildren == 1 {
		n.mergeChild()
	}

	// Unlinking a leaf may leave a value-less, non-root parent with a single
	// child; fold that child into the parent as well.
	if parent != nil && parent != root && !parent.public {
		parentChildren := 0
		for _, c := range parent.child {
			if c != nil {
				parentChildren++
			}
		}
		if parentChildren == 1 {
			parent.mergeChild()
		}
	}
	return true
}
409
410
411 func (t *Tree[T]) Len(isV6 bool) int {
412 count := 0
413 t.DepthFirstWalk(isV6, func(k netip.Prefix, v T) bool {
414 count++
415 return false
416 })
417 return count
418 }
419
420
421
422
// WalkFn is the callback used by the tree walks. It receives a stored prefix
// and its value; returning true stops the walk.
type WalkFn[T any] func(prefix netip.Prefix, value T) bool
424
425
426 func (t *Tree[T]) DepthFirstWalk(isIPv6 bool, fn WalkFn[T]) {
427 if isIPv6 {
428 recursiveWalk(t.rootV6, fn)
429 }
430 recursiveWalk(t.rootV4, fn)
431 }
432
433
434
435 func recursiveWalk[T any](n *node[T], fn WalkFn[T]) bool {
436 if n == nil {
437 return true
438 }
439
440 if n.public && fn(n.prefix, n.val) {
441 return true
442 }
443
444
445 if n.child[0] != nil {
446 if recursiveWalk(n.child[0], fn) {
447 return true
448 }
449 }
450 if n.child[1] != nil {
451 if recursiveWalk(n.child[1], fn) {
452 return true
453 }
454 }
455 return false
456 }
457
458
// WalkPrefix calls fn, depth first, for every stored prefix contained in
// prefix (prefix itself included when stored). The walk stops early when fn
// returns true.
func (t *Tree[T]) WalkPrefix(prefix netip.Prefix, fn WalkFn[T]) {
	n := t.rootV4
	if prefix.Addr().Is6() {
		n = t.rootV6
	}
	// bitPosition is the prefix length of the node we are currently at.
	bitPosition := 0
	// Edge lookups use the masked address so host bits are ignored.
	address := prefix.Masked().Addr()
	mask := prefix.Bits()

	for bitPosition < mask {
		n = n.child[getBitFromAddr(address, bitPosition+1)]
		if n == nil {
			return
		}
		if !n.prefix.Contains(address) {
			// n has left the path to prefix's network address. Two prefixes
			// either nest or are disjoint, so n's subtree can only match if n
			// itself lies inside prefix. Otherwise nothing matches: e.g. a lone
			// 11.0.0.0/8 is reached when walking 10.0.0.0/8 but is not in it.
			if n.prefix.Bits() < mask || !prefix.Contains(n.prefix.Addr()) {
				return
			}
			break
		}
		bitPosition = n.prefix.Bits()
	}
	recursiveWalk(n, fn)
}
486
487
488
489
490
// WalkPath calls fn on the stored prefixes met while descending from the
// root toward path: each public node shorter than path on that route, and
// finally path itself if it is stored exactly (compared with ==). The walk
// stops early when fn returns true.
func (t *Tree[T]) WalkPath(path netip.Prefix, fn WalkFn[T]) {
	cur := t.rootV4
	if path.Addr().Is6() {
		cur = t.rootV6
	}
	addr := path.Masked().Addr()
	mask := path.Bits()

	for depth := 0; depth < mask; depth = cur.prefix.Bits() {
		if cur.public && fn(cur.prefix, cur.val) {
			return
		}
		cur = cur.child[getBitFromAddr(addr, depth+1)]
		if cur == nil || !cur.prefix.Contains(addr) {
			return
		}
	}

	if cur != nil && cur.public && cur.prefix == path {
		fn(cur.prefix, cur.val)
	}
}
524
525
526
527
528
529 func (t *Tree[T]) TopLevelPrefixes(isIPv6 bool) map[string]T {
530 if isIPv6 {
531 return t.topLevelPrefixes(t.rootV6)
532 }
533 return t.topLevelPrefixes(t.rootV4)
534 }
535
536
537 func (t *Tree[T]) topLevelPrefixes(root *node[T]) map[string]T {
538 result := map[string]T{}
539 queue := []*node[T]{root}
540
541 for len(queue) > 0 {
542 n := queue[0]
543 queue = queue[1:]
544
545 if n.public {
546 result[n.prefix.String()] = n.val
547 continue
548 }
549 if n.child[0] != nil {
550 queue = append(queue, n.child[0])
551 }
552 if n.child[1] != nil {
553 queue = append(queue, n.child[1])
554 }
555 }
556 return result
557 }
558
559
560
561
// GetHostIPPrefixMatches returns the stored prefixes found by WalkPath on
// ip's host prefix (full bit length) for which prefixContainIP reports that
// they contain ip, mapped to their values.
func (t *Tree[T]) GetHostIPPrefixMatches(ip netip.Addr) map[netip.Prefix]T {
	host := netip.PrefixFrom(ip, ip.BitLen())
	hostAddr := host.Addr()

	matches := map[netip.Prefix]T{}
	collect := func(p netip.Prefix, val T) bool {
		if prefixContainIP(p, hostAddr) {
			matches[p] = val
		}
		return false // keep walking the whole path
	}
	t.WalkPath(host, collect)
	return matches
}
574
575
576
// getBitFromAddr returns bit number pos of ip as 0 or 1. Bits are numbered
// from 1, starting at the most significant bit of the first address byte.
// It panics when pos points past the end of the address.
func getBitFromAddr(ip netip.Addr, pos int) int {
	raw := ip.AsSlice()
	byteIdx := (pos - 1) / 8
	bitIdx := (pos - 1) % 8

	if byteIdx >= len(raw) {
		panic(fmt.Sprintf("ip %s pos %d index %d bytes %v", ip, pos, byteIdx, raw))
	}

	if raw[byteIdx]&(uint8(0x80)>>bitIdx) == 0 {
		return 0
	}
	return 1
}
592
593
// findAncestor returns the longest prefix shared by a and b: the leading
// address bits on which they agree, capped at the shorter of the two prefix
// lengths, in masked form. It assumes a and b have the same address family
// (b's bytes are indexed using a's length). It panics if the resulting bytes
// do not form a valid address.
func findAncestor(a, b netip.Prefix) netip.Prefix {
	rawA := a.Addr().AsSlice()
	rawB := b.Addr().AsSlice()
	common := make([]byte, len(rawA))

	limit := b.Bits()
	if a.Bits() < limit {
		limit = a.Bits()
	}

	shared := 0
	for i, x := range rawA {
		diff := x ^ rawB[i]
		if diff != 0 {
			// Keep only the bits before the first difference in this byte.
			lead := bits.LeadingZeros8(diff)
			common[i] = x & (^uint8(0) << (8 - lead))
			shared += lead
			break
		}
		common[i] = x
		shared += 8
	}
	if shared > limit {
		shared = limit
	}

	addr, ok := netip.AddrFromSlice(common)
	if !ok {
		panic(common)
	}
	return netip.PrefixFrom(addr, shared).Masked()
}
630
631
632
633
634
// prefixContainIP reports whether ip is inside prefix without being its
// network address. For an ip that is not IPv6 it also excludes the broadcast
// (last) address of prefix, and it reports false if that address cannot be
// computed.
func prefixContainIP(prefix netip.Prefix, ip netip.Addr) bool {
	if ip == prefix.Masked().Addr() {
		return false
	}
	if ip.Is6() {
		return prefix.Contains(ip)
	}
	last, err := broadcastAddress(prefix)
	if err != nil || last == ip {
		return false
	}
	return prefix.Contains(ip)
}
649
650
651
652
653
654
655
656
// broadcastAddress returns the last address of subnet: its masked network
// address with every host bit set. It returns an error when the masked
// address bytes cannot be turned back into a netip.Addr, which happens for
// an invalid (e.g. zero-value) subnet.
func broadcastAddress(subnet netip.Prefix) (netip.Addr, error) {
	buf := subnet.Masked().Addr().AsSlice()
	hostBits := 8*len(buf) - subnet.Bits()

	// Fill host bits from the last byte backwards: whole bytes first, then
	// the low-order bits of the byte that straddles the prefix boundary.
	for i := len(buf) - 1; hostBits > 0 && i >= 0; i-- {
		if hostBits < 8 {
			buf[i] |= ^uint8(0) >> (8 - hostBits)
			break
		}
		buf[i] = 0xff
		hostBits -= 8
	}

	last, ok := netip.AddrFromSlice(buf)
	if !ok {
		return netip.Addr{}, fmt.Errorf("invalid address %v", buf)
	}
	return last, nil
}
680
View as plain text