1
2
3
4
5 package netipx
6
7 import (
8 "fmt"
9 "net/netip"
10 "runtime"
11 "sort"
12 "strings"
13 )
14
15
16
17
18
19
20
21
22
23
24
25
26 type IPSetBuilder struct {
27
28 in []IPRange
29
30
31 out []IPRange
32
33
34 errs multiErr
35 }
36
37
38
// normalize subtracts every pending removal range in s.out from the added
// ranges in s.in, stores the surviving ranges in s.in and clears s.out.
//
// Both lists are first passed through mergeIPRanges (defined elsewhere;
// presumably it sorts and coalesces the ranges — the debug label "ranges
// sort" suggests so, confirm). If either call reports !ok, normalize returns
// early and leaves s.in and s.out exactly as they were.
//
// It panics if it meets an invalid range, or a pair of ranges that matches
// none of the overlap cases handled below.
func (s *IPSetBuilder) normalize() {
	// Flip to true locally to trace each step through debugf.
	const debug = false
	if debug {
		debugf("ranges start in=%v out=%v", s.in, s.out)
	}
	in, ok := mergeIPRanges(s.in)
	if !ok {
		return
	}
	out, ok := mergeIPRanges(s.out)
	if !ok {
		return
	}
	if debug {
		debugf("ranges sort in=%v out=%v", in, out)
	}

	// Walk both lists front to front, like a merge. min collects the
	// surviving pieces of the added ranges. in[0] may be trimmed in place
	// when a removal cuts into it. NOTE(review): the case analysis below
	// assumes in and out are each sorted and non-overlapping, as produced
	// by mergeIPRanges — confirm.
	min := make([]IPRange, 0, len(in))
	for len(in) > 0 && len(out) > 0 {
		rin, rout := in[0], out[0]
		if debug {
			debugf("step in=%v out=%v", rin, rout)
		}

		switch {
		case !rout.IsValid() || !rin.IsValid():
			// Invalid ranges are rejected before they reach in/out, so
			// reaching this case means an internal invariant was broken.
			panic("invalid IPRanges during Ranges merge")
		case rout.entirelyBefore(rin):
			// The removal lies wholly before the current added range and
			// cannot affect it; discard the removal.
			out = out[1:]
			if debug {
				debugf("out before in; drop out")
			}
		case rin.entirelyBefore(rout):
			// The added range lies wholly before the current removal, so
			// no removal touches it; keep it unchanged.
			min = append(min, rin)
			in = in[1:]
			if debug {
				debugf("in before out; append in")
				debugf("min=%v", min)
			}
		case rin.coveredBy(rout):
			// The removal swallows the whole added range; drop the added
			// range. The removal is kept because it may also cover the
			// next added range.
			in = in[1:]
			if debug {
				debugf("in inside out; drop in")
			}
		case rout.inMiddleOf(rin):
			// The removal punches a hole inside the added range. Keep the
			// part before the hole, and trim in[0] so it now starts just
			// after the hole; that remainder is re-examined next iteration.
			min = append(min, IPRange{from: rin.from, to: AddrPrior(rout.from)})
			in[0].from = rout.to.Next()
			out = out[1:]
			if debug {
				debugf("out inside in; split in, append first in, drop out, adjust second in")
				debugf("min=%v", min)
			}
		case rout.overlapsStartOf(rin):
			// The removal covers the front of the added range. Move its
			// start past the removal and keep it for further checks.
			in[0].from = rout.to.Next()
			out = out[1:]
			if debug {
				debugf("out cuts start of in; adjust in, drop out")
			}
		case rout.overlapsEndOf(rin):
			// The removal covers the tail of the added range. Keep the
			// front part. The removal is not consumed, because it extends
			// past this range and may cut into the next one.
			min = append(min, IPRange{from: rin.from, to: AddrPrior(rout.from)})
			in = in[1:]
			if debug {
				debugf("merge out cuts end of in; append shortened in")
				debugf("min=%v", min)
			}
		default:
			// Every relative position of two valid ranges should match
			// one of the cases above.
			panic("unexpected additional overlap scenario")
		}
	}
	if len(in) > 0 {
		// All removals are used up; the remaining added ranges survive
		// intact. Leftover removals (when in ran out first) are dropped.
		min = append(min, in...)
		if debug {
			debugf("min=%v", min)
		}
	}

	s.in = min
	s.out = nil
}
163
164
165 func (s *IPSetBuilder) Clone() *IPSetBuilder {
166 return &IPSetBuilder{
167 in: append([]IPRange(nil), s.in...),
168 out: append([]IPRange(nil), s.out...),
169 }
170 }
171
172 func (s *IPSetBuilder) addError(msg string, args ...interface{}) {
173 se := new(stacktraceErr)
174
175
176
177
178 n := runtime.Callers(3, se.pcs[:])
179 se.at = se.pcs[:n]
180 se.err = fmt.Errorf(msg, args...)
181 s.errs = append(s.errs, se)
182 }
183
184
185 func (s *IPSetBuilder) Add(ip netip.Addr) {
186 if !ip.IsValid() {
187 s.addError("Add(IP{})")
188 return
189 }
190 s.AddRange(IPRangeFrom(ip, ip))
191 }
192
193
194 func (s *IPSetBuilder) AddPrefix(p netip.Prefix) {
195 if r := RangeOfPrefix(p); r.IsValid() {
196 s.AddRange(r)
197 } else {
198 s.addError("AddPrefix(%v/%v)", p.Addr(), p.Bits())
199 }
200 }
201
202
203
204 func (s *IPSetBuilder) AddRange(r IPRange) {
205 if !r.IsValid() {
206 s.addError("AddRange(%v-%v)", r.From(), r.To())
207 return
208 }
209
210
211 if len(s.out) > 0 {
212 s.normalize()
213 }
214 s.in = append(s.in, r)
215 }
216
217
218 func (s *IPSetBuilder) AddSet(b *IPSet) {
219 if b == nil {
220 return
221 }
222 for _, r := range b.rr {
223 s.AddRange(r)
224 }
225 }
226
227
228 func (s *IPSetBuilder) Remove(ip netip.Addr) {
229 if !ip.IsValid() {
230 s.addError("Remove(IP{})")
231 } else {
232 s.RemoveRange(IPRangeFrom(ip, ip))
233 }
234 }
235
236
237 func (s *IPSetBuilder) RemovePrefix(p netip.Prefix) {
238 if r := RangeOfPrefix(p); r.IsValid() {
239 s.RemoveRange(r)
240 } else {
241 s.addError("RemovePrefix(%v/%v)", p.Addr(), p.Bits())
242 }
243 }
244
245
246 func (s *IPSetBuilder) RemoveRange(r IPRange) {
247 if r.IsValid() {
248 s.out = append(s.out, r)
249 } else {
250 s.addError("RemoveRange(%v-%v)", r.From(), r.To())
251 }
252 }
253
254
255 func (s *IPSetBuilder) RemoveSet(b *IPSet) {
256 if b == nil {
257 return
258 }
259 for _, r := range b.rr {
260 s.RemoveRange(r)
261 }
262 }
263
264
265 func (s *IPSetBuilder) removeBuilder(b *IPSetBuilder) {
266 b.normalize()
267 for _, r := range b.in {
268 s.RemoveRange(r)
269 }
270 }
271
272
273
274 func (s *IPSetBuilder) Complement() {
275 s.normalize()
276 s.out = s.in
277 s.in = []IPRange{
278 RangeOfPrefix(netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0)),
279 RangeOfPrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0)),
280 }
281 }
282
283
284 func (s *IPSetBuilder) Intersect(b *IPSet) {
285 var o IPSetBuilder
286 o.Complement()
287 o.RemoveSet(b)
288 s.removeBuilder(&o)
289 }
290
// discardf is a printf-shaped logger that ignores its arguments.
func discardf(_ string, _ ...interface{}) {}

