...

Source file src/google.golang.org/grpc/binarylog/binarylog_end2end_test.go

Documentation: google.golang.org/grpc/binarylog

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package binarylog_test
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"io"
    25  	"net"
    26  	"sort"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/binarylog"
    33  	"google.golang.org/grpc/codes"
    34  	"google.golang.org/grpc/credentials/insecure"
    35  	"google.golang.org/grpc/grpclog"
    36  	iblog "google.golang.org/grpc/internal/binarylog"
    37  	"google.golang.org/grpc/internal/grpctest"
    38  	"google.golang.org/grpc/internal/stubserver"
    39  	"google.golang.org/grpc/metadata"
    40  	"google.golang.org/grpc/status"
    41  	"google.golang.org/protobuf/proto"
    42  
    43  	binlogpb "google.golang.org/grpc/binarylog/grpc_binarylog_v1"
    44  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    45  	testpb "google.golang.org/grpc/interop/grpc_testing"
    46  )
    47  
// grpclogLogger is the grpclog component logger for this test file. The
// expected-entry builders below use it to report proto marshalling problems.
var grpclogLogger = grpclog.Component("binarylog")
    49  
// s is the test-suite receiver type. It embeds grpctest.Tester and is handed
// to grpctest.RunSubTests by Test; the individual test cases are declared as
// methods on s.
type s struct {
	grpctest.Tester
}
    53  
    54  func Test(t *testing.T) {
    55  	grpctest.RunSubTests(t, s{})
    56  }
    57  
// init configures binary logging for the whole test binary: it installs
// iblog.AllLogger as the logger and testSink as the sink that receives the
// log entries.
func init() {
	// Setting environment variable in tests doesn't work because of the init
	// orders. Set the loggers directly here.
	iblog.SetLogger(iblog.AllLogger)
	binarylog.SetSink(testSink)
}
    64  
// testSink is the shared in-memory sink installed by init. Tests read the
// captured entries back from it and clear it when they finish.
var testSink = &testBinLogSink{}
    66  
// testBinLogSink is a sink that keeps every written log entry in memory so
// that tests can inspect them afterwards.
type testBinLogSink struct {
	// mu guards buf.
	mu  sync.Mutex
	// buf holds the entries in the order Write received them.
	buf []*binlogpb.GrpcLogEntry
}
    71  
    72  func (s *testBinLogSink) Write(e *binlogpb.GrpcLogEntry) error {
    73  	s.mu.Lock()
    74  	s.buf = append(s.buf, e)
    75  	s.mu.Unlock()
    76  	return nil
    77  }
    78  
    79  func (s *testBinLogSink) Close() error { return nil }
    80  
    81  // Returns all client entris if client is true, otherwise return all server
    82  // entries.
    83  func (s *testBinLogSink) logEntries(client bool) []*binlogpb.GrpcLogEntry {
    84  	logger := binlogpb.GrpcLogEntry_LOGGER_SERVER
    85  	if client {
    86  		logger = binlogpb.GrpcLogEntry_LOGGER_CLIENT
    87  	}
    88  	var ret []*binlogpb.GrpcLogEntry
    89  	s.mu.Lock()
    90  	for _, e := range s.buf {
    91  		if e.Logger == logger {
    92  			ret = append(ret, e)
    93  		}
    94  	}
    95  	s.mu.Unlock()
    96  	return ret
    97  }
    98  
    99  func (s *testBinLogSink) clear() {
   100  	s.mu.Lock()
   101  	s.buf = nil
   102  	s.mu.Unlock()
   103  }
   104  
var (
	// testMetadata is the header metadata: the client helpers attach it to
	// every outgoing RPC, and the service handlers send the incoming metadata
	// back as response headers.
	testMetadata = metadata.MD{
		"key1": []string{"value1"},
		"key2": []string{"value2"},
	}
	// testTrailerMetadata is the trailer metadata set by the service handlers.
	testTrailerMetadata = metadata.MD{
		"tkey1": []string{"trailerValue1"},
		"tkey2": []string{"trailerValue2"},
	}
	// The id for which the service handler should return error.
	errorID int32 = 32202

	globalRPCID uint64 // RPC id starts with 1, but we do ++ at the beginning of each test.
)
   121  
   122  func idToPayload(id int32) *testpb.Payload {
   123  	return &testpb.Payload{Body: []byte{byte(id), byte(id >> 8), byte(id >> 16), byte(id >> 24)}}
   124  }
   125  
   126  func payloadToID(p *testpb.Payload) int32 {
   127  	if p == nil || len(p.Body) != 4 {
   128  		panic("invalid payload")
   129  	}
   130  	return int32(p.Body[0]) + int32(p.Body[1])<<8 + int32(p.Body[2])<<16 + int32(p.Body[3])<<24
   131  }
   132  
