1 package socket
2
3 import (
4 "context"
5 "fmt"
6 "net"
7 "time"
8
9 "github.com/google/uuid"
10 "github.com/spf13/afero"
11
12 "edge-infra.dev/pkg/lib/fog"
13 )
14
// ResetMessage is the only request payload the socket accepts. The payload
// must match exactly; for example, a trailing newline makes it invalid.
const ResetMessage = "reset"

// Replies written back to a socket client.
const (
	// SuccessResponse reports that a reset was started. It does not mean the
	// reset has completed: the reset keeps running in the background.
	SuccessResponse = "success"
	// LockedResponse reports that a reset is already in progress.
	LockedResponse = "locked"
	// InvalidResponse reports that the request was not ResetMessage.
	InvalidResponse = "invalid"
)
21
22 type Reseter interface {
23 WithTryLock(ctx context.Context, busy chan<- bool, fn func(context.Context) error) (bool, error)
24 ResetCluster(ctx context.Context) error
25 }
26
27 type Socket struct {
28 Fs afero.Fs
29 Path string
30 Listener net.Listener
31 Connections chan *Connection
32 }
33
34 type Connection struct {
35 net.Conn
36 uuid.UUID
37 }
38
39
40 func NewSocket(fs afero.Fs, path string) *Socket {
41 connections := make(chan *Connection, 5)
42
43 return &Socket{
44 Fs: fs,
45 Path: path,
46 Connections: connections,
47 }
48 }
49
50
// NewConnection wraps conn and assigns it a fresh random UUID for logging.
func NewConnection(conn net.Conn) *Connection {
	id := uuid.New()
	return &Connection{Conn: conn, UUID: id}
}
57
58
59
// Listen binds a Unix-domain stream listener at s.Path and stores it in
// s.Listener.
//
// Anything already at s.Path is removed first via Clear, because binding a
// Unix socket to an existing path fails.
func (s *Socket) Listen() error {
	err := s.Clear()
	if err != nil {
		return err
	}

	l, err := net.Listen("unix", s.Path)
	if err != nil {
		return fmt.Errorf("failed to create Unix socket listener: %w", err)
	}
	s.Listener = l
	return nil
}
72
73
74
75
76
77
78
// Accept runs the accept loop until ctx is cancelled, queueing every accepted
// client on s.Connections.
//
// On cancellation it closes s.Connections, which lets Handle's loop finish.
// Listen must have been called first so that s.Listener is set.
func (s *Socket) Accept(ctx context.Context) {
	// acceptConnection dereferences this pointer on every accept error, so it
	// must point at real storage. The zero time guarantees the first error is
	// always logged.
	var lastLoggedErrorAt time.Time
	log := fog.FromContext(ctx).WithValues("routine", "socket")
	ctx = fog.IntoContext(ctx, log)

	// Listener.Accept blocks and does not observe ctx. Closing the listener on
	// cancellation unblocks it so the loop below can see ctx.Done and exit.
	go func() {
		<-ctx.Done()
		if s.Listener == nil {
			return
		}
		if err := s.Listener.Close(); err != nil {
			log.Error(err, "failed to close socket listener", "socket", s.Path)
		}
	}()

	for {
		select {
		case <-ctx.Done():
			// Accept is the only sender, so it owns closing the channel.
			close(s.Connections)
			return
		default:
			s.acceptConnection(ctx, &lastLoggedErrorAt)
		}
	}
}
95
// acceptConnection waits for one client on s.Listener and queues it on
// s.Connections.
//
// Error handling:
//   - Accept errors are logged at most once every five minutes, tracked
//     through lastLoggedErrorAt. A nil lastLoggedErrorAt disables the rate
//     limit, so every error is logged.
//   - Each error is followed by a 20 second back-off. The back-off is cut
//     short if ctx is cancelled.
//   - An error seen after ctx is cancelled is treated as shutdown, for example
//     a closed listener, and is not logged.
//   - If ctx is cancelled before the connection can be queued, the connection
//     is closed because no handler will receive it.
func (s *Socket) acceptConnection(ctx context.Context, lastLoggedErrorAt *time.Time) {
	log := fog.FromContext(ctx)
	conn, err := s.Listener.Accept()
	if err != nil {
		if ctx.Err() != nil {
			return
		}
		if lastLoggedErrorAt == nil || time.Since(*lastLoggedErrorAt) > 5*time.Minute {
			log.Error(err, "failed to accept socket connections", "socket", s.Path)
			if lastLoggedErrorAt != nil {
				*lastLoggedErrorAt = time.Now()
			}
		}

		backoff := time.NewTimer(20 * time.Second)
		defer backoff.Stop()
		select {
		case <-ctx.Done():
		case <-backoff.C:
		}
		return
	}
	log.V(0).Info("socket connection accepted", "emaudit", "")

	select {
	case s.Connections <- NewConnection(conn):
	case <-ctx.Done():
		if cerr := conn.Close(); cerr != nil {
			log.Error(cerr, "failed to close socket connection after cancellation", "socket", s.Path)
		}
	}
}
112
113
114
// Handle serves queued connections one at a time until s.Connections is
// closed.
//
// A failure on one connection is logged and does not stop the loop.
func (s *Socket) Handle(ctx context.Context, reseter Reseter) {
	logger := fog.FromContext(ctx).WithValues("routine", "socket")
	ctx = fog.IntoContext(ctx, logger)
	for {
		conn, open := <-s.Connections
		if !open {
			return
		}
		err := conn.handle(ctx, reseter)
		if err == nil {
			continue
		}
		logger.Error(err, "failed to handle socket connection", "socket", s.Path)
	}
}
124
125
126 func (s *Socket) Clear() error {
127 return s.Fs.RemoveAll(s.Path)
128 }
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// handle serves a single request on c and closes the connection before
// returning.
//
// The read and the reply share one 10 second deadline. The payload must equal
// ResetMessage exactly; anything else gets InvalidResponse. A valid request is
// delegated to handleReset.
func (c *Connection) handle(ctx context.Context, reseter Reseter) error {
	logger := fog.FromContext(ctx).WithValues("connection", c.UUID)
	ctx = fog.IntoContext(ctx, logger)
	defer c.Close()

	deadline := time.Now().Add(10 * time.Second)
	if err := c.SetDeadline(deadline); err != nil {
		return err
	}

	msg, err := c.read()
	switch {
	case err != nil:
		return err
	case msg == ResetMessage:
		return handleReset(ctx, reseter, c)
	default:
		logger.V(0).Info("received invalid reset request")
		return c.write(InvalidResponse)
	}
}
165
166
// handleReset starts reseter.ResetCluster under reseter's try-lock in a
// background goroutine. It replies to conn as soon as the lock outcome is
// known:
//   - LockedResponse if a reset is already running.
//   - SuccessResponse otherwise.
//
// SuccessResponse only means the reset was started. Errors from the reset
// itself are logged and are never reported to the client.
//
// NOTE(review): this blocks until WithTryLock sends on the busy channel. If an
// implementation can return without sending, this call will hang.
func handleReset(ctx context.Context, reseter Reseter, conn *Connection) error {
	logger := fog.FromContext(ctx)
	logger.V(0).Info("received instant reset request, resetting cluster...")

	busy := make(chan bool)
	go func() {
		_, err := reseter.WithTryLock(ctx, busy, reseter.ResetCluster)
		if err != nil {
			logger.Error(err, "failed to reset the cluster")
		}
	}()

	reply := SuccessResponse
	if <-busy {
		logger.V(0).Info("reset already in progress")
		reply = LockedResponse
	}
	return conn.write(reply)
}
184
185
186 func (c *Connection) read() (string, error) {
187 buffer := make([]byte, 512)
188
189 n, err := c.Read(buffer)
190 if err != nil {
191 return "", fmt.Errorf("failed to read data from socket connection: %w", err)
192 }
193 return string(buffer[:n]), nil
194 }
195
196
197 func (c *Connection) write(response string) error {
198 _, err := c.Write([]byte(response))
199 return err
200 }
201
View as plain text