// debugf is the logging hook that normalize calls when its debug constant
// is enabled. It defaults to discardf, which does nothing.
var debugf = discardf
295
296
297
298
299
300
301
302
303
304
305
306
307
308 func (s *IPSetBuilder) IPSet() (*IPSet, error) {
309 s.normalize()
310 ret := &IPSet{
311 rr: append([]IPRange{}, s.in...),
312 }
313 if len(s.errs) == 0 {
314 return ret, nil
315 } else {
316 errs := s.errs
317 s.errs = nil
318 return ret, errs
319 }
320 }
321
322
323
324
325
326
327 type IPSet struct {
328
329
330
331
332
333 rr []IPRange
334 }
335
336
337
338 func (s *IPSet) Ranges() []IPRange {
339 return append([]IPRange{}, s.rr...)
340 }
341
342
343
344 func (s *IPSet) Prefixes() []netip.Prefix {
345 out := make([]netip.Prefix, 0, len(s.rr))
346 for _, r := range s.rr {
347 out = append(out, r.Prefixes()...)
348 }
349 return out
350 }
351
352
353
354 func (s *IPSet) Equal(o *IPSet) bool {
355 if len(s.rr) != len(o.rr) {
356 return false
357 }
358 for i := range s.rr {
359 if s.rr[i] != o.rr[i] {
360 return false
361 }
362 }
363 return true
364 }
365
366
367
368
369 func (s *IPSet) Contains(ip netip.Addr) bool {
370 if ip.Zone() != "" {
371 return false
372 }
373
374
375 i := sort.Search(len(s.rr), func(i int) bool {
376 return ip.Less(s.rr[i].from)
377 })
378 if i == 0 {
379 return false
380 }
381 i--
382 return s.rr[i].contains(ip)
383 }
384
385
386 func (s *IPSet) ContainsRange(r IPRange) bool {
387 for _, x := range s.rr {
388 if r.coveredBy(x) {
389 return true
390 }
391 }
392 return false
393 }
394
395
396 func (s *IPSet) ContainsPrefix(p netip.Prefix) bool {
397 return s.ContainsRange(RangeOfPrefix(p))
398 }
399
400
401 func (s *IPSet) Overlaps(b *IPSet) bool {
402
403 for _, r := range s.rr {
404 for _, or := range b.rr {
405 if r.Overlaps(or) {
406 return true
407 }
408 }
409 }
410 return false
411 }
412
413
414 func (s *IPSet) OverlapsRange(r IPRange) bool {
415
416 for _, x := range s.rr {
417 if x.Overlaps(r) {
418 return true
419 }
420 }
421 return false
422 }
423
424
425 func (s *IPSet) OverlapsPrefix(p netip.Prefix) bool {
426 return s.OverlapsRange(RangeOfPrefix(p))
427 }
428
429
430
431
432
433
434 func (s *IPSet) RemoveFreePrefix(bitLen uint8) (p netip.Prefix, newSet *IPSet, ok bool) {
435 var bestFit netip.Prefix
436 for _, r := range s.rr {
437 for _, prefix := range r.Prefixes() {
438 if uint8(prefix.Bits()) > bitLen {
439 continue
440 }
441 if !bestFit.Addr().IsValid() || prefix.Bits() > bestFit.Bits() {
442 bestFit = prefix
443 if uint8(bestFit.Bits()) == bitLen {
444
445 break
446 }
447 }
448 }
449 }
450
451 if !bestFit.Addr().IsValid() {
452 return netip.Prefix{}, s, false
453 }
454
455 prefix := netip.PrefixFrom(bestFit.Addr(), int(bitLen))
456
457 var b IPSetBuilder
458 b.AddSet(s)
459 b.RemovePrefix(prefix)
460 newSet, _ = b.IPSet()
461 return prefix, newSet, true
462 }
463
// multiErr is the error IPSetBuilder.IPSet returns when one or more builder
// calls were given invalid input. It holds every recorded error in order.
type multiErr []error