// testServer is the TestService implementation used by these tests. Its
// handlers echo the incoming metadata as headers, set testTrailerMetadata as
// trailers, and fail when a request payload carries errorID.
type testServer struct {
	testgrpc.UnimplementedTestServiceServer
	// te is the test fixture this server was created for.
	te *test
}
   137  
   138  func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   139  	md, ok := metadata.FromIncomingContext(ctx)
   140  	if ok {
   141  		if err := grpc.SendHeader(ctx, md); err != nil {
   142  			return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
   143  		}
   144  		if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
   145  			return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
   146  		}
   147  	}
   148  
   149  	if id := payloadToID(in.Payload); id == errorID {
   150  		return nil, fmt.Errorf("got error id: %v", id)
   151  	}
   152  
   153  	return &testpb.SimpleResponse{Payload: in.Payload}, nil
   154  }
   155  
   156  func (s *testServer) FullDuplexCall(stream testgrpc.TestService_FullDuplexCallServer) error {
   157  	md, ok := metadata.FromIncomingContext(stream.Context())
   158  	if ok {
   159  		if err := stream.SendHeader(md); err != nil {
   160  			return status.Errorf(status.Code(err), "stream.SendHeader(%v) = %v, want %v", md, err, nil)
   161  		}
   162  		stream.SetTrailer(testTrailerMetadata)
   163  	}
   164  	for {
   165  		in, err := stream.Recv()
   166  		if err == io.EOF {
   167  			// read done.
   168  			return nil
   169  		}
   170  		if err != nil {
   171  			return err
   172  		}
   173  
   174  		if id := payloadToID(in.Payload); id == errorID {
   175  			return fmt.Errorf("got error id: %v", id)
   176  		}
   177  
   178  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   179  			return err
   180  		}
   181  	}
   182  }
   183  
   184  func (s *testServer) StreamingInputCall(stream testgrpc.TestService_StreamingInputCallServer) error {
   185  	md, ok := metadata.FromIncomingContext(stream.Context())
   186  	if ok {
   187  		if err := stream.SendHeader(md); err != nil {
   188  			return status.Errorf(status.Code(err), "stream.SendHeader(%v) = %v, want %v", md, err, nil)
   189  		}
   190  		stream.SetTrailer(testTrailerMetadata)
   191  	}
   192  	for {
   193  		in, err := stream.Recv()
   194  		if err == io.EOF {
   195  			// read done.
   196  			return stream.SendAndClose(&testpb.StreamingInputCallResponse{AggregatedPayloadSize: 0})
   197  		}
   198  		if err != nil {
   199  			return err
   200  		}
   201  
   202  		if id := payloadToID(in.Payload); id == errorID {
   203  			return fmt.Errorf("got error id: %v", id)
   204  		}
   205  	}
   206  }
   207  
   208  func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error {
   209  	md, ok := metadata.FromIncomingContext(stream.Context())
   210  	if ok {
   211  		if err := stream.SendHeader(md); err != nil {
   212  			return status.Errorf(status.Code(err), "stream.SendHeader(%v) = %v, want %v", md, err, nil)
   213  		}
   214  		stream.SetTrailer(testTrailerMetadata)
   215  	}
   216  
   217  	if id := payloadToID(in.Payload); id == errorID {
   218  		return fmt.Errorf("got error id: %v", id)
   219  	}
   220  
   221  	for i := 0; i < 5; i++ {
   222  		if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
   223  			return err
   224  		}
   225  	}
   226  	return nil
   227  }
   228  
// test is the fixture for one end-to-end run. It should be created with the
// newTest func, modified as needed, and then started with its startServer
// method. It should be cleaned up with the tearDown method.
type test struct {
	t *testing.T

	testService testgrpc.TestServiceServer // nil means none
	// srv, srvAddr, srvIP and srvPort are set once startServer is called.
	srv     *grpc.Server
	srvAddr string // Listener address as returned by lis.Addr().String().
	srvIP   net.IP
	srvPort int

	cc *grpc.ClientConn // nil until requested via clientConn

	// Fields for the client address. Set by listenerWrapper.Accept for each
	// accepted connection and guarded by clientAddrMu.
	clientAddrMu sync.Mutex
	clientIP     net.IP
	clientPort   int
}
   249  
   250  func (te *test) tearDown() {
   251  	if te.cc != nil {
   252  		te.cc.Close()
   253  		te.cc = nil
   254  	}
   255  	te.srv.Stop()
   256  }
   257  
   258  // newTest returns a new test using the provided testing.T and
   259  // environment.  It is returned with default values. Tests should
   260  // modify it before calling its startServer and clientConn methods.
   261  func newTest(t *testing.T) *test {
   262  	te := &test{
   263  		t: t,
   264  	}
   265  	return te
   266  }
   267  
// listenerWrapper wraps a net.Listener so that the remote address of every
// accepted connection is recorded on the owning test fixture (see Accept).
type listenerWrapper struct {
	net.Listener
	// te is the fixture whose clientIP/clientPort fields are updated.
	te *test
}
   272  
   273  func (lw *listenerWrapper) Accept() (net.Conn, error) {
   274  	conn, err := lw.Listener.Accept()
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  	lw.te.clientAddrMu.Lock()
   279  	lw.te.clientIP = conn.RemoteAddr().(*net.TCPAddr).IP
   280  	lw.te.clientPort = conn.RemoteAddr().(*net.TCPAddr).Port
   281  	lw.te.clientAddrMu.Unlock()
   282  	return conn, nil
   283  }
   284  
   285  // startServer starts a gRPC server listening. Callers should defer a
   286  // call to te.tearDown to clean up.
   287  func (te *test) startServer(ts testgrpc.TestServiceServer) {
   288  	te.testService = ts
   289  	lis, err := net.Listen("tcp", "localhost:0")
   290  
   291  	lis = &listenerWrapper{
   292  		Listener: lis,
   293  		te:       te,
   294  	}
   295  
   296  	if err != nil {
   297  		te.t.Fatalf("Failed to listen: %v", err)
   298  	}
   299  	var opts []grpc.ServerOption
   300  	s := grpc.NewServer(opts...)
   301  	te.srv = s
   302  	if te.testService != nil {
   303  		testgrpc.RegisterTestServiceServer(s, te.testService)
   304  	}
   305  
   306  	go s.Serve(lis)
   307  	te.srvAddr = lis.Addr().String()
   308  	te.srvIP = lis.Addr().(*net.TCPAddr).IP
   309  	te.srvPort = lis.Addr().(*net.TCPAddr).Port
   310  }
   311  
   312  func (te *test) clientConn() *grpc.ClientConn {
   313  	if te.cc != nil {
   314  		return te.cc
   315  	}
   316  	opts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()}
   317  
   318  	var err error
   319  	te.cc, err = grpc.NewClient(te.srvAddr, opts...)
   320  	if err != nil {
   321  		te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
   322  	}
   323  	return te.cc
   324  }
   325  
