...

Source file src/k8s.io/kubernetes/pkg/proxy/nftables/helpers_test.go

Documentation: k8s.io/kubernetes/pkg/proxy/nftables

     1  //go:build linux
     2  // +build linux
     3  
     4  /*
     5  Copyright 2015 The Kubernetes Authors.
     6  
     7  Licensed under the Apache License, Version 2.0 (the "License");
     8  you may not use this file except in compliance with the License.
     9  You may obtain a copy of the License at
    10  
    11      http://www.apache.org/licenses/LICENSE-2.0
    12  
    13  Unless required by applicable law or agreed to in writing, software
    14  distributed under the License is distributed on an "AS IS" BASIS,
    15  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    16  See the License for the specific language governing permissions and
    17  limitations under the License.
    18  */
    19  
    20  package nftables
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"net"
    26  	"regexp"
    27  	"runtime"
    28  	"sort"
    29  	"strings"
    30  	"testing"
    31  
    32  	"github.com/google/go-cmp/cmp"
    33  	"github.com/lithammer/dedent"
    34  
    35  	"k8s.io/api/core/v1"
    36  	"k8s.io/apimachinery/pkg/util/sets"
    37  	netutils "k8s.io/utils/net"
    38  	"sigs.k8s.io/knftables"
    39  )
    40  
    41  // getLine returns a string containing the file and line number of the caller, if
    42  // possible. This is useful in tests with a large number of cases - when something goes
    43  // wrong you can find which case more easily.
    44  func getLine() string {
    45  	_, file, line, ok := runtime.Caller(1)
    46  	if !ok {
    47  		return ""
    48  	}
    49  	return fmt.Sprintf(" (from %s:%d)", file, line)
    50  }
    51  
    52  // objectOrder defines the order we sort different types into (higher = earlier); while
    53  // not necessary just for comparison purposes, it's more intuitive in the Diff output to
    54  // see rules/sets/maps before chains/elements.
    55  var objectOrder = map[string]int{
    56  	"table":   10,
    57  	"chain":   9,
    58  	"rule":    8,
    59  	"set":     7,
    60  	"map":     6,
    61  	"element": 5,
    62  	// anything else: 0
    63  }
    64  
    65  // sortNFTablesTransaction sorts an nftables transaction into a standard order for comparison
    66  func sortNFTablesTransaction(tx string) string {
    67  	lines := strings.Split(tx, "\n")
    68  
    69  	// strip blank lines and comments
    70  	for i := 0; i < len(lines); {
    71  		if lines[i] == "" || lines[i][0] == '#' {
    72  			lines = append(lines[:i], lines[i+1:]...)
    73  		} else {
    74  			i++
    75  		}
    76  	}
    77  
    78  	// sort remaining lines
    79  	sort.SliceStable(lines, func(i, j int) bool {
    80  		li := lines[i]
    81  		wi := strings.Split(li, " ")
    82  		lj := lines[j]
    83  		wj := strings.Split(lj, " ")
    84  
    85  		// All lines will start with "add OBJECTTYPE ip kube-proxy". Everything
    86  		// except "add table" will have an object name after the table name, and
    87  		// "add table" will have a comment after the table name. So every line
    88  		// should have at least 5 words.
    89  		if len(wi) < 5 || len(wj) < 5 {
    90  			return false
    91  		}
    92  
    93  		// Sort by object type first.
    94  		if wi[1] != wj[1] {
    95  			return objectOrder[wi[1]] >= objectOrder[wj[1]]
    96  		}
    97  
    98  		// Sort by object name when object type is identical.
    99  		if wi[4] != wj[4] {
   100  			return wi[4] < wj[4]
   101  		}
   102  
   103  		// Leave rules in the order they were originally added.
   104  		if wi[1] == "rule" {
   105  			return false
   106  		}
   107  
   108  		// Sort by the whole line when object type and name is identical. (e.g.,
   109  		// individual "add rule" and "add element" lines in a chain/set/map.)
   110  		return li < lj
   111  	})
   112  	return strings.Join(lines, "\n")
   113  }
   114  
   115  // diffNFTablesTransaction is a (testable) helper function for assertNFTablesTransactionEqual
   116  func diffNFTablesTransaction(expected, result string) string {
   117  	expected = sortNFTablesTransaction(expected)
   118  	result = sortNFTablesTransaction(result)
   119  
   120  	return cmp.Diff(expected, result)
   121  }
   122  
   123  // assertNFTablesTransactionEqual asserts that expected and result are equal, ignoring
   124  // irrelevant differences.
   125  func assertNFTablesTransactionEqual(t *testing.T, line string, expected, result string) {
   126  	diff := diffNFTablesTransaction(expected, result)
   127  	if diff != "" {
   128  		t.Errorf("tables do not match%s:\ndiff:\n%s\nfull result: %+v", line, diff, result)
   129  	}
   130  }
   131  
   132  // diffNFTablesChain is a (testable) helper function for assertNFTablesChainEqual
   133  func diffNFTablesChain(nft *knftables.Fake, chain, expected string) string {
   134  	expected = strings.TrimSpace(expected)
   135  	result := ""
   136  	if ch := nft.Table.Chains[chain]; ch != nil {
   137  		for i, rule := range ch.Rules {
   138  			if i > 0 {
   139  				result += "\n"
   140  			}
   141  			result += rule.Rule
   142  		}
   143  	}
   144  
   145  	return cmp.Diff(expected, result)
   146  }
   147  
