1
16
17
18 package cdsbalancer
19
20 import (
21 "context"
22 "encoding/json"
23 "fmt"
24 "sync/atomic"
25 "unsafe"
26
27 "google.golang.org/grpc/balancer"
28 "google.golang.org/grpc/balancer/base"
29 "google.golang.org/grpc/connectivity"
30 "google.golang.org/grpc/credentials"
31 "google.golang.org/grpc/credentials/tls/certprovider"
32 "google.golang.org/grpc/internal/balancer/nop"
33 xdsinternal "google.golang.org/grpc/internal/credentials/xds"
34 "google.golang.org/grpc/internal/grpclog"
35 "google.golang.org/grpc/internal/grpcsync"
36 "google.golang.org/grpc/internal/pretty"
37 "google.golang.org/grpc/resolver"
38 "google.golang.org/grpc/serviceconfig"
39 "google.golang.org/grpc/xds/internal/balancer/clusterresolver"
40 "google.golang.org/grpc/xds/internal/xdsclient"
41 "google.golang.org/grpc/xds/internal/xdsclient/xdsresource"
42 )
43
// cdsName is the name reported by the builder's Name method, i.e. the name
// under which this LB policy is registered.
const cdsName = "cds_experimental"

// aggregateClusterMaxDepth is the depth at which traversal of an aggregate
// cluster graph gives up and reports errExceedsMaxDepth.
const aggregateClusterMaxDepth = 16
48
49 var (
50 errBalancerClosed = fmt.Errorf("cds_experimental LB policy is closed")
51 errExceedsMaxDepth = fmt.Errorf("aggregate cluster graph exceeds max depth (%d)", aggregateClusterMaxDepth)
52
53
54
55 newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) {
56 builder := balancer.Get(clusterresolver.Name)
57 if builder == nil {
58 return nil, fmt.Errorf("xds: no balancer builder with name %v", clusterresolver.Name)
59 }
60
61
62
63 return builder.Build(cc, opts), nil
64 }
65 buildProvider = buildProviderFunc
66 )
67
68 func init() {
69 balancer.Register(bb{})
70 }
71
72
73
74
// bb is the balancer builder for the cds_experimental LB policy. It carries no
// state; see its Build, Name and ParseConfig methods.
type bb struct{}
76
77
78 func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
79 builder := balancer.Get(clusterresolver.Name)
80 if builder == nil {
81
82
83 logger.Errorf("%q LB policy is needed but not registered", clusterresolver.Name)
84 return nop.NewBalancer(cc, fmt.Errorf("%q LB policy is needed but not registered", clusterresolver.Name))
85 }
86 parser, ok := builder.(balancer.ConfigParser)
87 if !ok {
88
89 logger.Errorf("%q LB policy does not implement a config parser", clusterresolver.Name)
90 return nop.NewBalancer(cc, fmt.Errorf("%q LB policy does not implement a config parser", clusterresolver.Name))
91 }
92
93 ctx, cancel := context.WithCancel(context.Background())
94 hi := xdsinternal.NewHandshakeInfo(nil, nil, nil, false)
95 xdsHIPtr := unsafe.Pointer(hi)
96 b := &cdsBalancer{
97 bOpts: opts,
98 childConfigParser: parser,
99 serializer: grpcsync.NewCallbackSerializer(ctx),
100 serializerCancel: cancel,
101 xdsHIPtr: &xdsHIPtr,
102 watchers: make(map[string]*watcherState),
103 }
104 b.ccw = &ccWrapper{
105 ClientConn: cc,
106 xdsHIPtr: b.xdsHIPtr,
107 }
108 b.logger = prefixLogger(b)
109 b.logger.Infof("Created")
110
111 var creds credentials.TransportCredentials
112 switch {
113 case opts.DialCreds != nil:
114 creds = opts.DialCreds
115 case opts.CredsBundle != nil:
116 creds = opts.CredsBundle.TransportCredentials()
117 }
118 if xc, ok := creds.(interface{ UsesXDS() bool }); ok && xc.UsesXDS() {
119 b.xdsCredsInUse = true
120 }
121 b.logger.Infof("xDS credentials in use: %v", b.xdsCredsInUse)
122 return b
123 }
124
125
126 func (bb) Name() string {
127 return cdsName
128 }
129
130
131
// lbConfig is the configuration of the cds balancer, as produced by
// ParseConfig and consumed by UpdateClientConnState.
type lbConfig struct {
	serviceconfig.LoadBalancingConfig
	// ClusterName is read from the "Cluster" JSON key. An empty value is
	// rejected by UpdateClientConnState.
	ClusterName string `json:"Cluster"`
}
136
137
138
139 func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
140 var cfg lbConfig
141 if err := json.Unmarshal(c, &cfg); err != nil {
142 return nil, fmt.Errorf("xds: unable to unmarshal lbconfig: %s, error: %v", string(c), err)
143 }
144 return &cfg, nil
145 }
146
147
148
149
150
151
// cdsBalancer is the cds_experimental LB policy. It watches Cluster resources
// through the xDS client, resolves the (possibly aggregate) cluster graph into
// discovery mechanisms and forwards the result to a cluster_resolver child
// policy.
type cdsBalancer struct {
	// The following fields are set in Build.
	ccw               *ccWrapper            // ClientConn wrapper handed to the child policy.
	bOpts             balancer.BuildOptions // Build options handed to the child policy.
	childConfigParser balancer.ConfigParser // Parses the generated cluster_resolver config.
	logger            *grpclog.PrefixLogger // Logger used by all methods of this type.
	xdsCredsInUse     bool                  // True when the channel credentials report UsesXDS().

	// xdsHIPtr points at the current handshake info. It is shared with ccw and
	// updated with atomic.StorePointer in handleSecurityConfig.
	xdsHIPtr *unsafe.Pointer

	// The serializer and its cancel func are set in Build. The remaining
	// fields in this group are used by the callbacks scheduled on the
	// serializer and by UpdateClientConnState.
	serializer       *grpcsync.CallbackSerializer // Runs the scheduled callbacks.
	serializerCancel context.CancelFunc           // Cancels the serializer's context; called from Close.
	childLB          balancer.Balancer            // Child policy, created once the cluster graph resolves.
	xdsClient        xdsclient.XDSClient          // xDS client taken from the resolver state.
	watchers         map[string]*watcherState     // Cluster watchers and their state, keyed by cluster name.
	lbCfg            *lbConfig                    // Most recently accepted configuration.

	// The certificate providers currently in use are cached here so that they
	// can be closed when they are replaced or when the balancer is closed.
	cachedRoot     certprovider.Provider
	cachedIdentity certprovider.Provider
}
180
181
182
183
184
185
186
187 func (b *cdsBalancer) handleSecurityConfig(config *xdsresource.SecurityConfig) error {
188
189
190
191 if !b.xdsCredsInUse {
192 return nil
193 }
194 var xdsHI *xdsinternal.HandshakeInfo
195
196
197
198
199 if config == nil {
200
201
202
203 xdsHI = xdsinternal.NewHandshakeInfo(nil, nil, nil, false)
204 atomic.StorePointer(b.xdsHIPtr, unsafe.Pointer(xdsHI))
205 return nil
206
207 }
208
209
210 cpc := b.xdsClient.BootstrapConfig().CertProviderConfigs
211 rootProvider, err := buildProvider(cpc, config.RootInstanceName, config.RootCertName, false, true)
212 if err != nil {
213 return err
214 }
215
216
217 var identityProvider certprovider.Provider
218 if name, cert := config.IdentityInstanceName, config.IdentityCertName; name != "" {
219 var err error
220 identityProvider, err = buildProvider(cpc, name, cert, true, false)
221 if err != nil {
222 return err
223 }
224 }
225
226
227 if b.cachedRoot != nil {
228 b.cachedRoot.Close()
229 }
230 if b.cachedIdentity != nil {
231 b.cachedIdentity.Close()
232 }
233 b.cachedRoot = rootProvider
234 b.cachedIdentity = identityProvider
235 xdsHI = xdsinternal.NewHandshakeInfo(rootProvider, identityProvider, config.SubjectAltNameMatchers, false)
236 atomic.StorePointer(b.xdsHIPtr, unsafe.Pointer(xdsHI))
237 return nil
238 }
239
240 func buildProviderFunc(configs map[string]*certprovider.BuildableConfig, instanceName, certName string, wantIdentity, wantRoot bool) (certprovider.Provider, error) {
241 cfg, ok := configs[instanceName]
242 if !ok {
243
244
245
246 return nil, fmt.Errorf("certificate provider instance %q not found in bootstrap file", instanceName)
247 }
248 provider, err := cfg.Build(certprovider.BuildOptions{
249 CertName: certName,
250 WantIdentity: wantIdentity,
251 WantRoot: wantRoot,
252 })
253 if err != nil {
254
255
256
257
258 return nil, fmt.Errorf("xds: failed to get security plugin instance (%+v): %v", cfg, err)
259 }
260 return provider, nil
261 }
262
263
264
265
266 func (b *cdsBalancer) createAndAddWatcherForCluster(name string) {
267 w := &clusterWatcher{
268 name: name,
269 parent: b,
270 }
271 ws := &watcherState{
272 watcher: w,
273 cancelWatch: xdsresource.WatchCluster(b.xdsClient, name, w),
274 }
275 b.watchers[name] = ws
276 }
277
278
279
280
281 func (b *cdsBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
282 if b.xdsClient == nil {
283 c := xdsclient.FromResolverState(state.ResolverState)
284 if c == nil {
285 b.logger.Warningf("Received balancer config with no xDS client")
286 return balancer.ErrBadResolverState
287 }
288 b.xdsClient = c
289 }
290 b.logger.Infof("Received balancer config update: %s", pretty.ToJSON(state.BalancerConfig))
291
292
293
294
295 lbCfg, ok := state.BalancerConfig.(*lbConfig)
296 if !ok {
297 b.logger.Warningf("Received unexpected balancer config type: %T", state.BalancerConfig)
298 return balancer.ErrBadResolverState
299 }
300 if lbCfg.ClusterName == "" {
301 b.logger.Warningf("Received balancer config with no cluster name")
302 return balancer.ErrBadResolverState
303 }
304
305
306 if b.lbCfg != nil && b.lbCfg.ClusterName == lbCfg.ClusterName {
307 return nil
308 }
309 b.lbCfg = lbCfg
310
311
312 done := make(chan struct{})
313 ok = b.serializer.Schedule(func(context.Context) {
314
315
316 b.closeAllWatchers()
317
318
319
320
321 b.createAndAddWatcherForCluster(lbCfg.ClusterName)
322 close(done)
323 })
324 if !ok {
325
326
327 return errBalancerClosed
328 }
329 <-done
330 return nil
331 }
332
333
334 func (b *cdsBalancer) ResolverError(err error) {
335 b.serializer.Schedule(func(context.Context) {
336
337
338 if xdsresource.ErrType(err) == xdsresource.ErrorTypeResourceNotFound {
339 b.closeAllWatchers()
340 }
341 var root string
342 if b.lbCfg != nil {
343 root = b.lbCfg.ClusterName
344 }
345 b.onClusterError(root, err)
346 })
347 }
348
349
350 func (b *cdsBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
351 b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
352 }
353
354
355
356
357 func (b *cdsBalancer) closeAllWatchers() {
358 for name, state := range b.watchers {
359 state.cancelWatch()
360 delete(b.watchers, name)
361 }
362 }
363
364
365
366 func (b *cdsBalancer) Close() {
367 b.serializer.Schedule(func(ctx context.Context) {
368 b.closeAllWatchers()
369
370 if b.childLB != nil {
371 b.childLB.Close()
372 b.childLB = nil
373 }
374 if b.cachedRoot != nil {
375 b.cachedRoot.Close()
376 }
377 if b.cachedIdentity != nil {
378 b.cachedIdentity.Close()
379 }
380 b.logger.Infof("Shutdown")
381 })
382 b.serializerCancel()
383 <-b.serializer.Done()
384 }
385
386 func (b *cdsBalancer) ExitIdle() {
387 b.serializer.Schedule(func(context.Context) {
388 if b.childLB == nil {
389 b.logger.Warningf("Received ExitIdle with no child policy")
390 return
391 }
392
393
394
395
396 if ei, ok := b.childLB.(balancer.ExitIdler); ok {
397 ei.ExitIdle()
398 }
399 })
400 }
401
402
403
404
405
406
407 func (b *cdsBalancer) onClusterUpdate(name string, update xdsresource.ClusterUpdate) {
408 state := b.watchers[name]
409 if state == nil {
410
411 return
412 }
413
414 b.logger.Infof("Received Cluster resource: %s", pretty.ToJSON(update))
415
416
417 state.lastUpdate = &update
418
419
420
421 if name == b.lbCfg.ClusterName {
422
423
424
425
426
427 if err := b.handleSecurityConfig(update.SecurityCfg); err != nil {
428
429
430
431 b.onClusterError(name, fmt.Errorf("received Cluster resource contains invalid security config: %v", err))
432 return
433 }
434 }
435
436 clustersSeen := make(map[string]bool)
437 dms, ok, err := b.generateDMsForCluster(b.lbCfg.ClusterName, 0, nil, clustersSeen)
438 if err != nil {
439 b.onClusterError(b.lbCfg.ClusterName, fmt.Errorf("failed to generate discovery mechanisms: %v", err))
440 return
441 }
442 if ok {
443 if len(dms) == 0 {
444 b.onClusterError(b.lbCfg.ClusterName, fmt.Errorf("aggregate cluster graph has no leaf clusters"))
445 return
446 }
447
448 if b.childLB == nil {
449 childLB, err := newChildBalancer(b.ccw, b.bOpts)
450 if err != nil {
451 b.logger.Errorf("Failed to create child policy of type %s: %v", clusterresolver.Name, err)
452 return
453 }
454 b.childLB = childLB
455 b.logger.Infof("Created child policy %p of type %s", b.childLB, clusterresolver.Name)
456 }
457
458
459
460
461 childCfg := &clusterresolver.LBConfig{
462 DiscoveryMechanisms: dms,
463
464 XDSLBPolicy: b.watchers[b.lbCfg.ClusterName].lastUpdate.LBPolicy,
465 }
466 cfgJSON, err := json.Marshal(childCfg)
467 if err != nil {
468
469 b.logger.Errorf("cds_balancer: error marshalling prepared config: %v", childCfg)
470 return
471 }
472
473 var sc serviceconfig.LoadBalancingConfig
474 if sc, err = b.childConfigParser.ParseConfig(cfgJSON); err != nil {
475 b.logger.Errorf("cds_balancer: cluster_resolver config generated %v is invalid: %v", string(cfgJSON), err)
476 return
477 }
478
479 ccState := balancer.ClientConnState{
480 ResolverState: xdsclient.SetClient(resolver.State{}, b.xdsClient),
481 BalancerConfig: sc,
482 }
483 if err := b.childLB.UpdateClientConnState(ccState); err != nil {
484 b.logger.Errorf("Encountered error when sending config {%+v} to child policy: %v", ccState, err)
485 }
486 }
487
488
489 for cluster := range clustersSeen {
490 state, ok := b.watchers[cluster]
491 if ok {
492 continue
493 }
494 state.cancelWatch()
495 delete(b.watchers, cluster)
496 }
497 }
498
499
500
501
502
503
504 func (b *cdsBalancer) onClusterError(name string, err error) {
505 b.logger.Warningf("Cluster resource %q received error update: %v", name, err)
506
507 if b.childLB != nil {
508 if xdsresource.ErrType(err) != xdsresource.ErrorTypeConnection {
509
510
511 b.childLB.ResolverError(err)
512 }
513 } else {
514
515
516 b.ccw.UpdateState(balancer.State{
517 ConnectivityState: connectivity.TransientFailure,
518 Picker: base.NewErrPicker(fmt.Errorf("%q: %v", name, err)),
519 })
520 }
521 }
522
523
524
525
526
527
528 func (b *cdsBalancer) onClusterResourceNotFound(name string) {
529 err := xdsresource.NewErrorf(xdsresource.ErrorTypeResourceNotFound, "resource name %q of type Cluster not found in received response", name)
530 if b.childLB != nil {
531 b.childLB.ResolverError(err)
532 } else {
533
534 b.ccw.UpdateState(balancer.State{
535 ConnectivityState: connectivity.TransientFailure,
536 Picker: base.NewErrPicker(err),
537 })
538 }
539 }
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560 func (b *cdsBalancer) generateDMsForCluster(name string, depth int, dms []clusterresolver.DiscoveryMechanism, clustersSeen map[string]bool) ([]clusterresolver.DiscoveryMechanism, bool, error) {
561 if depth >= aggregateClusterMaxDepth {
562 return dms, false, errExceedsMaxDepth
563 }
564
565 if clustersSeen[name] {
566
567 return dms, true, nil
568 }
569 clustersSeen[name] = true
570
571 state, ok := b.watchers[name]
572 if !ok {
573
574
575 b.createAndAddWatcherForCluster(name)
576
577
578
579 return dms, false, nil
580 }
581
582
583 if state.lastUpdate == nil {
584 return dms, false, nil
585 }
586
587 var dm clusterresolver.DiscoveryMechanism
588 cluster := state.lastUpdate
589 switch cluster.ClusterType {
590 case xdsresource.ClusterTypeAggregate:
591
592
593
594
595
596 missingCluster := false
597 var err error
598 for _, child := range cluster.PrioritizedClusterNames {
599 var ok bool
600 dms, ok, err = b.generateDMsForCluster(child, depth+1, dms, clustersSeen)
601 if err != nil || !ok {
602 missingCluster = true
603 }
604 }
605 return dms, !missingCluster, err
606 case xdsresource.ClusterTypeEDS:
607 dm = clusterresolver.DiscoveryMechanism{
608 Type: clusterresolver.DiscoveryMechanismTypeEDS,
609 Cluster: cluster.ClusterName,
610 EDSServiceName: cluster.EDSServiceName,
611 MaxConcurrentRequests: cluster.MaxRequests,
612 LoadReportingServer: cluster.LRSServerConfig,
613 }
614 case xdsresource.ClusterTypeLogicalDNS:
615 dm = clusterresolver.DiscoveryMechanism{
616 Type: clusterresolver.DiscoveryMechanismTypeLogicalDNS,
617 Cluster: cluster.ClusterName,
618 DNSHostname: cluster.DNSHostName,
619 }
620 }
621 odJSON := cluster.OutlierDetection
622
623
624
625
626 if odJSON == nil {
627
628
629
630 odJSON = json.RawMessage(`{}`)
631 }
632 dm.OutlierDetection = odJSON
633
634 dm.TelemetryLabels = cluster.TelemetryLabels
635
636 return append(dms, dm), true, nil
637 }
638
639
640
641
642
643
644
// ccWrapper wraps the balancer.ClientConn passed to the cds balancer and is
// the ClientConn handed to the child policy. Its NewSubConn and
// UpdateAddresses methods attach the handshake info pointer to every address
// before delegating to the embedded ClientConn.
type ccWrapper struct {
	balancer.ClientConn

	// xdsHIPtr is the handshake info pointer shared with the cdsBalancer,
	// which updates the pointed-to value atomically.
	xdsHIPtr *unsafe.Pointer
}
650
651
652
653
654 func (ccw *ccWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
655 newAddrs := make([]resolver.Address, len(addrs))
656 for i, addr := range addrs {
657 newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHIPtr)
658 }
659
660
661
662 return ccw.ClientConn.NewSubConn(newAddrs, opts)
663 }
664
665 func (ccw *ccWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
666 newAddrs := make([]resolver.Address, len(addrs))
667 for i, addr := range addrs {
668 newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHIPtr)
669 }
670 ccw.ClientConn.UpdateAddresses(sc, newAddrs)
671 }
672
View as plain text