...
1
2
3 package socket
4
5 import (
6 "context"
7 "fmt"
8 "net"
9 "time"
10
11 "golang.org/x/sys/unix"
12
13 "edge-infra.dev/pkg/lib/fog"
14 )
15
// timeout is the per-write deadline applied to each remote TCP connection
// before a payload is written to it.
const timeout = 5 * time.Second
17
18
19
// Server forwards byte payloads to a single destination: either a kernel
// uevent netlink socket or every client connected to a TCP listener.
type Server struct {
	// bindAddress is the literal "netlink" or an IPv4 TCP listen address.
	bindAddress string
	// data carries payloads to be forwarded by Serve.
	data chan []byte
}

// GetDataChan returns the address of the server's payload channel. Because
// it points at the field itself, it observes a channel installed by a later
// call to Setup.
func (s *Server) GetDataChan() *chan []byte {
	field := &s.data
	return field
}

// Setup records the bind address and allocates a payload channel with a
// buffer of one element, replacing any channel from a previous Setup call.
func (s *Server) Setup(bindAddress string) {
	s.data = make(chan []byte, 1)
	s.bindAddress = bindAddress
}
34
35
36
37
38
39 func (s *Server) Serve(ctx context.Context, errs chan error) {
40 log := fog.FromContext(ctx)
41 if s.bindAddress == "netlink" {
42 log.Info("serving netlink")
43 s.serveNetlink(ctx, errs)
44 } else {
45 log.Info("serving remote destination")
46 s.serveTCP(ctx, errs)
47 }
48 }
49
50
51
52
53
54 func (s *Server) serveNetlink(ctx context.Context, errs chan error) {
55 fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_KOBJECT_UEVENT)
56 if err != nil {
57 errs <- err
58 return
59 }
60
61 addr := unix.SockaddrNetlink{
62 Family: unix.AF_NETLINK,
63 Groups: uint32(2),
64 }
65
66 if err = unix.Bind(fd, &addr); err != nil {
67 unix.Close(fd)
68 }
69
70 for {
71 select {
72 case data := <-s.data:
73 if err = unix.Sendto(fd, data, 0, &addr); err != nil {
74 errs <- fmt.Errorf("failed to start netlink connection: %w", err)
75 }
76 case <-ctx.Done():
77 return
78 }
79 }
80 }
81
82
83
84
85
86 func (s *Server) serveTCP(ctx context.Context, errs chan error) {
87 connChan := make(chan net.Conn, 1)
88
89 ln, err := net.Listen("tcp4", s.bindAddress)
90 if err != nil {
91 errs <- fmt.Errorf("failed to start remote connection: %w", err)
92 return
93 }
94
95
96 go writeToConnections(ctx, connChan, s.data)
97
98 for {
99 select {
100 case <-ctx.Done():
101 return
102 default:
103 conn, err := ln.Accept()
104 if err != nil {
105 errs <- err
106 return
107 }
108 connChan <- conn
109 }
110 }
111 }
112
113 func writeToConnections(ctx context.Context, connChan chan net.Conn, input chan []byte) {
114 connections := make(map[net.Conn]net.Conn, 1)
115
116 for {
117 select {
118 case conn := <-connChan:
119 connections[conn] = conn
120 case data := <-input:
121 for _, c := range connections {
122 if err := c.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
123 c.Close()
124 delete(connections, c)
125 }
126 _, err := c.Write(data)
127 if err != nil {
128 c.Close()
129 delete(connections, c)
130 }
131 }
132 case <-ctx.Done():
133
134 for _, c := range connections {
135 c.Close()
136 }
137 return
138 }
139 }
140 }
141
View as plain text