...
1
18
19
20
21
22
23
24
25
26 package endpointsharding
27
28 import (
29 "encoding/json"
30 "errors"
31 "fmt"
32 "sync"
33 "sync/atomic"
34
35 "google.golang.org/grpc/balancer"
36 "google.golang.org/grpc/balancer/base"
37 "google.golang.org/grpc/connectivity"
38 "google.golang.org/grpc/internal/balancer/gracefulswitch"
39 "google.golang.org/grpc/internal/grpcrand"
40 "google.golang.org/grpc/resolver"
41 "google.golang.org/grpc/serviceconfig"
42 )
43
44
45
46 type ChildState struct {
47 Endpoint resolver.Endpoint
48 State balancer.State
49 }
50
51
52
53 func NewBalancer(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
54 es := &endpointSharding{
55 cc: cc,
56 bOpts: opts,
57 }
58 es.children.Store(resolver.NewEndpointMap())
59 return es
60 }
61
62
63
64
65 type endpointSharding struct {
66 cc balancer.ClientConn
67 bOpts balancer.BuildOptions
68
69 children atomic.Pointer[resolver.EndpointMap]
70
71
72
73
74 inhibitChildUpdates atomic.Bool
75
76 mu sync.Mutex
77 }
78
79
80
81
82
83
84
85 func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState) error {
86 if len(state.ResolverState.Endpoints) == 0 {
87 return errors.New("endpoints list is empty")
88 }
89
90
91 for i, endpoint := range state.ResolverState.Endpoints {
92 if len(endpoint.Addresses) == 0 {
93 return fmt.Errorf("endpoint %d has empty addresses", i)
94 }
95 }
96
97 es.inhibitChildUpdates.Store(true)
98 defer func() {
99 es.inhibitChildUpdates.Store(false)
100 es.updateState()
101 }()
102 var ret error
103
104 children := es.children.Load()
105 newChildren := resolver.NewEndpointMap()
106
107
108 for _, endpoint := range state.ResolverState.Endpoints {
109 if _, ok := newChildren.Get(endpoint); ok {
110
111
112 continue
113 }
114 var bal *balancerWrapper
115 if child, ok := children.Get(endpoint); ok {
116 bal = child.(*balancerWrapper)
117 } else {
118 bal = &balancerWrapper{
119 childState: ChildState{Endpoint: endpoint},
120 ClientConn: es.cc,
121 es: es,
122 }
123 bal.Balancer = gracefulswitch.NewBalancer(bal, es.bOpts)
124 }
125 newChildren.Set(endpoint, bal)
126 if err := bal.UpdateClientConnState(balancer.ClientConnState{
127 BalancerConfig: state.BalancerConfig,
128 ResolverState: resolver.State{
129 Endpoints: []resolver.Endpoint{endpoint},
130 Attributes: state.ResolverState.Attributes,
131 },
132 }); err != nil && ret == nil {
133
134
135
136
137 ret = err
138 }
139 }
140
141 for _, e := range children.Keys() {
142 child, _ := children.Get(e)
143 bal := child.(balancer.Balancer)
144 if _, ok := newChildren.Get(e); !ok {
145 bal.Close()
146 }
147 }
148 es.children.Store(newChildren)
149 return ret
150 }
151
152
153
154
155 func (es *endpointSharding) ResolverError(err error) {
156 es.inhibitChildUpdates.Store(true)
157 defer func() {
158 es.inhibitChildUpdates.Store(false)
159 es.updateState()
160 }()
161 children := es.children.Load()
162 for _, child := range children.Values() {
163 bal := child.(balancer.Balancer)
164 bal.ResolverError(err)
165 }
166 }
167
168 func (es *endpointSharding) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
169
170 }
171
172 func (es *endpointSharding) Close() {
173 children := es.children.Load()
174 for _, child := range children.Values() {
175 bal := child.(balancer.Balancer)
176 bal.Close()
177 }
178 }
179
180
181
182
183 func (es *endpointSharding) updateState() {
184 if es.inhibitChildUpdates.Load() {
185 return
186 }
187 var readyPickers, connectingPickers, idlePickers, transientFailurePickers []balancer.Picker
188
189 es.mu.Lock()
190 defer es.mu.Unlock()
191
192 children := es.children.Load()
193 childStates := make([]ChildState, 0, children.Len())
194
195 for _, child := range children.Values() {
196 bw := child.(*balancerWrapper)
197 childState := bw.childState
198 childStates = append(childStates, childState)
199 childPicker := childState.State.Picker
200 switch childState.State.ConnectivityState {
201 case connectivity.Ready:
202 readyPickers = append(readyPickers, childPicker)
203 case connectivity.Connecting:
204 connectingPickers = append(connectingPickers, childPicker)
205 case connectivity.Idle:
206 idlePickers = append(idlePickers, childPicker)
207 case connectivity.TransientFailure:
208 transientFailurePickers = append(transientFailurePickers, childPicker)
209
210 }
211 }
212
213
214
215
216 var aggState connectivity.State
217 var pickers []balancer.Picker
218 if len(readyPickers) >= 1 {
219 aggState = connectivity.Ready
220 pickers = readyPickers
221 } else if len(connectingPickers) >= 1 {
222 aggState = connectivity.Connecting
223 pickers = connectingPickers
224 } else if len(idlePickers) >= 1 {
225 aggState = connectivity.Idle
226 pickers = idlePickers
227 } else if len(transientFailurePickers) >= 1 {
228 aggState = connectivity.TransientFailure
229 pickers = transientFailurePickers
230 } else {
231 aggState = connectivity.TransientFailure
232 pickers = []balancer.Picker{base.NewErrPicker(errors.New("no children to pick from"))}
233 }
234 p := &pickerWithChildStates{
235 pickers: pickers,
236 childStates: childStates,
237 next: uint32(grpcrand.Intn(len(pickers))),
238 }
239 es.cc.UpdateState(balancer.State{
240 ConnectivityState: aggState,
241 Picker: p,
242 })
243 }
244
245
246
247
248 type pickerWithChildStates struct {
249 pickers []balancer.Picker
250 childStates []ChildState
251 next uint32
252 }
253
254 func (p *pickerWithChildStates) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
255 nextIndex := atomic.AddUint32(&p.next, 1)
256 picker := p.pickers[nextIndex%uint32(len(p.pickers))]
257 return picker.Pick(info)
258 }
259
260
261
262 func ChildStatesFromPicker(picker balancer.Picker) []ChildState {
263 p, ok := picker.(*pickerWithChildStates)
264 if !ok {
265 return nil
266 }
267 return p.childStates
268 }
269
270
271
272 type balancerWrapper struct {
273 balancer.Balancer
274 balancer.ClientConn
275
276 es *endpointSharding
277
278 childState ChildState
279 }
280
281 func (bw *balancerWrapper) UpdateState(state balancer.State) {
282 bw.es.mu.Lock()
283 bw.childState.State = state
284 bw.es.mu.Unlock()
285 bw.es.updateState()
286 }
287
288 func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
289 return gracefulswitch.ParseConfig(cfg)
290 }
291
292
// PickFirstConfig is a JSON load balancing configuration list with a single
// entry selecting pick_first with an empty config.
const PickFirstConfig = `[{"pick_first": {}}]`
294
View as plain text