1 package netlink
2
3 import (
4 "bytes"
5 "encoding/binary"
6 "errors"
7 "fmt"
8 "net"
9 "time"
10
11 "github.com/vishvananda/netlink/nl"
12 "golang.org/x/sys/unix"
13 )
14
15
// ConntrackTableType selects which conntrack table a request addresses.
// newConntrackRequest places this value in the upper byte of the netlink
// message type.
type ConntrackTableType uint8
17
const (
	// ConntrackTable is the table id used for the conntrack table.
	// NOTE(review): presumably mirrors the kernel's NFNL_SUBSYS_CTNETLINK — confirm.
	ConntrackTable = 1

	// ConntrackExpectTable is the table id used for the conntrack expect table.
	// NOTE(review): presumably mirrors the kernel's NFNL_SUBSYS_CTNETLINK_EXP — confirm.
	ConntrackExpectTable = 2
)
26
const (
	// seekCurrent is the whence argument the parsers in this file hand to
	// bytes.Reader.Seek: the offset is relative to the current position
	// (same value as io.SeekCurrent).
	seekCurrent = 1
)
31
32
// InetFamily is the address family of a conntrack request.
// newConntrackRequest copies it into the NfgenFamily field of the message header.
type InetFamily uint8
34
35
36
37
38
39
40
41
42
43
44
45
46
// ConntrackTableList returns the flows of the given table and family.
// It delegates to the ConntrackTableList method of the package-level handle.
func ConntrackTableList(table ConntrackTableType, family InetFamily) ([]*ConntrackFlow, error) {
	return pkgHandle.ConntrackTableList(table, family)
}
50
51
52
53
// ConntrackTableFlush flushes the given table.
// It delegates to the ConntrackTableFlush method of the package-level handle.
func ConntrackTableFlush(table ConntrackTableType) error {
	return pkgHandle.ConntrackTableFlush(table)
}
57
58
59
// ConntrackCreate creates a new conntrack entry described by flow in the given
// table. It delegates to the ConntrackCreate method of the package-level handle.
func ConntrackCreate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error {
	return pkgHandle.ConntrackCreate(table, family, flow)
}
63
64
65
// ConntrackUpdate updates an existing conntrack entry with the data in flow.
// It delegates to the ConntrackUpdate method of the package-level handle.
func ConntrackUpdate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error {
	return pkgHandle.ConntrackUpdate(table, family, flow)
}
69
70
71
72
73
// ConntrackDeleteFilter deletes the flows of the given table and family that
// match filter. It is the single-filter form of ConntrackDeleteFilters and
// forwards to that method on the package-level handle, returning its count
// and error unchanged.
func ConntrackDeleteFilter(table ConntrackTableType, family InetFamily, filter CustomConntrackFilter) (uint, error) {
	return pkgHandle.ConntrackDeleteFilters(table, family, filter)
}
77
78
79
// ConntrackDeleteFilters deletes the flows of the given table and family that
// match at least one of filters. It delegates to the ConntrackDeleteFilters
// method of the package-level handle.
func ConntrackDeleteFilters(table ConntrackTableType, family InetFamily, filters ...CustomConntrackFilter) (uint, error) {
	return pkgHandle.ConntrackDeleteFilters(table, family, filters...)
}
83
84
85
86 func (h *Handle) ConntrackTableList(table ConntrackTableType, family InetFamily) ([]*ConntrackFlow, error) {
87 res, err := h.dumpConntrackTable(table, family)
88 if err != nil {
89 return nil, err
90 }
91
92
93 var result []*ConntrackFlow
94 for _, dataRaw := range res {
95 result = append(result, parseRawData(dataRaw))
96 }
97
98 return result, nil
99 }
100
101
102
103
104 func (h *Handle) ConntrackTableFlush(table ConntrackTableType) error {
105 req := h.newConntrackRequest(table, unix.AF_INET, nl.IPCTNL_MSG_CT_DELETE, unix.NLM_F_ACK)
106 _, err := req.Execute(unix.NETLINK_NETFILTER, 0)
107 return err
108 }
109
110
111
112 func (h *Handle) ConntrackCreate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error {
113 req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_CREATE)
114 attr, err := flow.toNlData()
115 if err != nil {
116 return err
117 }
118
119 for _, a := range attr {
120 req.AddData(a)
121 }
122
123 _, err = req.Execute(unix.NETLINK_NETFILTER, 0)
124 return err
125 }
126
127
128
129 func (h *Handle) ConntrackUpdate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error {
130 req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_REPLACE)
131 attr, err := flow.toNlData()
132 if err != nil {
133 return err
134 }
135
136 for _, a := range attr {
137 req.AddData(a)
138 }
139
140 _, err = req.Execute(unix.NETLINK_NETFILTER, 0)
141 return err
142 }
143
144
145
146
147
// ConntrackDeleteFilter deletes the flows of the given table and family that
// match filter. It is the single-filter form of ConntrackDeleteFilters and
// returns that method's count and error unchanged.
func (h *Handle) ConntrackDeleteFilter(table ConntrackTableType, family InetFamily, filter CustomConntrackFilter) (uint, error) {
	return h.ConntrackDeleteFilters(table, family, filter)
}
151
152
153
154 func (h *Handle) ConntrackDeleteFilters(table ConntrackTableType, family InetFamily, filters ...CustomConntrackFilter) (uint, error) {
155 res, err := h.dumpConntrackTable(table, family)
156 if err != nil {
157 return 0, err
158 }
159
160 var matched uint
161 for _, dataRaw := range res {
162 flow := parseRawData(dataRaw)
163 for _, filter := range filters {
164 if match := filter.MatchConntrackFlow(flow); match {
165 req2 := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_DELETE, unix.NLM_F_ACK)
166
167 req2.AddRawData(dataRaw[4:])
168 req2.Execute(unix.NETLINK_NETFILTER, 0)
169 matched++
170
171 break
172 }
173 }
174 }
175
176 return matched, nil
177 }
178
179 func (h *Handle) newConntrackRequest(table ConntrackTableType, family InetFamily, operation, flags int) *nl.NetlinkRequest {
180
181 req := h.newNetlinkRequest((int(table)<<8)|operation, flags)
182
183 msg := &nl.Nfgenmsg{
184 NfgenFamily: uint8(family),
185 Version: nl.NFNETLINK_V0,
186 ResId: 0,
187 }
188 req.AddData(msg)
189 return req
190 }
191
192 func (h *Handle) dumpConntrackTable(table ConntrackTableType, family InetFamily) ([][]byte, error) {
193 req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_GET, unix.NLM_F_DUMP)
194 return req.Execute(unix.NETLINK_NETFILTER, 0)
195 }
196
197
198
199
// ProtoInfo is the protocol-specific part of a conntrack flow. The
// implementations in this file are ProtoInfoTCP, ProtoInfoSCTP and
// ProtoInfoDCCP.
type ProtoInfo interface {
	// Protocol returns the lower-case name of the protocol the info belongs to.
	Protocol() string
}
203
204
205
// ProtoInfoTCP is the TCP flavour of ProtoInfo.
type ProtoInfoTCP struct {
	// State is the raw value of the CTA_PROTOINFO_TCP_STATE attribute.
	State uint8
}

