...

Source file src/go4.org/netipx/ipset.go

Documentation: go4.org/netipx

     1  // Copyright 2020 The Inet.Af AUTHORS. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package netipx
     6  
     7  import (
     8  	"fmt"
     9  	"net/netip"
    10  	"runtime"
    11  	"sort"
    12  	"strings"
    13  )
    14  
// IPSetBuilder builds an immutable IPSet.
//
// The zero value is a valid value representing a set of no IPs.
//
// The Add and Remove methods add or remove IPs to/from the set.
// Removals only affect the current membership of the set, so in
// general Adds should be called first: an IP removed earlier is added
// back by any later Add that covers it. Input ranges may overlap in
// any way.
//
// Most IPSetBuilder methods do not return errors.
// Instead, errors are accumulated and reported by IPSetBuilder.IPSet.
type IPSetBuilder struct {
	// in are the ranges in the set. Between calls to normalize they may
	// be unsorted, overlapping or contiguous; normalize rewrites them
	// into the minimal sorted list describing the set.
	in []IPRange

	// out are the ranges to be removed from 'in'. They stay pending
	// until the next normalize, which subtracts them from 'in' and
	// clears 'out'.
	out []IPRange

	// errs are errors accumulated during construction. IPSet reports
	// them and then resets errs to nil.
	errs multiErr
}
    36  
