1 package server
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "sync"
8 "time"
9
10 "edge-infra.dev/pkg/lib/fog"
11
12 "cloud.google.com/go/pubsub"
13 "google.golang.org/api/option"
14 "google.golang.org/grpc"
15 )
16
// ErrPollMaxRetries is returned by ReceiverMux.Poll, joined with the
// individual poll failures, once more than PollMaxRetries consecutive polls
// have failed. Match it with errors.Is.
var ErrPollMaxRetries = errors.New("maximum poll retries reached")
18
19 func (p *PSQLInjector) receiverMux() (*ReceiverMux, error) {
20 return NewReceiverMux(&ReceiverMuxConfig{
21 PollFunc: p.sql.GetBannerProjectIDs,
22 Handler: p.HandleMsg,
23
24 ForemanProjectID: p.cfg.ForemanProjectID,
25 SubscriptionID: p.cfg.SubscriptionID,
26 TopicID: p.cfg.TopicID,
27 Conn: p.cfg.TestPubSubConn,
28
29 PollPeriod: p.cfg.PollBannersPeriod,
30 PollMaxRetries: p.cfg.PollBannersMaxRetries,
31
32 PollSubscriptionExistsPeriod: p.cfg.PollSubscriptionExistsPeriod,
33 })
34 }
35
// ReceiverMuxConfig configures a ReceiverMux. NewReceiverMux keeps the
// pointer it is given rather than copying the struct, so later changes to
// the config are visible to the mux.
type ReceiverMuxConfig struct {
	// Handler is passed to every Receiver the mux creates.
	Handler ReceiverHandler

	// PollFunc returns the banner project IDs that should currently have a
	// receiver. NewReceiverMux calls it once with a one-minute timeout and
	// fails unless it returns at least one ID; Poll calls it on every poll.
	PollFunc func(context.Context) ([]string, error)
	// PollPeriod is the interval between polls in Run. It must be positive
	// because it is used as a time.NewTicker period.
	PollPeriod time.Duration
	// PollMaxRetries is the number of consecutive failed polls tolerated;
	// once more than this many polls have failed in a row, Poll returns an
	// error wrapping ErrPollMaxRetries.
	PollMaxRetries int

	// SubscriptionID is the subscription ID each receiver is bound to; it is
	// looked up inside every banner project.
	SubscriptionID string
	// TopicID is not referenced by the ReceiverMux code in this file.
	// NOTE(review): confirm whether it is used elsewhere or can be removed.
	TopicID string

	// Conn, when non-nil, is the gRPC connection the Pub/Sub client is built
	// on. The injector fills it from its TestPubSubConn setting.
	Conn *grpc.ClientConn
	// ForemanProjectID is the project the Pub/Sub client is created for.
	ForemanProjectID string

	// PollSubscriptionExistsPeriod is passed unchanged to every Receiver.
	// NOTE(review): presumably how often a receiver re-checks that its
	// subscription exists — confirm against NewReceiver.
	PollSubscriptionExistsPeriod time.Duration
}
54
55
56
57
// ReceiverMux maintains one Receiver per banner project. It creates and drops
// receivers as the set of project IDs returned by ReceiverMuxConfig.PollFunc
// changes. Build one with NewReceiverMux and drive it with Run.
type ReceiverMux struct {
	// cfg is the configuration the mux was built from.
	cfg *ReceiverMuxConfig
	// client is the Pub/Sub client, created for cfg.ForemanProjectID, that is
	// used to reference each banner project's subscription.
	client *pubsub.Client

	// pollErrs holds the errors from consecutive failed polls since the last
	// successful one. Poll reads and writes it without holding the mutex
	// below.
	pollErrs []error

	// The embedded mutex guards receivers. Because it is embedded, Lock and
	// Unlock are part of ReceiverMux's exported method set.
	sync.Mutex
	// receivers maps a banner project ID to the Receiver tracking it.
	receivers map[string]*Receiver
}
68
69 func NewReceiverMux(cfg *ReceiverMuxConfig) (*ReceiverMux, error) {
70 var googleOptions []option.ClientOption
71 if cfg.Conn != nil {
72 googleOptions = append(googleOptions, option.WithGRPCConn(cfg.Conn))
73 }
74
75 client, err := pubsub.NewClient(context.Background(), cfg.ForemanProjectID, googleOptions...)
76 if err != nil {
77 return nil, fmt.Errorf("error creating pubsub client: %w", err)
78 }
79
80
81 var timeoutCtx, timeoutCtxCancel = context.WithTimeout(context.Background(), time.Minute)
82 defer timeoutCtxCancel()
83 if banners, err := cfg.PollFunc(timeoutCtx); err != nil {
84 return nil, fmt.Errorf("failed to retrieve banners at startup: %w", err)
85 } else if len(banners) == 0 {
86 return nil, fmt.Errorf("found zero banners at startup")
87 }
88
89 return &ReceiverMux{
90 cfg: cfg,
91 client: client,
92 receivers: make(map[string]*Receiver),
93 }, nil
94 }
95
96 func (rm *ReceiverMux) Run(ctx context.Context) error {
97 log := fog.FromContext(ctx).WithName("banner-mux-run")
98 ctx, stopReceivers := context.WithCancel(ctx)
99
100 var wg sync.WaitGroup
101 defer func() {
102 log.Info("stopping all receivers")
103 stopReceivers()
104
105 wg.Wait()
106 log.Info("all receivers have exited")
107 }()
108
109 var pollTicker = time.NewTicker(rm.cfg.PollPeriod)
110 defer pollTicker.Stop()
111
112 for {
113
114 added, dropped, err := rm.Poll(ctx)
115 if err != nil {
116
117
118 return err
119 }
120
121 for _, r := range added {
122 wg.Add(1)
123 go func() {
124 defer wg.Done()
125 r.Start(ctx)
126
127
128
129 rm.Lock()
130 defer rm.Unlock()
131 delete(rm.receivers, r.projectID)
132 }()
133 }
134
135 for _, r := range dropped {
136
137
138 _ = r.Close()
139 }
140
141 select {
142 case <-pollTicker.C:
143 case <-ctx.Done():
144 return nil
145 }
146 }
147 }
148
149 func (rm *ReceiverMux) Poll(ctx context.Context) (added, dropped []*Receiver, maxRetryError error) {
150 projectIDs, err := rm.cfg.PollFunc(ctx)
151
152 if err != nil {
153 log := fog.FromContext(ctx).WithName("banner-poll")
154
155 log.Error(err, "failed to poll banners",
156 "retry_count", len(rm.pollErrs),
157 "max_retries", rm.cfg.PollMaxRetries,
158 )
159
160 rm.pollErrs = append(rm.pollErrs, err)
161 if len(rm.pollErrs) > rm.cfg.PollMaxRetries {
162 log.Error(ErrPollMaxRetries, "max retries exhausted",
163 "max_retries", rm.cfg.PollMaxRetries,
164 "errs", rm.pollErrs,
165 )
166 return nil, nil, errors.Join(ErrPollMaxRetries, errors.Join(rm.pollErrs...))
167 }
168
169 return nil, nil, nil
170 }
171 rm.pollErrs = nil
172
173
174
175 added, dropped = rm.diffReceivers(projectIDs...)
176 return added, dropped, nil
177 }
178
179
180
181
182
183 func (rm *ReceiverMux) diffReceivers(projectIDs ...string) (added, dropped []*Receiver) {
184 var keep = make(map[string]bool)
185 for _, projectID := range projectIDs {
186 keep[projectID] = true
187 }
188
189 rm.Lock()
190 defer rm.Unlock()
191
192
193 for _, projectID := range projectIDs {
194 if _, found := rm.receivers[projectID]; found {
195 continue
196 }
197
198 var pollPeriod = rm.cfg.PollSubscriptionExistsPeriod
199 var sub = rm.client.SubscriptionInProject(rm.cfg.SubscriptionID, projectID)
200 var r = NewReceiver(projectID, sub, rm.cfg.Handler, pollPeriod)
201 added = append(added, r)
202
203
204 rm.receivers[projectID] = r
205 }
206
207
208 for projectID, r := range rm.receivers {
209 if !keep[projectID] {
210
211
212 dropped = append(dropped, r)
213 }
214 }
215
216 return added, dropped
217 }
218
219
// ReceiverMuxHealthCheck is a snapshot of a ReceiverMux's receivers, as
// produced by ReceiverMux.HealthCheck.
type ReceiverMuxHealthCheck struct {
	// Count is the number of receivers tracked when the check ran.
	Count int
	// HealthyCount is the number of receivers whose check reported Healthy.
	HealthyCount int
	// UnhealthyCount is the number of receivers whose check did not report
	// Healthy.
	UnhealthyCount int
	// Receivers holds each receiver's own health check, keyed by banner
	// project ID.
	Receivers map[string]ReceiverHealthCheck
}
226
227 func (rm *ReceiverMux) HealthCheck() ReceiverMuxHealthCheck {
228 rm.Lock()
229 defer rm.Unlock()
230
231 var bmhc = ReceiverMuxHealthCheck{
232 Count: len(rm.receivers),
233 Receivers: make(map[string]ReceiverHealthCheck),
234 }
235
236 for projectID, r := range rm.receivers {
237 var rhc = r.HealthCheck()
238
239 if rhc.Healthy {
240 bmhc.HealthyCount++
241 } else {
242 bmhc.UnhealthyCount++
243 }
244
245 bmhc.Receivers[projectID] = rhc
246 }
247 return bmhc
248 }
249
View as plain text