1
2
3
4
19
20 package iptables
21
22 import (
23 "context"
24 "fmt"
25 "io"
26 "sync"
27 "sync/atomic"
28 "testing"
29 "time"
30
31 "k8s.io/apimachinery/pkg/util/sets"
32 utilwait "k8s.io/apimachinery/pkg/util/wait"
33 "k8s.io/utils/exec"
34 )
35
36
37
38
39
40
41 type monitorFakeExec struct {
42 sync.Mutex
43
44 tables map[string]sets.Set[string]
45
46 block bool
47 wasBlocked bool
48 }
49
50 func newMonitorFakeExec() *monitorFakeExec {
51 tables := make(map[string]sets.Set[string])
52 tables["mangle"] = sets.New[string]()
53 tables["filter"] = sets.New[string]()
54 tables["nat"] = sets.New[string]()
55 return &monitorFakeExec{tables: tables}
56 }
57
58 func (mfe *monitorFakeExec) blockIPTables(block bool) {
59 mfe.Lock()
60 defer mfe.Unlock()
61
62 mfe.block = block
63 }
64
65 func (mfe *monitorFakeExec) getWasBlocked() bool {
66 mfe.Lock()
67 defer mfe.Unlock()
68
69 wasBlocked := mfe.wasBlocked
70 mfe.wasBlocked = false
71 return wasBlocked
72 }
73
// Command returns a fake command bound to this exec. Nothing is evaluated
// until CombinedOutput is called on the result.
func (mfe *monitorFakeExec) Command(cmd string, args ...string) exec.Cmd {
	fake := &monitorFakeCmd{
		mfe:  mfe,
		cmd:  cmd,
		args: args,
	}
	return fake
}

// CommandContext behaves exactly like Command; ctx is ignored.
func (mfe *monitorFakeExec) CommandContext(ctx context.Context, cmd string, args ...string) exec.Cmd {
	return mfe.Command(cmd, args...)
}

// LookPath reports every binary as found under its bare name.
func (mfe *monitorFakeExec) LookPath(file string) (string, error) {
	return file, nil
}
85
86 type monitorFakeCmd struct {
87 mfe *monitorFakeExec
88 cmd string
89 args []string
90 }
91
// CombinedOutput emulates the iptables invocations this fake supports:
//
//   - any iptables-restore call succeeds with empty output;
//   - "iptables --version" reports v1.6.2;
//   - "iptables <wait flags> <op> <chain> -t <table>" creates, lists or
//     deletes a chain in the in-memory tables.
//
// Anything else panics, so an unexpected invocation fails the test loudly.
// While blocking is enabled, create and list calls fail with exit code 4
// (and record wasBlocked), but deletes still succeed so the test can flush
// canaries while the monitor is locked out.
func (mfc *monitorFakeCmd) CombinedOutput() ([]byte, error) {
	if mfc.cmd == cmdIPTablesRestore {
		// iptables-restore is not modeled; its output is not inspected here.
		return []byte{}, nil
	}
	if mfc.cmd != cmdIPTables {
		panic("bad command " + mfc.cmd)
	}

	args := mfc.args
	if len(args) == 1 && args[0] == "--version" {
		return []byte("iptables v1.6.2"), nil
	}

	// Expected shape: <4 wait-flag args> <op> <chain> -t <table>.
	wellFormed := len(args) == 8 &&
		args[0] == WaitString &&
		args[1] == WaitSecondsValue &&
		args[2] == WaitIntervalString &&
		args[3] == WaitIntervalUsecondsValue &&
		args[6] == "-t"
	if !wellFormed {
		panic(fmt.Sprintf("bad args %#v", args))
	}
	op, chain, tableName := operation(args[4]), args[5], args[7]

	mfc.mfe.Lock()
	defer mfc.mfe.Unlock()

	table := mfc.mfe.tables[tableName]
	if table == nil {
		return []byte{}, fmt.Errorf("no such table %q", tableName)
	}

	// Blocking only affects create and list, never delete (see doc above).
	if mfc.mfe.block && op != opDeleteChain {
		mfc.mfe.wasBlocked = true
		return []byte{}, exec.CodeExitError{Code: 4, Err: fmt.Errorf("could not get xtables.lock, etc")}
	}

	switch op {
	case opCreateChain:
		// Insert is a no-op when the chain already exists.
		table.Insert(chain)
		return []byte{}, nil
	case opListChain:
		if !table.Has(chain) {
			return []byte{}, fmt.Errorf("no such chain %q", chain)
		}
		return []byte{}, nil
	case opDeleteChain:
		table.Delete(chain)
		return []byte{}, nil
	}
	panic("should not be reached")
}
143
// SetStdin is accepted and ignored; the fake never reads stdin.
func (mfc *monitorFakeCmd) SetStdin(io.Reader) {}

