1
2
3
4 package network
5
6 import (
7 "context"
8 "fmt"
9 "net"
10 "os/exec"
11 "runtime"
12 "strconv"
13 "time"
14
15 "github.com/Microsoft/hcsshim/internal/guest/prot"
16 "github.com/Microsoft/hcsshim/internal/log"
17 "github.com/pkg/errors"
18 "github.com/sirupsen/logrus"
19 "github.com/vishvananda/netlink"
20 "github.com/vishvananda/netns"
21 )
22
23
24
25 func MoveInterfaceToNS(ifStr string, pid int) error {
26
27 link, err := netlink.LinkByName(ifStr)
28 if err != nil {
29 return errors.Wrapf(err, "netlink.LinkByName(%s) failed", ifStr)
30 }
31 if err := netlink.LinkSetDown(link); err != nil {
32 return errors.Wrapf(err, "netlink.LinkSetDown(%#v) failed", link)
33 }
34
35
36 if err := netlink.LinkSetNsPid(link, pid); err != nil {
37 return errors.Wrapf(err, "netlink.SetNsPid(%#v, %d) failed", link, pid)
38 }
39 return nil
40 }
41
42
43
44
45
46 func DoInNetNS(ns netns.NsHandle, run func() error) error {
47 runtime.LockOSThread()
48 defer runtime.UnlockOSThread()
49
50 origNs, err := netns.Get()
51 if err != nil {
52 return errors.Wrap(err, "failed to get current network namespace")
53 }
54 defer origNs.Close()
55
56 if err := netns.Set(ns); err != nil {
57 return errors.Wrapf(err, "failed to set network namespace to %v", ns)
58 }
59
60 defer netns.Set(origNs)
61
62 return run()
63 }
64
65
66
67
68
69
70 func NetNSConfig(ctx context.Context, ifStr string, nsPid int, adapter *prot.NetworkAdapter) error {
71 ctx, entry := log.S(ctx, logrus.Fields{
72 "ifname": ifStr,
73 "pid": nsPid,
74 })
75 if ifStr == "" || nsPid == -1 || adapter == nil {
76 return errors.New("All three arguments must be specified")
77 }
78
79 entry.Trace("Obtaining current namespace")
80 ns, err := netns.Get()
81 if err != nil {
82 return errors.Wrap(err, "netns.Get() failed")
83 }
84 defer ns.Close()
85 entry.WithField("namespace", ns).Debug("New network namespace from PID")
86
87
88 entry.Trace("Getting reference to interface")
89 link, err := netlink.LinkByName(ifStr)
90 if err != nil {
91 return errors.Wrapf(err, "netlink.LinkByName(%s) failed", ifStr)
92 }
93
94
95 if adapter.EncapOverhead != 0 {
96 mtu := link.Attrs().MTU - int(adapter.EncapOverhead)
97 entry.WithField("mtu", mtu).Debug("EncapOverhead non-zero, will set MTU")
98 if err = netlink.LinkSetMTU(link, mtu); err != nil {
99 return errors.Wrapf(err, "netlink.LinkSetMTU(%#v, %d) failed", link, mtu)
100 }
101 }
102
103
104 if adapter.NatEnabled {
105 entry.Tracef("Configuring interface with NAT: %s/%d gw=%s",
106 adapter.AllocatedIPAddress,
107 adapter.HostIPPrefixLength, adapter.HostIPAddress)
108 metric := 1
109 if adapter.EnableLowMetric {
110 metric = 500
111 }
112
113
114 if err := netlink.LinkSetUp(link); err != nil {
115 return errors.Wrapf(err, "netlink.LinkSetUp(%#v) failed", link)
116 }
117 if err := assignIPToLink(ctx, ifStr, nsPid, link,
118 adapter.AllocatedIPAddress, adapter.HostIPAddress, adapter.HostIPPrefixLength,
119 adapter.EnableLowMetric, metric,
120 ); err != nil {
121 return err
122 }
123 if err := assignIPToLink(ctx, ifStr, nsPid, link,
124 adapter.AllocatedIPv6Address, adapter.HostIPv6Address, adapter.HostIPv6PrefixLength,
125 adapter.EnableLowMetric, metric,
126 ); err != nil {
127 return err
128 }
129 } else {
130 timeout := 30 * time.Second
131 entry.Trace("Configure with DHCP")
132 entry.WithField("timeout", timeout.String()).Debug("Execing udhcpc with timeout...")
133 cmd := exec.Command("udhcpc", "-q", "-i", ifStr, "-s", "/sbin/udhcpc_config.script")
134
135 done := make(chan error)
136 go func() {
137 done <- cmd.Wait()
138 }()
139 defer close(done)
140
141 select {
142 case <-time.After(timeout):
143 var cos string
144 co, err := cmd.CombinedOutput()
145 if err != nil {
146 cos = string(co)
147 }
148 _ = cmd.Process.Kill()
149 entry.WithField("timeout", timeout.String()).Warningf("udhcpc timed out [%s]", cos)
150 return fmt.Errorf("udhcpc timed out. Failed to get DHCP address: %s", cos)
151 case <-done:
152 var cos string
153 co, err := cmd.CombinedOutput()
154 if err != nil {
155 cos = string(co)
156 }
157 if err != nil {
158 entry.WithError(err).Debugf("udhcpc failed [%s]", cos)
159 return errors.Wrapf(err, "process failed (%s)", cos)
160 }
161 }
162 var cos string
163 co, err := cmd.CombinedOutput()
164 if err != nil {
165 cos = string(co)
166 }
167 entry.Debugf("udhcpc succeeded: %s", cos)
168 }
169
170
171 if entry.Logger.GetLevel() >= logrus.DebugLevel {
172 curNS, _ := netns.Get()
173
174 link, _ = netlink.LinkByIndex(link.Attrs().Index)
175 attr := link.Attrs()
176 addrs, _ := netlink.AddrList(link, 0)
177 addrsStr := make([]string, 0, len(addrs))
178 for _, addr := range addrs {
179 addrsStr = append(addrsStr, fmt.Sprintf("%v", addr))
180 }
181
182 entry.WithField("addresses", addrsStr).Debugf("%v: %s[idx=%d,type=%s] is %v",
183 curNS, attr.Name, attr.Index, link.Type(), attr.OperState)
184 }
185
186 return nil
187 }
188
189 func assignIPToLink(ctx context.Context,
190 ifStr string,
191 nsPid int,
192 link netlink.Link,
193 allocatedIP string,
194 gatewayIP string,
195 prefixLen uint8,
196 enableLowMetric bool,
197 metric int,
198 ) error {
199 entry := log.G(ctx)
200 entry.WithFields(logrus.Fields{
201 "link": link.Attrs().Name,
202 "IP": allocatedIP,
203 "prefixLen": prefixLen,
204 "gateway": gatewayIP,
205 "metric": metric,
206 }).Trace("assigning IP address")
207 if allocatedIP == "" {
208 return nil
209 }
210
211 ip, addr, err := net.ParseCIDR(allocatedIP + "/" + strconv.FormatUint(uint64(prefixLen), 10))
212 if err != nil {
213 return errors.Wrapf(err, "parsing address %s/%d failed", allocatedIP, prefixLen)
214 }
215
216 addr.IP = ip
217 entry.WithFields(logrus.Fields{
218 "allocatedIP": ip,
219 "IP": addr,
220 }).Debugf("parsed ip address %s/%d", allocatedIP, prefixLen)
221 ipAddr := &netlink.Addr{IPNet: addr, Label: ""}
222 if err := netlink.AddrAdd(link, ipAddr); err != nil {
223 return errors.Wrapf(err, "netlink.AddrAdd(%#v, %#v) failed", link, ipAddr)
224 }
225 if gatewayIP == "" {
226 return nil
227 }
228
229 gw := net.ParseIP(gatewayIP)
230 if gw == nil {
231 return errors.Wrapf(err, "parsing gateway address %s failed", gatewayIP)
232 }
233
234 if !addr.Contains(gw) {
235
236
237
238 entry.Debugf("gw is outside of the subnet: Configure %s in %d with: %s/%d gw=%s\n",
239 ifStr, nsPid, allocatedIP, prefixLen, gatewayIP)
240 ml := len(gw) * 8
241 addr2 := &net.IPNet{
242 IP: gw,
243 Mask: net.CIDRMask(ml, ml)}
244 ipAddr2 := &netlink.Addr{IPNet: addr2, Label: ""}
245 if err := netlink.AddrAdd(link, ipAddr2); err != nil {
246 return errors.Wrapf(err, "netlink.AddrAdd(%#v, %#v) failed", link, ipAddr2)
247 }
248 }
249
250 var table int
251 if enableLowMetric {
252
253
254 _, ml := addr.Mask.Size()
255 srcNet := &net.IPNet{
256 IP: net.ParseIP(allocatedIP),
257 Mask: net.CIDRMask(ml, ml),
258 }
259 rule := netlink.NewRule()
260 rule.Table = 101
261 rule.Src = srcNet
262 rule.Priority = 5
263
264 if err := netlink.RuleAdd(rule); err != nil {
265 return errors.Wrapf(err, "netlink.RuleAdd(%#v) failed", rule)
266 }
267 table = rule.Table
268 }
269
270 route := netlink.Route{
271 Scope: netlink.SCOPE_UNIVERSE,
272 LinkIndex: link.Attrs().Index,
273 Gw: gw,
274 Table: table,
275 Priority: metric,
276 }
277 if err := netlink.RouteAdd(&route); err != nil {
278 return errors.Wrapf(err, "netlink.RouteAdd(%#v) failed", route)
279 }
280 return nil
281 }
282
View as plain text