...

Source file src/google.golang.org/grpc/interop/stress/client/main.go

Documentation: google.golang.org/grpc/interop/stress/client

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // client starts an interop client to run stress tests and a metrics server to report QPS.
    20  package main
    21  
    22  import (
    23  	"context"
    24  	"flag"
    25  	"fmt"
    26  	"math/rand"
    27  	"net"
    28  	"os"
    29  	"strconv"
    30  	"strings"
    31  	"sync"
    32  	"sync/atomic"
    33  	"time"
    34  
    35  	"google.golang.org/grpc"
    36  	"google.golang.org/grpc/codes"
    37  	"google.golang.org/grpc/credentials"
    38  	"google.golang.org/grpc/credentials/google"
    39  	"google.golang.org/grpc/credentials/insecure"
    40  	"google.golang.org/grpc/grpclog"
    41  	"google.golang.org/grpc/interop"
    42  	"google.golang.org/grpc/resolver"
    43  	"google.golang.org/grpc/status"
    44  	"google.golang.org/grpc/testdata"
    45  
    46  	_ "google.golang.org/grpc/xds/googledirectpath" // Register xDS resolver required for c2p directpath.
    47  
    48  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    49  	metricspb "google.golang.org/grpc/interop/stress/grpc_testing"
    50  )
    51  
    52  const (
    53  	googleDefaultCredsName = "google_default_credentials"
    54  	computeEngineCredsName = "compute_engine_channel_creds"
    55  )
    56  
    57  var (
    58  	serverAddresses       = flag.String("server_addresses", "localhost:8080", "a list of server addresses")
    59  	testCases             = flag.String("test_cases", "", "a list of test cases along with the relative weights")
    60  	testDurationSecs      = flag.Int("test_duration_secs", -1, "test duration in seconds")
    61  	numChannelsPerServer  = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
    62  	numStubsPerChannel    = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
    63  	metricsPort           = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
    64  	useTLS                = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
    65  	testCA                = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
    66  	tlsServerName         = flag.String("server_host_override", "foo.test.google.fr", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.")
    67  	caFile                = flag.String("ca_file", "", "The file containing the CA root cert file")
    68  	customCredentialsType = flag.String("custom_credentials_type", "", "Custom credentials type to use")
    69  
    70  	totalNumCalls int64
    71  	logger        = grpclog.Component("stress")
    72  )
    73  
    74  // testCaseWithWeight contains the test case type and its weight.
    75  type testCaseWithWeight struct {
    76  	name   string
    77  	weight int
    78  }
    79  
    80  // parseTestCases converts test case string to a list of struct testCaseWithWeight.
    81  func parseTestCases(testCaseString string) []testCaseWithWeight {
    82  	testCaseStrings := strings.Split(testCaseString, ",")
    83  	testCases := make([]testCaseWithWeight, len(testCaseStrings))
    84  	for i, str := range testCaseStrings {
    85  		testCaseNameAndWeight := strings.Split(str, ":")
    86  		if len(testCaseNameAndWeight) != 2 {
    87  			panic(fmt.Sprintf("invalid test case with weight: %s", str))
    88  		}
    89  		// Check if test case is supported.
    90  		testCaseName := strings.ToLower(testCaseNameAndWeight[0])
    91  		switch testCaseName {
    92  		case
    93  			"empty_unary",
    94  			"large_unary",
    95  			"client_streaming",
    96  			"server_streaming",
    97  			"ping_pong",
    98  			"empty_stream",
    99  			"timeout_on_sleeping_server",
   100  			"cancel_after_begin",
   101  			"cancel_after_first_response",
   102  			"status_code_and_message",
   103  			"custom_metadata":
   104  		default:
   105  			panic(fmt.Sprintf("unknown test type: %s", testCaseNameAndWeight[0]))
   106  		}
   107  		testCases[i].name = testCaseName
   108  		w, err := strconv.Atoi(testCaseNameAndWeight[1])
   109  		if err != nil {
   110  			panic(fmt.Sprintf("%v", err))
   111  		}
   112  		testCases[i].weight = w
   113  	}
   114  	return testCases
   115  }
   116  
   117  // weightedRandomTestSelector defines a weighted random selector for test case types.
   118  type weightedRandomTestSelector struct {
   119  	tests       []testCaseWithWeight
   120  	totalWeight int
   121  }
   122  
   123  // newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight.
   124  func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
   125  	var totalWeight int
   126  	for _, t := range tests {
   127  		totalWeight += t.weight
   128  	}
   129  	rand.Seed(time.Now().UnixNano())
   130  	return &weightedRandomTestSelector{tests, totalWeight}
   131  }
   132  
   133  func (selector weightedRandomTestSelector) getNextTest() string {
   134  	random := rand.Intn(selector.totalWeight)
   135  	var weightSofar int
   136  	for _, test := range selector.tests {
   137  		weightSofar += test.weight
   138  		if random < weightSofar {
   139  			return test.name
   140  		}
   141  	}
   142  	panic("no test case selected by weightedRandomTestSelector")
   143  }
   144  
   145  // gauge stores the qps of one interop client (one stub).
   146  type gauge struct {
   147  	mutex sync.RWMutex
   148  	val   int64
   149  }
   150  
   151  func (g *gauge) set(v int64) {
   152  	g.mutex.Lock()
   153  	defer g.mutex.Unlock()
   154  	g.val = v
   155  }
   156  
   157  func (g *gauge) get() int64 {
   158  	g.mutex.RLock()
   159  	defer g.mutex.RUnlock()
   160  	return g.val
   161  }
   162  
   163  // server implements metrics server functions.
   164  type server struct {
   165  	metricspb.UnimplementedMetricsServiceServer
   166  	mutex sync.RWMutex
   167  	// gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
   168  	gauges map[string]*gauge
   169  }
   170  
   171  // newMetricsServer returns a new metrics server.
   172  func newMetricsServer() *server {
   173  	return &server{gauges: make(map[string]*gauge)}
   174  }
   175  
   176  // GetAllGauges returns all gauges.
   177  func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
   178  	s.mutex.RLock()
   179  	defer s.mutex.RUnlock()
   180  
   181  	for name, gauge := range s.gauges {
   182  		if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
   183  			return err
   184  		}
   185  	}
   186  	return nil
   187  }
   188  
   189  // GetGauge returns the gauge for the given name.
   190  func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
   191  	s.mutex.RLock()
   192  	defer s.mutex.RUnlock()
   193  
   194  	if g, ok := s.gauges[in.Name]; ok {
   195  		return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
   196  	}
   197  	return nil, status.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
   198  }
   199  
   200  // createGauge creates a gauge using the given name in metrics server.
   201  func (s *server) createGauge(name string) *gauge {
   202  	s.mutex.Lock()
   203  	defer s.mutex.Unlock()
   204  
   205  	if _, ok := s.gauges[name]; ok {
   206  		// gauge already exists.
   207  		panic(fmt.Sprintf("gauge %s already exists", name))
   208  	}
   209  	var g gauge
   210  	s.gauges[name] = &g
   211  	return &g
   212  }
   213  
   214  func startServer(server *server, port int) {
   215  	lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
   216  	if err != nil {
   217  		logger.Fatalf("failed to listen: %v", err)
   218  	}
   219  
   220  	s := grpc.NewServer()
   221  	metricspb.RegisterMetricsServiceServer(s, server)
   222  	s.Serve(lis)
   223  }
   224  
   225  // performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
   226  func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
   227  	client := testgrpc.NewTestServiceClient(conn)
   228  	var numCalls int64
   229  	ctx := context.Background()
   230  	startTime := time.Now()
   231  	for {
   232  		test := selector.getNextTest()
   233  		switch test {
   234  		case "empty_unary":
   235  			interop.DoEmptyUnaryCall(ctx, client)
   236  		case "large_unary":
   237  			interop.DoLargeUnaryCall(ctx, client)
   238  		case "client_streaming":
   239  			interop.DoClientStreaming(ctx, client)
   240  		case "server_streaming":
   241  			interop.DoServerStreaming(ctx, client)
   242  		case "ping_pong":
   243  			interop.DoPingPong(ctx, client)
   244  		case "empty_stream":
   245  			interop.DoEmptyStream(ctx, client)
   246  		case "timeout_on_sleeping_server":
   247  			interop.DoTimeoutOnSleepingServer(ctx, client)
   248  		case "cancel_after_begin":
   249  			interop.DoCancelAfterBegin(ctx, client)
   250  		case "cancel_after_first_response":
   251  			interop.DoCancelAfterFirstResponse(ctx, client)
   252  		case "status_code_and_message":
   253  			interop.DoStatusCodeAndMessage(ctx, client)
   254  		case "custom_metadata":
   255  			interop.DoCustomMetadata(ctx, client)
   256  		}
   257  		numCalls++
   258  		defer func() { atomic.AddInt64(&totalNumCalls, numCalls) }()
   259  		gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
   260  
   261  		select {
   262  		case <-stop:
   263  			return
   264  		default:
   265  		}
   266  	}
   267  }
   268  
   269  func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
   270  	logger.Infof("server_addresses: %s", *serverAddresses)
   271  	logger.Infof("test_cases: %s", *testCases)
   272  	logger.Infof("test_duration_secs: %d", *testDurationSecs)
   273  	logger.Infof("num_channels_per_server: %d", *numChannelsPerServer)
   274  	logger.Infof("num_stubs_per_channel: %d", *numStubsPerChannel)
   275  	logger.Infof("metrics_port: %d", *metricsPort)
   276  	logger.Infof("use_tls: %t", *useTLS)
   277  	logger.Infof("use_test_ca: %t", *testCA)
   278  	logger.Infof("server_host_override: %s", *tlsServerName)
   279  	logger.Infof("custom_credentials_type: %s", *customCredentialsType)
   280  
   281  	logger.Infoln("addresses:")
   282  	for i, addr := range addresses {
   283  		logger.Infof("%d. %s\n", i+1, addr)
   284  	}
   285  	logger.Infoln("tests:")
   286  	for i, test := range tests {
   287  		logger.Infof("%d. %v\n", i+1, test)
   288  	}
   289  }
   290  
   291  func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
   292  	var opts []grpc.DialOption
   293  	if *customCredentialsType != "" {
   294  		if *customCredentialsType == googleDefaultCredsName {
   295  			opts = append(opts, grpc.WithCredentialsBundle(google.NewDefaultCredentials()))
   296  		} else if *customCredentialsType == computeEngineCredsName {
   297  			opts = append(opts, grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()))
   298  		} else {
   299  			logger.Fatalf("Unknown custom credentials: %v", *customCredentialsType)
   300  		}
   301  	} else if useTLS {
   302  		var sn string
   303  		if tlsServerName != "" {
   304  			sn = tlsServerName
   305  		}
   306  		var creds credentials.TransportCredentials
   307  		if testCA {
   308  			var err error
   309  			if *caFile == "" {
   310  				*caFile = testdata.Path("x509/server_ca_cert.pem")
   311  			}
   312  			creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
   313  			if err != nil {
   314  				logger.Fatalf("Failed to create TLS credentials: %v", err)
   315  			}
   316  		} else {
   317  			creds = credentials.NewClientTLSFromCert(nil, sn)
   318  		}
   319  		opts = append(opts, grpc.WithTransportCredentials(creds))
   320  	} else {
   321  		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
   322  	}
   323  	return grpc.Dial(address, opts...)
   324  }
   325  
   326  func main() {
   327  	flag.Parse()
   328  	resolver.SetDefaultScheme("dns")
   329  	addresses := strings.Split(*serverAddresses, ",")
   330  	tests := parseTestCases(*testCases)
   331  	logParameterInfo(addresses, tests)
   332  	testSelector := newWeightedRandomTestSelector(tests)
   333  	metricsServer := newMetricsServer()
   334  
   335  	var wg sync.WaitGroup
   336  	wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
   337  	stop := make(chan bool)
   338  
   339  	for serverIndex, address := range addresses {
   340  		for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
   341  			conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
   342  			if err != nil {
   343  				logger.Fatalf("Fail to dial: %v", err)
   344  			}
   345  			defer conn.Close()
   346  			for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
   347  				name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
   348  				go func() {
   349  					defer wg.Done()
   350  					g := metricsServer.createGauge(name)
   351  					performRPCs(g, conn, testSelector, stop)
   352  				}()
   353  			}
   354  
   355  		}
   356  	}
   357  	go startServer(metricsServer, *metricsPort)
   358  	if *testDurationSecs > 0 {
   359  		time.Sleep(time.Duration(*testDurationSecs) * time.Second)
   360  		close(stop)
   361  	}
   362  	wg.Wait()
   363  	fmt.Fprintf(os.Stdout, "Total calls made: %v\n", totalNumCalls)
   364  	logger.Infof(" ===== ALL DONE ===== ")
   365  }
   366  

View as plain text