...

Source file src/k8s.io/kubernetes/pkg/util/iptables/monitor_test.go

Documentation: k8s.io/kubernetes/pkg/util/iptables

     1  //go:build linux
     2  // +build linux
     3  
     4  /*
     5  Copyright 2019 The Kubernetes Authors.
     6  
     7  Licensed under the Apache License, Version 2.0 (the "License");
     8  you may not use this file except in compliance with the License.
     9  You may obtain a copy of the License at
    10  
    11      http://www.apache.org/licenses/LICENSE-2.0
    12  
    13  Unless required by applicable law or agreed to in writing, software
    14  distributed under the License is distributed on an "AS IS" BASIS,
    15  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    16  See the License for the specific language governing permissions and
    17  limitations under the License.
    18  */
    19  
    20  package iptables
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"io"
    26  	"sync"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  
    31  	"k8s.io/apimachinery/pkg/util/sets"
    32  	utilwait "k8s.io/apimachinery/pkg/util/wait"
    33  	"k8s.io/utils/exec"
    34  )
    35  
    36  // We can't use the normal FakeExec because we don't know precisely how many times the
    37  // Monitor thread will do its checks, and we don't know precisely how its iptables calls
    38  // will interleave with the main thread's. So we use our own fake Exec implementation that
    39  // implements a minimal iptables interface. This will need updates as iptables.runner
    40  // changes its use of Exec.
    41  type monitorFakeExec struct {
    42  	sync.Mutex
    43  
    44  	tables map[string]sets.Set[string]
    45  
    46  	block      bool
    47  	wasBlocked bool
    48  }
    49  
    50  func newMonitorFakeExec() *monitorFakeExec {
    51  	tables := make(map[string]sets.Set[string])
    52  	tables["mangle"] = sets.New[string]()
    53  	tables["filter"] = sets.New[string]()
    54  	tables["nat"] = sets.New[string]()
    55  	return &monitorFakeExec{tables: tables}
    56  }
    57  
    58  func (mfe *monitorFakeExec) blockIPTables(block bool) {
    59  	mfe.Lock()
    60  	defer mfe.Unlock()
    61  
    62  	mfe.block = block
    63  }
    64  
    65  func (mfe *monitorFakeExec) getWasBlocked() bool {
    66  	mfe.Lock()
    67  	defer mfe.Unlock()
    68  
    69  	wasBlocked := mfe.wasBlocked
    70  	mfe.wasBlocked = false
    71  	return wasBlocked
    72  }
    73  
    74  func (mfe *monitorFakeExec) Command(cmd string, args ...string) exec.Cmd {
    75  	return &monitorFakeCmd{mfe: mfe, cmd: cmd, args: args}
    76  }
    77  
    78  func (mfe *monitorFakeExec) CommandContext(ctx context.Context, cmd string, args ...string) exec.Cmd {
    79  	return mfe.Command(cmd, args...)
    80  }
    81  
    82  func (mfe *monitorFakeExec) LookPath(file string) (string, error) {
    83  	return file, nil
    84  }
    85  
    86  type monitorFakeCmd struct {
    87  	mfe  *monitorFakeExec
    88  	cmd  string
    89  	args []string
    90  }
    91  
    92  func (mfc *monitorFakeCmd) CombinedOutput() ([]byte, error) {
    93  	if mfc.cmd == cmdIPTablesRestore {
    94  		// Only used for "iptables-restore --version", and the result doesn't matter
    95  		return []byte{}, nil
    96  	} else if mfc.cmd != cmdIPTables {
    97  		panic("bad command " + mfc.cmd)
    98  	}
    99  
   100  	if len(mfc.args) == 1 && mfc.args[0] == "--version" {
   101  		return []byte("iptables v1.6.2"), nil
   102  	}
   103  
   104  	if len(mfc.args) != 8 || mfc.args[0] != WaitString || mfc.args[1] != WaitSecondsValue || mfc.args[2] != WaitIntervalString || mfc.args[3] != WaitIntervalUsecondsValue || mfc.args[6] != "-t" {
   105  		panic(fmt.Sprintf("bad args %#v", mfc.args))
   106  	}
   107  	op := operation(mfc.args[4])
   108  	chainName := mfc.args[5]
   109  	tableName := mfc.args[7]
   110  
   111  	mfc.mfe.Lock()
   112  	defer mfc.mfe.Unlock()
   113  
   114  	table := mfc.mfe.tables[tableName]
   115  	if table == nil {
   116  		return []byte{}, fmt.Errorf("no such table %q", tableName)
   117  	}
   118  
   119  	// For ease-of-testing reasons, blockIPTables blocks create and list, but not delete
   120  	if mfc.mfe.block && op != opDeleteChain {
   121  		mfc.mfe.wasBlocked = true
   122  		return []byte{}, exec.CodeExitError{Code: 4, Err: fmt.Errorf("could not get xtables.lock, etc")}
   123  	}
   124  
   125  	switch op {
   126  	case opCreateChain:
   127  		if !table.Has(chainName) {
   128  			table.Insert(chainName)
   129  		}
   130  		return []byte{}, nil
   131  	case opListChain:
   132  		if table.Has(chainName) {
   133  			return []byte{}, nil
   134  		}
   135  		return []byte{}, fmt.Errorf("no such chain %q", chainName)
   136  	case opDeleteChain:
   137  		table.Delete(chainName)
   138  		return []byte{}, nil
   139  	default:
   140  		panic("should not be reached")
   141  	}
   142  }
   143  
   144  func (mfc *monitorFakeCmd) SetStdin(in io.Reader) {
   145  	// Used by getIPTablesRestoreVersionString(), can be ignored
   146  }
   147  
   148  func (mfc *monitorFakeCmd) Run() error {
   149  	panic("should not be reached")
   150  }
   151  
   152  func (mfc *monitorFakeCmd) Output() ([]byte, error) {
   153  	panic("should not be reached")
   154  }
   155  
   156  func (mfc *monitorFakeCmd) SetDir(dir string) {
   157  	panic("should not be reached")
   158  }
   159  
   160  func (mfc *monitorFakeCmd) SetStdout(out io.Writer) {
   161  	panic("should not be reached")
   162  }
   163  
   164  func (mfc *monitorFakeCmd) SetStderr(out io.Writer) {
   165  	panic("should not be reached")
   166  }
   167  
   168  func (mfc *monitorFakeCmd) SetEnv(env []string) {
   169  	panic("should not be reached")
   170  }
   171  
   172  func (mfc *monitorFakeCmd) StdoutPipe() (io.ReadCloser, error) {
   173  	panic("should not be reached")
   174  }
   175  
   176  func (mfc *monitorFakeCmd) StderrPipe() (io.ReadCloser, error) {
   177  	panic("should not be reached")
   178  }
   179  
   180  func (mfc *monitorFakeCmd) Start() error {
   181  	panic("should not be reached")
   182  }
   183  
   184  func (mfc *monitorFakeCmd) Wait() error {
   185  	panic("should not be reached")
   186  }
   187  
   188  func (mfc *monitorFakeCmd) Stop() {
   189  	panic("should not be reached")
   190  }
   191  
   192  func TestIPTablesMonitor(t *testing.T) {
   193  	mfe := newMonitorFakeExec()
   194  	ipt := New(mfe, ProtocolIPv4)
   195  
   196  	var reloads uint32
   197  	stopCh := make(chan struct{})
   198  
   199  	canary := Chain("MONITOR-TEST-CANARY")
   200  	tables := []Table{TableMangle, TableFilter, TableNAT}
   201  	go ipt.Monitor(canary, tables, func() {
   202  		if !ensureNoChains(mfe) {
   203  			t.Errorf("reload called while canaries still exist")
   204  		}
   205  		atomic.AddUint32(&reloads, 1)
   206  	}, 100*time.Millisecond, stopCh)
   207  
   208  	// Monitor should create canary chains quickly
   209  	if err := waitForChains(mfe, canary, tables); err != nil {
   210  		t.Errorf("failed to create iptables canaries: %v", err)
   211  	}
   212  
   213  	if err := waitForReloads(&reloads, 0); err != nil {
   214  		t.Errorf("got unexpected reloads: %v", err)
   215  	}
   216  
   217  	// If we delete all of the chains, it should reload
   218  	ipt.DeleteChain(TableMangle, canary)
   219  	ipt.DeleteChain(TableFilter, canary)
   220  	ipt.DeleteChain(TableNAT, canary)
   221  
   222  	if err := waitForReloads(&reloads, 1); err != nil {
   223  		t.Errorf("got unexpected number of reloads after flush: %v", err)
   224  	}
   225  	if err := waitForChains(mfe, canary, tables); err != nil {
   226  		t.Errorf("failed to create iptables canaries: %v", err)
   227  	}
   228  
   229  	// If we delete two chains, it should not reload yet
   230  	ipt.DeleteChain(TableMangle, canary)
   231  	ipt.DeleteChain(TableFilter, canary)
   232  
   233  	if err := waitForNoReload(&reloads, 1); err != nil {
   234  		t.Errorf("got unexpected number of reloads after partial flush: %v", err)
   235  	}
   236  
   237  	// Now ensure that "iptables -L" will get an error about the xtables.lock, and
   238  	// delete the last chain. The monitor should not reload, because it can't actually
   239  	// tell if the chain was deleted or not.
   240  	mfe.blockIPTables(true)
   241  	ipt.DeleteChain(TableNAT, canary)
   242  	if err := waitForBlocked(mfe); err != nil {
   243  		t.Errorf("failed waiting for monitor to be blocked from monitoring: %v", err)
   244  	}
   245  
   246  	// After unblocking the monitor, it should now reload
   247  	mfe.blockIPTables(false)
   248  
   249  	if err := waitForReloads(&reloads, 2); err != nil {
   250  		t.Errorf("got unexpected number of reloads after slow flush: %v", err)
   251  	}
   252  	if err := waitForChains(mfe, canary, tables); err != nil {
   253  		t.Errorf("failed to create iptables canaries: %v", err)
   254  	}
   255  
   256  	// If we close the stop channel, it should stop running
   257  	close(stopCh)
   258  
   259  	if err := waitForNoReload(&reloads, 2); err != nil {
   260  		t.Errorf("got unexpected number of reloads after stop: %v", err)
   261  	}
   262  	if !ensureNoChains(mfe) {
   263  		t.Errorf("canaries still exist after stopping monitor")
   264  	}
   265  
   266  	// If we create a new monitor while the iptables lock is held, it will
   267  	// retry creating canaries until it succeeds
   268  
   269  	stopCh = make(chan struct{})
   270  	_ = mfe.getWasBlocked()
   271  	mfe.blockIPTables(true)
   272  	go ipt.Monitor(canary, tables, func() {
   273  		if !ensureNoChains(mfe) {
   274  			t.Errorf("reload called while canaries still exist")
   275  		}
   276  		atomic.AddUint32(&reloads, 1)
   277  	}, 100*time.Millisecond, stopCh)
   278  
   279  	// Monitor should not have created canaries yet
   280  	if !ensureNoChains(mfe) {
   281  		t.Errorf("canary created while iptables blocked")
   282  	}
   283  
   284  	if err := waitForBlocked(mfe); err != nil {
   285  		t.Errorf("failed waiting for monitor to fail creating canaries: %v", err)
   286  	}
   287  
   288  	mfe.blockIPTables(false)
   289  	if err := waitForChains(mfe, canary, tables); err != nil {
   290  		t.Errorf("failed to create iptables canaries: %v", err)
   291  	}
   292  
   293  	close(stopCh)
   294  }
   295  
   296  func waitForChains(mfe *monitorFakeExec, canary Chain, tables []Table) error {
   297  	return utilwait.PollImmediate(100*time.Millisecond, time.Second, func() (bool, error) {
   298  		mfe.Lock()
   299  		defer mfe.Unlock()
   300  
   301  		for _, table := range tables {
   302  			if !mfe.tables[string(table)].Has(string(canary)) {
   303  				return false, nil
   304  			}
   305  		}
   306  		return true, nil
   307  	})
   308  }
   309  
   310  func ensureNoChains(mfe *monitorFakeExec) bool {
   311  	mfe.Lock()
   312  	defer mfe.Unlock()
   313  	return mfe.tables["mangle"].Len() == 0 &&
   314  		mfe.tables["filter"].Len() == 0 &&
   315  		mfe.tables["nat"].Len() == 0
   316  }
   317  
   318  func waitForReloads(reloads *uint32, expected uint32) error {
   319  	if atomic.LoadUint32(reloads) < expected {
   320  		utilwait.PollImmediate(100*time.Millisecond, time.Second, func() (bool, error) {
   321  			return atomic.LoadUint32(reloads) >= expected, nil
   322  		})
   323  	}
   324  	got := atomic.LoadUint32(reloads)
   325  	if got != expected {
   326  		return fmt.Errorf("expected %d, got %d", expected, got)
   327  	}
   328  	return nil
   329  }
   330  
   331  func waitForNoReload(reloads *uint32, expected uint32) error {
   332  	utilwait.PollImmediate(50*time.Millisecond, 250*time.Millisecond, func() (bool, error) {
   333  		return atomic.LoadUint32(reloads) > expected, nil
   334  	})
   335  
   336  	got := atomic.LoadUint32(reloads)
   337  	if got != expected {
   338  		return fmt.Errorf("expected %d, got %d", expected, got)
   339  	}
   340  	return nil
   341  }
   342  
   343  func waitForBlocked(mfe *monitorFakeExec) error {
   344  	return utilwait.PollImmediate(100*time.Millisecond, time.Second, func() (bool, error) {
   345  		blocked := mfe.getWasBlocked()
   346  		return blocked, nil
   347  	})
   348  }
   349  

View as plain text