// nftablesTracer holds data used while virtually tracing a packet through a set of
// nftables rules (the chains, sets, and maps of a knftables.Fake table).
type nftablesTracer struct {
	// nft is the fake whose table is traced. nodeIPs are the IPs treated as local
	// when evaluating "fib saddr type local" / "fib daddr type local". t is used to
	// report malformed rules or test data.
	nft     *knftables.Fake
	nodeIPs sets.Set[string]
	t       *testing.T

	// matches accumulates the list of rules that were matched, for debugging purposes.
	matches []string

	// outputs accumulates the list of matched terminal rule targets (endpoint
	// IP:ports, or a special target like "REJECT") and is eventually used to generate
	// the return value of tracePacket.
	outputs []string

	// markMasq tracks whether the packet has been marked for masquerading (set when a
	// rule jumps to the mark-for-masquerade chain).
	markMasq bool
}
   166  
   167  // newNFTablesTracer creates an nftablesTracer. nodeIPs are the IP to treat as local node
   168  // IPs (for determining whether rules with "fib saddr type local" or "fib daddr type
   169  // local" match).
   170  func newNFTablesTracer(t *testing.T, nft *knftables.Fake, nodeIPs []string) *nftablesTracer {
   171  	return &nftablesTracer{
   172  		nft:     nft,
   173  		nodeIPs: sets.New(nodeIPs...),
   174  		t:       t,
   175  	}
   176  }
   177  
   178  func (tracer *nftablesTracer) addressMatches(ipStr string, wantMatch bool, ruleAddress string) bool {
   179  	ip := netutils.ParseIPSloppy(ipStr)
   180  	if ip == nil {
   181  		tracer.t.Fatalf("Bad IP in test case: %s", ipStr)
   182  	}
   183  
   184  	var match bool
   185  	if strings.Contains(ruleAddress, "/") {
   186  		_, cidr, err := netutils.ParseCIDRSloppy(ruleAddress)
   187  		if err != nil {
   188  			tracer.t.Errorf("Bad CIDR in kube-proxy output: %v", err)
   189  		}
   190  		match = cidr.Contains(ip)
   191  	} else {
   192  		ip2 := netutils.ParseIPSloppy(ruleAddress)
   193  		if ip2 == nil {
   194  			tracer.t.Errorf("Bad IP/CIDR in kube-proxy output: %s", ruleAddress)
   195  		}
   196  		match = ip.Equal(ip2)
   197  	}
   198  
   199  	return match == wantMatch
   200  }
   201  
   202  func (tracer *nftablesTracer) addressMatchesSet(ipStr string, wantMatch bool, ruleAddress string) bool {
   203  	ruleAddress = strings.ReplaceAll(ruleAddress, " ", "")
   204  	addresses := strings.Split(ruleAddress, ",")
   205  	var match bool
   206  	for _, address := range addresses {
   207  		match = tracer.addressMatches(ipStr, true, address)
   208  		if match != wantMatch {
   209  			return false
   210  		}
   211  	}
   212  	return true
   213  }
   214  
   215  // matchDestIPOnly checks an "ip daddr" against a set/map, and returns the matching
   216  // Element, if found.
   217  func (tracer *nftablesTracer) matchDestIPOnly(elements []*knftables.Element, destIP string) *knftables.Element {
   218  	for _, element := range elements {
   219  		if element.Key[0] == destIP {
   220  			return element
   221  		}
   222  	}
   223  	return nil
   224  }
   225  
   226  // matchDest checks an "ip daddr . meta l4proto . th dport" against a set/map, and returns
   227  // the matching Element, if found.
   228  func (tracer *nftablesTracer) matchDest(elements []*knftables.Element, destIP, protocol, destPort string) *knftables.Element {
   229  	for _, element := range elements {
   230  		if element.Key[0] == destIP && element.Key[1] == protocol && element.Key[2] == destPort {
   231  			return element
   232  		}
   233  	}
   234  	return nil
   235  }
   236  
   237  // matchDestAndSource checks an "ip daddr . meta l4proto . th dport . ip saddr" against a
   238  // set/map, where the source is allowed to be a CIDR, and returns the matching Element, if
   239  // found.
   240  func (tracer *nftablesTracer) matchDestAndSource(elements []*knftables.Element, destIP, protocol, destPort, sourceIP string) *knftables.Element {
   241  	for _, element := range elements {
   242  		if element.Key[0] == destIP && element.Key[1] == protocol && element.Key[2] == destPort && tracer.addressMatches(sourceIP, true, element.Key[3]) {
   243  			return element
   244  		}
   245  	}
   246  	return nil
   247  }
   248  
   249  // matchDestPort checks an "meta l4proto . th dport" against a set/map, and returns the
   250  // matching Element, if found.
   251  func (tracer *nftablesTracer) matchDestPort(elements []*knftables.Element, protocol, destPort string) *knftables.Element {
   252  	for _, element := range elements {
   253  		if element.Key[0] == protocol && element.Key[1] == destPort {
   254  			return element
   255  		}
   256  	}
   257  	return nil
   258  }
   259  
   260  // We intentionally don't try to parse arbitrary nftables rules, as the syntax is quite
   261  // complicated and context sensitive. (E.g., "ip daddr" could be the start of an address
   262  // comparison, or it could be the start of a set/map lookup.) Instead, we just have
   263  // regexps to recognize the specific pieces of rules that we create in proxier.go.
   264  // Anything matching ignoredRegexp gets stripped out of the rule, and then what's left
   265  // *must* match one of the cases in runChain or an error will be logged. In cases where
   266  // the regexp doesn't end with `$`, and the matched rule succeeds against the input data,
   267  // runChain will continue trying to match the rest of the rule. E.g., "ip daddr 10.0.0.1
   268  // drop" would first match destAddrRegexp, and then (assuming destIP was "10.0.0.1") would
   269  // match verdictRegexp.
   270  
   271  var destAddrRegexp = regexp.MustCompile(`^ip6* daddr (!= )?(\S+)`)
   272  var destAddrLookupRegexp = regexp.MustCompile(`^ip6* daddr (!= )?\{([^}]*)\}`)
   273  var destAddrLocalRegexp = regexp.MustCompile(`^fib daddr type local`)
   274  var destPortRegexp = regexp.MustCompile(`^(tcp|udp|sctp) dport (\d+)`)
   275  var destIPOnlyLookupRegexp = regexp.MustCompile(`^ip6* daddr @(\S+)`)
   276  var destLookupRegexp = regexp.MustCompile(`^ip6* daddr \. meta l4proto \. th dport @(\S+)`)
   277  var destSourceLookupRegexp = regexp.MustCompile(`^ip6* daddr \. meta l4proto \. th dport \. ip6* saddr @(\S+)`)
   278  var destPortLookupRegexp = regexp.MustCompile(`^meta l4proto \. th dport @(\S+)`)
   279  
   280  var destDispatchRegexp = regexp.MustCompile(`^ip6* daddr \. meta l4proto \. th dport vmap @(\S+)$`)
   281  var destPortDispatchRegexp = regexp.MustCompile(`^meta l4proto \. th dport vmap @(\S+)$`)
   282  
   283  var sourceAddrRegexp = regexp.MustCompile(`^ip6* saddr (!= )?(\S+)`)
   284  var sourceAddrLookupRegexp = regexp.MustCompile(`^ip6* saddr (!= )?\{([^}]*)\}`)
   285  var sourceAddrLocalRegexp = regexp.MustCompile(`^fib saddr type local`)
   286  
   287  var endpointVMAPRegexp = regexp.MustCompile(`^numgen random mod \d+ vmap \{(.*)\}$`)
   288  var endpointVMapEntryRegexp = regexp.MustCompile(`\d+ : goto (\S+)`)
   289  
   290  var masqueradeRegexp = regexp.MustCompile(`^jump ` + markMasqChain + `$`)
   291  var jumpRegexp = regexp.MustCompile(`^(jump|goto) (\S+)$`)
   292  var returnRegexp = regexp.MustCompile(`^return$`)
   293  var verdictRegexp = regexp.MustCompile(`^(drop|reject)$`)
   294  var dnatRegexp = regexp.MustCompile(`^meta l4proto (tcp|udp|sctp) dnat to (\S+)$`)
   295  
   296  var ignoredRegexp = regexp.MustCompile(strings.Join(
   297  	[]string{
   298  		// Ignore comments (which can only appear at the end of a rule).
   299  		` *comment "[^"]*"$`,
   300  
   301  		// The trace tests only check new connections, so for our purposes, this
   302  		// check always succeeds (and thus can be ignored).
   303  		`^ct state new`,
   304  	},
   305  	"|",
   306  ))
   307  
   308  // runChain runs the given packet through the rules in the given table and chain, updating
   309  // tracer's internal state accordingly. It returns true if it hits a terminal action.
   310  func (tracer *nftablesTracer) runChain(chname, sourceIP, protocol, destIP, destPort string) bool {
   311  	ch := tracer.nft.Table.Chains[chname]
   312  	if ch == nil {
   313  		tracer.t.Errorf("unknown chain %q", chname)
   314  		return true
   315  	}
   316  
   317  	for _, ruleObj := range ch.Rules {
   318  		rule := ignoredRegexp.ReplaceAllLiteralString(ruleObj.Rule, "")
   319  		for rule != "" {
   320  			rule = strings.TrimLeft(rule, " ")
   321  
   322  			// Note that the order of (some of) the cases is important. e.g.,
   323  			// masqueradeRegexp must be checked before jumpRegexp, since
   324  			// jumpRegexp would also match masqueradeRegexp but do the wrong
   325  			// thing with it.
   326  
   327  			switch {
   328  			case destIPOnlyLookupRegexp.MatchString(rule):
   329  				// `^ip6* daddr @(\S+)`
   330  				// Tests whether destIP is a member of the indicated set.
   331  				match := destIPOnlyLookupRegexp.FindStringSubmatch(rule)
   332  				rule = strings.TrimPrefix(rule, match[0])
   333  				set := match[1]
   334  				if tracer.matchDestIPOnly(tracer.nft.Table.Sets[set].Elements, destIP) == nil {
   335  					rule = ""
   336  					break
   337  				}
   338  
   339  			case destSourceLookupRegexp.MatchString(rule):
   340  				// `^ip6* daddr . meta l4proto . th dport . ip6* saddr @(\S+)`
   341  				// Tests whether "destIP . protocol . destPort . sourceIP" is
   342  				// a member of the indicated set.
   343  				match := destSourceLookupRegexp.FindStringSubmatch(rule)
   344  				rule = strings.TrimPrefix(rule, match[0])
   345  				set := match[1]
   346  				if tracer.matchDestAndSource(tracer.nft.Table.Sets[set].Elements, destIP, protocol, destPort, sourceIP) == nil {
   347  					rule = ""
   348  					break
   349  				}
   350  
   351  			case destLookupRegexp.MatchString(rule):
   352  				// `^ip6* daddr . meta l4proto . th dport @(\S+)`
   353  				// Tests whether "destIP . protocol . destPort" is a member
   354  				// of the indicated set.
   355  				match := destLookupRegexp.FindStringSubmatch(rule)
   356  				rule = strings.TrimPrefix(rule, match[0])
   357  				set := match[1]
   358  				if tracer.matchDest(tracer.nft.Table.Sets[set].Elements, destIP, protocol, destPort) == nil {
   359  					rule = ""
   360  					break
   361  				}
   362  
   363  			case destPortLookupRegexp.MatchString(rule):
   364  				// `^meta l4proto . th dport @(\S+)`
   365  				// Tests whether "protocol . destPort" is a member of the
   366  				// indicated set.
   367  				match := destPortLookupRegexp.FindStringSubmatch(rule)
   368  				rule = strings.TrimPrefix(rule, match[0])
   369  				set := match[1]
   370  				if tracer.matchDestPort(tracer.nft.Table.Sets[set].Elements, protocol, destPort) == nil {
   371  					rule = ""
   372  					break
   373  				}
   374  
   375  			case destDispatchRegexp.MatchString(rule):
   376  				// `^ip6* daddr \. meta l4proto \. th dport vmap @(\S+)$`
   377  				// Looks up "destIP . protocol . destPort" in the indicated
   378  				// verdict map, and if found, runs the assocated verdict.
   379  				match := destDispatchRegexp.FindStringSubmatch(rule)
   380  				mapName := match[1]
   381  				element := tracer.matchDest(tracer.nft.Table.Maps[mapName].Elements, destIP, protocol, destPort)
   382  				if element == nil {
   383  					rule = ""
   384  					break
   385  				} else {
   386  					rule = element.Value[0]
   387  				}
   388  
   389  			case destPortDispatchRegexp.MatchString(rule):
   390  				// `^meta l4proto \. th dport vmap @(\S+)$`
   391  				// Looks up "protocol . destPort" in the indicated verdict map,
   392  				// and if found, runs the assocated verdict.
   393  				match := destPortDispatchRegexp.FindStringSubmatch(rule)
   394  				mapName := match[1]
   395  				element := tracer.matchDestPort(tracer.nft.Table.Maps[mapName].Elements, protocol, destPort)
   396  				if element == nil {
   397  					rule = ""
   398  					break
   399  				} else {
   400  					rule = element.Value[0]
   401  				}
   402  
   403  			case destAddrLookupRegexp.MatchString(rule):
   404  				// `^ip6* daddr (!= )?\{([^}]*)\}`
   405  				// Tests whether destIP doesn't match an anonymous set.
   406  				match := destAddrLookupRegexp.FindStringSubmatch(rule)
   407  				rule = strings.TrimPrefix(rule, match[0])
   408  				wantMatch, set := match[1] != "!= ", match[2]
   409  				if !tracer.addressMatchesSet(destIP, wantMatch, set) {
   410  					rule = ""
   411  					break
   412  				}
   413  
   414  			case destAddrRegexp.MatchString(rule):
   415  				// `^ip6* daddr (!= )?(\S+)`
   416  				// Tests whether destIP does/doesn't match a literal.
   417  				match := destAddrRegexp.FindStringSubmatch(rule)
   418  				rule = strings.TrimPrefix(rule, match[0])
   419  				wantMatch, ip := match[1] != "!= ", match[2]
   420  				if !tracer.addressMatches(destIP, wantMatch, ip) {
   421  					rule = ""
   422  					break
   423  				}
   424  
   425  			case destAddrLocalRegexp.MatchString(rule):
   426  				// `^fib daddr type local`
   427  				// Tests whether destIP is a local IP.
   428  				match := destAddrLocalRegexp.FindStringSubmatch(rule)
   429  				rule = strings.TrimPrefix(rule, match[0])
   430  				if !tracer.nodeIPs.Has(destIP) {
   431  					rule = ""
   432  					break
   433  				}
   434  
   435  			case destPortRegexp.MatchString(rule):
   436  				// `^(tcp|udp|sctp) dport (\d+)`
   437  				// Tests whether destPort matches a literal.
   438  				match := destPortRegexp.FindStringSubmatch(rule)
   439  				rule = strings.TrimPrefix(rule, match[0])
   440  				proto, port := match[1], match[2]
   441  				if protocol != proto || destPort != port {
   442  					rule = ""
   443  					break
   444  				}
   445  
   446  			case sourceAddrLookupRegexp.MatchString(rule):
   447  				// `^ip6* saddr (!= )?\{([^}]*)\}`
   448  				// Tests whether sourceIP doesn't match an anonymous set.
   449  				match := sourceAddrLookupRegexp.FindStringSubmatch(rule)
   450  				rule = strings.TrimPrefix(rule, match[0])
   451  				wantMatch, set := match[1] != "!= ", match[2]
   452  				if !tracer.addressMatchesSet(sourceIP, wantMatch, set) {
   453  					rule = ""
   454  					break
   455  				}
   456  
   457  			case sourceAddrRegexp.MatchString(rule):
   458  				// `^ip6* saddr (!= )?(\S+)`
   459  				// Tests whether sourceIP does/doesn't match a literal.
   460  				match := sourceAddrRegexp.FindStringSubmatch(rule)
   461  				rule = strings.TrimPrefix(rule, match[0])
   462  				wantMatch, ip := match[1] != "!= ", match[2]
   463  				if !tracer.addressMatches(sourceIP, wantMatch, ip) {
   464  					rule = ""
   465  					break
   466  				}
   467  
   468  			case sourceAddrLocalRegexp.MatchString(rule):
   469  				// `^fib saddr type local`
   470  				// Tests whether sourceIP is a local IP.
   471  				match := sourceAddrLocalRegexp.FindStringSubmatch(rule)
   472  				rule = strings.TrimPrefix(rule, match[0])
   473  				if !tracer.nodeIPs.Has(sourceIP) {
   474  					rule = ""
   475  					break
   476  				}
   477  
   478  			case masqueradeRegexp.MatchString(rule):
   479  				// `^jump mark-for-masquerade$`
   480  				// Mark for masquerade: we just treat the jump rule itself as
   481  				// being what creates the mark, rather than trying to handle
   482  				// the rules inside that chain and the "masquerading" chain.
   483  				match := jumpRegexp.FindStringSubmatch(rule)
   484  				rule = strings.TrimPrefix(rule, match[0])
   485  
   486  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   487  				tracer.markMasq = true
   488  
   489  			case jumpRegexp.MatchString(rule):
   490  				// `^(jump|goto) (\S+)$`
   491  				// Jumps to another chain.
   492  				match := jumpRegexp.FindStringSubmatch(rule)
   493  				rule = strings.TrimPrefix(rule, match[0])
   494  				action, destChain := match[1], match[2]
   495  
   496  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   497  				terminated := tracer.runChain(destChain, sourceIP, protocol, destIP, destPort)
   498  				if terminated {
   499  					// destChain reached a terminal statement, so we
   500  					// terminate too.
   501  					return true
   502  				} else if action == "goto" {
   503  					// After a goto, return to our calling chain
   504  					// (without terminating) rather than continuing
   505  					// with this chain.
   506  					return false
   507  				}
   508  
   509  			case verdictRegexp.MatchString(rule):
   510  				// `^(drop|reject)$`
   511  				// Drop/reject the packet and terminate processing.
   512  				match := verdictRegexp.FindStringSubmatch(rule)
   513  				verdict := match[1]
   514  
   515  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   516  				tracer.outputs = append(tracer.outputs, strings.ToUpper(verdict))
   517  				return true
   518  
   519  			case returnRegexp.MatchString(rule):
   520  				// `^return$`
   521  				// Returns to the calling chain.
   522  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   523  				return false
   524  
   525  			case dnatRegexp.MatchString(rule):
   526  				// `meta l4proto (tcp|udp|sctp) dnat to (\S+)`
   527  				// DNAT to an endpoint IP and terminate processing.
   528  				match := dnatRegexp.FindStringSubmatch(rule)
   529  				destEndpoint := match[2]
   530  
   531  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   532  				tracer.outputs = append(tracer.outputs, destEndpoint)
   533  				return true
   534  
   535  			case endpointVMAPRegexp.MatchString(rule):
   536  				// `^numgen random mod \d+ vmap \{(.*)\}$`
   537  				// Selects a random endpoint and jumps to it. For tracePacket's
   538  				// purposes, we jump to *all* of the endpoints.
   539  				match := endpointVMAPRegexp.FindStringSubmatch(rule)
   540  				elements := match[1]
   541  
   542  				for _, match = range endpointVMapEntryRegexp.FindAllStringSubmatch(elements, -1) {
   543  					// `\d+ : goto (\S+)`
   544  					destChain := match[1]
   545  
   546  					tracer.matches = append(tracer.matches, ruleObj.Rule)
   547  					// Ignore return value; we know each endpoint has a
   548  					// terminating dnat verdict, but we want to gather all
   549  					// of the endpoints into tracer.output.
   550  					_ = tracer.runChain(destChain, sourceIP, protocol, destIP, destPort)
   551  				}
   552  				return true
   553  
   554  			default:
   555  				tracer.t.Errorf("unmatched rule: %s", ruleObj.Rule)
   556  				rule = ""
   557  			}
   558  		}
   559  	}
   560  
   561  	return false
   562  }
   563  
   564  // tracePacket determines what would happen to a packet with the given sourceIP, destIP,
   565  // and destPort, given the indicated iptables ruleData. nodeIPs are the local node IPs (for
   566  // rules matching "local"). (The protocol value should be lowercase as in nftables
   567  // rules, not uppercase as in corev1.)
   568  //
   569  // The return values are: an array of matched rules (for debugging), the final packet
   570  // destinations (a comma-separated list of IPs, or one of the special targets "ACCEPT",
   571  // "DROP", or "REJECT"), and whether the packet would be masqueraded.
   572  func tracePacket(t *testing.T, nft *knftables.Fake, sourceIP, protocol, destIP, destPort string, nodeIPs []string) ([]string, string, bool) {
   573  	var err error
   574  	tracer := newNFTablesTracer(t, nft, nodeIPs)
   575  
   576  	// filter-prerouting goes first, then nat-prerouting if not terminated.
   577  	if tracer.runChain("filter-prerouting", sourceIP, protocol, destIP, destPort) {
   578  		return tracer.matches, strings.Join(tracer.outputs, ", "), tracer.markMasq
   579  	}
   580  	tracer.runChain("nat-prerouting", sourceIP, protocol, destIP, destPort)
   581  	// After the prerouting rules run, pending DNATs are processed (which would affect
   582  	// the destination IP that later rules match against).
   583  	if len(tracer.outputs) != 0 {
   584  		destIP, _, err = net.SplitHostPort(tracer.outputs[0])
   585  		if err != nil {
   586  			t.Errorf("failed to parse host port '%s': %s", tracer.outputs[0], err.Error())
   587  		}
   588  	}
   589  
   590  	// Run filter-forward, return if packet is terminated.
   591  	if tracer.runChain("filter-forward", sourceIP, protocol, destIP, destPort) {
   592  		return tracer.matches, strings.Join(tracer.outputs, ", "), tracer.markMasq
   593  	}
   594  
   595  	// Run filter-input
   596  	tracer.runChain("filter-input", sourceIP, protocol, destIP, destPort)
   597  
   598  	// Skip filter-output and nat-output as they ought to be fully redundant with the prerouting chains.
   599  	// Skip nat-postrouting because it only does masquerading and we handle that separately.
   600  	return tracer.matches, strings.Join(tracer.outputs, ", "), tracer.markMasq
   601  }
   602  
