1
18
19 package grpclb
20
21 import (
22 "context"
23 "fmt"
24 "io"
25 "net"
26 "sync"
27 "time"
28
29 "google.golang.org/grpc"
30 "google.golang.org/grpc/balancer"
31 "google.golang.org/grpc/connectivity"
32 "google.golang.org/grpc/credentials/insecure"
33 "google.golang.org/grpc/internal/backoff"
34 imetadata "google.golang.org/grpc/internal/metadata"
35 "google.golang.org/grpc/keepalive"
36 "google.golang.org/grpc/metadata"
37 "google.golang.org/grpc/resolver"
38 "google.golang.org/protobuf/proto"
39 "google.golang.org/protobuf/types/known/timestamppb"
40
41 lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
42 )
43
44 func serverListEqual(a, b []*lbpb.Server) bool {
45 if len(a) != len(b) {
46 return false
47 }
48 for i := 0; i < len(a); i++ {
49 if !proto.Equal(a[i], b[i]) {
50 return false
51 }
52 }
53 return true
54 }
55
56
57
58 func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
59 if lb.logger.V(2) {
60 lb.logger.Infof("Processing server list: %#v", l)
61 }
62 lb.mu.Lock()
63 defer lb.mu.Unlock()
64
65
66
67 lb.serverListReceived = true
68
69
70 if serverListEqual(lb.fullServerList, l.Servers) {
71 if lb.logger.V(2) {
72 lb.logger.Infof("Ignoring new server list as it is the same as the previous one")
73 }
74 return
75 }
76 lb.fullServerList = l.Servers
77
78 var backendAddrs []resolver.Address
79 for i, s := range l.Servers {
80 if s.Drop {
81 continue
82 }
83
84 md := metadata.Pairs(lbTokenKey, s.LoadBalanceToken)
85 ip := net.IP(s.IpAddress)
86 ipStr := ip.String()
87 if ip.To4() == nil {
88
89
90 ipStr = fmt.Sprintf("[%s]", ipStr)
91 }
92 addr := imetadata.Set(resolver.Address{Addr: fmt.Sprintf("%s:%d", ipStr, s.Port)}, md)
93 if lb.logger.V(2) {
94 lb.logger.Infof("Server list entry:|%d|, ipStr:|%s|, port:|%d|, load balancer token:|%v|", i, ipStr, s.Port, s.LoadBalanceToken)
95 }
96 backendAddrs = append(backendAddrs, addr)
97 }
98
99
100
101 lb.refreshSubConns(backendAddrs, false, lb.usePickFirst)
102 }
103
104
105
106
107
// refreshSubConns makes lb's SubConns match backendAddrs. It also records the
// new addresses, fallback mode and balancing policy on lb.
//
// fallback is true when backendAddrs are the resolver-provided backends
// (lb.resolvedBackendAddrs at the call sites in this file). It is false when
// they come from a remote-balancer server list. pickFirst selects a single
// SubConn that carries all addresses; otherwise one SubConn is kept per
// address.
//
// Every caller in this file holds lb.mu. Presumably this is required, since
// the method reads and writes lb's maps and fields without locking.
func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fallback bool, pickFirst bool) {
	opts := balancer.NewSubConnOptions{}
	if !fallback {
		// Balancer-provided backends use the grpclb backend creds bundle.
		// Fallback backends leave CredsBundle unset.
		opts.CredsBundle = lb.grpclbBackendCreds
	}

	lb.backendAddrs = backendAddrs
	lb.backendAddrsWithoutMetadata = nil

	fallbackModeChanged := lb.inFallback != fallback
	lb.inFallback = fallback
	if fallbackModeChanged && lb.inFallback {
		// Entering fallback: forget the last balancer server list. If the
		// balancer later resends the same list, processServerList will then
		// not discard it as a duplicate.
		lb.fullServerList = nil
	}

	balancingPolicyChanged := lb.usePickFirst != pickFirst
	lb.usePickFirst = pickFirst

	if fallbackModeChanged || balancingPolicyChanged {
		// Start from scratch when the mode or the policy changes. Existing
		// SubConns may have been built with different creds (see opts above)
		// or for the other policy. Deleting map entries while ranging over
		// the map is allowed in Go.
		for a, sc := range lb.subConns {
			sc.Shutdown()
			delete(lb.subConns, a)
		}
	}

	if lb.usePickFirst {
		// pick_first: one SubConn carries all backend addresses.
		var (
			scKey resolver.Address
			sc    balancer.SubConn
		)
		// Take any existing SubConn. At most one is expected in this mode.
		for scKey, sc = range lb.subConns {
			break
		}
		if sc != nil {
			if len(backendAddrs) == 0 {
				sc.Shutdown()
				delete(lb.subConns, scKey)
				return
			}
			// Reuse the existing SubConn with the new address list.
			lb.cc.ClientConn.UpdateAddresses(sc, backendAddrs)
			sc.Connect()
			return
		}
		// The listener captures the variable sc, which the NewSubConn call
		// below assigns. ":=" reuses sc because it is already declared in
		// this scope; only err is new.
		opts.StateListener = func(scs balancer.SubConnState) { lb.updateSubConnState(sc, scs) }

		// Created through lb.cc.ClientConn directly rather than through lb.cc.
		sc, err := lb.cc.ClientConn.NewSubConn(backendAddrs, opts)
		if err != nil {
			// NOTE(review): an empty backendAddrs presumably ends up here via
			// an error from NewSubConn. Otherwise backendAddrs[0] below would
			// panic. Confirm.
			lb.logger.Warningf("Failed to create new SubConn: %v", err)
			return
		}
		sc.Connect()
		lb.subConns[backendAddrs[0]] = sc
		lb.scStates[sc] = connectivity.Idle
		// Unlike the per-address path below, this branch returns without
		// calling updateStateAndPicker.
		return
	}

	// Per-address mode. addrsSet holds the new addresses with Attributes
	// stripped. It is used below to find SubConns that are no longer wanted.
	addrsSet := make(map[resolver.Address]struct{})

	for _, addr := range backendAddrs {
		addrWithoutAttrs := addr
		addrWithoutAttrs.Attributes = nil
		addrsSet[addrWithoutAttrs] = struct{}{}
		lb.backendAddrsWithoutMetadata = append(lb.backendAddrsWithoutMetadata, addrWithoutAttrs)

		// lb.subConns is keyed by the attribute-less address. An existing
		// SubConn for the same address is therefore reused even if the
		// address's metadata changed.
		if _, ok := lb.subConns[addrWithoutAttrs]; !ok {
			// Create the SubConn with the full address, including metadata.
			// As in the pick_first path, the listener captures sc, which the
			// ":=" on the next line assigns.
			var sc balancer.SubConn
			opts.StateListener = func(scs balancer.SubConnState) { lb.updateSubConnState(sc, scs) }
			sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, opts)
			if err != nil {
				lb.logger.Warningf("Failed to create new SubConn: %v", err)
				continue
			}
			lb.subConns[addrWithoutAttrs] = sc
			if _, ok := lb.scStates[sc]; !ok {
				// Only initialize the state if sc has no entry yet.
				// NOTE(review): an entry could already exist if lb.cc handed
				// back a previously used SubConn. Confirm against lb.cc.
				lb.scStates[sc] = connectivity.Idle
			}
			sc.Connect()
		}
	}

	// Shut down SubConns whose address is not in the new list.
	for a, sc := range lb.subConns {
		if _, ok := addrsSet[a]; !ok {
			sc.Shutdown()
			delete(lb.subConns, a)
			// The lb.scStates entry is left in place here. Presumably it is
			// removed once sc reports its shutdown via updateSubConnState.
			// Confirm.
		}
	}

	// Always call updateStateAndPicker after reconciling, even when no
	// SubConn has changed state yet.
	lb.updateStateAndPicker(true, true)
}
217
// remoteBalancerCCWrapper owns the ClientConn to the remote grpclb balancer
// and the watchRemoteBalancer goroutine, which keeps a BalanceLoad stream
// open on it.
type remoteBalancerCCWrapper struct {
	cc      *grpc.ClientConn // channel to the remote balancer, created in newRemoteBalancerCCWrapper
	lb      *lbBalancer      // parent balancer; receives server lists and fallback switches
	backoff backoff.Strategy // delay between failed attempts in watchRemoteBalancer
	done    chan struct{}    // closed by close() to stop watchRemoteBalancer

	// streamMu guards streamCancel.
	streamMu     sync.Mutex
	streamCancel func() // cancels the context of the current BalanceLoad stream

	// wg tracks watchRemoteBalancer and the load-report goroutines started
	// by callRemoteBalancer. close() waits on it.
	wg sync.WaitGroup
}
230
231 func (lb *lbBalancer) newRemoteBalancerCCWrapper() error {
232 var dopts []grpc.DialOption
233 if creds := lb.opt.DialCreds; creds != nil {
234 dopts = append(dopts, grpc.WithTransportCredentials(creds))
235 } else if bundle := lb.grpclbClientConnCreds; bundle != nil {
236 dopts = append(dopts, grpc.WithCredentialsBundle(bundle))
237 } else {
238 dopts = append(dopts, grpc.WithTransportCredentials(insecure.NewCredentials()))
239 }
240 if lb.opt.Dialer != nil {
241 dopts = append(dopts, grpc.WithContextDialer(lb.opt.Dialer))
242 }
243 if lb.opt.CustomUserAgent != "" {
244 dopts = append(dopts, grpc.WithUserAgent(lb.opt.CustomUserAgent))
245 }
246
247 dopts = append(dopts, grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"pick_first"}`))
248 dopts = append(dopts, grpc.WithResolvers(lb.manualResolver))
249 dopts = append(dopts, grpc.WithChannelzParentID(lb.opt.ChannelzParent))
250
251
252 dopts = append(dopts, grpc.WithKeepaliveParams(keepalive.ClientParameters{
253 Time: 20 * time.Second,
254 Timeout: 10 * time.Second,
255 PermitWithoutStream: true,
256 }))
257
258
259
260
261
262 target := lb.manualResolver.Scheme() + ":///grpclb.subClientConn"
263 cc, err := grpc.Dial(target, dopts...)
264 if err != nil {
265 return fmt.Errorf("grpc.Dial(%s): %v", target, err)
266 }
267 ccw := &remoteBalancerCCWrapper{
268 cc: cc,
269 lb: lb,
270 backoff: lb.backoff,
271 done: make(chan struct{}),
272 }
273 lb.ccRemoteLB = ccw
274 ccw.wg.Add(1)
275 go ccw.watchRemoteBalancer()
276 return nil
277 }
278
279
280
// close shuts the wrapper down. It closes done, which makes
// watchRemoteBalancer return. It then closes the balancer ClientConn and
// waits for every goroutine tracked by wg to exit.
//
// It must be called at most once: a second call panics on close(ccw.done).
func (ccw *remoteBalancerCCWrapper) close() {
	close(ccw.done)
	// Closing the ClientConn presumably fails the in-flight BalanceLoad
	// stream, unblocking readServerList. The Close error is ignored, since
	// this is best-effort teardown.
	ccw.cc.Close()
	ccw.wg.Wait()
}
286
287 func (ccw *remoteBalancerCCWrapper) readServerList(s *balanceLoadClientStream) error {
288 for {
289 reply, err := s.Recv()
290 if err != nil {
291 if err == io.EOF {
292 return errServerTerminatedConnection
293 }
294 return fmt.Errorf("grpclb: failed to recv server list: %v", err)
295 }
296 if serverList := reply.GetServerList(); serverList != nil {
297 ccw.lb.processServerList(serverList)
298 }
299 if reply.GetFallbackResponse() != nil {
300
301 ccw.lb.mu.Lock()
302 ccw.lb.refreshSubConns(ccw.lb.resolvedBackendAddrs, true, ccw.lb.usePickFirst)
303 ccw.lb.mu.Unlock()
304 }
305 }
306 }
307
308 func (ccw *remoteBalancerCCWrapper) sendLoadReport(s *balanceLoadClientStream, interval time.Duration) {
309 ticker := time.NewTicker(interval)
310 defer ticker.Stop()
311 lastZero := false
312 for {
313 select {
314 case <-ticker.C:
315 case <-s.Context().Done():
316 return
317 }
318 stats := ccw.lb.clientStats.toClientStats()
319 zero := isZeroStats(stats)
320 if zero && lastZero {
321
322 continue
323 }
324 lastZero = zero
325 t := time.Now()
326 stats.Timestamp = ×tamppb.Timestamp{
327 Seconds: t.Unix(),
328 Nanos: int32(t.Nanosecond()),
329 }
330 if err := s.Send(&lbpb.LoadBalanceRequest{
331 LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{
332 ClientStats: stats,
333 },
334 }); err != nil {
335 return
336 }
337 }
338 }
339
340 func (ccw *remoteBalancerCCWrapper) callRemoteBalancer(ctx context.Context) (backoff bool, _ error) {
341 lbClient := &loadBalancerClient{cc: ccw.cc}
342 stream, err := lbClient.BalanceLoad(ctx, grpc.WaitForReady(true))
343 if err != nil {
344 return true, fmt.Errorf("grpclb: failed to perform RPC to the remote balancer: %v", err)
345 }
346 ccw.lb.mu.Lock()
347 ccw.lb.remoteBalancerConnected = true
348 ccw.lb.mu.Unlock()
349
350
351 initReq := &lbpb.LoadBalanceRequest{
352 LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
353 InitialRequest: &lbpb.InitialLoadBalanceRequest{
354 Name: ccw.lb.target,
355 },
356 },
357 }
358 if err := stream.Send(initReq); err != nil {
359 return true, fmt.Errorf("grpclb: failed to send init request: %v", err)
360 }
361 reply, err := stream.Recv()
362 if err != nil {
363 return true, fmt.Errorf("grpclb: failed to recv init response: %v", err)
364 }
365 initResp := reply.GetInitialResponse()
366 if initResp == nil {
367 return true, fmt.Errorf("grpclb: reply from remote balancer did not include initial response")
368 }
369
370 ccw.wg.Add(1)
371 go func() {
372 defer ccw.wg.Done()
373 if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 {
374 ccw.sendLoadReport(stream, d)
375 }
376 }()
377
378 return false, ccw.readServerList(stream)
379 }
380
381
382
383
384 func (ccw *remoteBalancerCCWrapper) cancelRemoteBalancerCall() {
385 ccw.streamMu.Lock()
386 if ccw.streamCancel != nil {
387 ccw.streamCancel()
388 ccw.streamCancel = nil
389 }
390 ccw.streamMu.Unlock()
391 }
392
// watchRemoteBalancer keeps a BalanceLoad stream to the remote balancer alive
// until ccw.done is closed. newRemoteBalancerCCWrapper starts it as a
// goroutine, with a matching ccw.wg.Add.
//
// Each iteration cancels the previous stream's context and opens a new stream
// via callRemoteBalancer. When that call ends, the loop:
//   - asks the parent ClientConn to re-resolve;
//   - marks the remote balancer as disconnected;
//   - clears lb.fullServerList;
//   - switches to the resolver-provided backends if lb is not already in
//     fallback and lb.state is not Ready.
//
// The next attempt starts immediately when callRemoteBalancer reported no
// backoff. Otherwise it starts after ccw.backoff.Backoff(retryCount).
func (ccw *remoteBalancerCCWrapper) watchRemoteBalancer() {
	defer func() {
		ccw.wg.Done()
		ccw.streamMu.Lock()
		if ccw.streamCancel != nil {
			// Cancel the last stream's context so it is not leaked when we
			// return from inside the loop below.
			ccw.streamCancel()
			ccw.streamCancel = nil
		}
		ccw.streamMu.Unlock()
	}()

	var retryCount int
	// Declared outside the loop so that ctx and the ccw.streamCancel field
	// can be assigned together with "=" below.
	var ctx context.Context
	for {
		// Replace the previous stream's context with a fresh cancellable one.
		ccw.streamMu.Lock()
		if ccw.streamCancel != nil {
			ccw.streamCancel()
			ccw.streamCancel = nil
		}
		ctx, ccw.streamCancel = context.WithCancel(context.Background())
		ccw.streamMu.Unlock()

		doBackoff, err := ccw.callRemoteBalancer(ctx)
		select {
		case <-ccw.done:
			// close() was called: exit without logging.
			return
		default:
			if err != nil {
				// readServerList returns errServerTerminatedConnection when the
				// balancer ends the stream (io.EOF). That case is logged at
				// Info; every other failure is logged at Warning.
				if err == errServerTerminatedConnection {
					ccw.lb.logger.Infof("Call to remote balancer failed: %v", err)
				} else {
					ccw.lb.logger.Warningf("Call to remote balancer failed: %v", err)
				}
			}
		}

		// Ask the parent ClientConn to re-resolve, presumably to refresh the
		// balancer and fallback addresses.
		ccw.lb.cc.ClientConn.ResolveNow(resolver.ResolveNowOptions{})

		ccw.lb.mu.Lock()
		ccw.lb.remoteBalancerConnected = false
		// Forget the last list so that processServerList's duplicate check
		// does not ignore an identical list on the next stream.
		ccw.lb.fullServerList = nil
		// Enter fallback right away unless lb is already in fallback or is
		// Ready. Presumably a Ready balancer keeps its current backends.
		if !ccw.lb.inFallback && ccw.lb.state != connectivity.Ready {
			ccw.lb.refreshSubConns(ccw.lb.resolvedBackendAddrs, true, ccw.lb.usePickFirst)
		}
		ccw.lb.mu.Unlock()

		if !doBackoff {
			// The previous stream got past the initial response: reconnect
			// immediately and restart the backoff sequence.
			retryCount = 0
			continue
		}

		timer := time.NewTimer(ccw.backoff.Backoff(retryCount))
		select {
		case <-timer.C:
		case <-ccw.done:
			timer.Stop()
			return
		}
		retryCount++
	}
}
459
View as plain text