1 //go:build linux 2 // +build linux 3 4 package socket 5 6 import ( 7 "errors" 8 "fmt" 9 "os" 10 "runtime" 11 12 "golang.org/x/sync/errgroup" 13 "golang.org/x/sys/unix" 14 ) 15 16 // errNetNSDisabled is returned when network namespaces are unavailable on 17 // a given system. 18 var errNetNSDisabled = errors.New("socket: Linux network namespaces are not enabled on this system") 19 20 // withNetNS invokes fn within the context of the network namespace specified by 21 // fd, while also managing the logic required to safely do so by manipulating 22 // thread-local state. 23 func withNetNS(fd int, fn func() (*Conn, error)) (*Conn, error) { 24 var ( 25 eg errgroup.Group 26 conn *Conn 27 ) 28 29 eg.Go(func() error { 30 // Retrieve and store the calling OS thread's network namespace so the 31 // thread can be reassigned to it after creating a socket in another network 32 // namespace. 33 runtime.LockOSThread() 34 35 ns, err := threadNetNS() 36 if err != nil { 37 // No thread-local manipulation, unlock. 38 runtime.UnlockOSThread() 39 return err 40 } 41 defer ns.Close() 42 43 // Beyond this point, the thread's network namespace is poisoned. Do not 44 // unlock the OS thread until all network namespace manipulation completes 45 // to avoid returning to the caller with altered thread-local state. 46 47 // Assign the current OS thread the goroutine is locked to to the given 48 // network namespace. 49 if err := ns.Set(fd); err != nil { 50 return err 51 } 52 53 // Attempt Conn creation and unconditionally restore the original namespace. 54 c, err := fn() 55 if nerr := ns.Restore(); nerr != nil { 56 // Failed to restore original namespace. Return an error and allow the 57 // runtime to terminate the thread. 58 if err == nil { 59 _ = c.Close() 60 } 61 62 return nerr 63 } 64 65 // No more thread-local state manipulation; return the new Conn. 66 runtime.UnlockOSThread() 67 conn = c 68 return nil 69 }) 70 71 if err := eg.Wait(); err != nil { 72 return nil, err 73 } 74 75 return conn, nil 76 } 77 78 // A netNS is a handle that can manipulate network namespaces. 79 // 80 // Operations performed on a netNS must use runtime.LockOSThread before 81 // manipulating any network namespaces. 82 type netNS struct { 83 // The handle to a network namespace. 84 f *os.File 85 86 // Indicates if network namespaces are disabled on this system, and thus 87 // operations should become a no-op or return errors. 88 disabled bool 89 } 90 91 // threadNetNS constructs a netNS using the network namespace of the calling 92 // thread. If the namespace is not the default namespace, runtime.LockOSThread 93 // should be invoked first. 94 func threadNetNS() (*netNS, error) { 95 return fileNetNS(fmt.Sprintf("/proc/self/task/%d/ns/net", unix.Gettid())) 96 } 97 98 // fileNetNS opens file and creates a netNS. fileNetNS should only be called 99 // directly in tests. 100 func fileNetNS(file string) (*netNS, error) { 101 f, err := os.Open(file) 102 switch { 103 case err == nil: 104 return &netNS{f: f}, nil 105 case os.IsNotExist(err): 106 // Network namespaces are not enabled on this system. Use this signal 107 // to return errors elsewhere if the caller explicitly asks for a 108 // network namespace to be set. 109 return &netNS{disabled: true}, nil 110 default: 111 return nil, err 112 } 113 } 114 115 // Close releases the handle to a network namespace. 116 func (n *netNS) Close() error { 117 return n.do(func() error { return n.f.Close() }) 118 } 119 120 // FD returns a file descriptor which represents the network namespace. 121 func (n *netNS) FD() int { 122 if n.disabled { 123 // No reasonable file descriptor value in this case, so specify a 124 // non-existent one. 125 return -1 126 } 127 128 return int(n.f.Fd()) 129 } 130 131 // Restore restores the original network namespace for the calling thread. 132 func (n *netNS) Restore() error { 133 return n.do(func() error { return n.Set(n.FD()) }) 134 } 135 136 // Set sets a new network namespace for the current thread using fd. 137 func (n *netNS) Set(fd int) error { 138 return n.do(func() error { 139 return os.NewSyscallError("setns", unix.Setns(fd, unix.CLONE_NEWNET)) 140 }) 141 } 142 143 // do runs fn if network namespaces are enabled on this system. 144 func (n *netNS) do(fn func() error) error { 145 if n.disabled { 146 return errNetNSDisabled 147 } 148 149 return fn() 150 } 151