...

Source file src/github.com/mdlayher/socket/netns_linux.go

Documentation: github.com/mdlayher/socket

     1  //go:build linux
     2  // +build linux
     3  
     4  package socket
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"os"
    10  	"runtime"
    11  
    12  	"golang.org/x/sync/errgroup"
    13  	"golang.org/x/sys/unix"
    14  )
    15  
    16  // errNetNSDisabled is returned when network namespaces are unavailable on
    17  // a given system.
    18  var errNetNSDisabled = errors.New("socket: Linux network namespaces are not enabled on this system")
    19  
    20  // withNetNS invokes fn within the context of the network namespace specified by
    21  // fd, while also managing the logic required to safely do so by manipulating
    22  // thread-local state.
    23  func withNetNS(fd int, fn func() (*Conn, error)) (*Conn, error) {
    24  	var (
    25  		eg   errgroup.Group
    26  		conn *Conn
    27  	)
    28  
    29  	eg.Go(func() error {
    30  		// Retrieve and store the calling OS thread's network namespace so the
    31  		// thread can be reassigned to it after creating a socket in another network
    32  		// namespace.
    33  		runtime.LockOSThread()
    34  
    35  		ns, err := threadNetNS()
    36  		if err != nil {
    37  			// No thread-local manipulation, unlock.
    38  			runtime.UnlockOSThread()
    39  			return err
    40  		}
    41  		defer ns.Close()
    42  
    43  		// Beyond this point, the thread's network namespace is poisoned. Do not
    44  		// unlock the OS thread until all network namespace manipulation completes
    45  		// to avoid returning to the caller with altered thread-local state.
    46  
    47  		// Assign the current OS thread the goroutine is locked to to the given
    48  		// network namespace.
    49  		if err := ns.Set(fd); err != nil {
    50  			return err
    51  		}
    52  
    53  		// Attempt Conn creation and unconditionally restore the original namespace.
    54  		c, err := fn()
    55  		if nerr := ns.Restore(); nerr != nil {
    56  			// Failed to restore original namespace. Return an error and allow the
    57  			// runtime to terminate the thread.
    58  			if err == nil {
    59  				_ = c.Close()
    60  			}
    61  
    62  			return nerr
    63  		}
    64  
    65  		// No more thread-local state manipulation; return the new Conn.
    66  		runtime.UnlockOSThread()
    67  		conn = c
    68  		return nil
    69  	})
    70  
    71  	if err := eg.Wait(); err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	return conn, nil
    76  }
    77  
    78  // A netNS is a handle that can manipulate network namespaces.
    79  //
    80  // Operations performed on a netNS must use runtime.LockOSThread before
    81  // manipulating any network namespaces.
    82  type netNS struct {
    83  	// The handle to a network namespace.
    84  	f *os.File
    85  
    86  	// Indicates if network namespaces are disabled on this system, and thus
    87  	// operations should become a no-op or return errors.
    88  	disabled bool
    89  }
    90  
    91  // threadNetNS constructs a netNS using the network namespace of the calling
    92  // thread. If the namespace is not the default namespace, runtime.LockOSThread
    93  // should be invoked first.
    94  func threadNetNS() (*netNS, error) {
    95  	return fileNetNS(fmt.Sprintf("/proc/self/task/%d/ns/net", unix.Gettid()))
    96  }
    97  
    98  // fileNetNS opens file and creates a netNS. fileNetNS should only be called
    99  // directly in tests.
   100  func fileNetNS(file string) (*netNS, error) {
   101  	f, err := os.Open(file)
   102  	switch {
   103  	case err == nil:
   104  		return &netNS{f: f}, nil
   105  	case os.IsNotExist(err):
   106  		// Network namespaces are not enabled on this system. Use this signal
   107  		// to return errors elsewhere if the caller explicitly asks for a
   108  		// network namespace to be set.
   109  		return &netNS{disabled: true}, nil
   110  	default:
   111  		return nil, err
   112  	}
   113  }
   114  
   115  // Close releases the handle to a network namespace.
   116  func (n *netNS) Close() error {
   117  	return n.do(func() error { return n.f.Close() })
   118  }
   119  
   120  // FD returns a file descriptor which represents the network namespace.
   121  func (n *netNS) FD() int {
   122  	if n.disabled {
   123  		// No reasonable file descriptor value in this case, so specify a
   124  		// non-existent one.
   125  		return -1
   126  	}
   127  
   128  	return int(n.f.Fd())
   129  }
   130  
   131  // Restore restores the original network namespace for the calling thread.
   132  func (n *netNS) Restore() error {
   133  	return n.do(func() error { return n.Set(n.FD()) })
   134  }
   135  
   136  // Set sets a new network namespace for the current thread using fd.
   137  func (n *netNS) Set(fd int) error {
   138  	return n.do(func() error {
   139  		return os.NewSyscallError("setns", unix.Setns(fd, unix.CLONE_NEWNET))
   140  	})
   141  }
   142  
   143  // do runs fn if network namespaces are enabled on this system.
   144  func (n *netNS) do(fn func() error) error {
   145  	if n.disabled {
   146  		return errNetNSDisabled
   147  	}
   148  
   149  	return fn()
   150  }
   151  

View as plain text