// Error joins the messages of all contained errors with "; ".
func (e multiErr) Error() string {
	msgs := make([]string, 0, len(e))
	for _, err := range e {
		msgs = append(msgs, err.Error())
	}
	return strings.Join(msgs, "; ")
}

// Unwrap exposes the contained errors so that errors.Is and errors.As
// (Go 1.20+) can match any of them, including through each element's own
// Unwrap method. The returned slice aliases e and must not be modified.
func (e multiErr) Unwrap() []error {
	return e
}
473
474
// stacktraceErr wraps an error together with the program counters of the
// call stack at the point it was recorded.
type stacktraceErr struct {
	pcs [16]uintptr // fixed backing storage for at; limits the trace to 16 PCs
	at  []uintptr   // the filled-in prefix of pcs
	err error       // the underlying error
}

// Error returns the wrapped error's message, then " @ ", then the file:line
// of every recorded frame, in the order runtime.CallersFrames yields them.
func (e *stacktraceErr) Error() string {
	buf := new(strings.Builder)
	buf.WriteString(e.err.Error())
	buf.WriteString(" @ ")
	if len(e.at) > 0 {
		frames := runtime.CallersFrames(e.at)
		for {
			frame, more := frames.Next()
			fmt.Fprintf(buf, "%s:%d ", frame.File, frame.Line)
			// more reports whether a further Next call yields a frame. It
			// is false on the final frame, which was already printed above.
			if !more {
				break
			}
		}
	}
	return strings.TrimSpace(buf.String())
}

// Unwrap returns the underlying error, for errors.Is and errors.As.
func (e *stacktraceErr) Unwrap() error {
	return e.err
}
499
View as plain text