// rpcType selects which kind of RPC a test case performs.
type rpcType int

const (
	// unaryRPC performs a TestService/UnaryCall.
	unaryRPC rpcType = iota
	// clientStreamRPC performs a TestService/StreamingInputCall.
	clientStreamRPC
	// serverStreamRPC performs a TestService/StreamingOutputCall.
	serverStreamRPC
	// fullDuplexStreamRPC performs a TestService/FullDuplexCall.
	fullDuplexStreamRPC
	// cancelRPC opens a TestService/FullDuplexCall stream and cancels its
	// context before any message is sent.
	cancelRPC
)
   335  
// rpcConfig describes the single RPC a test case performs.
type rpcConfig struct {
	count    int     // Number of requests and responses for streaming RPCs.
	success  bool    // Whether the RPC should succeed or return error.
	callType rpcType // Type of RPC.
}
   341  
   342  func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
   343  	var (
   344  		resp *testpb.SimpleResponse
   345  		req  *testpb.SimpleRequest
   346  		err  error
   347  	)
   348  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   349  	if c.success {
   350  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID + 1)}
   351  	} else {
   352  		req = &testpb.SimpleRequest{Payload: idToPayload(errorID)}
   353  	}
   354  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   355  	defer cancel()
   356  	ctx = metadata.NewOutgoingContext(ctx, testMetadata)
   357  
   358  	resp, err = tc.UnaryCall(ctx, req)
   359  	return req, resp, err
   360  }
   361  
   362  func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []proto.Message, error) {
   363  	var (
   364  		reqs  []proto.Message
   365  		resps []proto.Message
   366  		err   error
   367  	)
   368  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   369  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   370  	defer cancel()
   371  	ctx = metadata.NewOutgoingContext(ctx, testMetadata)
   372  
   373  	stream, err := tc.FullDuplexCall(ctx)
   374  	if err != nil {
   375  		return reqs, resps, err
   376  	}
   377  
   378  	if c.callType == cancelRPC {
   379  		cancel()
   380  		return reqs, resps, context.Canceled
   381  	}
   382  
   383  	var startID int32
   384  	if !c.success {
   385  		startID = errorID
   386  	}
   387  	for i := 0; i < c.count; i++ {
   388  		req := &testpb.StreamingOutputCallRequest{
   389  			Payload: idToPayload(int32(i) + startID),
   390  		}
   391  		reqs = append(reqs, req)
   392  		if err = stream.Send(req); err != nil {
   393  			return reqs, resps, err
   394  		}
   395  		var resp *testpb.StreamingOutputCallResponse
   396  		if resp, err = stream.Recv(); err != nil {
   397  			return reqs, resps, err
   398  		}
   399  		resps = append(resps, resp)
   400  	}
   401  	if err = stream.CloseSend(); err != nil && err != io.EOF {
   402  		return reqs, resps, err
   403  	}
   404  	if _, err = stream.Recv(); err != io.EOF {
   405  		return reqs, resps, err
   406  	}
   407  
   408  	return reqs, resps, nil
   409  }
   410  
   411  func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, proto.Message, error) {
   412  	var (
   413  		reqs []proto.Message
   414  		resp *testpb.StreamingInputCallResponse
   415  		err  error
   416  	)
   417  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   418  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   419  	defer cancel()
   420  	ctx = metadata.NewOutgoingContext(ctx, testMetadata)
   421  
   422  	stream, err := tc.StreamingInputCall(ctx)
   423  	if err != nil {
   424  		return reqs, resp, err
   425  	}
   426  	var startID int32
   427  	if !c.success {
   428  		startID = errorID
   429  	}
   430  	for i := 0; i < c.count; i++ {
   431  		req := &testpb.StreamingInputCallRequest{
   432  			Payload: idToPayload(int32(i) + startID),
   433  		}
   434  		reqs = append(reqs, req)
   435  		if err = stream.Send(req); err != nil {
   436  			return reqs, resp, err
   437  		}
   438  	}
   439  	resp, err = stream.CloseAndRecv()
   440  	return reqs, resp, err
   441  }
   442  
   443  func (te *test) doServerStreamCall(c *rpcConfig) (proto.Message, []proto.Message, error) {
   444  	var (
   445  		req   *testpb.StreamingOutputCallRequest
   446  		resps []proto.Message
   447  		err   error
   448  	)
   449  
   450  	tc := testgrpc.NewTestServiceClient(te.clientConn())
   451  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   452  	defer cancel()
   453  	ctx = metadata.NewOutgoingContext(ctx, testMetadata)
   454  
   455  	var startID int32
   456  	if !c.success {
   457  		startID = errorID
   458  	}
   459  	req = &testpb.StreamingOutputCallRequest{Payload: idToPayload(startID)}
   460  	stream, err := tc.StreamingOutputCall(ctx, req)
   461  	if err != nil {
   462  		return req, resps, err
   463  	}
   464  	for {
   465  		var resp *testpb.StreamingOutputCallResponse
   466  		resp, err := stream.Recv()
   467  		if err == io.EOF {
   468  			return req, resps, nil
   469  		} else if err != nil {
   470  			return req, resps, err
   471  		}
   472  		resps = append(resps, resp)
   473  	}
   474  }
   475  