// The remaining exec.Cmd methods are not expected to be used by the code
// under test. Each one panics so that an unexpected call fails immediately.

func (mfc *monitorFakeCmd) Run() error { panic("should not be reached") }

func (mfc *monitorFakeCmd) Output() ([]byte, error) { panic("should not be reached") }

func (mfc *monitorFakeCmd) SetDir(string) { panic("should not be reached") }

func (mfc *monitorFakeCmd) SetStdout(io.Writer) { panic("should not be reached") }

func (mfc *monitorFakeCmd) SetStderr(io.Writer) { panic("should not be reached") }

func (mfc *monitorFakeCmd) SetEnv([]string) { panic("should not be reached") }

func (mfc *monitorFakeCmd) StdoutPipe() (io.ReadCloser, error) { panic("should not be reached") }

func (mfc *monitorFakeCmd) StderrPipe() (io.ReadCloser, error) { panic("should not be reached") }

func (mfc *monitorFakeCmd) Start() error { panic("should not be reached") }

func (mfc *monitorFakeCmd) Wait() error { panic("should not be reached") }

func (mfc *monitorFakeCmd) Stop() { panic("should not be reached") }
191
// TestIPTablesMonitor drives Monitor against monitorFakeExec and checks that
// it creates its canary chains, reloads only after every canary is gone,
// copes with iptables being temporarily unavailable, and removes its
// canaries when stopped.
func TestIPTablesMonitor(t *testing.T) {
	mfe := newMonitorFakeExec()
	ipt := New(mfe, ProtocolIPv4)

	var reloads uint32
	canary := Chain("MONITOR-TEST-CANARY")
	tables := []Table{TableMangle, TableFilter, TableNAT}

	// reload is the Monitor callback; it must only run once every canary
	// has been flushed.
	reload := func() {
		if !ensureNoChains(mfe) {
			t.Errorf("reload called while canaries still exist")
		}
		atomic.AddUint32(&reloads, 1)
	}

	// startMonitor runs Monitor in its own goroutine and returns a channel
	// that is closed once Monitor returns. Joining it keeps the goroutine
	// (and its t.Errorf calls) from outliving the test.
	startMonitor := func(stopCh <-chan struct{}) <-chan struct{} {
		done := make(chan struct{})
		go func() {
			defer close(done)
			ipt.Monitor(canary, tables, reload, 100*time.Millisecond, stopCh)
		}()
		return done
	}

	// stopMonitor closes stopCh and waits, with a generous bound, for
	// Monitor to finish its shutdown work.
	stopMonitor := func(stopCh chan struct{}, done <-chan struct{}) {
		close(stopCh)
		select {
		case <-done:
		case <-time.After(5 * time.Second):
			t.Fatalf("monitor did not return after stopCh was closed")
		}
	}

	// deleteCanary simulates an external flush of one table's canary chain.
	deleteCanary := func(table Table) {
		if err := ipt.DeleteChain(table, canary); err != nil {
			t.Errorf("failed to delete canary from table %s: %v", table, err)
		}
	}

	stopCh := make(chan struct{})
	done := startMonitor(stopCh)

	// Monitor should create the canary chains quickly, without reloading.
	if err := waitForChains(mfe, canary, tables); err != nil {
		t.Errorf("failed to create iptables canaries: %v", err)
	}
	if err := waitForReloads(&reloads, 0); err != nil {
		t.Errorf("got unexpected reloads: %v", err)
	}

	// Deleting every canary should trigger exactly one reload, after which
	// the canaries are recreated.
	for _, table := range tables {
		deleteCanary(table)
	}
	if err := waitForReloads(&reloads, 1); err != nil {
		t.Errorf("got unexpected number of reloads after flush: %v", err)
	}
	if err := waitForChains(mfe, canary, tables); err != nil {
		t.Errorf("failed to create iptables canaries: %v", err)
	}

	// Deleting only some canaries must not trigger a reload.
	deleteCanary(TableMangle)
	deleteCanary(TableFilter)
	if err := waitForNoReload(&reloads, 1); err != nil {
		t.Errorf("got unexpected number of reloads after partial flush: %v", err)
	}

	// Delete the last canary while iptables is blocked. The monitor cannot
	// make progress until iptables is unblocked; then it should reload and
	// recreate the canaries.
	mfe.blockIPTables(true)
	deleteCanary(TableNAT)
	if err := waitForBlocked(mfe); err != nil {
		t.Errorf("failed waiting for monitor to be blocked from monitoring: %v", err)
	}

	mfe.blockIPTables(false)
	if err := waitForReloads(&reloads, 2); err != nil {
		t.Errorf("got unexpected number of reloads after slow flush: %v", err)
	}
	if err := waitForChains(mfe, canary, tables); err != nil {
		t.Errorf("failed to create iptables canaries: %v", err)
	}

	// Stopping the monitor must not reload, and it must remove its
	// canaries. Waiting for Monitor to return makes the canary check
	// deterministic instead of racing its cleanup.
	stopMonitor(stopCh, done)
	if err := waitForNoReload(&reloads, 2); err != nil {
		t.Errorf("got unexpected number of reloads after stop: %v", err)
	}
	if !ensureNoChains(mfe) {
		t.Errorf("canaries still exist after stopping monitor")
	}

	// A monitor started while iptables is blocked must not create canaries
	// until iptables becomes available again.
	stopCh = make(chan struct{})
	_ = mfe.getWasBlocked() // discard any stale flag left by the first monitor
	mfe.blockIPTables(true)
	done = startMonitor(stopCh)

	if !ensureNoChains(mfe) {
		t.Errorf("canary created while iptables blocked")
	}
	if err := waitForBlocked(mfe); err != nil {
		t.Errorf("failed waiting for monitor to fail creating canaries: %v", err)
	}

	mfe.blockIPTables(false)
	if err := waitForChains(mfe, canary, tables); err != nil {
		t.Errorf("failed to create iptables canaries: %v", err)
	}

	stopMonitor(stopCh, done)
}
295
// waitForChains polls every 100ms, for up to one second, until the canary
// chain exists in every table in tables. It returns the poll's error if that
// never happens in time.
func waitForChains(mfe *monitorFakeExec, canary Chain, tables []Table) error {
	allPresent := func() (bool, error) {
		mfe.Lock()
		defer mfe.Unlock()

		for _, table := range tables {
			if chains := mfe.tables[string(table)]; !chains.Has(string(canary)) {
				return false, nil
			}
		}
		return true, nil
	}
	return utilwait.PollImmediate(100*time.Millisecond, time.Second, allPresent)
}
309
310 func ensureNoChains(mfe *monitorFakeExec) bool {
311 mfe.Lock()
312 defer mfe.Unlock()
313 return mfe.tables["mangle"].Len() == 0 &&
314 mfe.tables["filter"].Len() == 0 &&
315 mfe.tables["nat"].Len() == 0
316 }
317
// waitForReloads waits up to one second for *reloads to reach expected, then
// returns an error unless the count is exactly expected (so an overshoot is
// also reported).
func waitForReloads(reloads *uint32, expected uint32) error {
	reached := func() (bool, error) {
		return atomic.LoadUint32(reloads) >= expected, nil
	}
	if ok, _ := reached(); !ok {
		// A poll timeout is deliberately ignored: the exact-count check
		// below reports the failure with a more useful message.
		_ = utilwait.PollImmediate(100*time.Millisecond, time.Second, reached)
	}

	if got := atomic.LoadUint32(reloads); got != expected {
		return fmt.Errorf("expected %d, got %d", expected, got)
	}
	return nil
}
330
// waitForNoReload gives the monitor a short (250ms) window in which an
// unexpected reload could show up, then returns an error unless *reloads is
// still exactly expected.
func waitForNoReload(reloads *uint32, expected uint32) error {
	// The poll ends early if the count rises; timing out is the normal
	// "nothing happened" outcome, so its error is discarded.
	_ = utilwait.PollImmediate(50*time.Millisecond, 250*time.Millisecond, func() (bool, error) {
		return atomic.LoadUint32(reloads) > expected, nil
	})

	got := atomic.LoadUint32(reloads)
	if got == expected {
		return nil
	}
	return fmt.Errorf("expected %d, got %d", expected, got)
}
342
// waitForBlocked polls for up to one second until the fake reports that a
// call was rejected because iptables was blocked. Each poll consumes (clears)
// the fake's wasBlocked flag.
func waitForBlocked(mfe *monitorFakeExec) error {
	return utilwait.PollImmediate(100*time.Millisecond, time.Second, func() (bool, error) {
		return mfe.getWasBlocked(), nil
	})
}
349
View as plain text