//go:build linux package socket import ( "context" "fmt" "net" "time" "golang.org/x/sys/unix" "edge-infra.dev/pkg/lib/fog" ) const timeout = time.Second * 5 // Socket server defines a bind address to listen/broadcast on // and the current data channel. type Server struct { bindAddress string data chan []byte } // Returns the socket servers data channel func (s *Server) GetDataChan() *chan []byte { return &s.data } func (s *Server) Setup(bindAddress string) { s.bindAddress = bindAddress s.data = make(chan []byte, 1) } // Starts the socket server, either replaying to the local netlink instance or // broadcast events locally to a socket. // // Stops whens the context is cancelled. Sends any errors over the error channel. func (s *Server) Serve(ctx context.Context, errs chan error) { log := fog.FromContext(ctx) if s.bindAddress == "netlink" { log.Info("serving netlink") s.serveNetlink(ctx, errs) } else { log.Info("serving remote destination") s.serveTCP(ctx, errs) } } // serveNetlink instantiates a new netlink socket and will replay // the data (uevents) from the data channel to netlink. // // Stops whens the context is cancelled. Sends any errors over the error channel. func (s *Server) serveNetlink(ctx context.Context, errs chan error) { fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_KOBJECT_UEVENT) if err != nil { errs <- err return } addr := unix.SockaddrNetlink{ Family: unix.AF_NETLINK, Groups: uint32(2), } if err = unix.Bind(fd, &addr); err != nil { unix.Close(fd) } for { select { case data := <-s.data: if err = unix.Sendto(fd, data, 0, &addr); err != nil { errs <- fmt.Errorf("failed to start netlink connection: %w", err) } case <-ctx.Done(): return } } } // serverTCP broadcasts a new socket and sends data (uevents) from the data // channel to any connected clients that are listening // // Stops whens the context is cancelled. Sends any errors over the error channel. func (s *Server) serveTCP(ctx context.Context, errs chan error) { connChan := make(chan net.Conn, 1) ln, err := net.Listen("tcp4", s.bindAddress) if err != nil { errs <- fmt.Errorf("failed to start remote connection: %w", err) return } // writes the data stream to any connected clients go writeToConnections(ctx, connChan, s.data) for { select { case <-ctx.Done(): return default: conn, err := ln.Accept() if err != nil { errs <- err return } connChan <- conn } } } func writeToConnections(ctx context.Context, connChan chan net.Conn, input chan []byte) { connections := make(map[net.Conn]net.Conn, 1) for { select { case conn := <-connChan: connections[conn] = conn case data := <-input: for _, c := range connections { if err := c.SetWriteDeadline(time.Now().Add(timeout)); err != nil { c.Close() delete(connections, c) } _, err := c.Write(data) if err != nil { c.Close() delete(connections, c) } } case <-ctx.Done(): // gracefully close connections and return for _, c := range connections { c.Close() } return } } }