1
18
19
20 package main
21
22 import (
23 "context"
24 "flag"
25 "fmt"
26 "math/rand"
27 "net"
28 "os"
29 "strconv"
30 "strings"
31 "sync"
32 "sync/atomic"
33 "time"
34
35 "google.golang.org/grpc"
36 "google.golang.org/grpc/codes"
37 "google.golang.org/grpc/credentials"
38 "google.golang.org/grpc/credentials/google"
39 "google.golang.org/grpc/credentials/insecure"
40 "google.golang.org/grpc/grpclog"
41 "google.golang.org/grpc/interop"
42 "google.golang.org/grpc/resolver"
43 "google.golang.org/grpc/status"
44 "google.golang.org/grpc/testdata"
45
46 _ "google.golang.org/grpc/xds/googledirectpath"
47
48 testgrpc "google.golang.org/grpc/interop/grpc_testing"
49 metricspb "google.golang.org/grpc/interop/stress/grpc_testing"
50 )
51
// googleDefaultCredsName selects google.NewDefaultCredentials() when passed as
// --custom_credentials_type.
const googleDefaultCredsName = "google_default_credentials"

// computeEngineCredsName selects google.NewComputeEngineCredentials() when
// passed as --custom_credentials_type.
const computeEngineCredsName = "compute_engine_channel_creds"
56
57 var (
58 serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
59 testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights")
60 testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds")
61 numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
62 numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
63 metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
64 useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
65 testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
66 tlsServerName = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
67 caFile = flag.String("ca_file", "", "The file containing the CA root cert file")
68 customCredentialsType = flag.String("custom_credentials_type", "", "Custom credentials type to use")
69
70 totalNumCalls int64
71 logger = grpclog.Component("stress")
72 )
73
74
// testCaseWithWeight pairs an interop test-case name with the relative weight
// used when randomly choosing which test to run next.
type testCaseWithWeight struct {
	name   string // interop test case name, lower-cased by parseTestCases
	weight int    // relative selection weight
}
79
80
// parseTestCases parses the --test_cases flag value, a comma-separated list of
// "name:weight" pairs such as "empty_unary:20,large_unary:10".
//
// Names are matched case-insensitively against the supported interop test
// cases and stored lower-cased. Whitespace around names and weights is
// ignored. The result preserves the input order.
//
// It panics (this is command-line validation) if an entry is not exactly one
// "name:weight" pair, names an unknown test case, or has a weight that is not a
// non-negative integer.
func parseTestCases(testCaseString string) []testCaseWithWeight {
	entries := strings.Split(testCaseString, ",")
	parsed := make([]testCaseWithWeight, len(entries))
	for i, entry := range entries {
		nameAndWeight := strings.Split(entry, ":")
		if len(nameAndWeight) != 2 {
			panic(fmt.Sprintf("invalid test case with weight: %s", entry))
		}

		rawName := strings.TrimSpace(nameAndWeight[0])
		name := strings.ToLower(rawName)
		switch name {
		case
			"empty_unary",
			"large_unary",
			"client_streaming",
			"server_streaming",
			"ping_pong",
			"empty_stream",
			"timeout_on_sleeping_server",
			"cancel_after_begin",
			"cancel_after_first_response",
			"status_code_and_message",
			"custom_metadata":
		default:
			panic(fmt.Sprintf("unknown test type: %s", rawName))
		}

		weight, err := strconv.Atoi(strings.TrimSpace(nameAndWeight[1]))
		if err != nil {
			panic(fmt.Sprintf("invalid weight in test case %s: %v", entry, err))
		}
		// A negative weight would make the cumulative-weight selection in
		// weightedRandomTestSelector pick tests with meaningless probabilities.
		if weight < 0 {
			panic(fmt.Sprintf("negative weight in test case %s", entry))
		}
		parsed[i] = testCaseWithWeight{name: name, weight: weight}
	}
	return parsed
}
116
117
118 type weightedRandomTestSelector struct {
119 tests []testCaseWithWeight
120 totalWeight int
121 }
122
123
// newWeightedRandomTestSelector builds a selector over tests and seeds the
// global math/rand source.
//
// It panics if any weight is negative or if the weights do not sum to a
// positive value. Without this check, getNextTest would fail later, inside
// every worker goroutine, with an opaque "invalid argument to Intn" panic.
func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
	var totalWeight int
	for _, t := range tests {
		if t.weight < 0 {
			panic(fmt.Sprintf("negative weight %d for test case %s", t.weight, t.name))
		}
		totalWeight += t.weight
	}
	if totalWeight <= 0 {
		panic("weightedRandomTestSelector: test case weights must sum to a positive value")
	}
	// NOTE: rand.Seed is deprecated since Go 1.20, where the global source is
	// seeded randomly; from Go 1.24 it is a no-op by default. It is kept so
	// older toolchains still get a non-deterministic test mix.
	rand.Seed(time.Now().UnixNano())
	return &weightedRandomTestSelector{tests: tests, totalWeight: totalWeight}
}
132
133 func (selector weightedRandomTestSelector) getNextTest() string {
134 random := rand.Intn(selector.totalWeight)
135 var weightSofar int
136 for _, test := range selector.tests {
137 weightSofar += test.weight
138 if random < weightSofar {
139 return test.name
140 }
141 }
142 panic("no test case selected by weightedRandomTestSelector")
143 }
144
145
146 type gauge struct {
147 mutex sync.RWMutex
148 val int64
149 }
150
151 func (g *gauge) set(v int64) {
152 g.mutex.Lock()
153 defer g.mutex.Unlock()
154 g.val = v
155 }
156
157 func (g *gauge) get() int64 {
158 g.mutex.RLock()
159 defer g.mutex.RUnlock()
160 return g.val
161 }
162
163
164 type server struct {
165 metricspb.UnimplementedMetricsServiceServer
166 mutex sync.RWMutex
167
168 gauges map[string]*gauge
169 }
170
171
172 func newMetricsServer() *server {
173 return &server{gauges: make(map[string]*gauge)}
174 }
175
176
177 func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
178 s.mutex.RLock()
179 defer s.mutex.RUnlock()
180
181 for name, gauge := range s.gauges {
182 if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
183 return err
184 }
185 }
186 return nil
187 }
188
189
// GetGauge returns the current value of the gauge named in.Name. If no gauge
// with that name is registered, it returns an InvalidArgument status.
func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
	s.mutex.RLock()
	g, found := s.gauges[in.Name]
	s.mutex.RUnlock()

	if !found {
		return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
	}
	return &metricspb.GaugeResponse{
		Name:  in.Name,
		Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()},
	}, nil
}
199
200
201 func (s *server) createGauge(name string) *gauge {
202 s.mutex.Lock()
203 defer s.mutex.Unlock()
204
205 if _, ok := s.gauges[name]; ok {
206
207 panic(fmt.Sprintf("gauge %s already exists", name))
208 }
209 var g gauge
210 s.gauges[name] = &g
211 return &g
212 }
213
// startServer serves the MetricsService for srv on the given TCP port, on all
// interfaces. It blocks until the gRPC server stops.
//
// Failing to listen is fatal. An error from Serve is logged rather than
// fatal, because the metrics endpoint is auxiliary and should not abort an
// in-progress stress run.
func startServer(srv *server, port int) {
	lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
	if err != nil {
		logger.Fatalf("failed to listen: %v", err)
	}

	s := grpc.NewServer()
	metricspb.RegisterMetricsServiceServer(s, srv)
	if err := s.Serve(lis); err != nil {
		logger.Errorf("metrics server stopped serving: %v", err)
	}
}
224
225
// performRPCs is a worker loop. It repeatedly runs a randomly selected interop
// test case against conn and publishes the achieved calls-per-second rate to
// g after each call.
//
// It returns after stop is closed, checking between test cases, so an
// in-flight test always completes. On return it adds the number of test cases
// it ran to the global totalNumCalls.
func performRPCs(g *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
	client := testgrpc.NewTestServiceClient(conn)
	var numCalls int64
	// Registered exactly once, before the loop. A defer inside the loop would
	// pile up one closure per call (unbounded growth on long runs). Every one
	// of them would then add the final numCalls, inflating the total to
	// roughly numCalls².
	defer func() { atomic.AddInt64(&totalNumCalls, numCalls) }()

	ctx := context.Background()
	startTime := time.Now()
	for {
		switch selector.getNextTest() {
		case "empty_unary":
			interop.DoEmptyUnaryCall(ctx, client)
		case "large_unary":
			interop.DoLargeUnaryCall(ctx, client)
		case "client_streaming":
			interop.DoClientStreaming(ctx, client)
		case "server_streaming":
			interop.DoServerStreaming(ctx, client)
		case "ping_pong":
			interop.DoPingPong(ctx, client)
		case "empty_stream":
			interop.DoEmptyStream(ctx, client)
		case "timeout_on_sleeping_server":
			interop.DoTimeoutOnSleepingServer(ctx, client)
		case "cancel_after_begin":
			interop.DoCancelAfterBegin(ctx, client)
		case "cancel_after_first_response":
			interop.DoCancelAfterFirstResponse(ctx, client)
		case "status_code_and_message":
			interop.DoStatusCodeAndMessage(ctx, client)
		case "custom_metadata":
			interop.DoCustomMetadata(ctx, client)
		}
		numCalls++
		g.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))

		// Non-blocking check so the loop keeps running until stop is closed.
		select {
		case <-stop:
			return
		default:
		}
	}
}
268
269 func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
270 logger.Infof("server_addresses: %s", *serverAddresses)
271 logger.Infof("test_cases: %s", *testCases)
272 logger.Infof("test_duration_secs: %d", *testDurationSecs)
273 logger.Infof("num_channels_per_server: %d", *numChannelsPerServer)
274 logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel)
275 logger.Infof("metrics_port: %d", *metricsPort)
276 logger.Infof("use_tls: %t", *useTLS)
277 logger.Infof("use_test_ca: %t", *testCA)
278 logger.Infof("server_host_override: %s", *tlsServerName)
279 logger.Infof("custom_credentials_type: %s", *customCredentialsType)
280
281 logger.Infoln("addresses:")
282 for i, addr := range addresses {
283 logger.Infof("%d. %s\n", i+1, addr)
284 }
285 logger.Infoln("tests:")
286 for i, test := range tests {
287 logger.Infof("%d. %v\n", i+1, test)
288 }
289 }
290
// newConn dials address using credentials chosen as follows:
//
//   - If --custom_credentials_type is set, the matching Google credentials
//     bundle is used (google_default_credentials or
//     compute_engine_channel_creds). useTLS, testCA and tlsServerName are
//     ignored in that case.
//   - Otherwise, if useTLS is true, TLS is used (see newClientTLSCreds).
//   - Otherwise the connection uses insecure (plaintext) transport credentials.
//
// It returns an error for an unknown credentials type, for TLS credentials
// that cannot be built, or if grpc.Dial fails.
func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
	var opts []grpc.DialOption
	switch *customCredentialsType {
	case "":
		if useTLS {
			creds, err := newClientTLSCreds(testCA, tlsServerName)
			if err != nil {
				return nil, err
			}
			opts = append(opts, grpc.WithTransportCredentials(creds))
		} else {
			opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
		}
	case googleDefaultCredsName:
		opts = append(opts, grpc.WithCredentialsBundle(google.NewDefaultCredentials()))
	case computeEngineCredsName:
		opts = append(opts, grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()))
	default:
		return nil, fmt.Errorf("unknown custom credentials: %v", *customCredentialsType)
	}
	return grpc.Dial(address, opts...)
}