// normalize normalizes s: s.in becomes the minimal sorted list of
// ranges required to describe s, and s.out becomes empty.
//
// Both lists are first merged with mergeIPRanges. The sorted removals
// in out are then subtracted from the sorted ranges in in, using a
// single linear merge pass over the two lists. If mergeIPRanges reports
// !ok for either list, normalize returns early and leaves s.in and
// s.out unchanged.
func (s *IPSetBuilder) normalize() {
	const debug = false
	if debug {
		debugf("ranges start in=%v out=%v", s.in, s.out)
	}
	in, ok := mergeIPRanges(s.in)
	if !ok {
		// NOTE(review): the meaning of !ok is defined by mergeIPRanges,
		// which is not visible here; s is left untouched in that case.
		return
	}
	out, ok := mergeIPRanges(s.out)
	if !ok {
		return
	}
	if debug {
		debugf("ranges sort  in=%v out=%v", in, out)
	}

	// in and out are sorted in ascending range order, and have no
	// overlaps within each other. We can run a merge of the two lists
	// in one pass.

	// min accumulates the surviving pieces of in, in ascending order.
	min := make([]IPRange, 0, len(in))
	for len(in) > 0 && len(out) > 0 {
		rin, rout := in[0], out[0]
		if debug {
			debugf("step in=%v out=%v", rin, rout)
		}

		// Cases are tested in order. The two entirelyBefore checks
		// handle disjoint ranges, so the remaining cases only have to
		// classify how rout overlaps rin.
		switch {
		case !rout.IsValid() || !rin.IsValid():
			// mergeIPRanges should have prevented invalid ranges from
			// sneaking in.
			panic("invalid IPRanges during Ranges merge")
		case rout.entirelyBefore(rin):
			// "out" is entirely before "in".
			//
			//    out         in
			// f-------t   f-------t
			out = out[1:]
			if debug {
				debugf("out before in; drop out")
			}
		case rin.entirelyBefore(rout):
			// "in" is entirely before "out".
			//
			//    in         out
			// f------t   f-------t
			min = append(min, rin)
			in = in[1:]
			if debug {
				debugf("in before out; append in")
				debugf("min=%v", min)
			}
		case rin.coveredBy(rout):
			// "out" entirely covers "in".
			//
			//       out
			// f-------------t
			//    f------t
			//       in
			in = in[1:]
			if debug {
				debugf("in inside out; drop in")
			}
		case rout.inMiddleOf(rin):
			// "in" entirely covers "out".
			//
			//       in
			// f-------------t
			//    f------t
			//       out
			min = append(min, IPRange{from: rin.from, to: AddrPrior(rout.from)})
			// Adjust in[0], not rin, because we want to consider the
			// mutated range on the next iteration.
			// NOTE(review): this writes into the slice returned by
			// mergeIPRanges; if that can alias s.in it is harmless only
			// because s.in is replaced below — confirm.
			in[0].from = rout.to.Next()
			out = out[1:]
			if debug {
				debugf("out inside in; split in, append first in, drop out, adjust second in")
				debugf("min=%v", min)
			}
		case rout.overlapsStartOf(rin):
			// "out" overlaps start of "in".
			//
			//   out
			// f------t
			//    f------t
			//       in
			in[0].from = rout.to.Next()
			// Can't move rin onto min yet, another later out might
			// trim it further. Just discard rout and continue.
			out = out[1:]
			if debug {
				debugf("out cuts start of in; adjust in, drop out")
			}
		case rout.overlapsEndOf(rin):
			// "out" overlaps end of "in".
			//
			//           out
			//        f------t
			//    f------t
			//       in
			min = append(min, IPRange{from: rin.from, to: AddrPrior(rout.from)})
			// rout is kept: its tail extends past rin and may also cut
			// into the next range of in.
			in = in[1:]
			if debug {
				debugf("merge out cuts end of in; append shortened in")
				debugf("min=%v", min)
			}
		default:
			// The above should account for all combinations of in and
			// out overlapping, but insert a panic to be sure.
			panic("unexpected additional overlap scenario")
		}
	}
	if len(in) > 0 {
		// Ran out of removals before the end of in. (If in ran out
		// first instead, leftover removals have nothing to remove.)
		min = append(min, in...)
		if debug {
			debugf("min=%v", min)
		}
	}

	s.in = min
	s.out = nil
}
   163  
   164  // Clone returns a copy of s that shares no memory with s.
   165  func (s *IPSetBuilder) Clone() *IPSetBuilder {
   166  	return &IPSetBuilder{
   167  		in:  append([]IPRange(nil), s.in...),
   168  		out: append([]IPRange(nil), s.out...),
   169  	}
   170  }
   171  
   172  func (s *IPSetBuilder) addError(msg string, args ...interface{}) {
   173  	se := new(stacktraceErr)
   174  	// Skip three frames: runtime.Callers, addError, and the IPSetBuilder
   175  	// method that called addError (such as IPSetBuilder.Add).
   176  	// The resulting stack trace ends at the line in the user's
   177  	// code where they called into netaddr.
   178  	n := runtime.Callers(3, se.pcs[:])
   179  	se.at = se.pcs[:n]
   180  	se.err = fmt.Errorf(msg, args...)
   181  	s.errs = append(s.errs, se)
   182  }
   183  
   184  // Add adds ip to s.
   185  func (s *IPSetBuilder) Add(ip netip.Addr) {
   186  	if !ip.IsValid() {
   187  		s.addError("Add(IP{})")
   188  		return
   189  	}
   190  	s.AddRange(IPRangeFrom(ip, ip))
   191  }
   192  
   193  // AddPrefix adds all IPs in p to s.
   194  func (s *IPSetBuilder) AddPrefix(p netip.Prefix) {
   195  	if r := RangeOfPrefix(p); r.IsValid() {
   196  		s.AddRange(r)
   197  	} else {
   198  		s.addError("AddPrefix(%v/%v)", p.Addr(), p.Bits())
   199  	}
   200  }
   201  
   202  // AddRange adds r to s.
   203  // If r is not Valid, AddRange does nothing.
   204  func (s *IPSetBuilder) AddRange(r IPRange) {
   205  	if !r.IsValid() {
   206  		s.addError("AddRange(%v-%v)", r.From(), r.To())
   207  		return
   208  	}
   209  	// If there are any removals (s.out), then we need to compact the set
   210  	// first to get the order right.
   211  	if len(s.out) > 0 {
   212  		s.normalize()
   213  	}
   214  	s.in = append(s.in, r)
   215  }
   216  
   217  // AddSet adds all IPs in b to s.
   218  func (s *IPSetBuilder) AddSet(b *IPSet) {
   219  	if b == nil {
   220  		return
   221  	}
   222  	for _, r := range b.rr {
   223  		s.AddRange(r)
   224  	}
   225  }
   226  
   227  // Remove removes ip from s.
   228  func (s *IPSetBuilder) Remove(ip netip.Addr) {
   229  	if !ip.IsValid() {
   230  		s.addError("Remove(IP{})")
   231  	} else {
   232  		s.RemoveRange(IPRangeFrom(ip, ip))
   233  	}
   234  }
   235  
   236  // RemovePrefix removes all IPs in p from s.
   237  func (s *IPSetBuilder) RemovePrefix(p netip.Prefix) {
   238  	if r := RangeOfPrefix(p); r.IsValid() {
   239  		s.RemoveRange(r)
   240  	} else {
   241  		s.addError("RemovePrefix(%v/%v)", p.Addr(), p.Bits())
   242  	}
   243  }
   244  
   245  // RemoveRange removes all IPs in r from s.
   246  func (s *IPSetBuilder) RemoveRange(r IPRange) {
   247  	if r.IsValid() {
   248  		s.out = append(s.out, r)
   249  	} else {
   250  		s.addError("RemoveRange(%v-%v)", r.From(), r.To())
   251  	}
   252  }
   253  
   254  // RemoveSet removes all IPs in o from s.
   255  func (s *IPSetBuilder) RemoveSet(b *IPSet) {
   256  	if b == nil {
   257  		return
   258  	}
   259  	for _, r := range b.rr {
   260  		s.RemoveRange(r)
   261  	}
   262  }
   263  
   264  // removeBuilder removes all IPs in b from s.
   265  func (s *IPSetBuilder) removeBuilder(b *IPSetBuilder) {
   266  	b.normalize()
   267  	for _, r := range b.in {
   268  		s.RemoveRange(r)
   269  	}
   270  }
   271  
   272  // Complement updates s to contain the complement of its current
   273  // contents.
   274  func (s *IPSetBuilder) Complement() {
   275  	s.normalize()
   276  	s.out = s.in
   277  	s.in = []IPRange{
   278  		RangeOfPrefix(netip.PrefixFrom(netip.AddrFrom4([4]byte{}), 0)),
   279  		RangeOfPrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0)),
   280  	}
   281  }
   282  
   283  // Intersect updates s to the set intersection of s and b.
   284  func (s *IPSetBuilder) Intersect(b *IPSet) {
   285  	var o IPSetBuilder
   286  	o.Complement()
   287  	o.RemoveSet(b)
   288  	s.removeBuilder(&o)
   289  }
   290  
   291  func discardf(format string, args ...interface{}) {}
   292  
   293  // debugf is reassigned by tests.
   294  var debugf = discardf
   295  
   296  // IPSet returns an immutable IPSet representing the current state of s.
   297  //
   298  // Most IPSetBuilder methods do not return errors.
   299  // Rather, the builder ignores any invalid inputs (such as an invalid IPPrefix),
   300  // and accumulates a list of any such errors that it encountered.
   301  //
   302  // IPSet also reports any such accumulated errors.
   303  // Even if the returned error is non-nil, the returned IPSet is usable
   304  // and contains all modifications made with valid inputs.
   305  //
   306  // The builder remains usable after calling IPSet.
   307  // Calling IPSet clears any accumulated errors.
   308  func (s *IPSetBuilder) IPSet() (*IPSet, error) {
   309  	s.normalize()
   310  	ret := &IPSet{
   311  		rr: append([]IPRange{}, s.in...),
   312  	}
   313  	if len(s.errs) == 0 {
   314  		return ret, nil
   315  	} else {
   316  		errs := s.errs
   317  		s.errs = nil
   318  		return ret, errs
   319  	}
   320  }
   321  
