...

Source file src/k8s.io/kubernetes/pkg/util/iptree/iptree_test.go

Documentation: k8s.io/kubernetes/pkg/util/iptree

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package iptree
    18  
    19  import (
    20  	"math/rand"
    21  	"net/netip"
    22  	"reflect"
    23  	"sort"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/google/go-cmp/cmp"
    28  	"k8s.io/apimachinery/pkg/util/sets"
    29  )
    30  
    31  func Test_InsertGetDelete(t *testing.T) {
    32  	testCases := []struct {
    33  		name   string
    34  		prefix netip.Prefix
    35  	}{
    36  		{
    37  			name:   "ipv4",
    38  			prefix: netip.MustParsePrefix("192.168.0.0/24"),
    39  		},
    40  		{
    41  			name:   "ipv6",
    42  			prefix: netip.MustParsePrefix("fd00:1:2:3::/124"),
    43  		},
    44  	}
    45  
    46  	for _, tc := range testCases {
    47  		t.Run(tc.name, func(t *testing.T) {
    48  			tree := New[int]()
    49  			ok := tree.InsertPrefix(tc.prefix, 1)
    50  			if ok {
    51  				t.Fatal("should not exist")
    52  			}
    53  			if _, ok := tree.GetPrefix(tc.prefix); !ok {
    54  				t.Errorf("CIDR %s not found", tc.prefix)
    55  			}
    56  			if ok := tree.DeletePrefix(tc.prefix); !ok {
    57  				t.Errorf("CIDR %s not deleted", tc.prefix)
    58  			}
    59  			if _, ok := tree.GetPrefix(tc.prefix); ok {
    60  				t.Errorf("CIDR %s found", tc.prefix)
    61  			}
    62  		})
    63  	}
    64  
    65  }
    66  
    67  func TestBasicIPv4(t *testing.T) {
    68  	tree := New[int]()
    69  	// insert
    70  	ipnet := netip.MustParsePrefix("192.168.0.0/24")
    71  	ok := tree.InsertPrefix(ipnet, 1)
    72  	if ok {
    73  		t.Fatal("should not exist")
    74  	}
    75  	// check exist
    76  	if _, ok := tree.GetPrefix(ipnet); !ok {
    77  		t.Errorf("CIDR %s not found", ipnet)
    78  	}
    79  
    80  	// check does not exist
    81  	ipnet2 := netip.MustParsePrefix("12.1.0.0/16")
    82  	if _, ok := tree.GetPrefix(ipnet2); ok {
    83  		t.Errorf("CIDR %s not expected", ipnet2)
    84  	}
    85  
    86  	// check insert existing prefix updates the value
    87  	ok = tree.InsertPrefix(ipnet2, 2)
    88  	if ok {
    89  		t.Errorf("should not exist: %s", ipnet2)
    90  	}
    91  
    92  	ok = tree.InsertPrefix(ipnet2, 3)
    93  	if !ok {
    94  		t.Errorf("should be updated: %s", ipnet2)
    95  	}
    96  
    97  	if v, ok := tree.GetPrefix(ipnet2); !ok || v != 3 {
    98  		t.Errorf("CIDR %s not expected", ipnet2)
    99  	}
   100  
   101  	// check longer prefix matching
   102  	ipnet3 := netip.MustParsePrefix("12.1.0.2/32")
   103  	lpm, _, ok := tree.LongestPrefixMatch(ipnet3)
   104  	if !ok || lpm != ipnet2 {
   105  		t.Errorf("expected %s got %s", ipnet2, lpm)
   106  	}
   107  }
   108  
   109  func TestBasicIPv6(t *testing.T) {
   110  	tree := New[int]()
   111  	// insert
   112  	ipnet := netip.MustParsePrefix("2001:db8::/64")
   113  	ok := tree.InsertPrefix(ipnet, 1)
   114  	if ok {
   115  		t.Fatal("should not exist")
   116  	}
   117  	// check exist
   118  	if _, ok := tree.GetPrefix(ipnet); !ok {
   119  		t.Errorf("CIDR %s not found", ipnet)
   120  	}
   121  
   122  	// check does not exist
   123  	ipnet2 := netip.MustParsePrefix("2001:db8:1:3:4::/64")
   124  	if _, ok := tree.GetPrefix(ipnet2); ok {
   125  		t.Errorf("CIDR %s not expected", ipnet2)
   126  	}
   127  
   128  	// check insert existing prefix updates the value
   129  	ok = tree.InsertPrefix(ipnet2, 2)
   130  	if ok {
   131  		t.Errorf("should not exist: %s", ipnet2)
   132  	}
   133  
   134  	ok = tree.InsertPrefix(ipnet2, 3)
   135  	if !ok {
   136  		t.Errorf("should be updated: %s", ipnet2)
   137  	}
   138  
   139  	if v, ok := tree.GetPrefix(ipnet2); !ok || v != 3 {
   140  		t.Errorf("CIDR %s not expected", ipnet2)
   141  	}
   142  
   143  	// check longer prefix matching
   144  	ipnet3 := netip.MustParsePrefix("2001:db8:1:3:4::/96")
   145  	lpm, _, ok := tree.LongestPrefixMatch(ipnet3)
   146  	if !ok || lpm != ipnet2 {
   147  		t.Errorf("expected %s got %s", ipnet2, lpm)
   148  	}
   149  }
   150  
   151  func TestInsertGetDelete100K(t *testing.T) {
   152  	testCases := []struct {
   153  		name string
   154  		is6  bool
   155  	}{
   156  		{
   157  			name: "ipv4",
   158  		},
   159  		{
   160  			name: "ipv6",
   161  			is6:  true,
   162  		},
   163  	}
   164  
   165  	for _, tc := range testCases {
   166  		t.Run(tc.name, func(t *testing.T) {
   167  			cidrs := generateRandomCIDRs(tc.is6, 100*1000)
   168  			tree := New[string]()
   169  
   170  			for k := range cidrs {
   171  				ok := tree.InsertPrefix(k, k.String())
   172  				if ok {
   173  					t.Errorf("error inserting: %v", k)
   174  				}
   175  			}
   176  
   177  			if tree.Len(tc.is6) != len(cidrs) {
   178  				t.Errorf("expected %d nodes on the tree, got %d", len(cidrs), tree.Len(tc.is6))
   179  			}
   180  
   181  			list := cidrs.UnsortedList()
   182  			for _, k := range list {
   183  				if v, ok := tree.GetPrefix(k); !ok {
   184  					t.Errorf("CIDR %s not found", k)
   185  					return
   186  				} else if v != k.String() {
   187  					t.Errorf("CIDR value %s not found", k)
   188  					return
   189  				}
   190  				ok := tree.DeletePrefix(k)
   191  				if !ok {
   192  					t.Errorf("CIDR delete %s error", k)
   193  				}
   194  			}
   195  
   196  			if tree.Len(tc.is6) != 0 {
   197  				t.Errorf("No node expected on the tree, got: %d %v", tree.Len(tc.is6), cidrs)
   198  			}
   199  		})
   200  	}
   201  }
   202  
   203  func Test_findAncestor(t *testing.T) {
   204  	tests := []struct {
   205  		name string
   206  		a    netip.Prefix
   207  		b    netip.Prefix
   208  		want netip.Prefix
   209  	}{
   210  		{
   211  			name: "ipv4 direct parent",
   212  			a:    netip.MustParsePrefix("192.168.0.0/24"),
   213  			b:    netip.MustParsePrefix("192.168.1.0/24"),
   214  			want: netip.MustParsePrefix("192.168.0.0/23"),
   215  		},
   216  		{
   217  			name: "ipv4 root parent ",
   218  			a:    netip.MustParsePrefix("192.168.0.0/24"),
   219  			b:    netip.MustParsePrefix("1.168.1.0/24"),
   220  			want: netip.MustParsePrefix("0.0.0.0/0"),
   221  		},
   222  		{
   223  			name: "ipv4 parent /1",
   224  			a:    netip.MustParsePrefix("192.168.0.0/24"),
   225  			b:    netip.MustParsePrefix("184.168.1.0/24"),
   226  			want: netip.MustParsePrefix("128.0.0.0/1"),
   227  		},
   228  		{
   229  			name: "ipv6 direct parent",
   230  			a:    netip.MustParsePrefix("fd00:1:1:1::/64"),
   231  			b:    netip.MustParsePrefix("fd00:1:1:2::/64"),
   232  			want: netip.MustParsePrefix("fd00:1:1::/62"),
   233  		},
   234  		{
   235  			name: "ipv6 root parent ",
   236  			a:    netip.MustParsePrefix("fd00:1:1:1::/64"),
   237  			b:    netip.MustParsePrefix("1:1:1:1::/64"),
   238  			want: netip.MustParsePrefix("::/0"),
   239  		},
   240  	}
   241  	for _, tt := range tests {
   242  		t.Run(tt.name, func(t *testing.T) {
   243  			if got := findAncestor(tt.a, tt.b); !reflect.DeepEqual(got, tt.want) {
   244  				t.Errorf("findAncestor() = %v, want %v", got, tt.want)
   245  			}
   246  		})
   247  	}
   248  }
   249  
   250  func Test_getBitFromAddr(t *testing.T) {
   251  	tests := []struct {
   252  		name string
   253  		ip   netip.Addr
   254  		pos  int
   255  		want int
   256  	}{
   257  		// 192.168.0.0
   258  		// 11000000.10101000.00000000.00000001
   259  		{
   260  			name: "ipv4 first is a one",
   261  			ip:   netip.MustParseAddr("192.168.0.0"),
   262  			pos:  1,
   263  			want: 1,
   264  		},
   265  		{
   266  			name: "ipv4 middle is a zero",
   267  			ip:   netip.MustParseAddr("192.168.0.0"),
   268  			pos:  16,
   269  			want: 0,
   270  		},
   271  		{
   272  			name: "ipv4 middle is a one",
   273  			ip:   netip.MustParseAddr("192.168.0.0"),
   274  			pos:  13,
   275  			want: 1,
   276  		},
   277  		{
   278  			name: "ipv4 last is a zero",
   279  			ip:   netip.MustParseAddr("192.168.0.0"),
   280  			pos:  32,
   281  			want: 0,
   282  		},
   283  		// 2001:db8::ff00:42:8329
   284  		// 0010000000000001:0000110110111000:0000000000000000:0000000000000000:0000000000000000:1111111100000000:0000000001000010:1000001100101001
   285  		{
   286  			name: "ipv6 first is a zero",
   287  			ip:   netip.MustParseAddr("2001:db8::ff00:42:8329"),
   288  			pos:  1,
   289  			want: 0,
   290  		},
   291  		{
   292  			name: "ipv6 middle is a zero",
   293  			ip:   netip.MustParseAddr("2001:db8::ff00:42:8329"),
   294  			pos:  56,
   295  			want: 0,
   296  		},
   297  		{
   298  			name: "ipv6 middle is a one",
   299  			ip:   netip.MustParseAddr("2001:db8::ff00:42:8329"),
   300  			pos:  81,
   301  			want: 1,
   302  		},
   303  		{
   304  			name: "ipv6 last is a one",
   305  			ip:   netip.MustParseAddr("2001:db8::ff00:42:8329"),
   306  			pos:  128,
   307  			want: 1,
   308  		},
   309  	}
   310  	for _, tt := range tests {
   311  		t.Run(tt.name, func(t *testing.T) {
   312  			if got := getBitFromAddr(tt.ip, tt.pos); got != tt.want {
   313  				t.Errorf("getBitFromAddr() = %v, want %v", got, tt.want)
   314  			}
   315  		})
   316  	}
   317  }
   318  
   319  func TestShortestPrefix(t *testing.T) {
   320  	r := New[int]()
   321  
   322  	keys := []string{
   323  		"10.0.0.0/8",
   324  		"10.21.0.0/16",
   325  		"10.221.0.0/16",
   326  		"10.1.2.3/32",
   327  		"10.1.2.0/24",
   328  		"192.168.0.0/24",
   329  		"192.168.0.0/16",
   330  	}
   331  	for _, k := range keys {
   332  		ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
   333  		if ok {
   334  			t.Errorf("unexpected update on insert %s", k)
   335  		}
   336  	}
   337  	if r.Len(false) != len(keys) {
   338  		t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
   339  	}
   340  
   341  	type exp struct {
   342  		inp string
   343  		out string
   344  	}
   345  	cases := []exp{
   346  		{"192.168.0.3/32", "192.168.0.0/16"},
   347  		{"10.1.2.4/21", "10.0.0.0/8"},
   348  		{"192.168.0.0/16", "192.168.0.0/16"},
   349  		{"192.168.0.0/32", "192.168.0.0/16"},
   350  		{"10.1.2.3/32", "10.0.0.0/8"},
   351  	}
   352  	for _, test := range cases {
   353  		m, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix(test.inp))
   354  		if !ok {
   355  			t.Fatalf("no match: %v", test)
   356  		}
   357  		if m != netip.MustParsePrefix(test.out) {
   358  			t.Fatalf("mis-match: %v %v", m, test)
   359  		}
   360  	}
   361  
   362  	// not match
   363  	_, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0"))
   364  	if ok {
   365  		t.Fatalf("match unexpected for 0.0.0.0/0")
   366  	}
   367  }
   368  
   369  func TestLongestPrefixMatch(t *testing.T) {
   370  	r := New[int]()
   371  
   372  	keys := []string{
   373  		"10.0.0.0/8",
   374  		"10.21.0.0/16",
   375  		"10.221.0.0/16",
   376  		"10.1.2.3/32",
   377  		"10.1.2.0/24",
   378  		"192.168.0.0/24",
   379  		"192.168.0.0/16",
   380  	}
   381  	for _, k := range keys {
   382  		ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
   383  		if ok {
   384  			t.Errorf("unexpected update on insert %s", k)
   385  		}
   386  	}
   387  	if r.Len(false) != len(keys) {
   388  		t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
   389  	}
   390  
   391  	type exp struct {
   392  		inp string
   393  		out string
   394  	}
   395  	cases := []exp{
   396  		{"192.168.0.3/32", "192.168.0.0/24"},
   397  		{"10.1.2.4/21", "10.0.0.0/8"},
   398  		{"10.21.2.0/24", "10.21.0.0/16"},
   399  		{"10.1.2.3/32", "10.1.2.3/32"},
   400  	}
   401  	for _, test := range cases {
   402  		m, _, ok := r.LongestPrefixMatch(netip.MustParsePrefix(test.inp))
   403  		if !ok {
   404  			t.Fatalf("no match: %v", test)
   405  		}
   406  		if m != netip.MustParsePrefix(test.out) {
   407  			t.Fatalf("mis-match: %v %v", m, test)
   408  		}
   409  	}
   410  	// not match
   411  	_, _, ok := r.LongestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0"))
   412  	if ok {
   413  		t.Fatalf("match unexpected for 0.0.0.0/0")
   414  	}
   415  }
   416  
   417  func TestTopLevelPrefixesV4(t *testing.T) {
   418  	r := New[string]()
   419  
   420  	keys := []string{
   421  		"10.0.0.0/8",
   422  		"10.21.0.0/16",
   423  		"10.221.0.0/16",
   424  		"10.1.2.3/32",
   425  		"10.1.2.0/24",
   426  		"192.168.0.0/20",
   427  		"192.168.1.0/24",
   428  		"172.16.0.0/12",
   429  		"172.21.23.0/24",
   430  	}
   431  	for _, k := range keys {
   432  		ok := r.InsertPrefix(netip.MustParsePrefix(k), k)
   433  		if ok {
   434  			t.Errorf("unexpected update on insert %s", k)
   435  		}
   436  	}
   437  	if r.Len(false) != len(keys) {
   438  		t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
   439  	}
   440  
   441  	expected := []string{
   442  		"10.0.0.0/8",
   443  		"192.168.0.0/20",
   444  		"172.16.0.0/12",
   445  	}
   446  	parents := r.TopLevelPrefixes(false)
   447  	if len(parents) != len(expected) {
   448  		t.Fatalf("bad len: %v %v", len(parents), len(expected))
   449  	}
   450  
   451  	for _, k := range expected {
   452  		v, ok := parents[k]
   453  		if !ok {
   454  			t.Errorf("key %s not found", k)
   455  		}
   456  		if v != k {
   457  			t.Errorf("value expected %s got %s", k, v)
   458  		}
   459  	}
   460  }
   461  
   462  func TestTopLevelPrefixesV6(t *testing.T) {
   463  	r := New[string]()
   464  
   465  	keys := []string{
   466  		"2001:db8:1:2:3::/64",
   467  		"2001:db8::/64",
   468  		"2001:db8:1:1:1::/64",
   469  		"2001:db8:1:1:1::/112",
   470  	}
   471  	for _, k := range keys {
   472  		ok := r.InsertPrefix(netip.MustParsePrefix(k), k)
   473  		if ok {
   474  			t.Errorf("unexpected update on insert %s", k)
   475  		}
   476  	}
   477  
   478  	if r.Len(true) != len(keys) {
   479  		t.Fatalf("bad len: %v %v", r.Len(true), len(keys))
   480  	}
   481  
   482  	expected := []string{
   483  		"2001:db8::/64",
   484  		"2001:db8:1:2:3::/64",
   485  		"2001:db8:1:1:1::/64",
   486  	}
   487  	parents := r.TopLevelPrefixes(true)
   488  	if len(parents) != len(expected) {
   489  		t.Fatalf("bad len: %v %v", len(parents), len(expected))
   490  	}
   491  
   492  	for _, k := range expected {
   493  		v, ok := parents[k]
   494  		if !ok {
   495  			t.Errorf("key %s not found", k)
   496  		}
   497  		if v != k {
   498  			t.Errorf("value expected %s got %s", k, v)
   499  		}
   500  	}
   501  }
   502  
   503  func TestWalkV4(t *testing.T) {
   504  	r := New[int]()
   505  
   506  	keys := []string{
   507  		"10.0.0.0/8",
   508  		"10.1.0.0/16",
   509  		"10.1.1.0/24",
   510  		"10.1.1.32/26",
   511  		"10.1.1.33/32",
   512  	}
   513  	for _, k := range keys {
   514  		ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
   515  		if ok {
   516  			t.Errorf("unexpected update on insert %s", k)
   517  		}
   518  	}
   519  	if r.Len(false) != len(keys) {
   520  		t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
   521  	}
   522  
   523  	// match exact prefix
   524  	path := []string{}
   525  	r.WalkPath(netip.MustParsePrefix("10.1.1.32/26"), func(k netip.Prefix, v int) bool {
   526  		path = append(path, k.String())
   527  		return false
   528  	})
   529  	if !cmp.Equal(path, keys[:4]) {
   530  		t.Errorf("Walkpath expected %v got %v", keys[:4], path)
   531  	}
   532  	// not match on prefix
   533  	path = []string{}
   534  	r.WalkPath(netip.MustParsePrefix("10.1.1.33/26"), func(k netip.Prefix, v int) bool {
   535  		path = append(path, k.String())
   536  		return false
   537  	})
   538  	if !cmp.Equal(path, keys[:3]) {
   539  		t.Errorf("Walkpath expected %v got %v", keys[:3], path)
   540  	}
   541  	// match exact prefix
   542  	path = []string{}
   543  	r.WalkPrefix(netip.MustParsePrefix("10.0.0.0/8"), func(k netip.Prefix, v int) bool {
   544  		path = append(path, k.String())
   545  		return false
   546  	})
   547  	if !cmp.Equal(path, keys) {
   548  		t.Errorf("WalkPrefix expected %v got %v", keys, path)
   549  	}
   550  	// not match on prefix
   551  	path = []string{}
   552  	r.WalkPrefix(netip.MustParsePrefix("10.0.0.0/9"), func(k netip.Prefix, v int) bool {
   553  		path = append(path, k.String())
   554  		return false
   555  	})
   556  	if !cmp.Equal(path, keys[1:]) {
   557  		t.Errorf("WalkPrefix expected %v got %v", keys[1:], path)
   558  	}
   559  }
   560  
   561  func TestWalkV6(t *testing.T) {
   562  	r := New[int]()
   563  
   564  	keys := []string{
   565  		"2001:db8::/48",
   566  		"2001:db8::/64",
   567  		"2001:db8::/96",
   568  		"2001:db8::/112",
   569  		"2001:db8::/128",
   570  	}
   571  	for _, k := range keys {
   572  		ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
   573  		if ok {
   574  			t.Errorf("unexpected update on insert %s", k)
   575  		}
   576  	}
   577  	if r.Len(true) != len(keys) {
   578  		t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
   579  	}
   580  
   581  	// match exact prefix
   582  	path := []string{}
   583  	r.WalkPath(netip.MustParsePrefix("2001:db8::/112"), func(k netip.Prefix, v int) bool {
   584  		path = append(path, k.String())
   585  		return false
   586  	})
   587  	if !cmp.Equal(path, keys[:4]) {
   588  		t.Errorf("Walkpath expected %v got %v", keys[:4], path)
   589  	}
   590  	// not match on prefix
   591  	path = []string{}
   592  	r.WalkPath(netip.MustParsePrefix("2001:db8::1/112"), func(k netip.Prefix, v int) bool {
   593  		path = append(path, k.String())
   594  		return false
   595  	})
   596  	if !cmp.Equal(path, keys[:3]) {
   597  		t.Errorf("Walkpath expected %v got %v", keys[:3], path)
   598  	}
   599  	// match exact prefix
   600  	path = []string{}
   601  	r.WalkPrefix(netip.MustParsePrefix("2001:db8::/48"), func(k netip.Prefix, v int) bool {
   602  		path = append(path, k.String())
   603  		return false
   604  	})
   605  	if !cmp.Equal(path, keys) {
   606  		t.Errorf("WalkPrefix expected %v got %v", keys, path)
   607  	}
   608  	// not match on prefix
   609  	path = []string{}
   610  	r.WalkPrefix(netip.MustParsePrefix("2001:db8::/49"), func(k netip.Prefix, v int) bool {
   611  		path = append(path, k.String())
   612  		return false
   613  	})
   614  	if !cmp.Equal(path, keys[1:]) {
   615  		t.Errorf("WalkPrefix expected %v got %v", keys[1:], path)
   616  	}
   617  }
   618  
   619  func TestGetHostIPPrefixMatches(t *testing.T) {
   620  	r := New[int]()
   621  
   622  	keys := []string{
   623  		"10.0.0.0/8",
   624  		"10.21.0.0/16",
   625  		"10.221.0.0/16",
   626  		"10.1.2.3/32",
   627  		"10.1.2.0/24",
   628  		"192.168.0.0/24",
   629  		"192.168.0.0/16",
   630  		"2001:db8::/48",
   631  		"2001:db8::/64",
   632  		"2001:db8::/96",
   633  	}
   634  	for _, k := range keys {
   635  		ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
   636  		if ok {
   637  			t.Errorf("unexpected update on insert %s", k)
   638  		}
   639  	}
   640  
   641  	type exp struct {
   642  		inp string
   643  		out []string
   644  	}
   645  	cases := []exp{
   646  		{"192.168.0.3", []string{"192.168.0.0/24", "192.168.0.0/16"}},
   647  		{"10.1.2.4", []string{"10.1.2.0/24", "10.0.0.0/8"}},
   648  		{"10.1.2.0", []string{"10.0.0.0/8"}},
   649  		{"10.1.2.255", []string{"10.0.0.0/8"}},
   650  		{"192.168.0.0", []string{}},
   651  		{"192.168.1.0", []string{"192.168.0.0/16"}},
   652  		{"10.1.2.255", []string{"10.0.0.0/8"}},
   653  		{"2001:db8::1", []string{"2001:db8::/96", "2001:db8::/64", "2001:db8::/48"}},
   654  		{"2001:db8::", []string{}},
   655  		{"2001:db8::ffff:ffff:ffff:ffff", []string{"2001:db8::/64", "2001:db8::/48"}},
   656  	}
   657  	for _, test := range cases {
   658  		m := r.GetHostIPPrefixMatches(netip.MustParseAddr(test.inp))
   659  		in := []netip.Prefix{}
   660  		for k := range m {
   661  			in = append(in, k)
   662  		}
   663  		out := []netip.Prefix{}
   664  		for _, s := range test.out {
   665  			out = append(out, netip.MustParsePrefix(s))
   666  		}
   667  
   668  		// sort by prefix bits to avoid flakes
   669  		sort.Slice(in, func(i, j int) bool { return in[i].Bits() < in[j].Bits() })
   670  		sort.Slice(out, func(i, j int) bool { return out[i].Bits() < out[j].Bits() })
   671  		if !reflect.DeepEqual(in, out) {
   672  			t.Fatalf("mis-match: %v %v", in, out)
   673  		}
   674  	}
   675  
   676  	// not match
   677  	_, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0"))
   678  	if ok {
   679  		t.Fatalf("match unexpected for 0.0.0.0/0")
   680  	}
   681  }
   682  
   683  func Test_prefixContainIP(t *testing.T) {
   684  	tests := []struct {
   685  		name   string
   686  		prefix netip.Prefix
   687  		ip     netip.Addr
   688  		want   bool
   689  	}{
   690  		{
   691  			name:   "IPv4 contains",
   692  			prefix: netip.MustParsePrefix("192.168.0.0/24"),
   693  			ip:     netip.MustParseAddr("192.168.0.1"),
   694  			want:   true,
   695  		},
   696  		{
   697  			name:   "IPv4 network address",
   698  			prefix: netip.MustParsePrefix("192.168.0.0/24"),
   699  			ip:     netip.MustParseAddr("192.168.0.0"),
   700  		},
   701  		{
   702  			name:   "IPv4 broadcast address",
   703  			prefix: netip.MustParsePrefix("192.168.0.0/24"),
   704  			ip:     netip.MustParseAddr("192.168.0.255"),
   705  		},
   706  		{
   707  			name:   "IPv4 does not contain",
   708  			prefix: netip.MustParsePrefix("192.168.0.0/24"),
   709  			ip:     netip.MustParseAddr("192.168.1.2"),
   710  		},
   711  		{
   712  			name:   "IPv6 contains",
   713  			prefix: netip.MustParsePrefix("2001:db2::/96"),
   714  			ip:     netip.MustParseAddr("2001:db2::1"),
   715  			want:   true,
   716  		},
   717  		{
   718  			name:   "IPv6 network address",
   719  			prefix: netip.MustParsePrefix("2001:db2::/96"),
   720  			ip:     netip.MustParseAddr("2001:db2::"),
   721  		},
   722  		{
   723  			name:   "IPv6 broadcast address",
   724  			prefix: netip.MustParsePrefix("2001:db2::/96"),
   725  			ip:     netip.MustParseAddr("2001:db2::ffff:ffff"),
   726  			want:   true,
   727  		},
   728  		{
   729  			name:   "IPv6 does not contain",
   730  			prefix: netip.MustParsePrefix("2001:db2::/96"),
   731  			ip:     netip.MustParseAddr("2001:db2:1:2:3::1"),
   732  		},
   733  	}
   734  	for _, tt := range tests {
   735  		t.Run(tt.name, func(t *testing.T) {
   736  			if got := prefixContainIP(tt.prefix, tt.ip); got != tt.want {
   737  				t.Errorf("prefixContainIP() = %v, want %v", got, tt.want)
   738  			}
   739  		})
   740  	}
   741  }
   742  
   743  func BenchmarkInsertUpdate(b *testing.B) {
   744  	r := New[bool]()
   745  	ipList := generateRandomCIDRs(true, 20000).UnsortedList()
   746  	for _, ip := range ipList {
   747  		r.InsertPrefix(ip, true)
   748  	}
   749  
   750  	b.ResetTimer()
   751  	for n := 0; n < b.N; n++ {
   752  		r.InsertPrefix(ipList[n%len(ipList)], true)
   753  	}
   754  }
   755  
   756  func generateRandomCIDRs(is6 bool, number int) sets.Set[netip.Prefix] {
   757  	n := 4
   758  	if is6 {
   759  		n = 16
   760  	}
   761  	cidrs := sets.Set[netip.Prefix]{}
   762  	rand.New(rand.NewSource(time.Now().UnixNano()))
   763  	for i := 0; i < number; i++ {
   764  		bytes := make([]byte, n)
   765  		for i := 0; i < n; i++ {
   766  			bytes[i] = uint8(rand.Intn(255))
   767  		}
   768  
   769  		ip, ok := netip.AddrFromSlice(bytes)
   770  		if !ok {
   771  			continue
   772  		}
   773  
   774  		bits := rand.Intn(n * 8)
   775  		prefix := netip.PrefixFrom(ip, bits).Masked()
   776  		if prefix.IsValid() {
   777  			cidrs.Insert(prefix)
   778  		}
   779  	}
   780  	return cidrs
   781  }
   782  

View as plain text