// newClientTLSCreds builds client TLS credentials that verify the server as
// serverName (when non-empty).
//
// With testCA, the root CA is loaded from --ca_file, or from the bundled test
// CA (x509/server_ca_cert.pem) when that flag is empty. The flag itself is not
// modified. Without testCA, a nil cert pool is passed, which selects the
// host's root CAs.
func newClientTLSCreds(testCA bool, serverName string) (credentials.TransportCredentials, error) {
	if !testCA {
		return credentials.NewClientTLSFromCert(nil, serverName), nil
	}
	caPath := *caFile
	if caPath == "" {
		caPath = testdata.Path("x509/server_ca_cert.pem")
	}
	creds, err := credentials.NewClientTLSFromFile(caPath, serverName)
	if err != nil {
		return nil, fmt.Errorf("failed to create TLS credentials: %w", err)
	}
	return creds, nil
}
325
// main parses flags and starts workers. It opens numChannelsPerServer
// connections to each server address and runs numStubsPerChannel worker
// goroutines on each connection, each with its own QPS gauge. It then serves
// metrics and stops the workers after --test_duration_secs, if that is
// positive.
func main() {
	flag.Parse()
	resolver.SetDefaultScheme("dns")

	addresses := strings.Split(*serverAddresses, ",")
	tests := parseTestCases(*testCases)
	logParameterInfo(addresses, tests)

	testSelector := newWeightedRandomTestSelector(tests)
	metricsServer := newMetricsServer()

	workerCount := len(addresses) * *numChannelsPerServer * *numStubsPerChannel
	var wg sync.WaitGroup
	wg.Add(workerCount)
	stop := make(chan bool)

	for serverIdx := range addresses {
		address := addresses[serverIdx]
		for channelIdx := 0; channelIdx < *numChannelsPerServer; channelIdx++ {
			conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
			if err != nil {
				logger.Fatalf("Fail to dial: %v", err)
			}
			// Deferred in main (not a helper) so the connection outlives the
			// worker goroutines that share it.
			defer conn.Close()
			for stubIdx := 0; stubIdx < *numStubsPerChannel; stubIdx++ {
				gaugeName := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIdx+1, channelIdx+1, stubIdx+1)
				go func(name string, cc *grpc.ClientConn) {
					defer wg.Done()
					performRPCs(metricsServer.createGauge(name), cc, testSelector, stop)
				}(gaugeName, conn)
			}
		}
	}

	go startServer(metricsServer, *metricsPort)
	if *testDurationSecs > 0 {
		time.Sleep(time.Duration(*testDurationSecs) * time.Second)
		close(stop)
	}
	wg.Wait()
	fmt.Fprintf(os.Stdout, "Total calls made: %v\n", totalNumCalls)
	logger.Infof(" ===== ALL DONE ===== ")
}
366
View as plain text