// IPSet represents a set of IP addresses.
//
// IPSet is safe for concurrent use: its methods only read rr and never
// modify it.
// The zero value is a valid value representing a set of no IPs.
// Use IPSetBuilder to construct IPSets.
type IPSet struct {
	// rr is the set of IPs that belong to this IPSet. The IPRanges
	// are normalized according to IPSetBuilder.normalize, meaning
	// they are a sorted, minimal representation (no overlapping
	// ranges, no contiguous ranges). The implementation of various
	// methods relies on this property (for example, Contains binary
	// searches rr and Equal compares rr element by element).
	rr []IPRange
}
   335  
   336  // Ranges returns the minimum and sorted set of IP
   337  // ranges that covers s.
   338  func (s *IPSet) Ranges() []IPRange {
   339  	return append([]IPRange{}, s.rr...)
   340  }
   341  
   342  // Prefixes returns the minimum and sorted set of IP prefixes
   343  // that covers s.
   344  func (s *IPSet) Prefixes() []netip.Prefix {
   345  	out := make([]netip.Prefix, 0, len(s.rr))
   346  	for _, r := range s.rr {
   347  		out = append(out, r.Prefixes()...)
   348  	}
   349  	return out
   350  }
   351  
   352  // Equal reports whether s and o represent the same set of IP
   353  // addresses.
   354  func (s *IPSet) Equal(o *IPSet) bool {
   355  	if len(s.rr) != len(o.rr) {
   356  		return false
   357  	}
   358  	for i := range s.rr {
   359  		if s.rr[i] != o.rr[i] {
   360  			return false
   361  		}
   362  	}
   363  	return true
   364  }
   365  
   366  // Contains reports whether ip is in s.
   367  // If ip has an IPv6 zone, Contains returns false,
   368  // because IPSets do not track zones.
   369  func (s *IPSet) Contains(ip netip.Addr) bool {
   370  	if ip.Zone() != "" {
   371  		return false
   372  	}
   373  	// TODO: data structure permitting more efficient lookups:
   374  	// https://github.com/inetaf/netaddr/issues/139
   375  	i := sort.Search(len(s.rr), func(i int) bool {
   376  		return ip.Less(s.rr[i].from)
   377  	})
   378  	if i == 0 {
   379  		return false
   380  	}
   381  	i--
   382  	return s.rr[i].contains(ip)
   383  }
   384  
   385  // ContainsRange reports whether all IPs in r are in s.
   386  func (s *IPSet) ContainsRange(r IPRange) bool {
   387  	for _, x := range s.rr {
   388  		if r.coveredBy(x) {
   389  			return true
   390  		}
   391  	}
   392  	return false
   393  }
   394  
   395  // ContainsPrefix reports whether all IPs in p are in s.
   396  func (s *IPSet) ContainsPrefix(p netip.Prefix) bool {
   397  	return s.ContainsRange(RangeOfPrefix(p))
   398  }
   399  
   400  // Overlaps reports whether any IP in b is also in s.
   401  func (s *IPSet) Overlaps(b *IPSet) bool {
   402  	// TODO: sorted ranges lets us do this in O(n+m)
   403  	for _, r := range s.rr {
   404  		for _, or := range b.rr {
   405  			if r.Overlaps(or) {
   406  				return true
   407  			}
   408  		}
   409  	}
   410  	return false
   411  }
   412  
   413  // OverlapsRange reports whether any IP in r is also in s.
   414  func (s *IPSet) OverlapsRange(r IPRange) bool {
   415  	// TODO: sorted ranges lets us do this more efficiently.
   416  	for _, x := range s.rr {
   417  		if x.Overlaps(r) {
   418  			return true
   419  		}
   420  	}
   421  	return false
   422  }
   423  
   424  // OverlapsPrefix reports whether any IP in p is also in s.
   425  func (s *IPSet) OverlapsPrefix(p netip.Prefix) bool {
   426  	return s.OverlapsRange(RangeOfPrefix(p))
   427  }
   428  
   429  // RemoveFreePrefix splits s into a Prefix of length bitLen and a new
   430  // IPSet with that prefix removed.
   431  //
   432  // If no contiguous prefix of length bitLen exists in s,
   433  // RemoveFreePrefix returns ok=false.
   434  func (s *IPSet) RemoveFreePrefix(bitLen uint8) (p netip.Prefix, newSet *IPSet, ok bool) {
   435  	var bestFit netip.Prefix
   436  	for _, r := range s.rr {
   437  		for _, prefix := range r.Prefixes() {
   438  			if uint8(prefix.Bits()) > bitLen {
   439  				continue
   440  			}
   441  			if !bestFit.Addr().IsValid() || prefix.Bits() > bestFit.Bits() {
   442  				bestFit = prefix
   443  				if uint8(bestFit.Bits()) == bitLen {
   444  					// exact match, done.
   445  					break
   446  				}
   447  			}
   448  		}
   449  	}
   450  
   451  	if !bestFit.Addr().IsValid() {
   452  		return netip.Prefix{}, s, false
   453  	}
   454  
   455  	prefix := netip.PrefixFrom(bestFit.Addr(), int(bitLen))
   456  
   457  	var b IPSetBuilder
   458  	b.AddSet(s)
   459  	b.RemovePrefix(prefix)
   460  	newSet, _ = b.IPSet()
   461  	return prefix, newSet, true
   462  }
   463  
   464  type multiErr []error
   465  
   466  func (e multiErr) Error() string {
   467  	var ret []string
   468  	for _, err := range e {
   469  		ret = append(ret, err.Error())
   470  	}
   471  	return strings.Join(ret, "; ")
   472  }
   473  
   474  // A stacktraceErr combines an error with a stack trace.
   475  type stacktraceErr struct {
   476  	pcs [16]uintptr // preallocated array of PCs
   477  	at  []uintptr   // stack trace whence the error
   478  	err error       // underlying error
   479  }
   480  
   481  func (e *stacktraceErr) Error() string {
   482  	frames := runtime.CallersFrames(e.at)
   483  	buf := new(strings.Builder)
   484  	buf.WriteString(e.err.Error())
   485  	buf.WriteString(" @ ")
   486  	for {
   487  		frame, more := frames.Next()
   488  		if !more {
   489  			break
   490  		}
   491  		fmt.Fprintf(buf, "%s:%d ", frame.File, frame.Line)
   492  	}
   493  	return strings.TrimSpace(buf.String())
   494  }
   495  
   496  func (e *stacktraceErr) Unwrap() error {
   497  	return e.err
   498  }
   499  

View as plain text