...

Source file src/google.golang.org/grpc/balancer/weightedroundrobin/balancer.go

Documentation: google.golang.org/grpc/balancer/weightedroundrobin

/*
 *
 * Copyright 2023 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package weightedroundrobin

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/base"
	"google.golang.org/grpc/balancer/weightedroundrobin/internal"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/internal/grpclog"
	"google.golang.org/grpc/internal/grpcrand"
	iserviceconfig "google.golang.org/grpc/internal/serviceconfig"
	"google.golang.org/grpc/orca"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/serviceconfig"

	v3orcapb "github.com/cncf/xds/go/xds/data/orca/v3"
)

// Name is the name of the weighted round robin balancer.
const Name = "weighted_round_robin"

func init() {
	balancer.Register(bb{})
}
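
// Illustrative usage (not part of the original file): a client normally
// selects this policy through the service config rather than constructing the
// balancer directly; importing this package runs init and registers the
// policy.  A minimal sketch, assuming an insecure connection to "target":
//
//	conn, err := grpc.Dial(target,
//		grpc.WithTransportCredentials(insecure.NewCredentials()),
//		grpc.WithDefaultServiceConfig(
//			`{"loadBalancingConfig": [{"weighted_round_robin": {}}]}`),
//	)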

type bb struct{}

func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
	b := &wrrBalancer{
		cc:                cc,
		subConns:          resolver.NewAddressMap(),
		csEvltr:           &balancer.ConnectivityStateEvaluator{},
		scMap:             make(map[balancer.SubConn]*weightedSubConn),
		connectivityState: connectivity.Connecting,
	}
	b.logger = prefixLogger(b)
	b.logger.Infof("Created")
	return b
}

func (bb) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
	lbCfg := &lbConfig{
		// Default values as documented in A58.
		OOBReportingPeriod:      iserviceconfig.Duration(10 * time.Second),
		BlackoutPeriod:          iserviceconfig.Duration(10 * time.Second),
		WeightExpirationPeriod:  iserviceconfig.Duration(3 * time.Minute),
		WeightUpdatePeriod:      iserviceconfig.Duration(time.Second),
		ErrorUtilizationPenalty: 1,
	}
	if err := json.Unmarshal(js, lbCfg); err != nil {
		return nil, fmt.Errorf("wrr: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
	}

	if lbCfg.ErrorUtilizationPenalty < 0 {
		return nil, fmt.Errorf("wrr: errorUtilizationPenalty must be non-negative")
	}

	// For easier comparisons later, ensure the OOB reporting period is unset
	// (0s) when OOB reports are disabled.
	if !lbCfg.EnableOOBLoadReport {
		lbCfg.OOBReportingPeriod = 0
	}

	// Impose lower bound of 100ms on weightUpdatePeriod.
	if !internal.AllowAnyWeightUpdatePeriod && lbCfg.WeightUpdatePeriod < iserviceconfig.Duration(100*time.Millisecond) {
		lbCfg.WeightUpdatePeriod = iserviceconfig.Duration(100 * time.Millisecond)
	}

	return lbCfg, nil
}
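
// Illustrative config (not part of the original file): ParseConfig above
// would accept JSON like the following, which enables OOB load reporting and
// keeps the remaining defaults.  Field names follow gRFC A58; the
// authoritative JSON tags are on the lbConfig type in config.go.
//
//	{
//	  "enableOobLoadReport": true,
//	  "oobReportingPeriod": "10s",
//	  "blackoutPeriod": "10s",
//	  "weightExpirationPeriod": "180s",
//	  "weightUpdatePeriod": "1s",
//	  "errorUtilizationPenalty": 1.0
//	}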

func (bb) Name() string {
	return Name
}

// wrrBalancer implements the weighted round robin LB policy.
type wrrBalancer struct {
	cc     balancer.ClientConn
	logger *grpclog.PrefixLogger

	// The following fields are only accessed on calls into the LB policy, and
	// do not need a mutex.
	cfg               *lbConfig            // active config
	subConns          *resolver.AddressMap // active weightedSubConns mapped by address
	scMap             map[balancer.SubConn]*weightedSubConn
	connectivityState connectivity.State // aggregate state
	csEvltr           *balancer.ConnectivityStateEvaluator
	resolverErr       error // the last error reported by the resolver; cleared on successful resolution
	connErr           error // the last connection error; cleared upon leaving TransientFailure
	stopPicker        func()
}

func (b *wrrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
	b.logger.Infof("UpdateCCS: %v", ccs)
	b.resolverErr = nil
	cfg, ok := ccs.BalancerConfig.(*lbConfig)
	if !ok {
		return fmt.Errorf("wrr: received nil or illegal BalancerConfig (type %T): %v", ccs.BalancerConfig, ccs.BalancerConfig)
	}

	b.cfg = cfg
	b.updateAddresses(ccs.ResolverState.Addresses)

	if len(ccs.ResolverState.Addresses) == 0 {
		b.ResolverError(errors.New("resolver produced zero addresses")) // will call regeneratePicker
		return balancer.ErrBadResolverState
	}

	b.regeneratePicker()

	return nil
}

func (b *wrrBalancer) updateAddresses(addrs []resolver.Address) {
	addrsSet := resolver.NewAddressMap()

	// Loop through new address list and create subconns for any new addresses.
	for _, addr := range addrs {
		if _, ok := addrsSet.Get(addr); ok {
			// Redundant address; skip.
			continue
		}
		addrsSet.Set(addr, nil)

		var wsc *weightedSubConn
		wsci, ok := b.subConns.Get(addr)
		if ok {
			wsc = wsci.(*weightedSubConn)
		} else {
			// addr is a new address (not existing in b.subConns).
			var sc balancer.SubConn
			sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{
				StateListener: func(state balancer.SubConnState) {
					b.updateSubConnState(sc, state)
				},
			})
			if err != nil {
				b.logger.Warningf("Failed to create new SubConn for address %v: %v", addr, err)
				continue
			}
			wsc = &weightedSubConn{
				SubConn:           sc,
				logger:            b.logger,
				connectivityState: connectivity.Idle,
				// Load reporting starts out disabled because it is not yet
				// running when the weightedSubConn is first created.
				cfg: &lbConfig{EnableOOBLoadReport: false},
			}
			b.subConns.Set(addr, wsc)
			b.scMap[sc] = wsc
			b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle)
			sc.Connect()
		}
		// Update config for existing weightedSubConn or send update for first
		// time to new one.  Ensures an OOB listener is running if needed
		// (and stops the existing one if applicable).
		wsc.updateConfig(b.cfg)
	}

	// Loop through existing subconns and remove ones that are not in addrs.
	for _, addr := range b.subConns.Keys() {
		if _, ok := addrsSet.Get(addr); ok {
			// Existing address also in new address list; skip.
			continue
		}
		// addr was removed by the resolver; shut down and remove its subconn.
		wsci, _ := b.subConns.Get(addr)
		wsc := wsci.(*weightedSubConn)
		wsc.SubConn.Shutdown()
		b.subConns.Delete(addr)
	}
}

func (b *wrrBalancer) ResolverError(err error) {
	b.resolverErr = err
	if b.subConns.Len() == 0 {
		b.connectivityState = connectivity.TransientFailure
	}
	if b.connectivityState != connectivity.TransientFailure {
		// No need to update the picker since no error is being returned.
		return
	}
	b.regeneratePicker()
}

func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
	b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
}

func (b *wrrBalancer) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
	wsc := b.scMap[sc]
	if wsc == nil {
		b.logger.Errorf("UpdateSubConnState called with an unknown SubConn: %p, %v", sc, state)
		return
	}
	if b.logger.V(2) {
		b.logger.Infof("UpdateSubConnState(%+v, %+v)", sc, state)
	}

	cs := state.ConnectivityState

	if cs == connectivity.TransientFailure {
		// Save error to be reported via picker.
		b.connErr = state.ConnectionError
	}

	if cs == connectivity.Shutdown {
		delete(b.scMap, sc)
		// The subconn was removed from b.subConns when the address was removed
		// in updateAddresses.
	}

	oldCS := wsc.updateConnectivityState(cs)
	b.connectivityState = b.csEvltr.RecordTransition(oldCS, cs)

	// Regenerate picker when one of the following happens:
	//  - this sc entered or left ready
	//  - the aggregated state of balancer is TransientFailure
	//    (may need to update error message)
	if (cs == connectivity.Ready) != (oldCS == connectivity.Ready) ||
		b.connectivityState == connectivity.TransientFailure {
		b.regeneratePicker()
	}
}

// Close stops the balancer.  It cancels any ongoing scheduler updates and
// stops any ORCA listeners.
func (b *wrrBalancer) Close() {
	if b.stopPicker != nil {
		b.stopPicker()
		b.stopPicker = nil
	}
	for _, wsc := range b.scMap {
		// Ensure any lingering OOB watchers are stopped.
		wsc.updateConnectivityState(connectivity.Shutdown)
	}
}

// ExitIdle is ignored; we always connect to all backends.
func (b *wrrBalancer) ExitIdle() {}

func (b *wrrBalancer) readySubConns() []*weightedSubConn {
	var ret []*weightedSubConn
	for _, v := range b.subConns.Values() {
		wsc := v.(*weightedSubConn)
		if wsc.connectivityState == connectivity.Ready {
			ret = append(ret, wsc)
		}
	}
	return ret
}

// mergeErrors builds an error from the last connection error and the last
// resolver error.  Must only be called if b.connectivityState is
// TransientFailure.
func (b *wrrBalancer) mergeErrors() error {
	// connErr must always be non-nil unless there are no SubConns, in which
	// case resolverErr must be non-nil.
	if b.connErr == nil {
		return fmt.Errorf("last resolver error: %v", b.resolverErr)
	}
	if b.resolverErr == nil {
		return fmt.Errorf("last connection error: %v", b.connErr)
	}
	return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr)
}
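
// For example, if the last connection error was "connection refused" and the
// last resolver error was "produced zero addresses", mergeErrors returns an
// error whose message reads:
//
//	last connection error: connection refused; last resolver error: produced zero addresses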

func (b *wrrBalancer) regeneratePicker() {
	if b.stopPicker != nil {
		b.stopPicker()
		b.stopPicker = nil
	}

	switch b.connectivityState {
	case connectivity.TransientFailure:
		b.cc.UpdateState(balancer.State{
			ConnectivityState: connectivity.TransientFailure,
			Picker:            base.NewErrPicker(b.mergeErrors()),
		})
		return
	case connectivity.Connecting, connectivity.Idle:
		// Idle could happen very briefly if all subconns are Idle and we've
		// asked them to connect but they haven't reported Connecting yet.
		// Report the same as Connecting since this is temporary.
		b.cc.UpdateState(balancer.State{
			ConnectivityState: connectivity.Connecting,
			Picker:            base.NewErrPicker(balancer.ErrNoSubConnAvailable),
		})
		return
	case connectivity.Ready:
		b.connErr = nil
	}

	p := &picker{
		v:        grpcrand.Uint32(), // start the scheduler at a random point
		cfg:      b.cfg,
		subConns: b.readySubConns(),
	}
	var ctx context.Context
	ctx, b.stopPicker = context.WithCancel(context.Background())
	p.start(ctx)
	b.cc.UpdateState(balancer.State{
		ConnectivityState: b.connectivityState,
		Picker:            p,
	})
}

// picker is the WRR policy's picker.  It uses live-updating backend weights
// to refresh the scheduler periodically and ensure that picks are routed in
// proportion to those weights.
type picker struct {
	scheduler unsafe.Pointer     // *scheduler; accessed atomically
	v         uint32             // incrementing value used by the scheduler; accessed atomically
	cfg       *lbConfig          // active config when picker created
	subConns  []*weightedSubConn // all READY subconns
}

// scWeights returns a slice containing the weights from p.subConns in the same
// order as p.subConns.
func (p *picker) scWeights() []float64 {
	ws := make([]float64, len(p.subConns))
	now := internal.TimeNow()
	for i, wsc := range p.subConns {
		ws[i] = wsc.weight(now, time.Duration(p.cfg.WeightExpirationPeriod), time.Duration(p.cfg.BlackoutPeriod))
	}
	return ws
}

func (p *picker) inc() uint32 {
	return atomic.AddUint32(&p.v, 1)
}

func (p *picker) regenerateScheduler() {
	s := newScheduler(p.scWeights(), p.inc)
	atomic.StorePointer(&p.scheduler, unsafe.Pointer(&s))
}

func (p *picker) start(ctx context.Context) {
	p.regenerateScheduler()
	if len(p.subConns) == 1 {
		// No need to regenerate weights with only one backend.
		return
	}
	go func() {
		ticker := time.NewTicker(time.Duration(p.cfg.WeightUpdatePeriod))
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				p.regenerateScheduler()
			}
		}
	}()
}

func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
	// Read the scheduler atomically.  All scheduler operations are threadsafe,
	// and if the scheduler is replaced during this usage, we want to use the
	// scheduler that was live when the pick started.
	sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler))

	pickedSC := p.subConns[sched.nextIndex()]
	pr := balancer.PickResult{SubConn: pickedSC.SubConn}
	if !p.cfg.EnableOOBLoadReport {
		pr.Done = func(info balancer.DoneInfo) {
			if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil {
				pickedSC.OnLoadReport(load)
			}
		}
	}
	return pr, nil
}
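
// The per-RPC load reports consumed in Pick's Done callback originate on the
// server.  A hedged sketch of the producing side, using helpers from the orca
// package (see google.golang.org/grpc/orca for the authoritative API; the
// recorded values below are hypothetical):
//
//	srv := grpc.NewServer(orca.CallMetricsServerOption(nil))
//	// ...and inside a unary handler:
//	if rec := orca.CallMetricsRecorderFromContext(ctx); rec != nil {
//		rec.SetCPUUtilization(0.5) // hypothetical measured value
//		rec.SetQPS(100)            // hypothetical measured value
//	}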

// weightedSubConn wraps a subconn and holds its weight (and the other
// parameters relevant to computing the effective weight).  When needed, it
// also tracks connectivity state, listens for metrics updates by implementing
// the orca.OOBListener interface, and manages that listener.
type weightedSubConn struct {
	balancer.SubConn
	logger *grpclog.PrefixLogger

	// The following fields are only accessed on calls into the LB policy, and
	// do not need a mutex.
	connectivityState connectivity.State
	stopORCAListener  func()

	// The following fields are accessed asynchronously and are protected by
	// mu.  Note that mu may not be held when calling stopORCAListener or when
	// registering a new listener, as those calls require the ORCA producer's
	// mu, which is held while the listener is invoked; the listener in turn
	// acquires mu.
	mu            sync.Mutex
	weightVal     float64
	nonEmptySince time.Time
	lastUpdated   time.Time
	cfg           *lbConfig
}

func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) {
	if w.logger.V(2) {
		w.logger.Infof("Received load report for subchannel %v: %v", w.SubConn, load)
	}
	// Update weights of this subchannel according to the reported load
	utilization := load.ApplicationUtilization
	if utilization == 0 {
		utilization = load.CpuUtilization
	}
	if utilization == 0 || load.RpsFractional == 0 {
		if w.logger.V(2) {
			w.logger.Infof("Ignoring empty load report for subchannel %v", w.SubConn)
		}
		return
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	errorRate := load.Eps / load.RpsFractional
	w.weightVal = load.RpsFractional / (utilization + errorRate*w.cfg.ErrorUtilizationPenalty)
	if w.logger.V(2) {
		w.logger.Infof("New weight for subchannel %v: %v", w.SubConn, w.weightVal)
	}

	w.lastUpdated = internal.TimeNow()
	if w.nonEmptySince == (time.Time{}) {
		w.nonEmptySince = w.lastUpdated
	}
}
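
// Worked example: a report with RpsFractional=100, ApplicationUtilization=0.5,
// and Eps=10 under errorUtilizationPenalty=1 yields errorRate = 10/100 = 0.1
// and weightVal = 100/(0.5+0.1) ≈ 166.7, so this backend would be picked
// roughly 1.67x as often as a backend whose weight is 100.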

// updateConfig updates the parameters of the WRR policy and
// stops/starts/restarts the ORCA OOB listener.
func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
	w.mu.Lock()
	oldCfg := w.cfg
	w.cfg = cfg
	w.mu.Unlock()

	newPeriod := cfg.OOBReportingPeriod
	if cfg.EnableOOBLoadReport == oldCfg.EnableOOBLoadReport &&
		newPeriod == oldCfg.OOBReportingPeriod {
		// Load reporting was either disabled both before and after the update,
		// or enabled both before and after with the same period.  (Note that
		// with load reporting disabled, OOBReportingPeriod is always 0.)
		return
	}
	// (Optionally stop and) start the listener to use the new config's
	// settings for OOB reporting.

	if w.stopORCAListener != nil {
		w.stopORCAListener()
	}
	if !cfg.EnableOOBLoadReport {
		w.stopORCAListener = nil
		return
	}
	if w.logger.V(2) {
		w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, newPeriod)
	}
	opts := orca.OOBListenerOptions{ReportInterval: time.Duration(newPeriod)}
	w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts)
}
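
// OOB reports only arrive if the server runs the ORCA OOB service.  A hedged
// sketch of the server side (names from the orca package; see
// google.golang.org/grpc/orca for the authoritative API, and treat the
// recorded value as hypothetical):
//
//	smr := orca.NewServerMetricsRecorder()
//	srv := grpc.NewServer()
//	if err := orca.Register(srv, orca.ServiceOptions{
//		ServerMetricsProvider: smr,
//		MinReportingInterval:  time.Second,
//	}); err != nil {
//		// handle registration error
//	}
//	smr.SetCPUUtilization(0.4) // hypothetical measured value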

func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connectivity.State {
	switch cs {
	case connectivity.Idle:
		// Always reconnect when idle.
		w.SubConn.Connect()
	case connectivity.Ready:
		// If we transition back to READY state, reset nonEmptySince so that we
		// apply the blackout period after we start receiving load data.  Note
		// that we cannot guarantee that we will never receive lingering
		// callbacks for backend metric reports from the previous connection
		// after the new connection has been established, but they should be
		// masked by new backend metric reports from the new connection by the
		// time the blackout period ends.
		w.mu.Lock()
		w.nonEmptySince = time.Time{}
		w.mu.Unlock()
	case connectivity.Shutdown:
		if w.stopORCAListener != nil {
			w.stopORCAListener()
		}
	}

	oldCS := w.connectivityState

	if oldCS == connectivity.TransientFailure &&
		(cs == connectivity.Connecting || cs == connectivity.Idle) {
		// Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or
		// CONNECTING transitions to prevent the aggregated state from being
		// always CONNECTING when many backends exist but are all down.
		return oldCS
	}

	w.connectivityState = cs

	return oldCS
}

// weight returns the current effective weight of the subconn, taking the
// given expiration and blackout parameters into account.  It returns 0 for
// blacked-out or expired data, which causes the backend's weight to be
// treated as the mean of the weights of the other backends.
func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration) float64 {
	w.mu.Lock()
	defer w.mu.Unlock()
	// If the most recent update was longer ago than the expiration period,
	// reset nonEmptySince so that we apply the blackout period again if we
	// start getting data again in the future, and return 0.
	if now.Sub(w.lastUpdated) >= weightExpirationPeriod {
		w.nonEmptySince = time.Time{}
		return 0
	}
	// If we don't have at least blackoutPeriod worth of data, return 0.
	if blackoutPeriod != 0 && (w.nonEmptySince == (time.Time{}) || now.Sub(w.nonEmptySince) < blackoutPeriod) {
		return 0
	}
	return w.weightVal
}
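
// Worked example: with blackoutPeriod=10s and weightExpirationPeriod=3m, a
// subconn whose first non-empty report lands at t=0 reports weight 0 (blacked
// out) until t=10s and weightVal thereafter; if reports then stop, the weight
// expires (back to 0) once 3m elapse since lastUpdated, and the blackout is
// applied again when reports resume.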