// packetFlowTest describes one packet for runPacketFlowTests to trace with tracePacket,
// together with the result expected for it.
type packetFlowTest struct {
	// name names the subtest and identifies the case in failure messages.
	name     string
	sourceIP string
	// protocol is lowercased by runPacketFlowTests; "" is treated as "tcp".
	protocol v1.Protocol
	destIP   string
	destPort int
	// output is the expected destination string returned by tracePacket.
	output   string
	// masq is whether the packet is expected to be marked for masquerading.
	masq     bool
}
   612  
   613  func runPacketFlowTests(t *testing.T, line string, nft *knftables.Fake, nodeIPs []string, testCases []packetFlowTest) {
   614  	for _, tc := range testCases {
   615  		t.Run(tc.name, func(t *testing.T) {
   616  			protocol := strings.ToLower(string(tc.protocol))
   617  			if protocol == "" {
   618  				protocol = "tcp"
   619  			}
   620  			matches, output, masq := tracePacket(t, nft, tc.sourceIP, protocol, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIPs)
   621  			var errors []string
   622  			if output != tc.output {
   623  				errors = append(errors, fmt.Sprintf("wrong output: expected %q got %q", tc.output, output))
   624  			}
   625  			if masq != tc.masq {
   626  				errors = append(errors, fmt.Sprintf("wrong masq: expected %v got %v", tc.masq, masq))
   627  			}
   628  			if errors != nil {
   629  				t.Errorf("Test %q of a packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n",
   630  					tc.name, tc.sourceIP, tc.destIP, tc.destPort, line, strings.Join(errors, "\n"), strings.Join(matches, "\n"))
   631  			}
   632  		})
   633  	}
   634  }
   635  
   636  // helpers_test unit tests
   637  
