...
1
18
19 package endpointsharding
20
21 import (
22 "context"
23 "encoding/json"
24 "fmt"
25 "log"
26 "testing"
27 "time"
28
29 "google.golang.org/grpc"
30 "google.golang.org/grpc/balancer"
31 "google.golang.org/grpc/credentials/insecure"
32 "google.golang.org/grpc/grpclog"
33 "google.golang.org/grpc/internal"
34 "google.golang.org/grpc/internal/grpctest"
35 "google.golang.org/grpc/internal/stubserver"
36 "google.golang.org/grpc/internal/testutils/roundrobin"
37 "google.golang.org/grpc/resolver"
38 "google.golang.org/grpc/resolver/manual"
39 "google.golang.org/grpc/serviceconfig"
40
41 testgrpc "google.golang.org/grpc/interop/grpc_testing"
42 )
43
44 type s struct {
45 grpctest.Tester
46 }
47
48 func Test(t *testing.T) {
49 grpctest.RunSubTests(t, s{})
50 }
51
52 var gracefulSwitchPickFirst serviceconfig.LoadBalancingConfig
53
54 var logger = grpclog.Component("endpoint-sharding-test")
55
56 func init() {
57 var err error
58 gracefulSwitchPickFirst, err = ParseConfig(json.RawMessage(PickFirstConfig))
59 if err != nil {
60 logger.Fatal(err)
61 }
62 balancer.Register(fakePetioleBuilder{})
63 }
64
65 const fakePetioleName = "fake_petiole"
66
67 type fakePetioleBuilder struct{}
68
69 func (fakePetioleBuilder) Name() string {
70 return fakePetioleName
71 }
72
73 func (fakePetioleBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
74 fp := &fakePetiole{
75 ClientConn: cc,
76 bOpts: opts,
77 }
78 fp.Balancer = NewBalancer(fp, opts)
79 return fp
80 }
81
82 func (fakePetioleBuilder) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
83 return nil, nil
84 }
85
86
87
88
89
90 type fakePetiole struct {
91 balancer.Balancer
92 balancer.ClientConn
93 bOpts balancer.BuildOptions
94 }
95
96 func (fp *fakePetiole) UpdateClientConnState(state balancer.ClientConnState) error {
97 if el := state.ResolverState.Endpoints; len(el) != 2 {
98 return fmt.Errorf("UpdateClientConnState wants two endpoints, got: %v", el)
99 }
100
101 return fp.Balancer.UpdateClientConnState(balancer.ClientConnState{
102 BalancerConfig: gracefulSwitchPickFirst,
103 ResolverState: state.ResolverState,
104 })
105 }
106
107 func (fp *fakePetiole) UpdateState(state balancer.State) {
108 childStates := ChildStatesFromPicker(state.Picker)
109
110
111 if len(childStates) != 2 {
112 logger.Fatal(fmt.Errorf("length of child states received: %v, want 2", len(childStates)))
113 }
114
115 fp.ClientConn.UpdateState(state)
116 }
117
118
119
120
121
122
123
124
125
126 func (s) TestEndpointShardingBasic(t *testing.T) {
127 backend1 := stubserver.StartTestService(t, nil)
128 defer backend1.Stop()
129 backend2 := stubserver.StartTestService(t, nil)
130 defer backend2.Stop()
131
132 mr := manual.NewBuilderWithScheme("e2e-test")
133 defer mr.Close()
134
135 json := `{"loadBalancingConfig": [{"fake_petiole":{}}]}`
136 sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
137 mr.InitialState(resolver.State{
138 Endpoints: []resolver.Endpoint{
139 {Addresses: []resolver.Address{{Addr: backend1.Address}}},
140 {Addresses: []resolver.Address{{Addr: backend2.Address}}},
141 },
142 ServiceConfig: sc,
143 })
144
145 cc, err := grpc.Dial(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
146 if err != nil {
147 log.Fatalf("Failed to dial: %v", err)
148 }
149 defer cc.Close()
150 ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
151 defer cancel()
152 client := testgrpc.NewTestServiceClient(cc)
153
154
155
156 if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend1.Address}, {Addr: backend2.Address}}); err != nil {
157 t.Fatalf("error in expected round robin: %v", err)
158 }
159 }
160
View as plain text