// Protocol returns "tcp".
func (*ProtoInfoTCP) Protocol() string { return "tcp" }
211 func (p *ProtoInfoTCP) toNlData() ([]*nl.RtAttr, error) {
212 ctProtoInfo := nl.NewRtAttr(unix.NLA_F_NESTED | nl.CTA_PROTOINFO, []byte{})
213 ctProtoInfoTCP := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_PROTOINFO_TCP, []byte{})
214 ctProtoInfoTCPState := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_STATE, nl.Uint8Attr(p.State))
215 ctProtoInfoTCP.AddChild(ctProtoInfoTCPState)
216 ctProtoInfo.AddChild(ctProtoInfoTCP)
217
218 return []*nl.RtAttr{ctProtoInfo}, nil
219 }
220
221
// ProtoInfoSCTP is the SCTP flavour of ProtoInfo. It carries no fields:
// parseProtoInfo only records that an SCTP protoinfo attribute was present.
type ProtoInfoSCTP struct{}

// Protocol returns "sctp".
func (*ProtoInfoSCTP) Protocol() string { return "sctp" }
225
226
// ProtoInfoDCCP is the DCCP flavour of ProtoInfo. It carries no fields:
// parseProtoInfo only records that a DCCP protoinfo attribute was present.
type ProtoInfoDCCP struct{}