// expectedData records what happened during one RPC (see runRPCs) so that
// the expected binary log entries can be derived from it.
type expectedData struct {
	// te is the fixture the RPC ran against; it supplies the server and
	// client addresses used for the expected peer fields.
	te *test
	// cc is the configuration of the RPC that was run.
	cc *rpcConfig

	// method is the full method name of the RPC.
	method    string
	// requests and responses are the messages sent and received by the
	// client.
	requests  []proto.Message
	responses []proto.Message
	// err is the error returned by the RPC; nil on success.
	err       error
}
   485  
   486  func (ed *expectedData) newClientHeaderEntry(client bool, rpcID, inRPCID uint64) *binlogpb.GrpcLogEntry {
   487  	logger := binlogpb.GrpcLogEntry_LOGGER_CLIENT
   488  	var peer *binlogpb.Address
   489  	if !client {
   490  		logger = binlogpb.GrpcLogEntry_LOGGER_SERVER
   491  		ed.te.clientAddrMu.Lock()
   492  		peer = &binlogpb.Address{
   493  			Address: ed.te.clientIP.String(),
   494  			IpPort:  uint32(ed.te.clientPort),
   495  		}
   496  		if ed.te.clientIP.To4() != nil {
   497  			peer.Type = binlogpb.Address_TYPE_IPV4
   498  		} else {
   499  			peer.Type = binlogpb.Address_TYPE_IPV6
   500  		}
   501  		ed.te.clientAddrMu.Unlock()
   502  	}
   503  	return &binlogpb.GrpcLogEntry{
   504  		Timestamp:            nil,
   505  		CallId:               rpcID,
   506  		SequenceIdWithinCall: inRPCID,
   507  		Type:                 binlogpb.GrpcLogEntry_EVENT_TYPE_CLIENT_HEADER,
   508  		Logger:               logger,
   509  		Payload: &binlogpb.GrpcLogEntry_ClientHeader{
   510  			ClientHeader: &binlogpb.ClientHeader{
   511  				Metadata:   iblog.MdToMetadataProto(testMetadata),
   512  				MethodName: ed.method,
   513  				Authority:  ed.te.srvAddr,
   514  			},
   515  		},
   516  		Peer: peer,
   517  	}
   518  }
   519  
   520  func (ed *expectedData) newServerHeaderEntry(client bool, rpcID, inRPCID uint64) *binlogpb.GrpcLogEntry {
   521  	logger := binlogpb.GrpcLogEntry_LOGGER_SERVER
   522  	var peer *binlogpb.Address
   523  	if client {
   524  		logger = binlogpb.GrpcLogEntry_LOGGER_CLIENT
   525  		peer = &binlogpb.Address{
   526  			Address: ed.te.srvIP.String(),
   527  			IpPort:  uint32(ed.te.srvPort),
   528  		}
   529  		if ed.te.srvIP.To4() != nil {
   530  			peer.Type = binlogpb.Address_TYPE_IPV4
   531  		} else {
   532  			peer.Type = binlogpb.Address_TYPE_IPV6
   533  		}
   534  	}
   535  	return &binlogpb.GrpcLogEntry{
   536  		Timestamp:            nil,
   537  		CallId:               rpcID,
   538  		SequenceIdWithinCall: inRPCID,
   539  		Type:                 binlogpb.GrpcLogEntry_EVENT_TYPE_SERVER_HEADER,
   540  		Logger:               logger,
   541  		Payload: &binlogpb.GrpcLogEntry_ServerHeader{
   542  			ServerHeader: &binlogpb.ServerHeader{
   543  				Metadata: iblog.MdToMetadataProto(testMetadata),
   544  			},
   545  		},
   546  		Peer: peer,
   547  	}
   548  }
   549  
   550  func (ed *expectedData) newClientMessageEntry(client bool, rpcID, inRPCID uint64, msg proto.Message) *binlogpb.GrpcLogEntry {
   551  	logger := binlogpb.GrpcLogEntry_LOGGER_CLIENT
   552  	if !client {
   553  		logger = binlogpb.GrpcLogEntry_LOGGER_SERVER
   554  	}
   555  	data, err := proto.Marshal(msg)
   556  	if err != nil {
   557  		grpclogLogger.Infof("binarylogging_testing: failed to marshal proto message: %v", err)
   558  	}
   559  	return &binlogpb.GrpcLogEntry{
   560  		Timestamp:            nil,
   561  		CallId:               rpcID,
   562  		SequenceIdWithinCall: inRPCID,
   563  		Type:                 binlogpb.GrpcLogEntry_EVENT_TYPE_CLIENT_MESSAGE,
   564  		Logger:               logger,
   565  		Payload: &binlogpb.GrpcLogEntry_Message{
   566  			Message: &binlogpb.Message{
   567  				Length: uint32(len(data)),
   568  				Data:   data,
   569  			},
   570  		},
   571  	}
   572  }
   573  
   574  func (ed *expectedData) newServerMessageEntry(client bool, rpcID, inRPCID uint64, msg proto.Message) *binlogpb.GrpcLogEntry {
   575  	logger := binlogpb.GrpcLogEntry_LOGGER_CLIENT
   576  	if !client {
   577  		logger = binlogpb.GrpcLogEntry_LOGGER_SERVER
   578  	}
   579  	data, err := proto.Marshal(msg)
   580  	if err != nil {
   581  		grpclogLogger.Infof("binarylogging_testing: failed to marshal proto message: %v", err)
   582  	}
   583  	return &binlogpb.GrpcLogEntry{
   584  		Timestamp:            nil,
   585  		CallId:               rpcID,
   586  		SequenceIdWithinCall: inRPCID,
   587  		Type:                 binlogpb.GrpcLogEntry_EVENT_TYPE_SERVER_MESSAGE,
   588  		Logger:               logger,
   589  		Payload: &binlogpb.GrpcLogEntry_Message{
   590  			Message: &binlogpb.Message{
   591  				Length: uint32(len(data)),
   592  				Data:   data,
   593  			},
   594  		},
   595  	}
   596  }
   597  
   598  func (ed *expectedData) newHalfCloseEntry(client bool, rpcID, inRPCID uint64) *binlogpb.GrpcLogEntry {
   599  	logger := binlogpb.GrpcLogEntry_LOGGER_CLIENT
   600  	if !client {
   601  		logger = binlogpb.GrpcLogEntry_LOGGER_SERVER
   602  	}
   603  	return &binlogpb.GrpcLogEntry{
   604  		Timestamp:            nil,
   605  		CallId:               rpcID,
   606  		SequenceIdWithinCall: inRPCID,
   607  		Type:                 binlogpb.GrpcLogEntry_EVENT_TYPE_CLIENT_HALF_CLOSE,
   608  		Payload:              nil, // No payload here.
   609  		Logger:               logger,
   610  	}
   611  }
   612  
   613  func (ed *expectedData) newServerTrailerEntry(client bool, rpcID, inRPCID uint64, stErr error) *binlogpb.GrpcLogEntry {
   614  	logger := binlogpb.GrpcLogEntry_LOGGER_SERVER
   615  	var peer *binlogpb.Address
   616  	if client {
   617  		logger = binlogpb.GrpcLogEntry_LOGGER_CLIENT
   618  		peer = &binlogpb.Address{
   619  			Address: ed.te.srvIP.String(),
   620  			IpPort:  uint32(ed.te.srvPort),
   621  		}
   622  		if ed.te.srvIP.To4() != nil {
   623  			peer.Type = binlogpb.Address_TYPE_IPV4
   624  		} else {
   625  			peer.Type = binlogpb.Address_TYPE_IPV6
   626  		}
   627  	}
   628  	st, ok := status.FromError(stErr)
   629  	if !ok {
   630  		grpclogLogger.Info("binarylogging: error in trailer is not a status error")
   631  	}
   632  	stProto := st.Proto()
   633  	var (
   634  		detailsBytes []byte
   635  		err          error
   636  	)
   637  	if stProto != nil && len(stProto.Details) != 0 {
   638  		detailsBytes, err = proto.Marshal(stProto)
   639  		if err != nil {
   640  			grpclogLogger.Infof("binarylogging: failed to marshal status proto: %v", err)
   641  		}
   642  	}
   643  	return &binlogpb.GrpcLogEntry{
   644  		Timestamp:            nil,
   645  		CallId:               rpcID,
   646  		SequenceIdWithinCall: inRPCID,
   647  		Type:                 binlogpb.GrpcLogEntry_EVENT_TYPE_SERVER_TRAILER,
   648  		Logger:               logger,
   649  		Payload: &binlogpb.GrpcLogEntry_Trailer{
   650  			Trailer: &binlogpb.Trailer{
   651  				Metadata: iblog.MdToMetadataProto(testTrailerMetadata),
   652  				// st will be nil if err was not a status error, but nil is ok.
   653  				StatusCode:    uint32(st.Code()),
   654  				StatusMessage: st.Message(),
   655  				StatusDetails: detailsBytes,
   656  			},
   657  		},
   658  		Peer: peer,
   659  	}
   660  }
   661  
   662  func (ed *expectedData) newCancelEntry(rpcID, inRPCID uint64) *binlogpb.GrpcLogEntry {
   663  	return &binlogpb.GrpcLogEntry{
   664  		Timestamp:            nil,
   665  		CallId:               rpcID,
   666  		SequenceIdWithinCall: inRPCID,
   667  		Type:                 binlogpb.GrpcLogEntry_EVENT_TYPE_CANCEL,
   668  		Logger:               binlogpb.GrpcLogEntry_LOGGER_CLIENT,
   669  		Payload:              nil,
   670  	}
   671  }
   672  
   673  func (ed *expectedData) toClientLogEntries() []*binlogpb.GrpcLogEntry {
   674  	var (
   675  		ret     []*binlogpb.GrpcLogEntry
   676  		idInRPC uint64 = 1
   677  	)
   678  	ret = append(ret, ed.newClientHeaderEntry(true, globalRPCID, idInRPC))
   679  	idInRPC++
   680  
   681  	switch ed.cc.callType {
   682  	case unaryRPC, fullDuplexStreamRPC:
   683  		for i := 0; i < len(ed.requests); i++ {
   684  			ret = append(ret, ed.newClientMessageEntry(true, globalRPCID, idInRPC, ed.requests[i]))
   685  			idInRPC++
   686  			if i == 0 {
   687  				// First message, append ServerHeader.
   688  				ret = append(ret, ed.newServerHeaderEntry(true, globalRPCID, idInRPC))
   689  				idInRPC++
   690  			}
   691  			if !ed.cc.success {
   692  				// There is no response in the RPC error case.
   693  				continue
   694  			}
   695  			ret = append(ret, ed.newServerMessageEntry(true, globalRPCID, idInRPC, ed.responses[i]))
   696  			idInRPC++
   697  		}
   698  		if ed.cc.success && ed.cc.callType == fullDuplexStreamRPC {
   699  			ret = append(ret, ed.newHalfCloseEntry(true, globalRPCID, idInRPC))
   700  			idInRPC++
   701  		}
   702  	case clientStreamRPC, serverStreamRPC:
   703  		for i := 0; i < len(ed.requests); i++ {
   704  			ret = append(ret, ed.newClientMessageEntry(true, globalRPCID, idInRPC, ed.requests[i]))
   705  			idInRPC++
   706  		}
   707  		if ed.cc.callType == clientStreamRPC {
   708  			ret = append(ret, ed.newHalfCloseEntry(true, globalRPCID, idInRPC))
   709  			idInRPC++
   710  		}
   711  		ret = append(ret, ed.newServerHeaderEntry(true, globalRPCID, idInRPC))
   712  		idInRPC++
   713  		if ed.cc.success {
   714  			for i := 0; i < len(ed.responses); i++ {
   715  				ret = append(ret, ed.newServerMessageEntry(true, globalRPCID, idInRPC, ed.responses[0]))
   716  				idInRPC++
   717  			}
   718  		}
   719  	}
   720  
   721  	if ed.cc.callType == cancelRPC {
   722  		ret = append(ret, ed.newCancelEntry(globalRPCID, idInRPC))
   723  		idInRPC++
   724  	} else {
   725  		ret = append(ret, ed.newServerTrailerEntry(true, globalRPCID, idInRPC, ed.err))
   726  		idInRPC++
   727  	}
   728  	return ret
   729  }
   730  
   731  func (ed *expectedData) toServerLogEntries() []*binlogpb.GrpcLogEntry {
   732  	var (
   733  		ret     []*binlogpb.GrpcLogEntry
   734  		idInRPC uint64 = 1
   735  	)
   736  	ret = append(ret, ed.newClientHeaderEntry(false, globalRPCID, idInRPC))
   737  	idInRPC++
   738  
   739  	switch ed.cc.callType {
   740  	case unaryRPC:
   741  		ret = append(ret, ed.newClientMessageEntry(false, globalRPCID, idInRPC, ed.requests[0]))
   742  		idInRPC++
   743  		ret = append(ret, ed.newServerHeaderEntry(false, globalRPCID, idInRPC))
   744  		idInRPC++
   745  		if ed.cc.success {
   746  			ret = append(ret, ed.newServerMessageEntry(false, globalRPCID, idInRPC, ed.responses[0]))
   747  			idInRPC++
   748  		}
   749  	case fullDuplexStreamRPC:
   750  		ret = append(ret, ed.newServerHeaderEntry(false, globalRPCID, idInRPC))
   751  		idInRPC++
   752  		for i := 0; i < len(ed.requests); i++ {
   753  			ret = append(ret, ed.newClientMessageEntry(false, globalRPCID, idInRPC, ed.requests[i]))
   754  			idInRPC++
   755  			if !ed.cc.success {
   756  				// There is no response in the RPC error case.
   757  				continue
   758  			}
   759  			ret = append(ret, ed.newServerMessageEntry(false, globalRPCID, idInRPC, ed.responses[i]))
   760  			idInRPC++
   761  		}
   762  
   763  		if ed.cc.success && ed.cc.callType == fullDuplexStreamRPC {
   764  			ret = append(ret, ed.newHalfCloseEntry(false, globalRPCID, idInRPC))
   765  			idInRPC++
   766  		}
   767  	case clientStreamRPC:
   768  		ret = append(ret, ed.newServerHeaderEntry(false, globalRPCID, idInRPC))
   769  		idInRPC++
   770  		for i := 0; i < len(ed.requests); i++ {
   771  			ret = append(ret, ed.newClientMessageEntry(false, globalRPCID, idInRPC, ed.requests[i]))
   772  			idInRPC++
   773  		}
   774  		if ed.cc.success {
   775  			ret = append(ret, ed.newHalfCloseEntry(false, globalRPCID, idInRPC))
   776  			idInRPC++
   777  			ret = append(ret, ed.newServerMessageEntry(false, globalRPCID, idInRPC, ed.responses[0]))
   778  			idInRPC++
   779  		}
   780  	case serverStreamRPC:
   781  		ret = append(ret, ed.newClientMessageEntry(false, globalRPCID, idInRPC, ed.requests[0]))
   782  		idInRPC++
   783  		ret = append(ret, ed.newServerHeaderEntry(false, globalRPCID, idInRPC))
   784  		idInRPC++
   785  		for i := 0; i < len(ed.responses); i++ {
   786  			ret = append(ret, ed.newServerMessageEntry(false, globalRPCID, idInRPC, ed.responses[0]))
   787  			idInRPC++
   788  		}
   789  	}
   790  
   791  	ret = append(ret, ed.newServerTrailerEntry(false, globalRPCID, idInRPC, ed.err))
   792  	idInRPC++
   793  
   794  	return ret
   795  }
   796  
   797  func runRPCs(t *testing.T, cc *rpcConfig) *expectedData {
   798  	te := newTest(t)
   799  	te.startServer(&testServer{te: te})
   800  	defer te.tearDown()
   801  
   802  	expect := &expectedData{
   803  		te: te,
   804  		cc: cc,
   805  	}
   806  
   807  	switch cc.callType {
   808  	case unaryRPC:
   809  		expect.method = "/grpc.testing.TestService/UnaryCall"
   810  		req, resp, err := te.doUnaryCall(cc)
   811  		expect.requests = []proto.Message{req}
   812  		expect.responses = []proto.Message{resp}
   813  		expect.err = err
   814  	case clientStreamRPC:
   815  		expect.method = "/grpc.testing.TestService/StreamingInputCall"
   816  		reqs, resp, err := te.doClientStreamCall(cc)
   817  		expect.requests = reqs
   818  		expect.responses = []proto.Message{resp}
   819  		expect.err = err
   820  	case serverStreamRPC:
   821  		expect.method = "/grpc.testing.TestService/StreamingOutputCall"
   822  		req, resps, err := te.doServerStreamCall(cc)
   823  		expect.responses = resps
   824  		expect.requests = []proto.Message{req}
   825  		expect.err = err
   826  	case fullDuplexStreamRPC, cancelRPC:
   827  		expect.method = "/grpc.testing.TestService/FullDuplexCall"
   828  		expect.requests, expect.responses, expect.err = te.doFullDuplexCallRoundtrip(cc)
   829  	}
   830  	if cc.success != (expect.err == nil) {
   831  		t.Fatalf("cc.success: %v, got error: %v", cc.success, expect.err)
   832  	}
   833  	te.cc.Close()
   834  	te.srv.GracefulStop() // Wait for the server to stop.
   835  
   836  	return expect
   837  }
   838  
   839  // equalLogEntry sorts the metadata entries by key (to compare metadata).
   840  //
   841  // This function is typically called with only two entries. It's written in this
   842  // way so the code can be put in a for loop instead of copied twice.
   843  func equalLogEntry(entries ...*binlogpb.GrpcLogEntry) (equal bool) {
   844  	for i, e := range entries {
   845  		// Clear out some fields we don't compare.
   846  		e.Timestamp = nil
   847  		e.CallId = 0 // CallID is global to the binary, hard to compare.
   848  		if h := e.GetClientHeader(); h != nil {
   849  			h.Timeout = nil
   850  			tmp := append(h.Metadata.Entry[:0], h.Metadata.Entry...)
   851  			h.Metadata.Entry = tmp
   852  			sort.Slice(h.Metadata.Entry, func(i, j int) bool { return h.Metadata.Entry[i].Key < h.Metadata.Entry[j].Key })
   853  		}
   854  		if h := e.GetServerHeader(); h != nil {
   855  			tmp := append(h.Metadata.Entry[:0], h.Metadata.Entry...)
   856  			h.Metadata.Entry = tmp
   857  			sort.Slice(h.Metadata.Entry, func(i, j int) bool { return h.Metadata.Entry[i].Key < h.Metadata.Entry[j].Key })
   858  		}
   859  		if h := e.GetTrailer(); h != nil {
   860  			sort.Slice(h.Metadata.Entry, func(i, j int) bool { return h.Metadata.Entry[i].Key < h.Metadata.Entry[j].Key })
   861  		}
   862  
   863  		if i > 0 && !proto.Equal(e, entries[i-1]) {
   864  			return false
   865  		}
   866  	}
   867  	return true
   868  }
   869  
   870  func testClientBinaryLog(t *testing.T, c *rpcConfig) error {
   871  	defer testSink.clear()
   872  	expect := runRPCs(t, c)
   873  	want := expect.toClientLogEntries()
   874  	var got []*binlogpb.GrpcLogEntry
   875  	// In racy cases, some entries are not logged when the RPC is finished (e.g.
   876  	// context.Cancel).
   877  	//
   878  	// Check 10 times, with a sleep of 1/100 seconds between each check. Makes
   879  	// it an 1-second wait in total.
   880  	for i := 0; i < 10; i++ {
   881  		got = testSink.logEntries(true) // all client entries.
   882  		if len(want) == len(got) {
   883  			break
   884  		}
   885  		time.Sleep(100 * time.Millisecond)
   886  	}
   887  	if len(want) != len(got) {
   888  		for i, e := range want {
   889  			t.Errorf("in want: %d, %s", i, e.GetType())
   890  		}
   891  		for i, e := range got {
   892  			t.Errorf("in got: %d, %s", i, e.GetType())
   893  		}
   894  		return fmt.Errorf("didn't get same amount of log entries, want: %d, got: %d", len(want), len(got))
   895  	}
   896  	var errored bool
   897  	for i := 0; i < len(got); i++ {
   898  		if !equalLogEntry(want[i], got[i]) {
   899  			t.Errorf("entry: %d, want %+v, got %+v", i, want[i], got[i])
   900  			errored = true
   901  		}
   902  	}
   903  	if errored {
   904  		return fmt.Errorf("test failed")
   905  	}
   906  	return nil
   907  }
   908  
   909  func (s) TestClientBinaryLogUnaryRPC(t *testing.T) {
   910  	if err := testClientBinaryLog(t, &rpcConfig{success: true, callType: unaryRPC}); err != nil {
   911  		t.Fatal(err)
   912  	}
   913  }
   914  
   915  func (s) TestClientBinaryLogUnaryRPCError(t *testing.T) {
   916  	if err := testClientBinaryLog(t, &rpcConfig{success: false, callType: unaryRPC}); err != nil {
   917  		t.Fatal(err)
   918  	}
   919  }
   920  
   921  func (s) TestClientBinaryLogClientStreamRPC(t *testing.T) {
   922  	count := 5
   923  	if err := testClientBinaryLog(t, &rpcConfig{count: count, success: true, callType: clientStreamRPC}); err != nil {
   924  		t.Fatal(err)
   925  	}
   926  }
   927  
   928  func (s) TestClientBinaryLogClientStreamRPCError(t *testing.T) {
   929  	count := 1
   930  	if err := testClientBinaryLog(t, &rpcConfig{count: count, success: false, callType: clientStreamRPC}); err != nil {
   931  		t.Fatal(err)
   932  	}
   933  }
   934  
   935  func (s) TestClientBinaryLogServerStreamRPC(t *testing.T) {
   936  	count := 5
   937  	if err := testClientBinaryLog(t, &rpcConfig{count: count, success: true, callType: serverStreamRPC}); err != nil {
   938  		t.Fatal(err)
   939  	}
   940  }
   941  
   942  func (s) TestClientBinaryLogServerStreamRPCError(t *testing.T) {
   943  	count := 5
   944  	if err := testClientBinaryLog(t, &rpcConfig{count: count, success: false, callType: serverStreamRPC}); err != nil {
   945  		t.Fatal(err)
   946  	}
   947  }
   948  
   949  func (s) TestClientBinaryLogFullDuplexRPC(t *testing.T) {
   950  	count := 5
   951  	if err := testClientBinaryLog(t, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}); err != nil {
   952  		t.Fatal(err)
   953  	}
   954  }
   955  
   956  func (s) TestClientBinaryLogFullDuplexRPCError(t *testing.T) {
   957  	count := 5
   958  	if err := testClientBinaryLog(t, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}); err != nil {
   959  		t.Fatal(err)
   960  	}
   961  }
   962  
   963  func (s) TestClientBinaryLogCancel(t *testing.T) {
   964  	count := 5
   965  	if err := testClientBinaryLog(t, &rpcConfig{count: count, success: false, callType: cancelRPC}); err != nil {
   966  		t.Fatal(err)
   967  	}
   968  }
   969  
   970  func testServerBinaryLog(t *testing.T, c *rpcConfig) error {
   971  	defer testSink.clear()
   972  	expect := runRPCs(t, c)
   973  	want := expect.toServerLogEntries()
   974  	var got []*binlogpb.GrpcLogEntry
   975  	// In racy cases, some entries are not logged when the RPC is finished (e.g.
   976  	// context.Cancel). This is unlikely to happen on server side, but it does
   977  	// no harm to retry.
   978  	//
   979  	// Check 10 times, with a sleep of 1/100 seconds between each check. Makes
   980  	// it an 1-second wait in total.
   981  	for i := 0; i < 10; i++ {
   982  		got = testSink.logEntries(false) // all server entries.
   983  		if len(want) == len(got) {
   984  			break
   985  		}
   986  		time.Sleep(100 * time.Millisecond)
   987  	}
   988  
   989  	if len(want) != len(got) {
   990  		for i, e := range want {
   991  			t.Errorf("in want: %d, %s", i, e.GetType())
   992  		}
   993  		for i, e := range got {
   994  			t.Errorf("in got: %d, %s", i, e.GetType())
   995  		}
   996  		return fmt.Errorf("didn't get same amount of log entries, want: %d, got: %d", len(want), len(got))
   997  	}
   998  	var errored bool
   999  	for i := 0; i < len(got); i++ {
  1000  		if !equalLogEntry(want[i], got[i]) {
  1001  			t.Errorf("entry: %d, want %+v, got %+v", i, want[i], got[i])
  1002  			errored = true
  1003  		}
  1004  	}
  1005  	if errored {
  1006  		return fmt.Errorf("test failed")
  1007  	}
  1008  	return nil
  1009  }
  1010  
  1011  func (s) TestServerBinaryLogUnaryRPC(t *testing.T) {
  1012  	if err := testServerBinaryLog(t, &rpcConfig{success: true, callType: unaryRPC}); err != nil {
  1013  		t.Fatal(err)
  1014  	}
  1015  }
  1016  
  1017  func (s) TestServerBinaryLogUnaryRPCError(t *testing.T) {
  1018  	if err := testServerBinaryLog(t, &rpcConfig{success: false, callType: unaryRPC}); err != nil {
  1019  		t.Fatal(err)
  1020  	}
  1021  }
  1022  
  1023  func (s) TestServerBinaryLogClientStreamRPC(t *testing.T) {
  1024  	count := 5
  1025  	if err := testServerBinaryLog(t, &rpcConfig{count: count, success: true, callType: clientStreamRPC}); err != nil {
  1026  		t.Fatal(err)
  1027  	}
  1028  }
  1029  
  1030  func (s) TestServerBinaryLogClientStreamRPCError(t *testing.T) {
  1031  	count := 1
  1032  	if err := testServerBinaryLog(t, &rpcConfig{count: count, success: false, callType: clientStreamRPC}); err != nil {
  1033  		t.Fatal(err)
  1034  	}
  1035  }
  1036  
  1037  func (s) TestServerBinaryLogServerStreamRPC(t *testing.T) {
  1038  	count := 5
  1039  	if err := testServerBinaryLog(t, &rpcConfig{count: count, success: true, callType: serverStreamRPC}); err != nil {
  1040  		t.Fatal(err)
  1041  	}
  1042  }
  1043  
  1044  func (s) TestServerBinaryLogServerStreamRPCError(t *testing.T) {
  1045  	count := 5
  1046  	if err := testServerBinaryLog(t, &rpcConfig{count: count, success: false, callType: serverStreamRPC}); err != nil {
  1047  		t.Fatal(err)
  1048  	}
  1049  }
  1050  
  1051  func (s) TestServerBinaryLogFullDuplex(t *testing.T) {
  1052  	count := 5
  1053  	if err := testServerBinaryLog(t, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}); err != nil {
  1054  		t.Fatal(err)
  1055  	}
  1056  }
  1057  
  1058  func (s) TestServerBinaryLogFullDuplexError(t *testing.T) {
  1059  	count := 5
  1060  	if err := testServerBinaryLog(t, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}); err != nil {
  1061  		t.Fatal(err)
  1062  	}
  1063  }
  1064  
  1065  // TestCanceledStatus ensures a server that responds with a Canceled status has
  1066  // its trailers logged appropriately and is not treated as a canceled RPC.
  1067  func (s) TestCanceledStatus(t *testing.T) {
  1068  	defer testSink.clear()
  1069  
  1070  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  1071  	defer cancel()
  1072  
  1073  	const statusMsgWant = "server returned Canceled"
  1074  	ss := &stubserver.StubServer{
  1075  		UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
  1076  			grpc.SetTrailer(ctx, metadata.Pairs("key", "value"))
  1077  			return nil, status.Error(codes.Canceled, statusMsgWant)
  1078  		},
  1079  	}
  1080  	if err := ss.Start(nil); err != nil {
  1081  		t.Fatalf("Error starting endpoint server: %v", err)
  1082  	}
  1083  	defer ss.Stop()
  1084  
  1085  	if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Canceled {
  1086  		t.Fatalf("Received unexpected error from UnaryCall: %v; want Canceled", err)
  1087  	}
  1088  
  1089  	got := testSink.logEntries(true)
  1090  	last := got[len(got)-1]
  1091  	if last.Type != binlogpb.GrpcLogEntry_EVENT_TYPE_SERVER_TRAILER ||
  1092  		last.GetTrailer().GetStatusCode() != uint32(codes.Canceled) ||
  1093  		last.GetTrailer().GetStatusMessage() != statusMsgWant ||
  1094  		len(last.GetTrailer().GetMetadata().GetEntry()) != 1 ||
  1095  		last.GetTrailer().GetMetadata().GetEntry()[0].GetKey() != "key" ||
  1096  		string(last.GetTrailer().GetMetadata().GetEntry()[0].GetValue()) != "value" {
  1097  		t.Fatalf("Got binary log: %+v; want last entry is server trailing with status Canceled", got)
  1098  	}
  1099  }
  1100  

View as plain text