// testInput is a sample nftables transaction for the helpers_test unit tests. Its
// lines are grouped by service rather than by object type, and it contains blank
// lines and "#" comments, so normalizing it requires both reordering and stripping.
// NOTE(review): testExpected, which follows, appears to be this same transaction in
// sorted order — confirm against the tests that use them.
var testInput = dedent.Dedent(`
	add table ip testing { comment "rules for kube-proxy" ; }

	add chain ip testing mark-for-masquerade
	add rule ip testing mark-for-masquerade mark set mark or 0x4000
	add chain ip testing masquerading
	add rule ip testing masquerading mark and 0x4000 == 0 return
	add rule ip testing masquerading mark set mark xor 0x4000
	add rule ip testing masquerading masquerade fully-random

	add set ip testing firewall { type ipv4_addr . inet_proto . inet_service ; comment "destinations that are subject to LoadBalancerSourceRanges" ; }
	add set ip testing firewall-allow { type ipv4_addr . inet_proto . inet_service . ipv4_addr ; flags interval ; comment "destinations+sources that are allowed by LoadBalancerSourceRanges" ; }
	add chain ip testing firewall-check
	add chain ip testing firewall-allow-check
	add rule ip testing firewall-allow-check ip daddr . meta l4proto . th dport . ip saddr @firewall-allow return
	add rule ip testing firewall-allow-check drop
	add rule ip testing firewall-check ip daddr . meta l4proto . th dport @firewall jump firewall-allow-check

	# svc1
	add chain ip testing service-ULMVA6XW-ns1/svc1/tcp/p80
	add rule ip testing service-ULMVA6XW-ns1/svc1/tcp/p80 ip daddr 172.30.0.41 tcp dport 80 ip saddr != 10.0.0.0/8 jump mark-for-masquerade
	add rule ip testing service-ULMVA6XW-ns1/svc1/tcp/p80 numgen random mod 1 vmap { 0 : goto endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 }

	add chain ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80
	add rule ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 ip saddr 10.180.0.1 jump mark-for-masquerade
	add rule ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 meta l4proto tcp dnat to 10.180.0.1:80

	add element ip testing service-ips { 172.30.0.41 . tcp . 80 : goto service-ULMVA6XW-ns1/svc1/tcp/p80 }

	# svc2
	add chain ip testing service-42NFTM6N-ns2/svc2/tcp/p80
	add rule ip testing service-42NFTM6N-ns2/svc2/tcp/p80 ip daddr 172.30.0.42 tcp dport 80 ip saddr != 10.0.0.0/8 jump mark-for-masquerade
	add rule ip testing service-42NFTM6N-ns2/svc2/tcp/p80 numgen random mod 1 vmap { 0 : goto endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 }
	add chain ip testing external-42NFTM6N-ns2/svc2/tcp/p80
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 ip saddr 10.0.0.0/8 goto service-42NFTM6N-ns2/svc2/tcp/p80 comment "short-circuit pod traffic"
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 fib saddr type local jump mark-for-masquerade comment "masquerade local traffic"
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 fib saddr type local goto service-42NFTM6N-ns2/svc2/tcp/p80 comment "short-circuit local traffic"
	add chain ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80
	add rule ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 ip saddr 10.180.0.2 jump mark-for-masquerade
	add rule ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 meta l4proto tcp dnat to 10.180.0.2:80

	add element ip testing service-ips { 172.30.0.42 . tcp . 80 : goto service-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-ips { 192.168.99.22 . tcp . 80 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-ips { 1.2.3.4 . tcp . 80 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-nodeports { tcp . 3001 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }

	add element ip testing no-endpoint-nodeports { tcp . 3001 comment "ns2/svc2:p80" : drop }
	add element ip testing no-endpoint-services { 1.2.3.4 . tcp . 80 comment "ns2/svc2:p80" : drop }
	add element ip testing no-endpoint-services { 192.168.99.22 . tcp . 80 comment "ns2/svc2:p80" : drop }
	`)
   688  