// Protocol returns "dccp".
func (*ProtoInfoDCCP) Protocol() string { return "dccp" }
230
231
232
233
// IPTuple describes one direction of a conntrack flow.
type IPTuple struct {
	// Bytes is the byte counter of this direction; parseRawData fills it from
	// the counters attribute of the direction.
	Bytes uint64
	// DstIP is the destination address.
	DstIP net.IP
	// DstPort is the destination port; parseIpTuple decodes it for TCP and UDP only.
	DstPort uint16
	// Packets is the packet counter of this direction.
	Packets uint64
	// Protocol is the layer-4 protocol number.
	Protocol uint8
	// SrcIP is the source address.
	SrcIP net.IP
	// SrcPort is the source port; parseIpTuple decodes it for TCP and UDP only.
	SrcPort uint16
}
243
244
245
246 func (t *IPTuple) toNlData(family uint8) ([]*nl.RtAttr, error) {
247
248 var srcIPsFlag, dstIPsFlag int
249 if family == nl.FAMILY_V4 {
250 srcIPsFlag = nl.CTA_IP_V4_SRC
251 dstIPsFlag = nl.CTA_IP_V4_DST
252 } else if family == nl.FAMILY_V6 {
253 srcIPsFlag = nl.CTA_IP_V6_SRC
254 dstIPsFlag = nl.CTA_IP_V6_DST
255 } else {
256 return []*nl.RtAttr{}, fmt.Errorf("couldn't generate netlink message for tuple due to unrecognized FamilyType '%d'", family)
257 }
258
259 ctTupleIP := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_IP, nil)
260 ctTupleIPSrc := nl.NewRtAttr(srcIPsFlag, t.SrcIP)
261 ctTupleIP.AddChild(ctTupleIPSrc)
262 ctTupleIPDst := nl.NewRtAttr(dstIPsFlag, t.DstIP)
263 ctTupleIP.AddChild(ctTupleIPDst)
264
265 ctTupleProto := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_PROTO, nil)
266 ctTupleProtoNum := nl.NewRtAttr(nl.CTA_PROTO_NUM, []byte{t.Protocol})
267 ctTupleProto.AddChild(ctTupleProtoNum)
268 ctTupleProtoSrcPort := nl.NewRtAttr(nl.CTA_PROTO_SRC_PORT, nl.BEUint16Attr(t.SrcPort))
269 ctTupleProto.AddChild(ctTupleProtoSrcPort)
270 ctTupleProtoDstPort := nl.NewRtAttr(nl.CTA_PROTO_DST_PORT, nl.BEUint16Attr(t.DstPort))
271 ctTupleProto.AddChild(ctTupleProtoDstPort, )
272
273 return []*nl.RtAttr{ctTupleIP, ctTupleProto}, nil
274 }
275
// ConntrackFlow is the decoded form of one conntrack entry.
type ConntrackFlow struct {
	// FamilyType is the address family read from the first byte of the raw message.
	FamilyType uint8
	// Forward is the original-direction tuple (CTA_TUPLE_ORIG).
	Forward IPTuple
	// Reverse is the reply-direction tuple (CTA_TUPLE_REPLY).
	Reverse IPTuple
	// Mark is the connection mark (CTA_MARK).
	Mark uint32
	// Zone is the conntrack zone (CTA_ZONE).
	Zone uint16
	// TimeStart and TimeStop are the CTA_TIMESTAMP values; String interprets
	// them as nanoseconds since the Unix epoch.
	TimeStart uint64
	TimeStop  uint64
	// TimeOut is the CTA_TIMEOUT value; String prints it as seconds.
	TimeOut uint32
	// Labels holds the 16 bytes read for CTA_LABELS, when present.
	Labels []byte
	// ProtoInfo is the protocol-specific information (CTA_PROTOINFO), or nil.
	ProtoInfo ProtoInfo
}
288
289 func (s *ConntrackFlow) String() string {
290
291
292
293 start := time.Unix(0, int64(s.TimeStart))
294 stop := time.Unix(0, int64(s.TimeStop))
295 timeout := int32(s.TimeOut)
296 res := fmt.Sprintf("%s\t%d src=%s dst=%s sport=%d dport=%d packets=%d bytes=%d\tsrc=%s dst=%s sport=%d dport=%d packets=%d bytes=%d mark=0x%x ",
297 nl.L4ProtoMap[s.Forward.Protocol], s.Forward.Protocol,
298 s.Forward.SrcIP.String(), s.Forward.DstIP.String(), s.Forward.SrcPort, s.Forward.DstPort, s.Forward.Packets, s.Forward.Bytes,
299 s.Reverse.SrcIP.String(), s.Reverse.DstIP.String(), s.Reverse.SrcPort, s.Reverse.DstPort, s.Reverse.Packets, s.Reverse.Bytes,
300 s.Mark)
301 if len(s.Labels) > 0 {
302 res += fmt.Sprintf("labels=0x%x ", s.Labels)
303 }
304 if s.Zone != 0 {
305 res += fmt.Sprintf("zone=%d ", s.Zone)
306 }
307 res += fmt.Sprintf("start=%v stop=%v timeout=%d(sec)", start, stop, timeout)
308 return res
309 }
310
311
312 func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) {
313 var payload []*nl.RtAttr
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350 ctTupleOrig := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_ORIG, nil)
351 forwardFlowAttrs, err := s.Forward.toNlData(s.FamilyType)
352 if err != nil {
353 return nil, fmt.Errorf("couldn't generate netlink data for conntrack forward flow: %w", err)
354 }
355 for _, a := range forwardFlowAttrs {
356 ctTupleOrig.AddChild(a)
357 }
358
359
360 ctTupleReply := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_REPLY, nil)
361 reverseFlowAttrs, err := s.Reverse.toNlData(s.FamilyType)
362 if err != nil {
363 return nil, fmt.Errorf("couldn't generate netlink data for conntrack reverse flow: %w", err)
364 }
365 for _, a := range reverseFlowAttrs {
366 ctTupleReply.AddChild(a)
367 }
368
369 ctMark := nl.NewRtAttr(nl.CTA_MARK, nl.BEUint32Attr(s.Mark))
370 ctTimeout := nl.NewRtAttr(nl.CTA_TIMEOUT, nl.BEUint32Attr(s.TimeOut))
371
372 payload = append(payload, ctTupleOrig, ctTupleReply, ctMark, ctTimeout)
373
374 if s.ProtoInfo != nil {
375 switch p := s.ProtoInfo.(type) {
376 case *ProtoInfoTCP:
377 attrs, err := p.toNlData()
378 if err != nil {
379 return nil, fmt.Errorf("couldn't generate netlink data for conntrack flow's TCP protoinfo: %w", err)
380 }
381 payload = append(payload, attrs...)
382 default:
383 return nil, errors.New("couldn't generate netlink data for conntrack: field 'ProtoInfo' only supports TCP or nil")
384 }
385 }
386
387 return payload, nil
388 }
389
390
391
392
393
394
395
396
// parseIpTuple decodes one address/protocol/port tuple from reader into tpl
// and returns the layer-4 protocol number it stored in tpl.Protocol.
// parseRawData calls it right after consuming a nested CTA_TUPLE_IP header.
//
// Read and seek errors are not checked; the caller is expected to hand in a
// well-formed message.
func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 {
	// Two attributes are read here: source and destination address, in either
	// order and of either IP version. Other attribute types are ignored.
	for i := 0; i < 2; i++ {
		_, t, _, v := parseNfAttrTLV(reader)
		switch t {
		case nl.CTA_IP_V4_SRC, nl.CTA_IP_V6_SRC:
			tpl.SrcIP = v
		case nl.CTA_IP_V4_DST, nl.CTA_IP_V6_DST:
			tpl.DstIP = v
		}
	}
	// Header of the nested attribute that follows the addresses; only its
	// payload length is kept so the unread remainder can be skipped later.
	// NOTE(review): presumably CTA_TUPLE_PROTO — the type is not checked.
	_, _, protoInfoTotalLen := parseNfAttrTL(reader)
	_, t, l, v := parseNfAttrTLV(reader)
	// Running count of payload bytes consumed from that nested attribute.
	protoInfoBytesRead := uint16(nl.SizeofNfattr) + l
	if t == nl.CTA_PROTO_NUM {
		tpl.Protocol = uint8(v[0])
	}
	// Ports are only decoded for TCP and UDP; for any other protocol the rest
	// of the nested attribute is skipped.
	if tpl.Protocol != unix.IPPROTO_TCP && tpl.Protocol != unix.IPPROTO_UDP {
		bytesRemaining := protoInfoTotalLen - protoInfoBytesRead
		reader.Seek(int64(bytesRemaining), seekCurrent)
		return tpl.Protocol
	}
	// Skip 3 bytes. NOTE(review): presumably the padding that follows the
	// 1-byte protocol number — confirm.
	reader.Seek(3, seekCurrent)
	protoInfoBytesRead += 3
	for i := 0; i < 2; i++ {
		_, t, _ := parseNfAttrTL(reader)
		protoInfoBytesRead += uint16(nl.SizeofNfattr)
		switch t {
		case nl.CTA_PROTO_SRC_PORT:
			parseBERaw16(reader, &tpl.SrcPort)
			protoInfoBytesRead += 2
		case nl.CTA_PROTO_DST_PORT:
			parseBERaw16(reader, &tpl.DstPort)
			protoInfoBytesRead += 2
		}
		// Skip 2 bytes after every attribute header of this loop.
		// NOTE(review): presumably the padding after the 2-byte port — confirm.
		reader.Seek(2, seekCurrent)
		protoInfoBytesRead += 2
	}
	// Skip whatever is left of the nested attribute.
	bytesRemaining := protoInfoTotalLen - protoInfoBytesRead
	reader.Seek(int64(bytesRemaining), seekCurrent)

	return tpl.Protocol
}
446
447 func parseNfAttrTLV(r *bytes.Reader) (isNested bool, attrType, len uint16, value []byte) {
448 isNested, attrType, len = parseNfAttrTL(r)
449
450 value = make([]byte, len)
451 binary.Read(r, binary.BigEndian, &value)
452 return isNested, attrType, len, value
453 }
454
455 func parseNfAttrTL(r *bytes.Reader) (isNested bool, attrType, len uint16) {
456 binary.Read(r, nl.NativeEndian(), &len)
457 len -= nl.SizeofNfattr
458
459 binary.Read(r, nl.NativeEndian(), &attrType)
460 isNested = (attrType & nl.NLA_F_NESTED) == nl.NLA_F_NESTED
461 attrType = attrType & (nl.NLA_F_NESTED - 1)
462 return isNested, attrType, len
463 }
464
465
466
467
468 func skipNfAttrValue(r *bytes.Reader, len uint16) uint16 {
469 len = (len + nl.NLA_ALIGNTO - 1) & ^(nl.NLA_ALIGNTO - 1)
470 r.Seek(int64(len), seekCurrent)
471 return len
472 }
473
// parseBERaw16 reads a big-endian uint16 from r into *v. When fewer than two
// bytes are left, the remaining bytes are consumed and *v is left untouched.
func parseBERaw16(r *bytes.Reader, v *uint16) {
	var raw [2]byte
	if n, _ := r.Read(raw[:]); n == len(raw) {
		*v = binary.BigEndian.Uint16(raw[:])
	}
}
477
// parseBERaw32 reads a big-endian uint32 from r into *v. When fewer than four
// bytes are left, the remaining bytes are consumed and *v is left untouched.
func parseBERaw32(r *bytes.Reader, v *uint32) {
	var raw [4]byte
	if n, _ := r.Read(raw[:]); n == len(raw) {
		*v = binary.BigEndian.Uint32(raw[:])
	}
}
481
// parseBERaw64 reads a big-endian uint64 from r into *v. When fewer than eight
// bytes are left, the remaining bytes are consumed and *v is left untouched.
func parseBERaw64(r *bytes.Reader, v *uint64) {
	var raw [8]byte
	if n, _ := r.Read(raw[:]); n == len(raw) {
		*v = binary.BigEndian.Uint64(raw[:])
	}
}
485
486 func parseRaw32(r *bytes.Reader, v *uint32) {
487 binary.Read(r, nl.NativeEndian(), v)
488 }
489
490 func parseByteAndPacketCounters(r *bytes.Reader) (bytes, packets uint64) {
491 for i := 0; i < 2; i++ {
492 switch _, t, _ := parseNfAttrTL(r); t {
493 case nl.CTA_COUNTERS_BYTES:
494 parseBERaw64(r, &bytes)
495 case nl.CTA_COUNTERS_PACKETS:
496 parseBERaw64(r, &packets)
497 default:
498 return
499 }
500 }
501 return
502 }
503
504
505 func parseTimeStamp(r *bytes.Reader, readSize uint16) (tstart, tstop uint64) {
506 var numTimeStamps int
507 oneItem := nl.SizeofNfattr + 8
508 if readSize == uint16(oneItem) {
509 numTimeStamps = 1
510 } else if readSize == 2*uint16(oneItem) {
511 numTimeStamps = 2
512 } else {
513 return
514 }
515 for i := 0; i < numTimeStamps; i++ {
516 switch _, t, _ := parseNfAttrTL(r); t {
517 case nl.CTA_TIMESTAMP_START:
518 parseBERaw64(r, &tstart)
519 case nl.CTA_TIMESTAMP_STOP:
520 parseBERaw64(r, &tstop)
521 default:
522 return
523 }
524 }
525 return
526
527 }
528
529 func parseProtoInfoTCPState(r *bytes.Reader) (s uint8) {
530 binary.Read(r, binary.BigEndian, &s)
531 r.Seek(nl.SizeofNfattr - 1, seekCurrent)
532 return s
533 }
534
535
536 func parseProtoInfoTCP(r *bytes.Reader, attrLen uint16) (*ProtoInfoTCP) {
537 p := new(ProtoInfoTCP)
538 bytesRead := 0
539 for bytesRead < int(attrLen) {
540 _, t, l := parseNfAttrTL(r)
541 bytesRead += nl.SizeofNfattr
542
543 switch t {
544 case nl.CTA_PROTOINFO_TCP_STATE:
545 p.State = parseProtoInfoTCPState(r)
546 bytesRead += nl.SizeofNfattr
547 default:
548 bytesRead += int(skipNfAttrValue(r, l))
549 }
550 }
551
552 return p
553 }
554
555 func parseProtoInfo(r *bytes.Reader, attrLen uint16) (p ProtoInfo) {
556 bytesRead := 0
557 for bytesRead < int(attrLen) {
558 _, t, l := parseNfAttrTL(r)
559 bytesRead += nl.SizeofNfattr
560
561 switch t {
562 case nl.CTA_PROTOINFO_TCP:
563 p = parseProtoInfoTCP(r, l)
564 bytesRead += int(l)
565
566 case nl.CTA_PROTOINFO_DCCP:
567 p = new(ProtoInfoDCCP)
568 skipped := skipNfAttrValue(r, l)
569 bytesRead += int(skipped)
570 case nl.CTA_PROTOINFO_SCTP:
571 p = new(ProtoInfoSCTP)
572 skipped := skipNfAttrValue(r, l)
573 bytesRead += int(skipped)
574 default:
575 skipped := skipNfAttrValue(r, l)
576 bytesRead += int(skipped)
577 }
578 }
579
580 return p
581 }
582
583 func parseTimeOut(r *bytes.Reader) (ttimeout uint32) {
584 parseBERaw32(r, &ttimeout)
585 return
586 }
587
588 func parseConnectionMark(r *bytes.Reader) (mark uint32) {
589 parseBERaw32(r, &mark)
590 return
591 }
592
593 func parseConnectionLabels(r *bytes.Reader) (label []byte) {
594 label = make([]byte, 16)
595 binary.Read(r, nl.NativeEndian(), &label)
596 return
597 }
598
599 func parseConnectionZone(r *bytes.Reader) (zone uint16) {
600 parseBERaw16(r, &zone)
601 r.Seek(2, seekCurrent)
602 return
603 }
604
// parseRawData decodes one raw conntrack message, as returned by
// dumpConntrackTable, into a ConntrackFlow. The message is a 4-byte header
// followed by a sequence of attributes; attributes that are not handled below
// are skipped. Read errors are not checked, so a truncated message yields a
// partially filled flow.
func parseRawData(data []byte) *ConntrackFlow {
	s := &ConntrackFlow{}

	// The first byte of the header is the address family.
	reader := bytes.NewReader(data)
	binary.Read(reader, nl.NativeEndian(), &s.FamilyType)

	// Skip the other 3 header bytes.
	// NOTE(review): presumably the version and resource id of the nfgenmsg — confirm.
	reader.Seek(3, seekCurrent)

	// Walk the attributes until the message is exhausted. Nested and flat
	// attributes are dispatched separately.
	for reader.Len() > 0 {
		if nested, t, l := parseNfAttrTL(reader); nested {
			switch t {
			case nl.CTA_TUPLE_ORIG:
				// NOTE(review): unlike the reply case below, an unexpected
				// inner header is not skipped here — confirm this is intended.
				if nested, t, l = parseNfAttrTL(reader); nested && t == nl.CTA_TUPLE_IP {
					parseIpTuple(reader, &s.Forward)
				}
			case nl.CTA_TUPLE_REPLY:
				if nested, t, l = parseNfAttrTL(reader); nested && t == nl.CTA_TUPLE_IP {
					parseIpTuple(reader, &s.Reverse)
				} else {
					// Inner header not recognised: skip its value.
					skipNfAttrValue(reader, l)
				}
			case nl.CTA_COUNTERS_ORIG:
				s.Forward.Bytes, s.Forward.Packets = parseByteAndPacketCounters(reader)
			case nl.CTA_COUNTERS_REPLY:
				s.Reverse.Bytes, s.Reverse.Packets = parseByteAndPacketCounters(reader)
			case nl.CTA_TIMESTAMP:
				s.TimeStart, s.TimeStop = parseTimeStamp(reader, l)
			case nl.CTA_PROTOINFO:
				s.ProtoInfo = parseProtoInfo(reader, l)
			default:
				skipNfAttrValue(reader, l)
			}
		} else {
			switch t {
			case nl.CTA_MARK:
				s.Mark = parseConnectionMark(reader)
			case nl.CTA_LABELS:
				// Always reads 16 bytes, independent of the attribute length l.
				s.Labels = parseConnectionLabels(reader)
			case nl.CTA_TIMEOUT:
				s.TimeOut = parseTimeOut(reader)
			case nl.CTA_ID, nl.CTA_STATUS, nl.CTA_USE:
				// Known attributes whose values are not kept.
				skipNfAttrValue(reader, l)
			case nl.CTA_ZONE:
				s.Zone = parseConnectionZone(reader)
			default:
				skipNfAttrValue(reader, l)
			}
		}
	}
	return s
}
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
// ConntrackFilterType names the flow attribute a ConntrackFilter entry applies
// to; it is the key type of the filter maps in ConntrackFilter.
type ConntrackFilterType uint8
700
// Filter attribute selectors, used as ConntrackFilterType keys.
const (
	ConntrackOrigSrcIP     = iota // source IP of the forward (original) tuple
	ConntrackOrigDstIP            // destination IP of the forward (original) tuple
	ConntrackReplySrcIP           // source IP of the reverse (reply) tuple
	ConntrackReplyDstIP           // destination IP of the reverse (reply) tuple
	ConntrackReplyAnyIP           // source or destination IP of the reverse (reply) tuple
	ConntrackOrigSrcPort          // source port of the forward (original) tuple
	ConntrackOrigDstPort          // destination port of the forward (original) tuple
	ConntrackMatchLabels          // labels the flow must contain
	ConntrackUnmatchLabels        // labels the flow must not contain
	// The Nat* names are aliases of the corresponding Reply* selectors.
	ConntrackNatSrcIP = ConntrackReplySrcIP
	ConntrackNatDstIP = ConntrackReplyDstIP
	ConntrackNatAnyIP = ConntrackReplyAnyIP
)
715
// CustomConntrackFilter is the predicate ConntrackDeleteFilters uses to decide
// which flows to delete. ConntrackFilter is the implementation provided by
// this file.
type CustomConntrackFilter interface {
	// MatchConntrackFlow reports whether flow is selected by the filter;
	// ConntrackDeleteFilters sends a delete request for every flow for which
	// it returns true.
	MatchConntrackFlow(flow *ConntrackFlow) bool
}
721
// ConntrackFilter is a CustomConntrackFilter built from criteria added with
// the Add* methods; a flow matches only when it satisfies every criterion. The
// zero value has no criteria and matches no flow.
type ConntrackFilter struct {
	// ipNetFilter holds at most one network per IP selector.
	ipNetFilter map[ConntrackFilterType]*net.IPNet
	// portFilter holds at most one port per port selector.
	portFilter map[ConntrackFilterType]uint16
	// protoFilter is the layer-4 protocol number; 0 means not set.
	protoFilter uint8
	// labelFilter holds the label sets for the match/unmatch selectors.
	labelFilter map[ConntrackFilterType][][]byte
	// zoneFilter is the zone to match; nil means not set.
	zoneFilter *uint16
}
729
730
731 func (f *ConntrackFilter) AddIPNet(tp ConntrackFilterType, ipNet *net.IPNet) error {
732 if ipNet == nil {
733 return fmt.Errorf("Filter attribute empty")
734 }
735 if f.ipNetFilter == nil {
736 f.ipNetFilter = make(map[ConntrackFilterType]*net.IPNet)
737 }
738 if _, ok := f.ipNetFilter[tp]; ok {
739 return errors.New("Filter attribute already present")
740 }
741 f.ipNetFilter[tp] = ipNet
742 return nil
743 }
744
745
746 func (f *ConntrackFilter) AddIP(tp ConntrackFilterType, ip net.IP) error {
747 if ip == nil {
748 return fmt.Errorf("Filter attribute empty")
749 }
750 return f.AddIPNet(tp, NewIPNet(ip))
751 }
752
753
754 func (f *ConntrackFilter) AddPort(tp ConntrackFilterType, port uint16) error {
755 switch f.protoFilter {
756
757 case 6, 17, 33, 132, 136:
758 default:
759 return fmt.Errorf("Filter attribute not available without a valid Layer 4 protocol: %d", f.protoFilter)
760 }
761
762 if f.portFilter == nil {
763 f.portFilter = make(map[ConntrackFilterType]uint16)
764 }
765 if _, ok := f.portFilter[tp]; ok {
766 return errors.New("Filter attribute already present")
767 }
768 f.portFilter[tp] = port
769 return nil
770 }
771
772
773 func (f *ConntrackFilter) AddProtocol(proto uint8) error {
774 if f.protoFilter != 0 {
775 return errors.New("Filter attribute already present")
776 }
777 f.protoFilter = proto
778 return nil
779 }
780
781
782
783
784
785
786
787
788
789
790
791 func (f *ConntrackFilter) AddLabels(tp ConntrackFilterType, labels [][]byte) error {
792 if len(labels) == 0 {
793 return errors.New("Invalid length for provided labels")
794 }
795 if f.labelFilter == nil {
796 f.labelFilter = make(map[ConntrackFilterType][][]byte)
797 }
798 if _, ok := f.labelFilter[tp]; ok {
799 return errors.New("Filter attribute already present")
800 }
801 f.labelFilter[tp] = labels
802 return nil
803 }
804
805
806 func (f *ConntrackFilter) AddZone(zone uint16) error {
807 if f.zoneFilter != nil {
808 return errors.New("Filter attribute already present")
809 }
810 f.zoneFilter = &zone
811 return nil
812 }
813
814
815
816 func (f *ConntrackFilter) MatchConntrackFlow(flow *ConntrackFlow) bool {
817 if len(f.ipNetFilter) == 0 && len(f.portFilter) == 0 && f.protoFilter == 0 && len(f.labelFilter) == 0 && f.zoneFilter == nil {
818
819 return false
820 }
821
822
823 if f.protoFilter != 0 && flow.Forward.Protocol != f.protoFilter {
824
825 return false
826 }
827
828
829 if f.zoneFilter != nil && *f.zoneFilter != flow.Zone {
830 return false
831 }
832
833 match := true
834
835
836 if len(f.ipNetFilter) > 0 {
837
838 if elem, found := f.ipNetFilter[ConntrackOrigSrcIP]; found {
839 match = match && elem.Contains(flow.Forward.SrcIP)
840 }
841
842
843 if elem, found := f.ipNetFilter[ConntrackOrigDstIP]; match && found {
844 match = match && elem.Contains(flow.Forward.DstIP)
845 }
846
847
848 if elem, found := f.ipNetFilter[ConntrackReplySrcIP]; match && found {
849 match = match && elem.Contains(flow.Reverse.SrcIP)
850 }
851
852
853 if elem, found := f.ipNetFilter[ConntrackReplyDstIP]; match && found {
854 match = match && elem.Contains(flow.Reverse.DstIP)
855 }
856
857
858 if elem, found := f.ipNetFilter[ConntrackReplyAnyIP]; match && found {
859 match = match && (elem.Contains(flow.Reverse.SrcIP) || elem.Contains(flow.Reverse.DstIP))
860 }
861 }
862
863
864 if len(f.portFilter) > 0 {
865
866 if elem, found := f.portFilter[ConntrackOrigSrcPort]; match && found {
867 match = match && elem == flow.Forward.SrcPort
868 }
869
870
871 if elem, found := f.portFilter[ConntrackOrigDstPort]; match && found {
872 match = match && elem == flow.Forward.DstPort
873 }
874 }
875
876
877 if len(f.labelFilter) > 0 {
878 if len(flow.Labels) > 0 {
879
880
881 if elem, found := f.labelFilter[ConntrackMatchLabels]; match && found {
882 for _, label := range elem {
883 match = match && (bytes.Contains(flow.Labels, label))
884 }
885 }
886
887
888 if elem, found := f.labelFilter[ConntrackUnmatchLabels]; match && found {
889 for _, label := range elem {
890 match = match && !(bytes.Contains(flow.Labels, label))
891 }
892 }
893 } else {
894
895 match = false
896 }
897 }
898
899 return match
900 }
901
// Compile-time check that *ConntrackFilter implements CustomConntrackFilter.
var _ CustomConntrackFilter = (*ConntrackFilter)(nil)
903
View as plain text