...

Source file src/google.golang.org/grpc/stats/stats_test.go

Documentation: google.golang.org/grpc/stats

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package stats_test
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"io"
    25  	"net"
    26  	"reflect"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/credentials/insecure"
    33  	"google.golang.org/grpc/internal"
    34  	"google.golang.org/grpc/internal/grpctest"
    35  	"google.golang.org/grpc/internal/stubserver"
    36  	"google.golang.org/grpc/internal/testutils"
    37  	"google.golang.org/grpc/metadata"
    38  	"google.golang.org/grpc/stats"
    39  	"google.golang.org/grpc/status"
    40  	"google.golang.org/protobuf/proto"
    41  
    42  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    43  	testpb "google.golang.org/grpc/interop/grpc_testing"
    44  )
    45  
    46  const defaultTestTimeout = 10 * time.Second
    47  
    48  type s struct {
    49  	grpctest.Tester
    50  }
    51  
    52  func Test(t *testing.T) {
    53  	grpctest.RunSubTests(t, s{})
    54  }
    55  
    56  func init() {
    57  	grpc.EnableTracing = false
    58  }
    59  
    60  type connCtxKey struct{}
    61  type rpcCtxKey struct{}
    62  
    63  var (
    64  	// For headers sent to server:
    65  	testMetadata = metadata.MD{
    66  		"key1":       []string{"value1"},
    67  		"key2":       []string{"value2"},
    68  		"user-agent": []string{fmt.Sprintf("test/0.0.1 grpc-go/%s", grpc.Version)},
    69  	}
    70  	// For headers sent from server:
    71  	testHeaderMetadata = metadata.MD{
    72  		"hkey1": []string{"headerValue1"},
    73  		"hkey2": []string{"headerValue2"},
    74  	}
    75  	// For trailers sent from server:
    76  	testTrailerMetadata = metadata.MD{
    77  		"tkey1": []string{"trailerValue1"},
    78  		"tkey2": []string{"trailerValue2"},
    79  	}
    80  	// The id for which the service handler should return error.
    81  	errorID int32 = 32202
    82  )
    83  
    84  func idToPayload(id int32) *testpb.Payload {
    85  	return &testpb.Payload{Body: []byte{byte(id), byte(id >> 8), byte(id >> 16), byte(id >> 24)}}
    86  }
    87  
    88  func payloadToID(p *testpb.Payload) int32 {
    89  	if p == nil || len(p.Body) != 4 {
    90  		panic("invalid payload")
    91  	}
    92  	return int32(p.Body[0]) + int32(p.Body[1])<<8 + int32(p.Body[2])<<16 + int32(p.Body[3])<<24
    93  }
    94  
    95  type testServer struct {
    96  	testgrpc.UnimplementedTestServiceServer
    97  }
    98  
    99  func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   100  	if err := grpc.SendHeader(ctx, testHeaderMetadata); err != nil {
   101  		return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", testHeaderMetadata, err)
   102  	}
   103  	if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
   104  		return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
   105  	}
   106  
   107  	if id := payloadToID(in.Payload); id == errorID {
   108  		return nil, fmt.Errorf("got error id: %v", id)
   109  	}
   110  
   111  	return &testpb.SimpleResponse{Payload: in.Payload}, nil
   112  }
   113  
   114  func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error {
   115  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   116  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   117  	}
   118  	stream.SetTrailer(testTrailerMetadata)
   119  	for {
   120  		in, err := stream.Recv()
   121  		if err == io.EOF {
   122  			// read done.
   123  			return nil
   124  		}
   125  		if err != nil {
   126  			return err
   127  		}
   128  
   129  		if id := payloadToID(in.Payload); id == errorID {
   130  			return fmt.Errorf("got error id: %v", id)
   131  		}
   132  
   133  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   134  			return err
   135  		}
   136  	}
   137  }
   138  
   139  func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error {
   140  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   141  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   142  	}
   143  	stream.SetTrailer(testTrailerMetadata)
   144  	for {
   145  		in, err := stream.Recv()
   146  		if err == io.EOF {
   147  			// read done.
   148  			return stream.SendAndClose(&testpb.StreamingInputCallResponse{AggregatedPayloadSize: 0})
   149  		}
   150  		if err != nil {
   151  			return err
   152  		}
   153  
   154  		if id := payloadToID(in.Payload); id == errorID {
   155  			return fmt.Errorf("got error id: %v", id)
   156  		}
   157  	}
   158  }
   159  
   160  func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error {
   161  	if err := stream.SendHeader(testHeaderMetadata); err != nil {
   162  		return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil)
   163  	}
   164  	stream.SetTrailer(testTrailerMetadata)
   165  
   166  	if id := payloadToID(in.Payload); id == errorID {
   167  		return fmt.Errorf("got error id: %v", id)
   168  	}
   169  
   170  	for i := 0; i < 5; i++ {
   171  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   172  			return err
   173  		}
   174  	}
   175  	return nil
   176  }
   177  
   178  // test is an end-to-end test. It should be created with the newTest
   179  // func, modified as needed, and then started with its startServer method.
   180  // It should be cleaned up with the tearDown method.
   181  type test struct {
   182  	t                   *testing.T
   183  	compress            string
   184  	clientStatsHandlers []stats.Handler
   185  	serverStatsHandlers []stats.Handler
   186  
   187  	testServer testgrpc.TestServiceServer // nil means none
   188  	// srv and srvAddr are set once startServer is called.
   189  	srv     *grpc.Server
   190  	srvAddr string
   191  
   192  	cc *grpc.ClientConn // nil until requested via clientConn
   193  }
   194  
   195  func (te *test) tearDown() {
   196  	if te.cc != nil {
   197  		te.cc.Close()
   198  		te.cc = nil
   199  	}
   200  	te.srv.Stop()
   201  }
   202  
   203  type testConfig struct {
   204  	compress string
   205  }
   206  
   207  // newTest returns a new test using the provided testing.T and
   208  // environment.  It is returned with default values. Tests should
   209  // modify it before calling its startServer and clientConn methods.
   210  func newTest(t *testing.T, tc *testConfig, chs []stats.Handler, shs []stats.Handler) *test {
   211  	te := &test{
   212  		t:                   t,
   213  		compress:            tc.compress,
   214  		clientStatsHandlers: chs,
   215  		serverStatsHandlers: shs,
   216  	}
   217  	return te
   218  }
   219  
   220  // startServer starts a gRPC server listening. Callers should defer a
   221  // call to te.tearDown to clean up.
   222  func (te *test) startServer(ts testgrpc.TestServiceServer) {
   223  	te.testServer = ts
   224  	lis, err := net.Listen("tcp", "localhost:0")
   225  	if err != nil {
   226  		te.t.Fatalf("Failed to listen: %v", err)
   227  	}
   228  	var opts []grpc.ServerOption
   229  	if te.compress == "gzip" {
   230  		opts = append(opts,
   231  			grpc.RPCCompressor(grpc.NewGZIPCompressor()),
   232  			grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
   233  		)
   234  	}
   235  	for _, sh := range te.serverStatsHandlers {
   236  		opts = append(opts, grpc.StatsHandler(sh))
   237  	}
   238  	s := grpc.NewServer(opts...)
   239  	te.srv = s
   240  	if te.testServer != nil {
   241  		testgrpc.RegisterTestServiceServer(s, te.testServer)
   242  	}
   243  
   244  	go s.Serve(lis)
   245  	te.srvAddr = lis.Addr().String()
   246  }
   247  
   248  func (te *test) clientConn() *grpc.ClientConn {
   249  	if te.cc != nil {
   250  		return te.cc
   251  	}
   252  	opts := []grpc.DialOption{
   253  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   254  		grpc.WithBlock(),
   255  		grpc.WithUserAgent("test/0.0.1"),
   256  	}
   257  	if te.compress == "gzip" {
   258  		opts = append(opts,
   259  			grpc.WithCompressor(grpc.NewGZIPCompressor()),
   260  			grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
   261  		)
   262  	}
   263  	for _, sh := range te.clientStatsHandlers {
   264  		opts = append(opts, grpc.WithStatsHandler(sh))
   265  	}
   266  
   267  	var err error
   268  	te.cc, err = grpc.Dial(te.srvAddr, opts...)
   269  	if err != nil {
   270  		te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
   271  	}
   272  	return te.cc
   273  }
   274  
   275  type rpcType int
   276  
   277  const (
   278  	unaryRPC rpcType = iota
   279  	clientStreamRPC
   280  	serverStreamRPC
   281  	fullDuplexStreamRPC
   282  )
   283  
   284  type rpcConfig struct {
   285  	count    int  // Number of requests and responses for streaming RPCs.
   286  	success  bool // Whether the RPC should succeed or return error.
   287  	failfast bool
   288  	callType rpcType // Type of RPC.
   289  }
   290  
   291  func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
   292  	var (
   293  		resp *testpb.SimpleResponse
   294  		req  *testpb.SimpleRequest
   295  		err  error
   296  	)
   297  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   298  	if c.success {
   299  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID + 1)}
   300  	} else {
   301  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID)}
   302  	}
   303  
   304  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   305  	defer cancel()
   306  	resp, err = tc.UnaryCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
   307  	return req, resp, err
   308  }
   309  
   310  func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []proto.Message, error) {
   311  	var (
   312  		reqs  []proto.Message
   313  		resps []proto.Message
   314  		err   error
   315  	)
   316  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   317  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   318  	defer cancel()
   319  	stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
   320  	if err != nil {
   321  		return reqs, resps, err
   322  	}
   323  	var startID int32
   324  	if !c.success {
   325  		startID = errorID
   326  	}
   327  	for i := 0; i < c.count; i++ {
   328  		req := &testpb.StreamingOutputCallRequest{
   329  			Payload: idToPayload(int32(i) + startID),
   330  		}
   331  		reqs = append(reqs, req)
   332  		if err = stream.Send(req); err != nil {
   333  			return reqs, resps, err
   334  		}
   335  		var resp *testpb.StreamingOutputCallResponse
   336  		if resp, err = stream.Recv(); err != nil {
   337  			return reqs, resps, err
   338  		}
   339  		resps = append(resps, resp)
   340  	}
   341  	if err = stream.CloseSend(); err != nil && err != io.EOF {
   342  		return reqs, resps, err
   343  	}
   344  	if _, err = stream.Recv(); err != io.EOF {
   345  		return reqs, resps, err
   346  	}
   347  
   348  	return reqs, resps, nil
   349  }
   350  
   351  func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, *testpb.StreamingInputCallResponse, error) {
   352  	var (
   353  		reqs []proto.Message
   354  		resp *testpb.StreamingInputCallResponse
   355  		err  error
   356  	)
   357  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   358  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   359  	defer cancel()
   360  	stream, err := tc.StreamingInputCall(metadata.NewOutgoingContext(tCtx, testMetadata), grpc.WaitForReady(!c.failfast))
   361  	if err != nil {
   362  		return reqs, resp, err
   363  	}
   364  	var startID int32
   365  	if !c.success {
   366  		startID = errorID
   367  	}
   368  	for i := 0; i < c.count; i++ {
   369  		req := &testpb.StreamingInputCallRequest{
   370  			Payload: idToPayload(int32(i) + startID),
   371  		}
   372  		reqs = append(reqs, req)
   373  		if err = stream.Send(req); err != nil {
   374  			return reqs, resp, err
   375  		}
   376  	}
   377  	resp, err = stream.CloseAndRecv()
   378  	return reqs, resp, err
   379  }
   380  
   381  func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallRequest, []proto.Message, error) {
   382  	var (
   383  		req   *testpb.StreamingOutputCallRequest
   384  		resps []proto.Message
   385  		err   error
   386  	)
   387  
   388  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   389  
   390  	var startID int32
   391  	if !c.success {
   392  		startID = errorID
   393  	}
   394  	req = &testpb.StreamingOutputCallRequest{Payload: idToPayload(startID)}
   395  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   396  	defer cancel()
   397  	stream, err := tc.StreamingOutputCall(metadata.NewOutgoingContext(tCtx, testMetadata), req, grpc.WaitForReady(!c.failfast))
   398  	if err != nil {
   399  		return req, resps, err
   400  	}
   401  	for {
   402  		var resp *testpb.StreamingOutputCallResponse
   403  		resp, err := stream.Recv()
   404  		if err == io.EOF {
   405  			return req, resps, nil
   406  		} else if err != nil {
   407  			return req, resps, err
   408  		}
   409  		resps = append(resps, resp)
   410  	}
   411  }
   412  
   413  type expectedData struct {
   414  	method         string
   415  	isClientStream bool
   416  	isServerStream bool
   417  	serverAddr     string
   418  	compression    string
   419  	reqIdx         int
   420  	requests       []proto.Message
   421  	respIdx        int
   422  	responses      []proto.Message
   423  	err            error
   424  	failfast       bool
   425  }
   426  
   427  type gotData struct {
   428  	ctx    context.Context
   429  	client bool
   430  	s      any // This could be RPCStats or ConnStats.
   431  }
   432  
   433  const (
   434  	begin int = iota
   435  	end
   436  	inPayload
   437  	inHeader
   438  	inTrailer
   439  	outPayload
   440  	outHeader
   441  	// TODO: test outTrailer ?
   442  	connBegin
   443  	connEnd
   444  )
   445  
   446  func checkBegin(t *testing.T, d *gotData, e *expectedData) {
   447  	var (
   448  		ok bool
   449  		st *stats.Begin
   450  	)
   451  	if st, ok = d.s.(*stats.Begin); !ok {
   452  		t.Fatalf("got %T, want Begin", d.s)
   453  	}
   454  	if d.ctx == nil {
   455  		t.Fatalf("d.ctx = nil, want <non-nil>")
   456  	}
   457  	if st.BeginTime.IsZero() {
   458  		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
   459  	}
   460  	if d.client {
   461  		if st.FailFast != e.failfast {
   462  			t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast)
   463  		}
   464  	}
   465  	if st.IsClientStream != e.isClientStream {
   466  		t.Fatalf("st.IsClientStream = %v, want %v", st.IsClientStream, e.isClientStream)
   467  	}
   468  	if st.IsServerStream != e.isServerStream {
   469  		t.Fatalf("st.IsServerStream = %v, want %v", st.IsServerStream, e.isServerStream)
   470  	}
   471  }
   472  
   473  func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
   474  	var (
   475  		ok bool
   476  		st *stats.InHeader
   477  	)
   478  	if st, ok = d.s.(*stats.InHeader); !ok {
   479  		t.Fatalf("got %T, want InHeader", d.s)
   480  	}
   481  	if d.ctx == nil {
   482  		t.Fatalf("d.ctx = nil, want <non-nil>")
   483  	}
   484  	if st.Compression != e.compression {
   485  		t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
   486  	}
   487  	if d.client {
   488  		// additional headers might be injected so instead of testing equality, test that all the
   489  		// expected headers keys have the expected header values.
   490  		for key := range testHeaderMetadata {
   491  			if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) {
   492  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key))
   493  			}
   494  		}
   495  	} else {
   496  		if st.FullMethod != e.method {
   497  			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
   498  		}
   499  		if st.LocalAddr.String() != e.serverAddr {
   500  			t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
   501  		}
   502  		// additional headers might be injected so instead of testing equality, test that all the
   503  		// expected headers keys have the expected header values.
   504  		for key := range testMetadata {
   505  			if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) {
   506  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key))
   507  			}
   508  		}
   509  
   510  		if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok {
   511  			if connInfo.RemoteAddr != st.RemoteAddr {
   512  				t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr)
   513  			}
   514  			if connInfo.LocalAddr != st.LocalAddr {
   515  				t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr)
   516  			}
   517  		} else {
   518  			t.Fatalf("got context %v, want one with connCtxKey", d.ctx)
   519  		}
   520  		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
   521  			if rpcInfo.FullMethodName != st.FullMethod {
   522  				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
   523  			}
   524  		} else {
   525  			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
   526  		}
   527  	}
   528  }
   529  
   530  func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
   531  	var (
   532  		ok bool
   533  		st *stats.InPayload
   534  	)
   535  	if st, ok = d.s.(*stats.InPayload); !ok {
   536  		t.Fatalf("got %T, want InPayload", d.s)
   537  	}
   538  	if d.ctx == nil {
   539  		t.Fatalf("d.ctx = nil, want <non-nil>")
   540  	}
   541  	if d.client {
   542  		b, err := proto.Marshal(e.responses[e.respIdx])
   543  		if err != nil {
   544  			t.Fatalf("failed to marshal message: %v", err)
   545  		}
   546  		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
   547  			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
   548  		}
   549  		e.respIdx++
   550  		if string(st.Data) != string(b) {
   551  			t.Fatalf("st.Data = %v, want %v", st.Data, b)
   552  		}
   553  		if st.Length != len(b) {
   554  			t.Fatalf("st.Length = %v, want %v", st.Length, len(b))
   555  		}
   556  	} else {
   557  		b, err := proto.Marshal(e.requests[e.reqIdx])
   558  		if err != nil {
   559  			t.Fatalf("failed to marshal message: %v", err)
   560  		}
   561  		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
   562  			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
   563  		}
   564  		e.reqIdx++
   565  		if string(st.Data) != string(b) {
   566  			t.Fatalf("st.Data = %v, want %v", st.Data, b)
   567  		}
   568  		if st.Length != len(b) {
   569  			t.Fatalf("st.Length = %v, want %v", st.Length, len(b))
   570  		}
   571  	}
   572  	// Below are sanity checks that WireLength and RecvTime are populated.
   573  	// TODO: check values of WireLength and RecvTime.
   574  	if len(st.Data) > 0 && st.CompressedLength == 0 {
   575  		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
   576  			st.CompressedLength)
   577  	}
   578  	if st.RecvTime.IsZero() {
   579  		t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime)
   580  	}
   581  }
   582  
   583  func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
   584  	var (
   585  		ok bool
   586  		st *stats.InTrailer
   587  	)
   588  	if st, ok = d.s.(*stats.InTrailer); !ok {
   589  		t.Fatalf("got %T, want InTrailer", d.s)
   590  	}
   591  	if d.ctx == nil {
   592  		t.Fatalf("d.ctx = nil, want <non-nil>")
   593  	}
   594  	if !st.Client {
   595  		t.Fatalf("st IsClient = false, want true")
   596  	}
   597  	if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   598  		t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   599  	}
   600  }
   601  
   602  func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
   603  	var (
   604  		ok bool
   605  		st *stats.OutHeader
   606  	)
   607  	if st, ok = d.s.(*stats.OutHeader); !ok {
   608  		t.Fatalf("got %T, want OutHeader", d.s)
   609  	}
   610  	if d.ctx == nil {
   611  		t.Fatalf("d.ctx = nil, want <non-nil>")
   612  	}
   613  	if st.Compression != e.compression {
   614  		t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
   615  	}
   616  	if d.client {
   617  		if st.FullMethod != e.method {
   618  			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
   619  		}
   620  		if st.RemoteAddr.String() != e.serverAddr {
   621  			t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr)
   622  		}
   623  		// additional headers might be injected so instead of testing equality, test that all the
   624  		// expected headers keys have the expected header values.
   625  		for key := range testMetadata {
   626  			if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) {
   627  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key))
   628  			}
   629  		}
   630  
   631  		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
   632  			if rpcInfo.FullMethodName != st.FullMethod {
   633  				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
   634  			}
   635  		} else {
   636  			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
   637  		}
   638  	} else {
   639  		// additional headers might be injected so instead of testing equality, test that all the
   640  		// expected headers keys have the expected header values.
   641  		for key := range testHeaderMetadata {
   642  			if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) {
   643  				t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key))
   644  			}
   645  		}
   646  	}
   647  }
   648  
   649  func checkOutPayload(t *testing.T, d *gotData, e *expectedData) {
   650  	var (
   651  		ok bool
   652  		st *stats.OutPayload
   653  	)
   654  	if st, ok = d.s.(*stats.OutPayload); !ok {
   655  		t.Fatalf("got %T, want OutPayload", d.s)
   656  	}
   657  	if d.ctx == nil {
   658  		t.Fatalf("d.ctx = nil, want <non-nil>")
   659  	}
   660  	if d.client {
   661  		b, err := proto.Marshal(e.requests[e.reqIdx])
   662  		if err != nil {
   663  			t.Fatalf("failed to marshal message: %v", err)
   664  		}
   665  		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
   666  			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
   667  		}
   668  		e.reqIdx++
   669  		if string(st.Data) != string(b) {
   670  			t.Fatalf("st.Data = %v, want %v", st.Data, b)
   671  		}
   672  		if st.Length != len(b) {
   673  			t.Fatalf("st.Length = %v, want %v", st.Length, len(b))
   674  		}
   675  	} else {
   676  		b, err := proto.Marshal(e.responses[e.respIdx])
   677  		if err != nil {
   678  			t.Fatalf("failed to marshal message: %v", err)
   679  		}
   680  		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
   681  			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
   682  		}
   683  		e.respIdx++
   684  		if string(st.Data) != string(b) {
   685  			t.Fatalf("st.Data = %v, want %v", st.Data, b)
   686  		}
   687  		if st.Length != len(b) {
   688  			t.Fatalf("st.Length = %v, want %v", st.Length, len(b))
   689  		}
   690  	}
   691  	// Below are sanity checks that WireLength and SentTime are populated.
   692  	// TODO: check values of WireLength and SentTime.
   693  	if len(st.Data) > 0 && st.WireLength == 0 {
   694  		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
   695  			st.WireLength)
   696  	}
   697  	if st.SentTime.IsZero() {
   698  		t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime)
   699  	}
   700  }
   701  
   702  func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) {
   703  	var (
   704  		ok bool
   705  		st *stats.OutTrailer
   706  	)
   707  	if st, ok = d.s.(*stats.OutTrailer); !ok {
   708  		t.Fatalf("got %T, want OutTrailer", d.s)
   709  	}
   710  	if d.ctx == nil {
   711  		t.Fatalf("d.ctx = nil, want <non-nil>")
   712  	}
   713  	if st.Client {
   714  		t.Fatalf("st IsClient = true, want false")
   715  	}
   716  	if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   717  		t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   718  	}
   719  }
   720  
   721  func checkEnd(t *testing.T, d *gotData, e *expectedData) {
   722  	var (
   723  		ok bool
   724  		st *stats.End
   725  	)
   726  	if st, ok = d.s.(*stats.End); !ok {
   727  		t.Fatalf("got %T, want End", d.s)
   728  	}
   729  	if d.ctx == nil {
   730  		t.Fatalf("d.ctx = nil, want <non-nil>")
   731  	}
   732  	if st.BeginTime.IsZero() {
   733  		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
   734  	}
   735  	if st.EndTime.IsZero() {
   736  		t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime)
   737  	}
   738  
   739  	actual, ok := status.FromError(st.Error)
   740  	if !ok {
   741  		t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
   742  	}
   743  
   744  	expectedStatus, _ := status.FromError(e.err)
   745  	if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() {
   746  		t.Fatalf("st.Error = %v, want %v", st.Error, e.err)
   747  	}
   748  
   749  	if st.Client {
   750  		if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
   751  			t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
   752  		}
   753  	} else {
   754  		if st.Trailer != nil {
   755  			t.Fatalf("st.Trailer = %v, want nil", st.Trailer)
   756  		}
   757  	}
   758  }
   759  
   760  func checkConnBegin(t *testing.T, d *gotData) {
   761  	var (
   762  		ok bool
   763  		st *stats.ConnBegin
   764  	)
   765  	if st, ok = d.s.(*stats.ConnBegin); !ok {
   766  		t.Fatalf("got %T, want ConnBegin", d.s)
   767  	}
   768  	if d.ctx == nil {
   769  		t.Fatalf("d.ctx = nil, want <non-nil>")
   770  	}
   771  	st.IsClient() // TODO remove this.
   772  }
   773  
   774  func checkConnEnd(t *testing.T, d *gotData) {
   775  	var (
   776  		ok bool
   777  		st *stats.ConnEnd
   778  	)
   779  	if st, ok = d.s.(*stats.ConnEnd); !ok {
   780  		t.Fatalf("got %T, want ConnEnd", d.s)
   781  	}
   782  	if d.ctx == nil {
   783  		t.Fatalf("d.ctx = nil, want <non-nil>")
   784  	}
   785  	st.IsClient() // TODO remove this.
   786  }
   787  
   788  type statshandler struct {
   789  	mu      sync.Mutex
   790  	gotRPC  []*gotData
   791  	gotConn []*gotData
   792  }
   793  
   794  func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
   795  	return context.WithValue(ctx, connCtxKey{}, info)
   796  }
   797  
   798  func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
   799  	return context.WithValue(ctx, rpcCtxKey{}, info)
   800  }
   801  
   802  func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) {
   803  	h.mu.Lock()
   804  	defer h.mu.Unlock()
   805  	h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s})
   806  }
   807  
   808  func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
   809  	h.mu.Lock()
   810  	defer h.mu.Unlock()
   811  	h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s})
   812  }
   813  
   814  func checkConnStats(t *testing.T, got []*gotData) {
   815  	if len(got) <= 0 || len(got)%2 != 0 {
   816  		for i, g := range got {
   817  			t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx)
   818  		}
   819  		t.Fatalf("got %v stats, want even positive number", len(got))
   820  	}
   821  	// The first conn stats must be a ConnBegin.
   822  	checkConnBegin(t, got[0])
   823  	// The last conn stats must be a ConnEnd.
   824  	checkConnEnd(t, got[len(got)-1])
   825  }
   826  
   827  func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
   828  	if len(got) != len(checkFuncs) {
   829  		for i, g := range got {
   830  			t.Errorf(" - %v, %T", i, g.s)
   831  		}
   832  		t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
   833  	}
   834  
   835  	for i, f := range checkFuncs {
   836  		f(t, got[i], expect)
   837  	}
   838  }
   839  
   840  func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
   841  	h := &statshandler{}
   842  	te := newTest(t, tc, nil, []stats.Handler{h})
   843  	te.startServer(&testServer{})
   844  	defer te.tearDown()
   845  
   846  	var (
   847  		reqs   []proto.Message
   848  		resps  []proto.Message
   849  		err    error
   850  		method string
   851  
   852  		isClientStream bool
   853  		isServerStream bool
   854  
   855  		req  proto.Message
   856  		resp proto.Message
   857  		e    error
   858  	)
   859  
   860  	switch cc.callType {
   861  	case unaryRPC:
   862  		method = "/grpc.testing.TestService/UnaryCall"
   863  		req, resp, e = te.doUnaryCall(cc)
   864  		reqs = []proto.Message{req}
   865  		resps = []proto.Message{resp}
   866  		err = e
   867  	case clientStreamRPC:
   868  		method = "/grpc.testing.TestService/StreamingInputCall"
   869  		reqs, resp, e = te.doClientStreamCall(cc)
   870  		resps = []proto.Message{resp}
   871  		err = e
   872  		isClientStream = true
   873  	case serverStreamRPC:
   874  		method = "/grpc.testing.TestService/StreamingOutputCall"
   875  		req, resps, e = te.doServerStreamCall(cc)
   876  		reqs = []proto.Message{req}
   877  		err = e
   878  		isServerStream = true
   879  	case fullDuplexStreamRPC:
   880  		method = "/grpc.testing.TestService/FullDuplexCall"
   881  		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
   882  		isClientStream = true
   883  		isServerStream = true
   884  	}
   885  	if cc.success != (err == nil) {
   886  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
   887  	}
   888  	te.cc.Close()
   889  	te.srv.GracefulStop() // Wait for the server to stop.
   890  
   891  	for {
   892  		h.mu.Lock()
   893  		if len(h.gotRPC) >= len(checkFuncs) {
   894  			h.mu.Unlock()
   895  			break
   896  		}
   897  		h.mu.Unlock()
   898  		time.Sleep(10 * time.Millisecond)
   899  	}
   900  
   901  	for {
   902  		h.mu.Lock()
   903  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
   904  			h.mu.Unlock()
   905  			break
   906  		}
   907  		h.mu.Unlock()
   908  		time.Sleep(10 * time.Millisecond)
   909  	}
   910  
   911  	expect := &expectedData{
   912  		serverAddr:     te.srvAddr,
   913  		compression:    tc.compress,
   914  		method:         method,
   915  		requests:       reqs,
   916  		responses:      resps,
   917  		err:            err,
   918  		isClientStream: isClientStream,
   919  		isServerStream: isServerStream,
   920  	}
   921  
   922  	h.mu.Lock()
   923  	checkConnStats(t, h.gotConn)
   924  	h.mu.Unlock()
   925  	checkServerStats(t, h.gotRPC, expect, checkFuncs)
   926  }
   927  
   928  func (s) TestServerStatsUnaryRPC(t *testing.T) {
   929  	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   930  		checkInHeader,
   931  		checkBegin,
   932  		checkInPayload,
   933  		checkOutHeader,
   934  		checkOutPayload,
   935  		checkOutTrailer,
   936  		checkEnd,
   937  	})
   938  }
   939  
   940  func (s) TestServerStatsUnaryRPCError(t *testing.T) {
   941  	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   942  		checkInHeader,
   943  		checkBegin,
   944  		checkInPayload,
   945  		checkOutHeader,
   946  		checkOutTrailer,
   947  		checkEnd,
   948  	})
   949  }
   950  
   951  func (s) TestServerStatsClientStreamRPC(t *testing.T) {
   952  	count := 5
   953  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   954  		checkInHeader,
   955  		checkBegin,
   956  		checkOutHeader,
   957  	}
   958  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   959  		checkInPayload,
   960  	}
   961  	for i := 0; i < count; i++ {
   962  		checkFuncs = append(checkFuncs, ioPayFuncs...)
   963  	}
   964  	checkFuncs = append(checkFuncs,
   965  		checkOutPayload,
   966  		checkOutTrailer,
   967  		checkEnd,
   968  	)
   969  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs)
   970  }
   971  
   972  func (s) TestServerStatsClientStreamRPCError(t *testing.T) {
   973  	count := 1
   974  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
   975  		checkInHeader,
   976  		checkBegin,
   977  		checkOutHeader,
   978  		checkInPayload,
   979  		checkOutTrailer,
   980  		checkEnd,
   981  	})
   982  }
   983  
   984  func (s) TestServerStatsServerStreamRPC(t *testing.T) {
   985  	count := 5
   986  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   987  		checkInHeader,
   988  		checkBegin,
   989  		checkInPayload,
   990  		checkOutHeader,
   991  	}
   992  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
   993  		checkOutPayload,
   994  	}
   995  	for i := 0; i < count; i++ {
   996  		checkFuncs = append(checkFuncs, ioPayFuncs...)
   997  	}
   998  	checkFuncs = append(checkFuncs,
   999  		checkOutTrailer,
  1000  		checkEnd,
  1001  	)
  1002  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs)
  1003  }
  1004  
  1005  func (s) TestServerStatsServerStreamRPCError(t *testing.T) {
  1006  	count := 5
  1007  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
  1008  		checkInHeader,
  1009  		checkBegin,
  1010  		checkInPayload,
  1011  		checkOutHeader,
  1012  		checkOutTrailer,
  1013  		checkEnd,
  1014  	})
  1015  }
  1016  
  1017  func (s) TestServerStatsFullDuplexRPC(t *testing.T) {
  1018  	count := 5
  1019  	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
  1020  		checkInHeader,
  1021  		checkBegin,
  1022  		checkOutHeader,
  1023  	}
  1024  	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
  1025  		checkInPayload,
  1026  		checkOutPayload,
  1027  	}
  1028  	for i := 0; i < count; i++ {
  1029  		checkFuncs = append(checkFuncs, ioPayFuncs...)
  1030  	}
  1031  	checkFuncs = append(checkFuncs,
  1032  		checkOutTrailer,
  1033  		checkEnd,
  1034  	)
  1035  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs)
  1036  }
  1037  
  1038  func (s) TestServerStatsFullDuplexRPCError(t *testing.T) {
  1039  	count := 5
  1040  	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
  1041  		checkInHeader,
  1042  		checkBegin,
  1043  		checkOutHeader,
  1044  		checkInPayload,
  1045  		checkOutTrailer,
  1046  		checkEnd,
  1047  	})
  1048  }
  1049  
  1050  type checkFuncWithCount struct {
  1051  	f func(t *testing.T, d *gotData, e *expectedData)
  1052  	c int // expected count
  1053  }
  1054  
  1055  func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) {
  1056  	var expectLen int
  1057  	for _, v := range checkFuncs {
  1058  		expectLen += v.c
  1059  	}
  1060  	if len(got) != expectLen {
  1061  		for i, g := range got {
  1062  			t.Errorf(" - %v, %T", i, g.s)
  1063  		}
  1064  		t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
  1065  	}
  1066  
  1067  	var tagInfoInCtx *stats.RPCTagInfo
  1068  	for i := 0; i < len(got); i++ {
  1069  		if _, ok := got[i].s.(stats.RPCStats); ok {
  1070  			tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
  1071  			if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
  1072  				t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
  1073  			}
  1074  			tagInfoInCtx = tagInfoInCtxNew
  1075  		}
  1076  	}
  1077  
  1078  	for _, s := range got {
  1079  		switch s.s.(type) {
  1080  		case *stats.Begin:
  1081  			if checkFuncs[begin].c <= 0 {
  1082  				t.Fatalf("unexpected stats: %T", s.s)
  1083  			}
  1084  			checkFuncs[begin].f(t, s, expect)
  1085  			checkFuncs[begin].c--
  1086  		case *stats.OutHeader:
  1087  			if checkFuncs[outHeader].c <= 0 {
  1088  				t.Fatalf("unexpected stats: %T", s.s)
  1089  			}
  1090  			checkFuncs[outHeader].f(t, s, expect)
  1091  			checkFuncs[outHeader].c--
  1092  		case *stats.OutPayload:
  1093  			if checkFuncs[outPayload].c <= 0 {
  1094  				t.Fatalf("unexpected stats: %T", s.s)
  1095  			}
  1096  			checkFuncs[outPayload].f(t, s, expect)
  1097  			checkFuncs[outPayload].c--
  1098  		case *stats.InHeader:
  1099  			if checkFuncs[inHeader].c <= 0 {
  1100  				t.Fatalf("unexpected stats: %T", s.s)
  1101  			}
  1102  			checkFuncs[inHeader].f(t, s, expect)
  1103  			checkFuncs[inHeader].c--
  1104  		case *stats.InPayload:
  1105  			if checkFuncs[inPayload].c <= 0 {
  1106  				t.Fatalf("unexpected stats: %T", s.s)
  1107  			}
  1108  			checkFuncs[inPayload].f(t, s, expect)
  1109  			checkFuncs[inPayload].c--
  1110  		case *stats.InTrailer:
  1111  			if checkFuncs[inTrailer].c <= 0 {
  1112  				t.Fatalf("unexpected stats: %T", s.s)
  1113  			}
  1114  			checkFuncs[inTrailer].f(t, s, expect)
  1115  			checkFuncs[inTrailer].c--
  1116  		case *stats.End:
  1117  			if checkFuncs[end].c <= 0 {
  1118  				t.Fatalf("unexpected stats: %T", s.s)
  1119  			}
  1120  			checkFuncs[end].f(t, s, expect)
  1121  			checkFuncs[end].c--
  1122  		case *stats.ConnBegin:
  1123  			if checkFuncs[connBegin].c <= 0 {
  1124  				t.Fatalf("unexpected stats: %T", s.s)
  1125  			}
  1126  			checkFuncs[connBegin].f(t, s, expect)
  1127  			checkFuncs[connBegin].c--
  1128  		case *stats.ConnEnd:
  1129  			if checkFuncs[connEnd].c <= 0 {
  1130  				t.Fatalf("unexpected stats: %T", s.s)
  1131  			}
  1132  			checkFuncs[connEnd].f(t, s, expect)
  1133  			checkFuncs[connEnd].c--
  1134  		default:
  1135  			t.Fatalf("unexpected stats: %T", s.s)
  1136  		}
  1137  	}
  1138  }
  1139  
  1140  func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
  1141  	h := &statshandler{}
  1142  	te := newTest(t, tc, []stats.Handler{h}, nil)
  1143  	te.startServer(&testServer{})
  1144  	defer te.tearDown()
  1145  
  1146  	var (
  1147  		reqs   []proto.Message
  1148  		resps  []proto.Message
  1149  		method string
  1150  		err    error
  1151  
  1152  		isClientStream bool
  1153  		isServerStream bool
  1154  
  1155  		req  proto.Message
  1156  		resp proto.Message
  1157  		e    error
  1158  	)
  1159  	switch cc.callType {
  1160  	case unaryRPC:
  1161  		method = "/grpc.testing.TestService/UnaryCall"
  1162  		req, resp, e = te.doUnaryCall(cc)
  1163  		reqs = []proto.Message{req}
  1164  		resps = []proto.Message{resp}
  1165  		err = e
  1166  	case clientStreamRPC:
  1167  		method = "/grpc.testing.TestService/StreamingInputCall"
  1168  		reqs, resp, e = te.doClientStreamCall(cc)
  1169  		resps = []proto.Message{resp}
  1170  		err = e
  1171  		isClientStream = true
  1172  	case serverStreamRPC:
  1173  		method = "/grpc.testing.TestService/StreamingOutputCall"
  1174  		req, resps, e = te.doServerStreamCall(cc)
  1175  		reqs = []proto.Message{req}
  1176  		err = e
  1177  		isServerStream = true
  1178  	case fullDuplexStreamRPC:
  1179  		method = "/grpc.testing.TestService/FullDuplexCall"
  1180  		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
  1181  		isClientStream = true
  1182  		isServerStream = true
  1183  	}
  1184  	if cc.success != (err == nil) {
  1185  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1186  	}
  1187  	te.cc.Close()
  1188  	te.srv.GracefulStop() // Wait for the server to stop.
  1189  
  1190  	lenRPCStats := 0
  1191  	for _, v := range checkFuncs {
  1192  		lenRPCStats += v.c
  1193  	}
  1194  	for {
  1195  		h.mu.Lock()
  1196  		if len(h.gotRPC) >= lenRPCStats {
  1197  			h.mu.Unlock()
  1198  			break
  1199  		}
  1200  		h.mu.Unlock()
  1201  		time.Sleep(10 * time.Millisecond)
  1202  	}
  1203  
  1204  	for {
  1205  		h.mu.Lock()
  1206  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
  1207  			h.mu.Unlock()
  1208  			break
  1209  		}
  1210  		h.mu.Unlock()
  1211  		time.Sleep(10 * time.Millisecond)
  1212  	}
  1213  
  1214  	expect := &expectedData{
  1215  		serverAddr:     te.srvAddr,
  1216  		compression:    tc.compress,
  1217  		method:         method,
  1218  		requests:       reqs,
  1219  		responses:      resps,
  1220  		failfast:       cc.failfast,
  1221  		err:            err,
  1222  		isClientStream: isClientStream,
  1223  		isServerStream: isServerStream,
  1224  	}
  1225  
  1226  	h.mu.Lock()
  1227  	checkConnStats(t, h.gotConn)
  1228  	h.mu.Unlock()
  1229  	checkClientStats(t, h.gotRPC, expect, checkFuncs)
  1230  }
  1231  
  1232  func (s) TestClientStatsUnaryRPC(t *testing.T) {
  1233  	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
  1234  		begin:      {checkBegin, 1},
  1235  		outHeader:  {checkOutHeader, 1},
  1236  		outPayload: {checkOutPayload, 1},
  1237  		inHeader:   {checkInHeader, 1},
  1238  		inPayload:  {checkInPayload, 1},
  1239  		inTrailer:  {checkInTrailer, 1},
  1240  		end:        {checkEnd, 1},
  1241  	})
  1242  }
  1243  
  1244  func (s) TestClientStatsUnaryRPCError(t *testing.T) {
  1245  	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
  1246  		begin:      {checkBegin, 1},
  1247  		outHeader:  {checkOutHeader, 1},
  1248  		outPayload: {checkOutPayload, 1},
  1249  		inHeader:   {checkInHeader, 1},
  1250  		inTrailer:  {checkInTrailer, 1},
  1251  		end:        {checkEnd, 1},
  1252  	})
  1253  }
  1254  
  1255  func (s) TestClientStatsClientStreamRPC(t *testing.T) {
  1256  	count := 5
  1257  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
  1258  		begin:      {checkBegin, 1},
  1259  		outHeader:  {checkOutHeader, 1},
  1260  		inHeader:   {checkInHeader, 1},
  1261  		outPayload: {checkOutPayload, count},
  1262  		inTrailer:  {checkInTrailer, 1},
  1263  		inPayload:  {checkInPayload, 1},
  1264  		end:        {checkEnd, 1},
  1265  	})
  1266  }
  1267  
  1268  func (s) TestClientStatsClientStreamRPCError(t *testing.T) {
  1269  	count := 1
  1270  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
  1271  		begin:      {checkBegin, 1},
  1272  		outHeader:  {checkOutHeader, 1},
  1273  		inHeader:   {checkInHeader, 1},
  1274  		outPayload: {checkOutPayload, 1},
  1275  		inTrailer:  {checkInTrailer, 1},
  1276  		end:        {checkEnd, 1},
  1277  	})
  1278  }
  1279  
  1280  func (s) TestClientStatsServerStreamRPC(t *testing.T) {
  1281  	count := 5
  1282  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
  1283  		begin:      {checkBegin, 1},
  1284  		outHeader:  {checkOutHeader, 1},
  1285  		outPayload: {checkOutPayload, 1},
  1286  		inHeader:   {checkInHeader, 1},
  1287  		inPayload:  {checkInPayload, count},
  1288  		inTrailer:  {checkInTrailer, 1},
  1289  		end:        {checkEnd, 1},
  1290  	})
  1291  }
  1292  
  1293  func (s) TestClientStatsServerStreamRPCError(t *testing.T) {
  1294  	count := 5
  1295  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
  1296  		begin:      {checkBegin, 1},
  1297  		outHeader:  {checkOutHeader, 1},
  1298  		outPayload: {checkOutPayload, 1},
  1299  		inHeader:   {checkInHeader, 1},
  1300  		inTrailer:  {checkInTrailer, 1},
  1301  		end:        {checkEnd, 1},
  1302  	})
  1303  }
  1304  
  1305  func (s) TestClientStatsFullDuplexRPC(t *testing.T) {
  1306  	count := 5
  1307  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
  1308  		begin:      {checkBegin, 1},
  1309  		outHeader:  {checkOutHeader, 1},
  1310  		outPayload: {checkOutPayload, count},
  1311  		inHeader:   {checkInHeader, 1},
  1312  		inPayload:  {checkInPayload, count},
  1313  		inTrailer:  {checkInTrailer, 1},
  1314  		end:        {checkEnd, 1},
  1315  	})
  1316  }
  1317  
  1318  func (s) TestClientStatsFullDuplexRPCError(t *testing.T) {
  1319  	count := 5
  1320  	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
  1321  		begin:      {checkBegin, 1},
  1322  		outHeader:  {checkOutHeader, 1},
  1323  		outPayload: {checkOutPayload, 1},
  1324  		inHeader:   {checkInHeader, 1},
  1325  		inTrailer:  {checkInTrailer, 1},
  1326  		end:        {checkEnd, 1},
  1327  	})
  1328  }
  1329  
  1330  func (s) TestTags(t *testing.T) {
  1331  	b := []byte{5, 2, 4, 3, 1}
  1332  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1333  	defer cancel()
  1334  	ctx := stats.SetTags(tCtx, b)
  1335  	if tg := stats.OutgoingTags(ctx); !reflect.DeepEqual(tg, b) {
  1336  		t.Errorf("OutgoingTags(%v) = %v; want %v", ctx, tg, b)
  1337  	}
  1338  	if tg := stats.Tags(ctx); tg != nil {
  1339  		t.Errorf("Tags(%v) = %v; want nil", ctx, tg)
  1340  	}
  1341  
  1342  	ctx = stats.SetIncomingTags(tCtx, b)
  1343  	if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) {
  1344  		t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b)
  1345  	}
  1346  	if tg := stats.OutgoingTags(ctx); tg != nil {
  1347  		t.Errorf("OutgoingTags(%v) = %v; want nil", ctx, tg)
  1348  	}
  1349  }
  1350  
  1351  func (s) TestTrace(t *testing.T) {
  1352  	b := []byte{5, 2, 4, 3, 1}
  1353  	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1354  	defer cancel()
  1355  	ctx := stats.SetTrace(tCtx, b)
  1356  	if tr := stats.OutgoingTrace(ctx); !reflect.DeepEqual(tr, b) {
  1357  		t.Errorf("OutgoingTrace(%v) = %v; want %v", ctx, tr, b)
  1358  	}
  1359  	if tr := stats.Trace(ctx); tr != nil {
  1360  		t.Errorf("Trace(%v) = %v; want nil", ctx, tr)
  1361  	}
  1362  
  1363  	ctx = stats.SetIncomingTrace(tCtx, b)
  1364  	if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) {
  1365  		t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b)
  1366  	}
  1367  	if tr := stats.OutgoingTrace(ctx); tr != nil {
  1368  		t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr)
  1369  	}
  1370  }
  1371  
  1372  func (s) TestMultipleClientStatsHandler(t *testing.T) {
  1373  	h := &statshandler{}
  1374  	tc := &testConfig{compress: ""}
  1375  	te := newTest(t, tc, []stats.Handler{h, h}, nil)
  1376  	te.startServer(&testServer{})
  1377  	defer te.tearDown()
  1378  
  1379  	cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC}
  1380  	_, _, err := te.doUnaryCall(cc)
  1381  	if cc.success != (err == nil) {
  1382  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1383  	}
  1384  	te.cc.Close()
  1385  	te.srv.GracefulStop() // Wait for the server to stop.
  1386  
  1387  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1388  		h.mu.Lock()
  1389  		if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok && len(h.gotRPC) == 12 {
  1390  			h.mu.Unlock()
  1391  			break
  1392  		}
  1393  		h.mu.Unlock()
  1394  		time.Sleep(10 * time.Millisecond)
  1395  	}
  1396  
  1397  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1398  		h.mu.Lock()
  1399  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok && len(h.gotConn) == 4 {
  1400  			h.mu.Unlock()
  1401  			break
  1402  		}
  1403  		h.mu.Unlock()
  1404  		time.Sleep(10 * time.Millisecond)
  1405  	}
  1406  
  1407  	// Each RPC generates 6 stats events on the client-side, times 2 StatsHandler
  1408  	if len(h.gotRPC) != 12 {
  1409  		t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12)
  1410  	}
  1411  
  1412  	// Each connection generates 4 conn events on the client-side, times 2 StatsHandler
  1413  	if len(h.gotConn) != 4 {
  1414  		t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
  1415  	}
  1416  }
  1417  
  1418  func (s) TestMultipleServerStatsHandler(t *testing.T) {
  1419  	h := &statshandler{}
  1420  	tc := &testConfig{compress: ""}
  1421  	te := newTest(t, tc, nil, []stats.Handler{h, h})
  1422  	te.startServer(&testServer{})
  1423  	defer te.tearDown()
  1424  
  1425  	cc := &rpcConfig{success: false, failfast: false, callType: unaryRPC}
  1426  	_, _, err := te.doUnaryCall(cc)
  1427  	if cc.success != (err == nil) {
  1428  		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
  1429  	}
  1430  	te.cc.Close()
  1431  	te.srv.GracefulStop() // Wait for the server to stop.
  1432  
  1433  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1434  		h.mu.Lock()
  1435  		if _, ok := h.gotRPC[len(h.gotRPC)-1].s.(*stats.End); ok {
  1436  			h.mu.Unlock()
  1437  			break
  1438  		}
  1439  		h.mu.Unlock()
  1440  		time.Sleep(10 * time.Millisecond)
  1441  	}
  1442  
  1443  	for start := time.Now(); time.Since(start) < defaultTestTimeout; {
  1444  		h.mu.Lock()
  1445  		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
  1446  			h.mu.Unlock()
  1447  			break
  1448  		}
  1449  		h.mu.Unlock()
  1450  		time.Sleep(10 * time.Millisecond)
  1451  	}
  1452  
  1453  	// Each RPC generates 6 stats events on the server-side, times 2 StatsHandler
  1454  	if len(h.gotRPC) != 12 {
  1455  		t.Fatalf("h.gotRPC: unexpected amount of RPCStats: %v != %v", len(h.gotRPC), 12)
  1456  	}
  1457  
  1458  	// Each connection generates 4 conn events on the server-side, times 2 StatsHandler
  1459  	if len(h.gotConn) != 4 {
  1460  		t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
  1461  	}
  1462  }
  1463  
  1464  // TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler
  1465  // gets access to a Server on the server side, and thus the method that the
  1466  // server owns which specifies whether a method is made or not. The test sets up
  1467  // a server with a unary call and full duplex call configured, and makes an RPC.
  1468  // Within the stats handler, asking the server whether unary or duplex method
  1469  // names are registered should return true, and any other query should return
  1470  // false.
  1471  func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) {
  1472  	wg := sync.WaitGroup{}
  1473  	wg.Add(1)
  1474  	stubStatsHandler := &testutils.StubStatsHandler{
  1475  		TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
  1476  			// OpenTelemetry instrumentation needs the passed in Server to determine if
  1477  			// methods are registered in different handle calls in to record metrics.
  1478  			// This tag RPC call context gets passed into every handle call, so can
  1479  			// assert once here, since it maps to all the handle RPC calls that come
  1480  			// after. These internal calls will be how the OpenTelemetry instrumentation
  1481  			// component accesses this server and the subsequent helper on the server.
  1482  			server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx)
  1483  			if server == nil {
  1484  				t.Errorf("stats handler received ctx has no server present")
  1485  			}
  1486  			isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)
  1487  			// /s/m and s/m are valid.
  1488  			if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") {
  1489  				t.Errorf("UnaryCall should be a registered method according to server")
  1490  			}
  1491  			if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") {
  1492  				t.Errorf("FullDuplexCall should be a registered method according to server")
  1493  			}
  1494  			if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") {
  1495  				t.Errorf("DoesNotExistCall should not be a registered method according to server")
  1496  			}
  1497  			if isRegisteredMethod(server, "/unknownService/UnaryCall") {
  1498  				t.Errorf("/unknownService/UnaryCall should not be a registered method according to server")
  1499  			}
  1500  			wg.Done()
  1501  			return ctx
  1502  		},
  1503  	}
  1504  	ss := &stubserver.StubServer{
  1505  		UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
  1506  			return &testpb.SimpleResponse{}, nil
  1507  		},
  1508  	}
  1509  	if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil {
  1510  		t.Fatalf("Error starting endpoint server: %v", err)
  1511  	}
  1512  	defer ss.Stop()
  1513  
  1514  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1515  	defer cancel()
  1516  	if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil {
  1517  		t.Fatalf("Unexpected error from UnaryCall: %v", err)
  1518  	}
  1519  	wg.Wait()
  1520  }
  1521  

View as plain text