1
2
3
4
19
20 package testing
21
22 import (
23 "bytes"
24 "fmt"
25 "strings"
26 "time"
27
28 "k8s.io/apimachinery/pkg/util/sets"
29 "k8s.io/kubernetes/pkg/util/iptables"
30 )
31
32
// FakeIPTables is an in-memory fake implementation of iptables.Interface.
// Instead of invoking iptables, its methods read and modify the tables held
// in Dump.
type FakeIPTables struct {
	// hasRandomFully is the value reported by HasRandomFully; it is set by
	// SetHasRandomFully.
	hasRandomFully bool

	// protocol is the IP family reported by Protocol and IsIPv6.
	protocol iptables.Protocol

	// Dump holds the fake's current tables, chains and rules. It is
	// exported, presumably so tests can inspect it directly.
	Dump *IPTablesDump
}
39
40
41 func NewFake() *FakeIPTables {
42 f := &FakeIPTables{
43 protocol: iptables.ProtocolIPv4,
44 Dump: &IPTablesDump{
45 Tables: []Table{
46 {
47 Name: iptables.TableNAT,
48 Chains: []Chain{
49 {Name: iptables.ChainPrerouting},
50 {Name: iptables.ChainInput},
51 {Name: iptables.ChainOutput},
52 {Name: iptables.ChainPostrouting},
53 },
54 },
55 {
56 Name: iptables.TableFilter,
57 Chains: []Chain{
58 {Name: iptables.ChainInput},
59 {Name: iptables.ChainForward},
60 {Name: iptables.ChainOutput},
61 },
62 },
63 {
64 Name: iptables.TableMangle,
65 Chains: []Chain{},
66 },
67 },
68 },
69 }
70
71 return f
72 }
73
74
75 func NewIPv6Fake() *FakeIPTables {
76 f := NewFake()
77 f.protocol = iptables.ProtocolIPv6
78 return f
79 }
80
81
82 func (f *FakeIPTables) SetHasRandomFully(can bool) *FakeIPTables {
83 f.hasRandomFully = can
84 return f
85 }
86
87
88 func (f *FakeIPTables) EnsureChain(table iptables.Table, chain iptables.Chain) (bool, error) {
89 t, err := f.Dump.GetTable(table)
90 if err != nil {
91 return false, err
92 }
93 if c, _ := f.Dump.GetChain(table, chain); c != nil {
94 return true, nil
95 }
96 t.Chains = append(t.Chains, Chain{Name: chain})
97 return false, nil
98 }
99
100
101 func (f *FakeIPTables) FlushChain(table iptables.Table, chain iptables.Chain) error {
102 if c, _ := f.Dump.GetChain(table, chain); c != nil {
103 c.Rules = nil
104 }
105 return nil
106 }
107
108
109 func (f *FakeIPTables) DeleteChain(table iptables.Table, chain iptables.Chain) error {
110 t, err := f.Dump.GetTable(table)
111 if err != nil {
112 return err
113 }
114 for i := range t.Chains {
115 if t.Chains[i].Name == chain {
116 t.Chains = append(t.Chains[:i], t.Chains[i+1:]...)
117 return nil
118 }
119 }
120 return nil
121 }
122
123
124 func (f *FakeIPTables) ChainExists(table iptables.Table, chain iptables.Chain) (bool, error) {
125 if _, err := f.Dump.GetTable(table); err != nil {
126 return false, err
127 }
128 if c, _ := f.Dump.GetChain(table, chain); c != nil {
129 return true, nil
130 }
131 return false, nil
132 }
133
134
135 func (f *FakeIPTables) EnsureRule(position iptables.RulePosition, table iptables.Table, chain iptables.Chain, args ...string) (bool, error) {
136 c, err := f.Dump.GetChain(table, chain)
137 if err != nil {
138 return false, err
139 }
140
141 rule := "-A " + string(chain) + " " + strings.Join(args, " ")
142 for _, r := range c.Rules {
143 if r.Raw == rule {
144 return true, nil
145 }
146 }
147
148 parsed, err := ParseRule(rule, false)
149 if err != nil {
150 return false, err
151 }
152
153 if position == iptables.Append {
154 c.Rules = append(c.Rules, parsed)
155 } else {
156 c.Rules = append([]*Rule{parsed}, c.Rules...)
157 }
158 return false, nil
159 }
160
161
162 func (f *FakeIPTables) DeleteRule(table iptables.Table, chain iptables.Chain, args ...string) error {
163 c, err := f.Dump.GetChain(table, chain)
164 if err != nil {
165 return err
166 }
167
168 rule := "-A " + string(chain) + " " + strings.Join(args, " ")
169 for i, r := range c.Rules {
170 if r.Raw == rule {
171 c.Rules = append(c.Rules[:i], c.Rules[i+1:]...)
172 break
173 }
174 }
175 return nil
176 }
177
178
179 func (f *FakeIPTables) IsIPv6() bool {
180 return f.protocol == iptables.ProtocolIPv6
181 }
182
183
184 func (f *FakeIPTables) Protocol() iptables.Protocol {
185 return f.protocol
186 }
187
188 func (f *FakeIPTables) saveTable(table iptables.Table, buffer *bytes.Buffer) error {
189 t, err := f.Dump.GetTable(table)
190 if err != nil {
191 return err
192 }
193
194 fmt.Fprintf(buffer, "*%s\n", table)
195 for _, c := range t.Chains {
196 fmt.Fprintf(buffer, ":%s - [%d:%d]\n", c.Name, c.Packets, c.Bytes)
197 }
198 for _, c := range t.Chains {
199 for _, r := range c.Rules {
200 fmt.Fprintf(buffer, "%s\n", r.Raw)
201 }
202 }
203 fmt.Fprintf(buffer, "COMMIT\n")
204 return nil
205 }
206
207
208 func (f *FakeIPTables) SaveInto(table iptables.Table, buffer *bytes.Buffer) error {
209 if table == "" {
210
211
212 for i := range f.Dump.Tables {
213 err := f.saveTable(f.Dump.Tables[i].Name, buffer)
214 if err != nil {
215 return err
216 }
217 }
218 return nil
219 }
220
221 return f.saveTable(table, buffer)
222 }
223
224
225 var builtinTargets = sets.New("ACCEPT", "DROP", "RETURN", "REJECT", "DNAT", "SNAT", "MASQUERADE", "MARK")
226
227 func (f *FakeIPTables) restoreTable(newDump *IPTablesDump, newTable *Table, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error {
228 oldTable, err := f.Dump.GetTable(newTable.Name)
229 if err != nil {
230 return err
231 }
232
233 backupChains := make([]Chain, len(oldTable.Chains))
234 copy(backupChains, oldTable.Chains)
235
236
237 if flush == iptables.FlushTables {
238 oldTable.Chains = make([]Chain, 0, len(newTable.Chains))
239 }
240 for _, newChain := range newTable.Chains {
241 oldChain, _ := f.Dump.GetChain(newTable.Name, newChain.Name)
242 switch {
243 case oldChain == nil && newChain.Deleted:
244
245 case oldChain == nil && !newChain.Deleted:
246 oldTable.Chains = append(oldTable.Chains, newChain)
247 case oldChain != nil && newChain.Deleted:
248 _ = f.DeleteChain(newTable.Name, newChain.Name)
249 case oldChain != nil && !newChain.Deleted:
250
251 oldChain.Rules = newChain.Rules
252 if counters == iptables.RestoreCounters {
253 oldChain.Packets = newChain.Packets
254 oldChain.Bytes = newChain.Bytes
255 }
256 }
257 }
258
259
260 for _, chain := range oldTable.Chains {
261 for _, rule := range chain.Rules {
262 if rule.Jump == nil {
263 continue
264 }
265 if builtinTargets.Has(rule.Jump.Value) {
266 continue
267 }
268
269 jumpedChain, _ := f.Dump.GetChain(oldTable.Name, iptables.Chain(rule.Jump.Value))
270 if jumpedChain == nil {
271 newChain, _ := newDump.GetChain(oldTable.Name, iptables.Chain(rule.Jump.Value))
272 if newChain != nil {
273
274
275 oldTable.Chains = backupChains
276 return fmt.Errorf("deleted chain %q is referenced by existing rules", newChain.Name)
277 } else {
278
279
280 oldTable.Chains = backupChains
281 return fmt.Errorf("rule %q jumps to a non-existent chain", rule.Raw)
282 }
283 }
284 }
285 }
286
287 return nil
288 }
289
290
291 func (f *FakeIPTables) Restore(table iptables.Table, data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error {
292 dump, err := ParseIPTablesDump(string(data))
293 if err != nil {
294 return err
295 }
296
297 newTable, err := dump.GetTable(table)
298 if err != nil {
299 return err
300 }
301
302 return f.restoreTable(dump, newTable, flush, counters)
303 }
304
305
306 func (f *FakeIPTables) RestoreAll(data []byte, flush iptables.FlushFlag, counters iptables.RestoreCountersFlag) error {
307 dump, err := ParseIPTablesDump(string(data))
308 if err != nil {
309 return err
310 }
311
312 for i := range dump.Tables {
313 err = f.restoreTable(dump, &dump.Tables[i], flush, counters)
314 if err != nil {
315 return err
316 }
317 }
318 return nil
319 }
320
321
322 func (f *FakeIPTables) Monitor(canary iptables.Chain, tables []iptables.Table, reloadFunc func(), interval time.Duration, stopCh <-chan struct{}) {
323 }
324
325
326 func (f *FakeIPTables) HasRandomFully() bool {
327 return f.hasRandomFully
328 }
329
330 func (f *FakeIPTables) Present() bool {
331 return true
332 }
333
334 var _ = iptables.Interface(&FakeIPTables{})
335
View as plain text