...

Source file src/edge-infra.dev/pkg/lib/kernel/netlink/socket/socket.go

Documentation: edge-infra.dev/pkg/lib/kernel/netlink/socket

     1  //go:build linux
     2  
     3  package socket
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  	"net"
     9  	"time"
    10  
    11  	"golang.org/x/sys/unix"
    12  
    13  	"edge-infra.dev/pkg/lib/fog"
    14  )
    15  
    16  const timeout = time.Second * 5
    17  
// Server relays payloads received on its data channel either to the local
// netlink uevent socket or to TCP clients, depending on bindAddress.
type Server struct {
	// bindAddress is either the literal "netlink" (replay to netlink) or a
	// TCP address to listen on for remote clients.
	bindAddress string
	// data carries the payloads (uevents) to be replayed or broadcast.
	data        chan []byte
}
    24  
    25  // Returns the socket servers data channel
    26  func (s *Server) GetDataChan() *chan []byte {
    27  	return &s.data
    28  }
    29  
    30  func (s *Server) Setup(bindAddress string) {
    31  	s.bindAddress = bindAddress
    32  	s.data = make(chan []byte, 1)
    33  }
    34  
    35  // Starts the socket server, either replaying to the local netlink instance or
    36  // broadcast events locally to a socket.
    37  //
    38  // Stops whens the context is cancelled. Sends any errors over the error channel.
    39  func (s *Server) Serve(ctx context.Context, errs chan error) {
    40  	log := fog.FromContext(ctx)
    41  	if s.bindAddress == "netlink" {
    42  		log.Info("serving netlink")
    43  		s.serveNetlink(ctx, errs)
    44  	} else {
    45  		log.Info("serving remote destination")
    46  		s.serveTCP(ctx, errs)
    47  	}
    48  }
    49  
    50  // serveNetlink instantiates a new netlink socket and will replay
    51  // the data (uevents) from the data channel to netlink.
    52  //
    53  // Stops whens the context is cancelled. Sends any errors over the error channel.
    54  func (s *Server) serveNetlink(ctx context.Context, errs chan error) {
    55  	fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_KOBJECT_UEVENT)
    56  	if err != nil {
    57  		errs <- err
    58  		return
    59  	}
    60  
    61  	addr := unix.SockaddrNetlink{
    62  		Family: unix.AF_NETLINK,
    63  		Groups: uint32(2),
    64  	}
    65  
    66  	if err = unix.Bind(fd, &addr); err != nil {
    67  		unix.Close(fd)
    68  	}
    69  
    70  	for {
    71  		select {
    72  		case data := <-s.data:
    73  			if err = unix.Sendto(fd, data, 0, &addr); err != nil {
    74  				errs <- fmt.Errorf("failed to start netlink connection: %w", err)
    75  			}
    76  		case <-ctx.Done():
    77  			return
    78  		}
    79  	}
    80  }
    81  
    82  // serverTCP broadcasts a new socket and sends data (uevents) from the data
    83  // channel to any connected clients that are listening
    84  //
    85  // Stops whens the context is cancelled. Sends any errors over the error channel.
    86  func (s *Server) serveTCP(ctx context.Context, errs chan error) {
    87  	connChan := make(chan net.Conn, 1)
    88  
    89  	ln, err := net.Listen("tcp4", s.bindAddress)
    90  	if err != nil {
    91  		errs <- fmt.Errorf("failed to start remote connection: %w", err)
    92  		return
    93  	}
    94  
    95  	// writes the data stream to any connected clients
    96  	go writeToConnections(ctx, connChan, s.data)
    97  
    98  	for {
    99  		select {
   100  		case <-ctx.Done():
   101  			return
   102  		default:
   103  			conn, err := ln.Accept()
   104  			if err != nil {
   105  				errs <- err
   106  				return
   107  			}
   108  			connChan <- conn
   109  		}
   110  	}
   111  }
   112  
   113  func writeToConnections(ctx context.Context, connChan chan net.Conn, input chan []byte) {
   114  	connections := make(map[net.Conn]net.Conn, 1)
   115  
   116  	for {
   117  		select {
   118  		case conn := <-connChan:
   119  			connections[conn] = conn
   120  		case data := <-input:
   121  			for _, c := range connections {
   122  				if err := c.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
   123  					c.Close()
   124  					delete(connections, c)
   125  				}
   126  				_, err := c.Write(data)
   127  				if err != nil {
   128  					c.Close()
   129  					delete(connections, c)
   130  				}
   131  			}
   132  		case <-ctx.Done():
   133  			// gracefully close connections and return
   134  			for _, c := range connections {
   135  				c.Close()
   136  			}
   137  			return
   138  		}
   139  	}
   140  }
   141  

View as plain text