1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package sotw
18
19 import (
20 "context"
21 "errors"
22 "strconv"
23 "sync/atomic"
24
25 "google.golang.org/grpc"
26 "google.golang.org/grpc/codes"
27 "google.golang.org/grpc/status"
28
29 core "github.com/datawire/ambassador/v2/pkg/api/envoy/config/core/v3"
30 discovery "github.com/datawire/ambassador/v2/pkg/api/envoy/service/discovery/v3"
31 "github.com/datawire/ambassador/v2/pkg/envoy-control-plane/cache/v3"
32 "github.com/datawire/ambassador/v2/pkg/envoy-control-plane/resource/v3"
33 )
34
35 type Server interface {
36 StreamHandler(stream Stream, typeURL string) error
37 }
38
39 type Callbacks interface {
40
41
42 OnStreamOpen(context.Context, int64, string) error
43
44 OnStreamClosed(int64)
45
46
47 OnStreamRequest(int64, *discovery.DiscoveryRequest) error
48
49 OnStreamResponse(int64, *discovery.DiscoveryRequest, *discovery.DiscoveryResponse)
50 }
51
52
53 func NewServer(ctx context.Context, config cache.ConfigWatcher, callbacks Callbacks) Server {
54 return &server{cache: config, callbacks: callbacks, ctx: ctx}
55 }
56
// server is the Server implementation returned by NewServer.
type server struct {
	// cache creates the watches that produce responses for requests.
	cache cache.ConfigWatcher
	// callbacks is optional; every use is guarded by a nil check.
	callbacks Callbacks
	// ctx is the server-wide context; every stream stops once it is done.
	ctx context.Context

	// streamCount is incremented with atomic.AddInt64 to hand out stream IDs.
	// NOTE(review): 64-bit atomics need 64-bit alignment on 32-bit platforms;
	// re-check alignment if fields are added or reordered.
	streamCount int64
}
65
66
67 type Stream interface {
68 grpc.ServerStream
69
70 Send(*discovery.DiscoveryResponse) error
71 Recv() (*discovery.DiscoveryRequest, error)
72 }
73
74
// watches tracks the open cache watches of a single stream.
//
// The six well-known resource types each have a dedicated response channel,
// cancel function and last-sent nonce. Every other type URL is tracked in the
// maps at the bottom, keyed by type URL, and its responses are funnelled
// through the shared responses channel. Init must be called before use.
type watches struct {
	// Response channels of the well-known type watches. They stay nil until a
	// watch for the type is created, so selecting on them blocks until then.
	endpoints chan cache.Response
	clusters  chan cache.Response
	routes    chan cache.Response
	listeners chan cache.Response
	secrets   chan cache.Response
	runtimes  chan cache.Response

	// Cancel functions of the well-known type watches (nil if none yet).
	endpointCancel func()
	clusterCancel  func()
	routeCancel    func()
	listenerCancel func()
	secretCancel   func()
	runtimeCancel  func()

	// Nonce of the last response sent for each well-known type; empty until
	// the first response of that type has been sent.
	endpointNonce string
	clusterNonce  string
	routeNonce    string
	listenerNonce string
	secretNonce   string
	runtimeNonce  string

	// responses receives relayed responses (or errorResponse) for all other
	// type URLs.
	responses chan cache.Response
	// cancellations holds the cache cancel function per type URL.
	cancellations map[string]func()
	// nonces holds the nonce of the last response sent per type URL.
	nonces map[string]string
	// terminations holds, per type URL, the channel closed to stop the
	// goroutine relaying that type's current watch.
	terminations map[string]chan struct{}
}
103
104
105 func (values *watches) Init() {
106
107 values.responses = make(chan cache.Response, 5)
108 values.cancellations = make(map[string]func())
109 values.nonces = make(map[string]string)
110 values.terminations = make(map[string]chan struct{})
111 }
112
113
114 var errorResponse = &cache.RawResponse{}
115
116
117 func (values *watches) Cancel() {
118 if values.endpointCancel != nil {
119 values.endpointCancel()
120 }
121 if values.clusterCancel != nil {
122 values.clusterCancel()
123 }
124 if values.routeCancel != nil {
125 values.routeCancel()
126 }
127 if values.listenerCancel != nil {
128 values.listenerCancel()
129 }
130 if values.secretCancel != nil {
131 values.secretCancel()
132 }
133 if values.runtimeCancel != nil {
134 values.runtimeCancel()
135 }
136 for _, cancel := range values.cancellations {
137 if cancel != nil {
138 cancel()
139 }
140 }
141 for _, terminate := range values.terminations {
142 close(terminate)
143 }
144 }
145
146
147 func (s *server) process(stream Stream, reqCh <-chan *discovery.DiscoveryRequest, defaultTypeURL string) error {
148
149 streamID := atomic.AddInt64(&s.streamCount, 1)
150
151
152
153 var streamNonce int64
154
155
156 var values watches
157 values.Init()
158 defer func() {
159 values.Cancel()
160 if s.callbacks != nil {
161 s.callbacks.OnStreamClosed(streamID)
162 }
163 }()
164
165
166 send := func(resp cache.Response, typeURL string) (string, error) {
167 if resp == nil {
168 return "", errors.New("missing response")
169 }
170
171 out, err := resp.GetDiscoveryResponse()
172 if err != nil {
173 return "", err
174 }
175
176
177 streamNonce = streamNonce + 1
178 out.Nonce = strconv.FormatInt(streamNonce, 10)
179 if s.callbacks != nil {
180 s.callbacks.OnStreamResponse(streamID, resp.GetRequest(), out)
181 }
182 return out.Nonce, stream.Send(out)
183 }
184
185 if s.callbacks != nil {
186 if err := s.callbacks.OnStreamOpen(stream.Context(), streamID, defaultTypeURL); err != nil {
187 return err
188 }
189 }
190
191
192 var node = &core.Node{}
193
194 for {
195 select {
196 case <-s.ctx.Done():
197 return nil
198
199 case resp, more := <-values.endpoints:
200 if !more {
201 return status.Errorf(codes.Unavailable, "endpoints watch failed")
202 }
203 nonce, err := send(resp, resource.EndpointType)
204 if err != nil {
205 return err
206 }
207 values.endpointNonce = nonce
208
209 case resp, more := <-values.clusters:
210 if !more {
211 return status.Errorf(codes.Unavailable, "clusters watch failed")
212 }
213 nonce, err := send(resp, resource.ClusterType)
214 if err != nil {
215 return err
216 }
217 values.clusterNonce = nonce
218
219 case resp, more := <-values.routes:
220 if !more {
221 return status.Errorf(codes.Unavailable, "routes watch failed")
222 }
223 nonce, err := send(resp, resource.RouteType)
224 if err != nil {
225 return err
226 }
227 values.routeNonce = nonce
228
229 case resp, more := <-values.listeners:
230 if !more {
231 return status.Errorf(codes.Unavailable, "listeners watch failed")
232 }
233 nonce, err := send(resp, resource.ListenerType)
234 if err != nil {
235 return err
236 }
237 values.listenerNonce = nonce
238
239 case resp, more := <-values.secrets:
240 if !more {
241 return status.Errorf(codes.Unavailable, "secrets watch failed")
242 }
243 nonce, err := send(resp, resource.SecretType)
244 if err != nil {
245 return err
246 }
247 values.secretNonce = nonce
248
249 case resp, more := <-values.runtimes:
250 if !more {
251 return status.Errorf(codes.Unavailable, "runtimes watch failed")
252 }
253 nonce, err := send(resp, resource.RuntimeType)
254 if err != nil {
255 return err
256 }
257 values.runtimeNonce = nonce
258
259 case resp, more := <-values.responses:
260 if more {
261 if resp == errorResponse {
262 return status.Errorf(codes.Unavailable, "resource watch failed")
263 }
264 typeUrl := resp.GetRequest().TypeUrl
265 nonce, err := send(resp, typeUrl)
266 if err != nil {
267 return err
268 }
269 values.nonces[typeUrl] = nonce
270 }
271
272 case req, more := <-reqCh:
273
274 if !more {
275 return nil
276 }
277 if req == nil {
278 return status.Errorf(codes.Unavailable, "empty request")
279 }
280
281
282 if req.Node != nil {
283 node = req.Node
284 } else {
285 req.Node = node
286 }
287
288
289 nonce := req.GetResponseNonce()
290
291
292 if defaultTypeURL == resource.AnyType {
293 if req.TypeUrl == "" {
294 return status.Errorf(codes.InvalidArgument, "type URL is required for ADS")
295 }
296 } else if req.TypeUrl == "" {
297 req.TypeUrl = defaultTypeURL
298 }
299
300 if s.callbacks != nil {
301 if err := s.callbacks.OnStreamRequest(streamID, req); err != nil {
302 return err
303 }
304 }
305
306
307 switch {
308 case req.TypeUrl == resource.EndpointType:
309 if values.endpointNonce == "" || values.endpointNonce == nonce {
310 if values.endpointCancel != nil {
311 values.endpointCancel()
312 }
313 values.endpoints, values.endpointCancel = s.cache.CreateWatch(req)
314 }
315 case req.TypeUrl == resource.ClusterType:
316 if values.clusterNonce == "" || values.clusterNonce == nonce {
317 if values.clusterCancel != nil {
318 values.clusterCancel()
319 }
320 values.clusters, values.clusterCancel = s.cache.CreateWatch(req)
321 }
322 case req.TypeUrl == resource.RouteType:
323 if values.routeNonce == "" || values.routeNonce == nonce {
324 if values.routeCancel != nil {
325 values.routeCancel()
326 }
327 values.routes, values.routeCancel = s.cache.CreateWatch(req)
328 }
329 case req.TypeUrl == resource.ListenerType:
330 if values.listenerNonce == "" || values.listenerNonce == nonce {
331 if values.listenerCancel != nil {
332 values.listenerCancel()
333 }
334 values.listeners, values.listenerCancel = s.cache.CreateWatch(req)
335 }
336 case req.TypeUrl == resource.SecretType:
337 if values.secretNonce == "" || values.secretNonce == nonce {
338 if values.secretCancel != nil {
339 values.secretCancel()
340 }
341 values.secrets, values.secretCancel = s.cache.CreateWatch(req)
342 }
343 case req.TypeUrl == resource.RuntimeType:
344 if values.runtimeNonce == "" || values.runtimeNonce == nonce {
345 if values.runtimeCancel != nil {
346 values.runtimeCancel()
347 }
348 values.runtimes, values.runtimeCancel = s.cache.CreateWatch(req)
349 }
350 default:
351 typeUrl := req.TypeUrl
352 responseNonce, seen := values.nonces[typeUrl]
353 if !seen || responseNonce == nonce {
354
355
356 if terminate, exists := values.terminations[typeUrl]; exists {
357 close(terminate)
358 }
359 if cancel, seen := values.cancellations[typeUrl]; seen && cancel != nil {
360 cancel()
361 }
362 var watch chan cache.Response
363 watch, values.cancellations[typeUrl] = s.cache.CreateWatch(req)
364
365
366 terminate := make(chan struct{})
367 values.terminations[typeUrl] = terminate
368 go func() {
369 select {
370 case resp, more := <-watch:
371 if more {
372 values.responses <- resp
373 } else {
374
375 select {
376 case <-terminate:
377 default:
378
379
380 values.responses <- errorResponse
381 }
382 }
383 break
384 case <-terminate:
385 break
386 }
387 }()
388 }
389 }
390 }
391 }
392 }
393
394
395 func (s *server) StreamHandler(stream Stream, typeURL string) error {
396
397 reqCh := make(chan *discovery.DiscoveryRequest)
398 reqStop := int32(0)
399 go func() {
400 for {
401 req, err := stream.Recv()
402 if atomic.LoadInt32(&reqStop) != 0 {
403 return
404 }
405 if err != nil {
406 close(reqCh)
407 return
408 }
409 select {
410 case reqCh <- req:
411 case <-s.ctx.Done():
412 return
413 }
414 }
415 }()
416
417 err := s.process(stream, reqCh, typeURL)
418
419
420
421 atomic.StoreInt32(&reqStop, 1)
422
423 return err
424 }
425
View as plain text