1
18
19 package weightedroundrobin
20
21 import (
22 "context"
23 "encoding/json"
24 "errors"
25 "fmt"
26 "sync"
27 "sync/atomic"
28 "time"
29 "unsafe"
30
31 "google.golang.org/grpc/balancer"
32 "google.golang.org/grpc/balancer/base"
33 "google.golang.org/grpc/balancer/weightedroundrobin/internal"
34 "google.golang.org/grpc/connectivity"
35 "google.golang.org/grpc/internal/grpclog"
36 "google.golang.org/grpc/internal/grpcrand"
37 iserviceconfig "google.golang.org/grpc/internal/serviceconfig"
38 "google.golang.org/grpc/orca"
39 "google.golang.org/grpc/resolver"
40 "google.golang.org/grpc/serviceconfig"
41
42 v3orcapb "github.com/cncf/xds/go/xds/data/orca/v3"
43 )
44
45
46 const Name = "weighted_round_robin"
47
48 func init() {
49 balancer.Register(bb{})
50 }
51
52 type bb struct{}
53
54 func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
55 b := &wrrBalancer{
56 cc: cc,
57 subConns: resolver.NewAddressMap(),
58 csEvltr: &balancer.ConnectivityStateEvaluator{},
59 scMap: make(map[balancer.SubConn]*weightedSubConn),
60 connectivityState: connectivity.Connecting,
61 }
62 b.logger = prefixLogger(b)
63 b.logger.Infof("Created")
64 return b
65 }
66
67 func (bb) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
68 lbCfg := &lbConfig{
69
70 OOBReportingPeriod: iserviceconfig.Duration(10 * time.Second),
71 BlackoutPeriod: iserviceconfig.Duration(10 * time.Second),
72 WeightExpirationPeriod: iserviceconfig.Duration(3 * time.Minute),
73 WeightUpdatePeriod: iserviceconfig.Duration(time.Second),
74 ErrorUtilizationPenalty: 1,
75 }
76 if err := json.Unmarshal(js, lbCfg); err != nil {
77 return nil, fmt.Errorf("wrr: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
78 }
79
80 if lbCfg.ErrorUtilizationPenalty < 0 {
81 return nil, fmt.Errorf("wrr: errorUtilizationPenalty must be non-negative")
82 }
83
84
85
86 if !lbCfg.EnableOOBLoadReport {
87 lbCfg.OOBReportingPeriod = 0
88 }
89
90
91 if !internal.AllowAnyWeightUpdatePeriod && lbCfg.WeightUpdatePeriod < iserviceconfig.Duration(100*time.Millisecond) {
92 lbCfg.WeightUpdatePeriod = iserviceconfig.Duration(100 * time.Millisecond)
93 }
94
95 return lbCfg, nil
96 }
97
// Name returns the package-level Name constant, "weighted_round_robin".
func (bb) Name() string {
	return Name
}
101
102
// wrrBalancer is the balancer created by bb.Build. It keeps one
// weightedSubConn per resolver address and publishes pickers to cc.
type wrrBalancer struct {
	cc     balancer.ClientConn   // used to create SubConns and to publish balancer state
	logger *grpclog.PrefixLogger // assigned from prefixLogger in Build

	// NOTE(review): none of the fields below are guarded by a mutex in this
	// file; presumably they are only touched from serialized balancer
	// callbacks -- confirm.
	cfg               *lbConfig                             // config stored by the latest accepted UpdateClientConnState
	subConns          *resolver.AddressMap                  // address -> *weightedSubConn
	scMap             map[balancer.SubConn]*weightedSubConn // entry removed when the SubConn reports Shutdown
	connectivityState connectivity.State                    // aggregate state, as returned by csEvltr.RecordTransition
	csEvltr           *balancer.ConnectivityStateEvaluator  // fed every SubConn transition
	resolverErr       error                                 // set by ResolverError; cleared at the start of UpdateClientConnState
	connErr           error                                 // last SubConn TransientFailure error; cleared by regeneratePicker when Ready
	stopPicker        func()                                // cancels the context given to the current picker's start; nil if none
}
118
119 func (b *wrrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
120 b.logger.Infof("UpdateCCS: %v", ccs)
121 b.resolverErr = nil
122 cfg, ok := ccs.BalancerConfig.(*lbConfig)
123 if !ok {
124 return fmt.Errorf("wrr: received nil or illegal BalancerConfig (type %T): %v", ccs.BalancerConfig, ccs.BalancerConfig)
125 }
126
127 b.cfg = cfg
128 b.updateAddresses(ccs.ResolverState.Addresses)
129
130 if len(ccs.ResolverState.Addresses) == 0 {
131 b.ResolverError(errors.New("resolver produced zero addresses"))
132 return balancer.ErrBadResolverState
133 }
134
135 b.regeneratePicker()
136
137 return nil
138 }
139
140 func (b *wrrBalancer) updateAddresses(addrs []resolver.Address) {
141 addrsSet := resolver.NewAddressMap()
142
143
144 for _, addr := range addrs {
145 if _, ok := addrsSet.Get(addr); ok {
146
147 continue
148 }
149 addrsSet.Set(addr, nil)
150
151 var wsc *weightedSubConn
152 wsci, ok := b.subConns.Get(addr)
153 if ok {
154 wsc = wsci.(*weightedSubConn)
155 } else {
156
157 var sc balancer.SubConn
158 sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{
159 StateListener: func(state balancer.SubConnState) {
160 b.updateSubConnState(sc, state)
161 },
162 })
163 if err != nil {
164 b.logger.Warningf("Failed to create new SubConn for address %v: %v", addr, err)
165 continue
166 }
167 wsc = &weightedSubConn{
168 SubConn: sc,
169 logger: b.logger,
170 connectivityState: connectivity.Idle,
171
172
173 cfg: &lbConfig{EnableOOBLoadReport: false},
174 }
175 b.subConns.Set(addr, wsc)
176 b.scMap[sc] = wsc
177 b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle)
178 sc.Connect()
179 }
180
181
182
183 wsc.updateConfig(b.cfg)
184 }
185
186
187 for _, addr := range b.subConns.Keys() {
188 if _, ok := addrsSet.Get(addr); ok {
189
190 continue
191 }
192
193 wsci, _ := b.subConns.Get(addr)
194 wsc := wsci.(*weightedSubConn)
195 wsc.SubConn.Shutdown()
196 b.subConns.Delete(addr)
197 }
198 }
199
200 func (b *wrrBalancer) ResolverError(err error) {
201 b.resolverErr = err
202 if b.subConns.Len() == 0 {
203 b.connectivityState = connectivity.TransientFailure
204 }
205 if b.connectivityState != connectivity.TransientFailure {
206
207 return
208 }
209 b.regeneratePicker()
210 }
211
// UpdateSubConnState only logs an error: this balancer receives SubConn state
// through the StateListener installed in updateAddresses (which calls
// updateSubConnState), so a call to this method is unexpected.
func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
	b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
}
215
216 func (b *wrrBalancer) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
217 wsc := b.scMap[sc]
218 if wsc == nil {
219 b.logger.Errorf("UpdateSubConnState called with an unknown SubConn: %p, %v", sc, state)
220 return
221 }
222 if b.logger.V(2) {
223 logger.Infof("UpdateSubConnState(%+v, %+v)", sc, state)
224 }
225
226 cs := state.ConnectivityState
227
228 if cs == connectivity.TransientFailure {
229
230 b.connErr = state.ConnectionError
231 }
232
233 if cs == connectivity.Shutdown {
234 delete(b.scMap, sc)
235
236
237 }
238
239 oldCS := wsc.updateConnectivityState(cs)
240 b.connectivityState = b.csEvltr.RecordTransition(oldCS, cs)
241
242
243
244
245
246 if (cs == connectivity.Ready) != (oldCS == connectivity.Ready) ||
247 b.connectivityState == connectivity.TransientFailure {
248 b.regeneratePicker()
249 }
250 }
251
252
253
254 func (b *wrrBalancer) Close() {
255 if b.stopPicker != nil {
256 b.stopPicker()
257 b.stopPicker = nil
258 }
259 for _, wsc := range b.scMap {
260
261 wsc.updateConnectivityState(connectivity.Shutdown)
262 }
263 }
264
265
// ExitIdle is a no-op. SubConns are asked to connect when they are created in
// updateAddresses and again whenever they report Idle (see
// weightedSubConn.updateConnectivityState).
func (b *wrrBalancer) ExitIdle() {}
267
268 func (b *wrrBalancer) readySubConns() []*weightedSubConn {
269 var ret []*weightedSubConn
270 for _, v := range b.subConns.Values() {
271 wsc := v.(*weightedSubConn)
272 if wsc.connectivityState == connectivity.Ready {
273 ret = append(ret, wsc)
274 }
275 }
276 return ret
277 }
278
279
280
281
282 func (b *wrrBalancer) mergeErrors() error {
283
284
285 if b.connErr == nil {
286 return fmt.Errorf("last resolver error: %v", b.resolverErr)
287 }
288 if b.resolverErr == nil {
289 return fmt.Errorf("last connection error: %v", b.connErr)
290 }
291 return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr)
292 }
293
294 func (b *wrrBalancer) regeneratePicker() {
295 if b.stopPicker != nil {
296 b.stopPicker()
297 b.stopPicker = nil
298 }
299
300 switch b.connectivityState {
301 case connectivity.TransientFailure:
302 b.cc.UpdateState(balancer.State{
303 ConnectivityState: connectivity.TransientFailure,
304 Picker: base.NewErrPicker(b.mergeErrors()),
305 })
306 return
307 case connectivity.Connecting, connectivity.Idle:
308
309
310
311 b.cc.UpdateState(balancer.State{
312 ConnectivityState: connectivity.Connecting,
313 Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable),
314 })
315 return
316 case connectivity.Ready:
317 b.connErr = nil
318 }
319
320 p := &picker{
321 v: grpcrand.Uint32(),
322 cfg: b.cfg,
323 subConns: b.readySubConns(),
324 }
325 var ctx context.Context
326 ctx, b.stopPicker = context.WithCancel(context.Background())
327 p.start(ctx)
328 b.cc.UpdateState(balancer.State{
329 ConnectivityState: b.connectivityState,
330 Picker: p,
331 })
332 }
333
334
335
336
// picker chooses among the SubConns that were Ready when regeneratePicker
// built it, using a scheduler that start refreshes periodically.
type picker struct {
	scheduler unsafe.Pointer     // stored by regenerateScheduler and loaded by Pick as *scheduler, both atomically
	v         uint32             // counter advanced atomically by inc; seeded with grpcrand.Uint32
	cfg       *lbConfig          // balancer config at the time the picker was created
	subConns  []*weightedSubConn // result of readySubConns at creation
}
343
344
345
346 func (p *picker) scWeights() []float64 {
347 ws := make([]float64, len(p.subConns))
348 now := internal.TimeNow()
349 for i, wsc := range p.subConns {
350 ws[i] = wsc.weight(now, time.Duration(p.cfg.WeightExpirationPeriod), time.Duration(p.cfg.BlackoutPeriod))
351 }
352 return ws
353 }
354
// inc atomically increments p.v and returns the new value. It is handed to
// newScheduler by regenerateScheduler.
func (p *picker) inc() uint32 {
	return atomic.AddUint32(&p.v, 1)
}
358
// regenerateScheduler builds a scheduler from the current SubConn weights and
// atomically publishes a pointer to it in p.scheduler.
func (p *picker) regenerateScheduler() {
	// NOTE(review): Pick loads this pointer as *scheduler, so newScheduler
	// presumably returns a value of type scheduler -- confirm.
	s := newScheduler(p.scWeights(), p.inc)
	atomic.StorePointer(&p.scheduler, unsafe.Pointer(&s))
}
363
364 func (p *picker) start(ctx context.Context) {
365 p.regenerateScheduler()
366 if len(p.subConns) == 1 {
367
368 return
369 }
370 go func() {
371 ticker := time.NewTicker(time.Duration(p.cfg.WeightUpdatePeriod))
372 defer ticker.Stop()
373 for {
374 select {
375 case <-ctx.Done():
376 return
377 case <-ticker.C:
378 p.regenerateScheduler()
379 }
380 }
381 }()
382 }
383
384 func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
385
386
387
388 sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler))
389
390 pickedSC := p.subConns[sched.nextIndex()]
391 pr := balancer.PickResult{SubConn: pickedSC.SubConn}
392 if !p.cfg.EnableOOBLoadReport {
393 pr.Done = func(info balancer.DoneInfo) {
394 if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil {
395 pickedSC.OnLoadReport(load)
396 }
397 }
398 }
399 return pr, nil
400 }
401
402
403
404
405
// weightedSubConn wraps a balancer.SubConn together with the state needed to
// compute its weight: the last weight derived from a load report, when that
// weight was updated, and since when reports have been arriving. It also
// tracks the SubConn's connectivity state and owns the ORCA listener
// registration made in updateConfig (the value itself is passed as the
// listener there).
type weightedSubConn struct {
	balancer.SubConn
	logger *grpclog.PrefixLogger

	// These two fields are read and written without holding mu in this file.
	// NOTE(review): presumably they are only touched from serialized balancer
	// callbacks -- confirm.
	connectivityState connectivity.State
	stopORCAListener  func() // stops the registered ORCA listener; nil when none has been registered

	// The fields below are accessed with mu held (OnLoadReport, updateConfig,
	// updateConnectivityState and weight). updateConfig releases mu before it
	// stops or registers the ORCA listener.
	mu            sync.Mutex
	weightVal     float64   // weight computed from the latest non-empty load report
	nonEmptySince time.Time // time of the first accepted report since the last reset; zero when reset
	lastUpdated   time.Time // time weightVal was last updated
	cfg           *lbConfig // config last passed to updateConfig
}
426
427 func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) {
428 if w.logger.V(2) {
429 w.logger.Infof("Received load report for subchannel %v: %v", w.SubConn, load)
430 }
431
432 utilization := load.ApplicationUtilization
433 if utilization == 0 {
434 utilization = load.CpuUtilization
435 }
436 if utilization == 0 || load.RpsFractional == 0 {
437 if w.logger.V(2) {
438 w.logger.Infof("Ignoring empty load report for subchannel %v", w.SubConn)
439 }
440 return
441 }
442
443 w.mu.Lock()
444 defer w.mu.Unlock()
445
446 errorRate := load.Eps / load.RpsFractional
447 w.weightVal = load.RpsFractional / (utilization + errorRate*w.cfg.ErrorUtilizationPenalty)
448 if w.logger.V(2) {
449 w.logger.Infof("New weight for subchannel %v: %v", w.SubConn, w.weightVal)
450 }
451
452 w.lastUpdated = internal.TimeNow()
453 if w.nonEmptySince == (time.Time{}) {
454 w.nonEmptySince = w.lastUpdated
455 }
456 }
457
458
459
460 func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
461 w.mu.Lock()
462 oldCfg := w.cfg
463 w.cfg = cfg
464 w.mu.Unlock()
465
466 newPeriod := cfg.OOBReportingPeriod
467 if cfg.EnableOOBLoadReport == oldCfg.EnableOOBLoadReport &&
468 newPeriod == oldCfg.OOBReportingPeriod {
469
470
471
472 return
473 }
474
475
476
477 if w.stopORCAListener != nil {
478 w.stopORCAListener()
479 }
480 if !cfg.EnableOOBLoadReport {
481 w.stopORCAListener = nil
482 return
483 }
484 if w.logger.V(2) {
485 w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, newPeriod)
486 }
487 opts := orca.OOBListenerOptions{ReportInterval: time.Duration(newPeriod)}
488 w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts)
489 }
490
491 func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connectivity.State {
492 switch cs {
493 case connectivity.Idle:
494
495 w.SubConn.Connect()
496 case connectivity.Ready:
497
498
499
500
501
502
503
504 w.mu.Lock()
505 w.nonEmptySince = time.Time{}
506 w.mu.Unlock()
507 case connectivity.Shutdown:
508 if w.stopORCAListener != nil {
509 w.stopORCAListener()
510 }
511 }
512
513 oldCS := w.connectivityState
514
515 if oldCS == connectivity.TransientFailure &&
516 (cs == connectivity.Connecting || cs == connectivity.Idle) {
517
518
519
520 return oldCS
521 }
522
523 w.connectivityState = cs
524
525 return oldCS
526 }
527
528
529
530
531
532 func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration) float64 {
533 w.mu.Lock()
534 defer w.mu.Unlock()
535
536
537
538 if now.Sub(w.lastUpdated) >= weightExpirationPeriod {
539 w.nonEmptySince = time.Time{}
540 return 0
541 }
542
543 if blackoutPeriod != 0 && (w.nonEmptySince == (time.Time{}) || now.Sub(w.nonEmptySince) < blackoutPeriod) {
544 return 0
545 }
546 return w.weightVal
547 }
548
View as plain text