1 package netlink
2
3 import (
4 "bytes"
5 "fmt"
6 "net"
7
8 "github.com/vishvananda/netlink/nl"
9 "golang.org/x/sys/unix"
10 )
11
// FibRuleInvert is the rule-header flag bit that ruleHandle sets when
// Rule.Invert is true and that RuleListFiltered reads back into Rule.Invert.
const FibRuleInvert = 1 << 1 // 0x2
13
14
15
16 func RuleAdd(rule *Rule) error {
17 return pkgHandle.RuleAdd(rule)
18 }
19
20
21
22 func (h *Handle) RuleAdd(rule *Rule) error {
23 req := h.newNetlinkRequest(unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
24 return ruleHandle(rule, req)
25 }
26
27
28
29 func RuleDel(rule *Rule) error {
30 return pkgHandle.RuleDel(rule)
31 }
32
33
34
35 func (h *Handle) RuleDel(rule *Rule) error {
36 req := h.newNetlinkRequest(unix.RTM_DELRULE, unix.NLM_F_ACK)
37 return ruleHandle(rule, req)
38 }
39
40 func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
41 msg := nl.NewRtMsg()
42 msg.Family = unix.AF_INET
43 msg.Protocol = unix.RTPROT_BOOT
44 msg.Scope = unix.RT_SCOPE_UNIVERSE
45 msg.Table = unix.RT_TABLE_UNSPEC
46 msg.Type = rule.Type
47 if msg.Type == 0 && req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 {
48 msg.Type = unix.RTN_UNICAST
49 }
50 if rule.Invert {
51 msg.Flags |= FibRuleInvert
52 }
53 if rule.Family != 0 {
54 msg.Family = uint8(rule.Family)
55 }
56 if rule.Table >= 0 && rule.Table < 256 {
57 msg.Table = uint8(rule.Table)
58 }
59 if rule.Tos != 0 {
60 msg.Tos = uint8(rule.Tos)
61 }
62
63 var dstFamily uint8
64 var rtAttrs []*nl.RtAttr
65 if rule.Dst != nil && rule.Dst.IP != nil {
66 dstLen, _ := rule.Dst.Mask.Size()
67 msg.Dst_len = uint8(dstLen)
68 msg.Family = uint8(nl.GetIPFamily(rule.Dst.IP))
69 dstFamily = msg.Family
70 var dstData []byte
71 if msg.Family == unix.AF_INET {
72 dstData = rule.Dst.IP.To4()
73 } else {
74 dstData = rule.Dst.IP.To16()
75 }
76 rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, dstData))
77 }
78
79 if rule.Src != nil && rule.Src.IP != nil {
80 msg.Family = uint8(nl.GetIPFamily(rule.Src.IP))
81 if dstFamily != 0 && dstFamily != msg.Family {
82 return fmt.Errorf("source and destination ip are not the same IP family")
83 }
84 srcLen, _ := rule.Src.Mask.Size()
85 msg.Src_len = uint8(srcLen)
86 var srcData []byte
87 if msg.Family == unix.AF_INET {
88 srcData = rule.Src.IP.To4()
89 } else {
90 srcData = rule.Src.IP.To16()
91 }
92 rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, srcData))
93 }
94
95 req.AddData(msg)
96 for i := range rtAttrs {
97 req.AddData(rtAttrs[i])
98 }
99
100 if rule.Priority >= 0 {
101 b := make([]byte, 4)
102 native.PutUint32(b, uint32(rule.Priority))
103 req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
104 }
105 if rule.Mark != 0 || rule.Mask != nil {
106 b := make([]byte, 4)
107 native.PutUint32(b, rule.Mark)
108 req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
109 }
110 if rule.Mask != nil {
111 b := make([]byte, 4)
112 native.PutUint32(b, *rule.Mask)
113 req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
114 }
115 if rule.Flow >= 0 {
116 b := make([]byte, 4)
117 native.PutUint32(b, uint32(rule.Flow))
118 req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
119 }
120 if rule.TunID > 0 {
121 b := make([]byte, 4)
122 native.PutUint32(b, uint32(rule.TunID))
123 req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
124 }
125 if rule.Table >= 256 {
126 b := make([]byte, 4)
127 native.PutUint32(b, uint32(rule.Table))
128 req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
129 }
130 if msg.Table > 0 {
131 if rule.SuppressPrefixlen >= 0 {
132 b := make([]byte, 4)
133 native.PutUint32(b, uint32(rule.SuppressPrefixlen))
134 req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
135 }
136 if rule.SuppressIfgroup >= 0 {
137 b := make([]byte, 4)
138 native.PutUint32(b, uint32(rule.SuppressIfgroup))
139 req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
140 }
141 }
142 if rule.IifName != "" {
143 req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName+"\x00")))
144 }
145 if rule.OifName != "" {
146 req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName+"\x00")))
147 }
148 if rule.Goto >= 0 {
149 msg.Type = nl.FR_ACT_GOTO
150 b := make([]byte, 4)
151 native.PutUint32(b, uint32(rule.Goto))
152 req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
153 }
154
155 if rule.IPProto > 0 {
156 b := make([]byte, 4)
157 native.PutUint32(b, uint32(rule.IPProto))
158 req.AddData(nl.NewRtAttr(nl.FRA_IP_PROTO, b))
159 }
160
161 if rule.Dport != nil {
162 b := rule.Dport.toRtAttrData()
163 req.AddData(nl.NewRtAttr(nl.FRA_DPORT_RANGE, b))
164 }
165
166 if rule.Sport != nil {
167 b := rule.Sport.toRtAttrData()
168 req.AddData(nl.NewRtAttr(nl.FRA_SPORT_RANGE, b))
169 }
170
171 if rule.UIDRange != nil {
172 b := rule.UIDRange.toRtAttrData()
173 req.AddData(nl.NewRtAttr(nl.FRA_UID_RANGE, b))
174 }
175
176 if rule.Protocol > 0 {
177 req.AddData(nl.NewRtAttr(nl.FRA_PROTOCOL, nl.Uint8Attr(rule.Protocol)))
178 }
179
180 _, err := req.Execute(unix.NETLINK_ROUTE, 0)
181 return err
182 }
183
184
185
186 func RuleList(family int) ([]Rule, error) {
187 return pkgHandle.RuleList(family)
188 }
189
190
191
192 func (h *Handle) RuleList(family int) ([]Rule, error) {
193 return h.RuleListFiltered(family, nil, 0)
194 }
195
196
197
198
199 func RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
200 return pkgHandle.RuleListFiltered(family, filter, filterMask)
201 }
202
203
204
205 func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
206 req := h.newNetlinkRequest(unix.RTM_GETRULE, unix.NLM_F_DUMP|unix.NLM_F_REQUEST)
207 msg := nl.NewIfInfomsg(family)
208 req.AddData(msg)
209
210 msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWRULE)
211 if err != nil {
212 return nil, err
213 }
214
215 var res = make([]Rule, 0)
216 for i := range msgs {
217 msg := nl.DeserializeRtMsg(msgs[i])
218 attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():])
219 if err != nil {
220 return nil, err
221 }
222
223 rule := NewRule()
224 rule.Priority = 0
225
226 rule.Invert = msg.Flags&FibRuleInvert > 0
227 rule.Family = int(msg.Family)
228 rule.Tos = uint(msg.Tos)
229
230 for j := range attrs {
231 switch attrs[j].Attr.Type {
232 case unix.RTA_TABLE:
233 rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
234 case nl.FRA_SRC:
235 rule.Src = &net.IPNet{
236 IP: attrs[j].Value,
237 Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)),
238 }
239 case nl.FRA_DST:
240 rule.Dst = &net.IPNet{
241 IP: attrs[j].Value,
242 Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
243 }
244 case nl.FRA_FWMARK:
245 rule.Mark = native.Uint32(attrs[j].Value[0:4])
246 case nl.FRA_FWMASK:
247 mask := native.Uint32(attrs[j].Value[0:4])
248 rule.Mask = &mask
249 case nl.FRA_TUN_ID:
250 rule.TunID = uint(native.Uint64(attrs[j].Value[0:8]))
251 case nl.FRA_IIFNAME:
252 rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
253 case nl.FRA_OIFNAME:
254 rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
255 case nl.FRA_SUPPRESS_PREFIXLEN:
256 i := native.Uint32(attrs[j].Value[0:4])
257 if i != 0xffffffff {
258 rule.SuppressPrefixlen = int(i)
259 }
260 case nl.FRA_SUPPRESS_IFGROUP:
261 i := native.Uint32(attrs[j].Value[0:4])
262 if i != 0xffffffff {
263 rule.SuppressIfgroup = int(i)
264 }
265 case nl.FRA_FLOW:
266 rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
267 case nl.FRA_GOTO:
268 rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
269 case nl.FRA_PRIORITY:
270 rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
271 case nl.FRA_IP_PROTO:
272 rule.IPProto = int(native.Uint32(attrs[j].Value[0:4]))
273 case nl.FRA_DPORT_RANGE:
274 rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
275 case nl.FRA_SPORT_RANGE:
276 rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
277 case nl.FRA_UID_RANGE:
278 rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8]))
279 case nl.FRA_PROTOCOL:
280 rule.Protocol = uint8(attrs[j].Value[0])
281 }
282 }
283
284 if filter != nil {
285 switch {
286 case filterMask&RT_FILTER_SRC != 0 &&
287 (rule.Src == nil || rule.Src.String() != filter.Src.String()):
288 continue
289 case filterMask&RT_FILTER_DST != 0 &&
290 (rule.Dst == nil || rule.Dst.String() != filter.Dst.String()):
291 continue
292 case filterMask&RT_FILTER_TABLE != 0 &&
293 filter.Table != unix.RT_TABLE_UNSPEC && rule.Table != filter.Table:
294 continue
295 case filterMask&RT_FILTER_TOS != 0 && rule.Tos != filter.Tos:
296 continue
297 case filterMask&RT_FILTER_PRIORITY != 0 && rule.Priority != filter.Priority:
298 continue
299 case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark:
300 continue
301 case filterMask&RT_FILTER_MASK != 0 && !ptrEqual(rule.Mask, filter.Mask):
302 continue
303 }
304 }
305
306 res = append(res, *rule)
307 }
308
309 return res, nil
310 }
311
312 func (pr *RulePortRange) toRtAttrData() []byte {
313 b := [][]byte{make([]byte, 2), make([]byte, 2)}
314 native.PutUint16(b[0], pr.Start)
315 native.PutUint16(b[1], pr.End)
316 return bytes.Join(b, []byte{})
317 }
318
319 func (pr *RuleUIDRange) toRtAttrData() []byte {
320 b := [][]byte{make([]byte, 4), make([]byte, 4)}
321 native.PutUint32(b[0], pr.Start)
322 native.PutUint32(b[1], pr.End)
323 return bytes.Join(b, []byte{})
324 }
325
// ptrEqual reports whether a and b are both nil, or both non-nil and point
// to equal values.
func ptrEqual(a, b *uint32) bool {
	if a == nil || b == nil {
		// Equal only when both are nil.
		return a == b
	}
	return *a == *b
}
335
336 func (r Rule) typeString() string {
337 switch r.Type {
338 case unix.RTN_UNSPEC:
339 return ""
340 case unix.RTN_UNICAST:
341 return ""
342 case unix.RTN_LOCAL:
343 return "local"
344 case unix.RTN_BROADCAST:
345 return "broadcast"
346 case unix.RTN_ANYCAST:
347 return "anycast"
348 case unix.RTN_MULTICAST:
349 return "multicast"
350 case unix.RTN_BLACKHOLE:
351 return "blackhole"
352 case unix.RTN_UNREACHABLE:
353 return "unreachable"
354 case unix.RTN_PROHIBIT:
355 return "prohibit"
356 case unix.RTN_THROW:
357 return "throw"
358 case unix.RTN_NAT:
359 return "nat"
360 case unix.RTN_XRESOLVE:
361 return "xresolve"
362 default:
363 return fmt.Sprintf("type(0x%x)", r.Type)
364 }
365 }
366
View as plain text