// testExpected is the canonical (sorted) form of testInput.
// Test_sortNFTablesTransaction expects sortNFTablesTransaction(testInput) to
// equal this text after strings.TrimSpace, and Test_diffNFTablesTransaction
// expects diffNFTablesTransaction to report no difference between the two.
//
// As written below, the table comes first, then the chains sorted by name, then
// the rules grouped by chain (in chain-name order, keeping each chain's rules in
// their original relative order), then the sets, then the elements in lexical
// order.
var testExpected = dedent.Dedent(`
	add table ip testing { comment "rules for kube-proxy" ; }
	add chain ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80
	add chain ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80
	add chain ip testing external-42NFTM6N-ns2/svc2/tcp/p80
	add chain ip testing firewall-allow-check
	add chain ip testing firewall-check
	add chain ip testing mark-for-masquerade
	add chain ip testing masquerading
	add chain ip testing service-42NFTM6N-ns2/svc2/tcp/p80
	add chain ip testing service-ULMVA6XW-ns1/svc1/tcp/p80
	add rule ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 ip saddr 10.180.0.1 jump mark-for-masquerade
	add rule ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 meta l4proto tcp dnat to 10.180.0.1:80
	add rule ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 ip saddr 10.180.0.2 jump mark-for-masquerade
	add rule ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 meta l4proto tcp dnat to 10.180.0.2:80
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 ip saddr 10.0.0.0/8 goto service-42NFTM6N-ns2/svc2/tcp/p80 comment "short-circuit pod traffic"
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 fib saddr type local jump mark-for-masquerade comment "masquerade local traffic"
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 fib saddr type local goto service-42NFTM6N-ns2/svc2/tcp/p80 comment "short-circuit local traffic"
	add rule ip testing firewall-allow-check ip daddr . meta l4proto . th dport . ip saddr @firewall-allow return
	add rule ip testing firewall-allow-check drop
	add rule ip testing firewall-check ip daddr . meta l4proto . th dport @firewall jump firewall-allow-check
	add rule ip testing mark-for-masquerade mark set mark or 0x4000
	add rule ip testing masquerading mark and 0x4000 == 0 return
	add rule ip testing masquerading mark set mark xor 0x4000
	add rule ip testing masquerading masquerade fully-random
	add rule ip testing service-42NFTM6N-ns2/svc2/tcp/p80 ip daddr 172.30.0.42 tcp dport 80 ip saddr != 10.0.0.0/8 jump mark-for-masquerade
	add rule ip testing service-42NFTM6N-ns2/svc2/tcp/p80 numgen random mod 1 vmap { 0 : goto endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 }
	add rule ip testing service-ULMVA6XW-ns1/svc1/tcp/p80 ip daddr 172.30.0.41 tcp dport 80 ip saddr != 10.0.0.0/8 jump mark-for-masquerade
	add rule ip testing service-ULMVA6XW-ns1/svc1/tcp/p80 numgen random mod 1 vmap { 0 : goto endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 }
	add set ip testing firewall { type ipv4_addr . inet_proto . inet_service ; comment "destinations that are subject to LoadBalancerSourceRanges" ; }
	add set ip testing firewall-allow { type ipv4_addr . inet_proto . inet_service . ipv4_addr ; flags interval ; comment "destinations+sources that are allowed by LoadBalancerSourceRanges" ; }
	add element ip testing no-endpoint-nodeports { tcp . 3001 comment "ns2/svc2:p80" : drop }
	add element ip testing no-endpoint-services { 1.2.3.4 . tcp . 80 comment "ns2/svc2:p80" : drop }
	add element ip testing no-endpoint-services { 192.168.99.22 . tcp . 80 comment "ns2/svc2:p80" : drop }
	add element ip testing service-ips { 1.2.3.4 . tcp . 80 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-ips { 172.30.0.41 . tcp . 80 : goto service-ULMVA6XW-ns1/svc1/tcp/p80 }
	add element ip testing service-ips { 172.30.0.42 . tcp . 80 : goto service-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-ips { 192.168.99.22 . tcp . 80 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-nodeports { tcp . 3001 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	`)
   729  
   730  func Test_sortNFTablesTransaction(t *testing.T) {
   731  	output := sortNFTablesTransaction(testInput)
   732  	expected := strings.TrimSpace(testExpected)
   733  
   734  	diff := cmp.Diff(expected, output)
   735  	if diff != "" {
   736  		t.Errorf("output does not match expected:\n%s", diff)
   737  	}
   738  }
   739  
   740  func Test_diffNFTablesTransaction(t *testing.T) {
   741  	diff := diffNFTablesTransaction(testInput, testExpected)
   742  	if diff != "" {
   743  		t.Errorf("found diff in inputs that should have been equal:\n%s", diff)
   744  	}
   745  
   746  	notExpected := strings.Join(strings.Split(testExpected, "\n")[2:], "\n")
   747  	diff = diffNFTablesTransaction(testInput, notExpected)
   748  	if diff == "" {
   749  		t.Errorf("found no diff in inputs that should have been different")
   750  	}
   751  }
   752  
   753  func Test_diffNFTablesChain(t *testing.T) {
   754  	fake := knftables.NewFake(knftables.IPv4Family, "testing")
   755  	tx := fake.NewTransaction()
   756  
   757  	tx.Add(&knftables.Table{})
   758  	tx.Add(&knftables.Chain{
   759  		Name: "mark-masq-chain",
   760  	})
   761  	tx.Add(&knftables.Chain{
   762  		Name: "masquerade-chain",
   763  	})
   764  	tx.Add(&knftables.Chain{
   765  		Name: "empty-chain",
   766  	})
   767  
   768  	tx.Add(&knftables.Rule{
   769  		Chain: "mark-masq-chain",
   770  		Rule:  "mark set mark or 0x4000",
   771  	})
   772  
   773  	tx.Add(&knftables.Rule{
   774  		Chain: "masquerade-chain",
   775  		Rule:  "mark and 0x4000 == 0 return",
   776  	})
   777  	tx.Add(&knftables.Rule{
   778  		Chain: "masquerade-chain",
   779  		Rule:  "mark set mark xor 0x4000",
   780  	})
   781  	tx.Add(&knftables.Rule{
   782  		Chain: "masquerade-chain",
   783  		Rule:  "masquerade fully-random",
   784  	})
   785  
   786  	err := fake.Run(context.Background(), tx)
   787  	if err != nil {
   788  		t.Fatalf("Unexpected error running transaction: %v", err)
   789  	}
   790  
   791  	diff := diffNFTablesChain(fake, "mark-masq-chain", "mark set mark or 0x4000")
   792  	if diff != "" {
   793  		t.Errorf("unexpected difference in mark-masq-chain:\n%s", diff)
   794  	}
   795  	diff = diffNFTablesChain(fake, "mark-masq-chain", "mark set mark or 0x4000\n")
   796  	if diff != "" {
   797  		t.Errorf("unexpected difference in mark-masq-chain with trailing newline:\n%s", diff)
   798  	}
   799  
   800  	diff = diffNFTablesChain(fake, "masquerade-chain", "mark and 0x4000 == 0 return\nmark set mark xor 0x4000\nmasquerade fully-random")
   801  	if diff != "" {
   802  		t.Errorf("unexpected difference in masquerade-chain:\n%s", diff)
   803  	}
   804  	diff = diffNFTablesChain(fake, "masquerade-chain", "mark set mark xor 0x4000\nmasquerade fully-random")
   805  	if diff == "" {
   806  		t.Errorf("unexpected lack of difference in wrong masquerade-chain")
   807  	}
   808  
   809  	diff = diffNFTablesChain(fake, "empty-chain", "")
   810  	if diff != "" {
   811  		t.Errorf("unexpected difference in empty-chain:\n%s", diff)
   812  	}
   813  	diff = diffNFTablesChain(fake, "empty-chain", "\n")
   814  	if diff != "" {
   815  		t.Errorf("unexpected difference in empty-chain with trailing newline:\n%s", diff)
   816  	}
   817  }
   818  

View as plain text