1 package netlink
2
3 import (
4 "fmt"
5 "time"
6
7 "github.com/vishvananda/netlink/nl"
8 "github.com/vishvananda/netns"
9 "golang.org/x/sys/unix"
10 )
11
12
13 var pkgHandle = &Handle{}
14
15
16
17
18
19 type Handle struct {
20 sockets map[int]*nl.SocketHandle
21 lookupByDump bool
22 }
23
24
25 func SetSocketTimeout(to time.Duration) error {
26 if to < time.Microsecond {
27 return fmt.Errorf("invalid timeout, minimul value is %s", time.Microsecond)
28 }
29
30 nl.SocketTimeoutTv = unix.NsecToTimeval(to.Nanoseconds())
31 return nil
32 }
33
34
35 func GetSocketTimeout() time.Duration {
36 nsec := unix.TimevalToNsec(nl.SocketTimeoutTv)
37 return time.Duration(nsec) * time.Nanosecond
38 }
39
40
41 func (h *Handle) SupportsNetlinkFamily(nlFamily int) bool {
42 _, ok := h.sockets[nlFamily]
43 return ok
44 }
45
46
47
48
49
50 func NewHandle(nlFamilies ...int) (*Handle, error) {
51 return newHandle(netns.None(), netns.None(), nlFamilies...)
52 }
53
54
55
56
57
58 func (h *Handle) SetSocketTimeout(to time.Duration) error {
59 if to < time.Microsecond {
60 return fmt.Errorf("invalid timeout, minimul value is %s", time.Microsecond)
61 }
62 tv := unix.NsecToTimeval(to.Nanoseconds())
63 for _, sh := range h.sockets {
64 if err := sh.Socket.SetSendTimeout(&tv); err != nil {
65 return err
66 }
67 if err := sh.Socket.SetReceiveTimeout(&tv); err != nil {
68 return err
69 }
70 }
71 return nil
72 }
73
74
75
76
77 func (h *Handle) SetSocketReceiveBufferSize(size int, force bool) error {
78 opt := unix.SO_RCVBUF
79 if force {
80 opt = unix.SO_RCVBUFFORCE
81 }
82 for _, sh := range h.sockets {
83 fd := sh.Socket.GetFd()
84 err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, opt, size)
85 if err != nil {
86 return err
87 }
88 }
89 return nil
90 }
91
92
93
94
95 func (h *Handle) GetSocketReceiveBufferSize() ([]int, error) {
96 results := make([]int, len(h.sockets))
97 i := 0
98 for _, sh := range h.sockets {
99 fd := sh.Socket.GetFd()
100 size, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF)
101 if err != nil {
102 return nil, err
103 }
104 results[i] = size
105 i++
106 }
107 return results, nil
108 }
109
110
111 func (h *Handle) SetStrictCheck(state bool) error {
112 for _, sh := range h.sockets {
113 var stateInt int = 0
114 if state {
115 stateInt = 1
116 }
117 err := unix.SetsockoptInt(sh.Socket.GetFd(), unix.SOL_NETLINK, unix.NETLINK_GET_STRICT_CHK, stateInt)
118 if err != nil {
119 return err
120 }
121 }
122 return nil
123 }
124
125
126
127
128 func NewHandleAt(ns netns.NsHandle, nlFamilies ...int) (*Handle, error) {
129 return newHandle(ns, netns.None(), nlFamilies...)
130 }
131
132
133
134 func NewHandleAtFrom(newNs, curNs netns.NsHandle) (*Handle, error) {
135 return newHandle(newNs, curNs)
136 }
137
138 func newHandle(newNs, curNs netns.NsHandle, nlFamilies ...int) (*Handle, error) {
139 h := &Handle{sockets: map[int]*nl.SocketHandle{}}
140 fams := nl.SupportedNlFamilies
141 if len(nlFamilies) != 0 {
142 fams = nlFamilies
143 }
144 for _, f := range fams {
145 s, err := nl.GetNetlinkSocketAt(newNs, curNs, f)
146 if err != nil {
147 return nil, err
148 }
149 h.sockets[f] = &nl.SocketHandle{Socket: s}
150 }
151 return h, nil
152 }
153
154
155 func (h *Handle) Close() {
156 for _, sh := range h.sockets {
157 sh.Close()
158 }
159 h.sockets = nil
160 }
161
162
163
164
165
166 func (h *Handle) Delete() {
167 h.Close()
168 }
169
170 func (h *Handle) newNetlinkRequest(proto, flags int) *nl.NetlinkRequest {
171
172 if h.sockets == nil {
173 return nl.NewNetlinkRequest(proto, flags)
174 }
175 return &nl.NetlinkRequest{
176 NlMsghdr: unix.NlMsghdr{
177 Len: uint32(unix.SizeofNlMsghdr),
178 Type: uint16(proto),
179 Flags: unix.NLM_F_REQUEST | uint16(flags),
180 },
181 Sockets: h.sockets,
182 }
183 }
184
View as plain text