...

Source file src/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal/test/test_utils.go

Documentation: go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal/test

     1  // Copyright The OpenTelemetry Authors
     2  // SPDX-License-Identifier: Apache-2.0
     3  
     4  /*
     5   *
     6   * Copyright 2014 gRPC authors.
     7   *
     8   * Licensed under the Apache License, Version 2.0 (the "License");
     9   * you may not use this file except in compliance with the License.
    10   * You may obtain a copy of the License at
    11   *
    12   *     http://www.apache.org/licenses/LICENSE-2.0
    13   *
    14   * Unless required by applicable law or agreed to in writing, software
    15   * distributed under the License is distributed on an "AS IS" BASIS,
    16   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    17   * See the License for the specific language governing permissions and
    18   * limitations under the License.
    19   *
    20   */
    21  
    22  // Package test contains functions used by interop client/server.
    23  //
    24  // Copied from https://github.com/grpc/grpc-go/tree/v1.61.0/interop
    25  // That package was not intended to be used by external code.
    26  // See https://github.com/open-telemetry/opentelemetry-go-contrib/issues/4896
    27  package test // import "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal/test"
    28  
    29  import (
    30  	"context"
    31  	"fmt"
    32  	"io"
    33  	"time"
    34  
    35  	"google.golang.org/grpc"
    36  	"google.golang.org/grpc/codes"
    37  	"google.golang.org/grpc/grpclog"
    38  	"google.golang.org/grpc/metadata"
    39  	"google.golang.org/grpc/status"
    40  	"google.golang.org/protobuf/proto"
    41  
    42  	testpb "google.golang.org/grpc/interop/grpc_testing"
    43  )
    44  
    45  var (
    46  	reqSizes            = []int{27182, 8, 1828, 45904}
    47  	respSizes           = []int{31415, 9, 2653, 58979}
    48  	largeReqSize        = 271828
    49  	largeRespSize       = 314159
    50  	initialMetadataKey  = "x-grpc-test-echo-initial"
    51  	trailingMetadataKey = "x-grpc-test-echo-trailing-bin"
    52  
    53  	logger = grpclog.Component("interop")
    54  )
    55  
    56  // ClientNewPayload returns a payload of the given type and size.
    57  func ClientNewPayload(t testpb.PayloadType, size int) *testpb.Payload {
    58  	if size < 0 {
    59  		logger.Fatalf("Requested a response with invalid length %d", size)
    60  	}
    61  	body := make([]byte, size)
    62  	switch t {
    63  	case testpb.PayloadType_COMPRESSABLE:
    64  	default:
    65  		logger.Fatalf("Unsupported payload type: %d", t)
    66  	}
    67  	return &testpb.Payload{
    68  		Type: t,
    69  		Body: body,
    70  	}
    71  }
    72  
    73  // DoEmptyUnaryCall performs a unary RPC with empty request and response messages.
    74  func DoEmptyUnaryCall(ctx context.Context, tc testpb.TestServiceClient, args ...grpc.CallOption) {
    75  	reply, err := tc.EmptyCall(ctx, &testpb.Empty{}, args...)
    76  	if err != nil {
    77  		logger.Fatal("/TestService/EmptyCall RPC failed: ", err)
    78  	}
    79  	if !proto.Equal(&testpb.Empty{}, reply) {
    80  		logger.Fatalf("/TestService/EmptyCall receives %v, want %v", reply, testpb.Empty{})
    81  	}
    82  }
    83  
    84  // DoLargeUnaryCall performs a unary RPC with large payload in the request and response.
    85  func DoLargeUnaryCall(ctx context.Context, tc testpb.TestServiceClient, args ...grpc.CallOption) {
    86  	pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
    87  	req := &testpb.SimpleRequest{
    88  		ResponseType: testpb.PayloadType_COMPRESSABLE,
    89  		ResponseSize: int32(largeRespSize),
    90  		Payload:      pl,
    91  	}
    92  	reply, err := tc.UnaryCall(ctx, req, args...)
    93  	if err != nil {
    94  		logger.Fatal("/TestService/UnaryCall RPC failed: ", err)
    95  	}
    96  	t := reply.GetPayload().GetType()
    97  	s := len(reply.GetPayload().GetBody())
    98  	if t != testpb.PayloadType_COMPRESSABLE || s != largeRespSize {
    99  		logger.Fatalf("Got the reply with type %d len %d; want %d, %d", t, s, testpb.PayloadType_COMPRESSABLE, largeRespSize)
   100  	}
   101  }
   102  
   103  // DoClientStreaming performs a client streaming RPC.
   104  func DoClientStreaming(ctx context.Context, tc testpb.TestServiceClient, args ...grpc.CallOption) {
   105  	stream, err := tc.StreamingInputCall(ctx, args...)
   106  	if err != nil {
   107  		logger.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err)
   108  	}
   109  	var sum int
   110  	for _, s := range reqSizes {
   111  		pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, s)
   112  		req := &testpb.StreamingInputCallRequest{
   113  			Payload: pl,
   114  		}
   115  		if err := stream.Send(req); err != nil {
   116  			logger.Fatalf("%v has error %v while sending %v", stream, err, req)
   117  		}
   118  		sum += s
   119  	}
   120  	reply, err := stream.CloseAndRecv()
   121  	if err != nil {
   122  		logger.Fatalf("%v.CloseAndRecv() got error %v, want %v", stream, err, nil)
   123  	}
   124  	if reply.GetAggregatedPayloadSize() != int32(sum) {
   125  		logger.Fatalf("%v.CloseAndRecv().GetAggregatePayloadSize() = %v; want %v", stream, reply.GetAggregatedPayloadSize(), sum)
   126  	}
   127  }
   128  
   129  // DoServerStreaming performs a server streaming RPC.
   130  func DoServerStreaming(ctx context.Context, tc testpb.TestServiceClient, args ...grpc.CallOption) {
   131  	respParam := make([]*testpb.ResponseParameters, len(respSizes))
   132  	for i, s := range respSizes {
   133  		respParam[i] = &testpb.ResponseParameters{
   134  			Size: int32(s),
   135  		}
   136  	}
   137  	req := &testpb.StreamingOutputCallRequest{
   138  		ResponseType:       testpb.PayloadType_COMPRESSABLE,
   139  		ResponseParameters: respParam,
   140  	}
   141  	stream, err := tc.StreamingOutputCall(ctx, req, args...)
   142  	if err != nil {
   143  		logger.Fatalf("%v.StreamingOutputCall(_) = _, %v", tc, err)
   144  	}
   145  	var rpcStatus error
   146  	var respCnt int
   147  	var index int
   148  	for {
   149  		reply, err := stream.Recv()
   150  		if err != nil {
   151  			rpcStatus = err
   152  			break
   153  		}
   154  		t := reply.GetPayload().GetType()
   155  		if t != testpb.PayloadType_COMPRESSABLE {
   156  			logger.Fatalf("Got the reply of type %d, want %d", t, testpb.PayloadType_COMPRESSABLE)
   157  		}
   158  		size := len(reply.GetPayload().GetBody())
   159  		if size != respSizes[index] {
   160  			logger.Fatalf("Got reply body of length %d, want %d", size, respSizes[index])
   161  		}
   162  		index++
   163  		respCnt++
   164  	}
   165  	if rpcStatus != io.EOF {
   166  		logger.Fatalf("Failed to finish the server streaming rpc: %v", rpcStatus)
   167  	}
   168  	if respCnt != len(respSizes) {
   169  		logger.Fatalf("Got %d reply, want %d", len(respSizes), respCnt)
   170  	}
   171  }
   172  
   173  // DoPingPong performs ping-pong style bi-directional streaming RPC.
   174  func DoPingPong(ctx context.Context, tc testpb.TestServiceClient, args ...grpc.CallOption) {
   175  	stream, err := tc.FullDuplexCall(ctx, args...)
   176  	if err != nil {
   177  		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
   178  	}
   179  	var index int
   180  	for index < len(reqSizes) {
   181  		respParam := []*testpb.ResponseParameters{
   182  			{
   183  				Size: int32(respSizes[index]),
   184  			},
   185  		}
   186  		pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, reqSizes[index])
   187  		req := &testpb.StreamingOutputCallRequest{
   188  			ResponseType:       testpb.PayloadType_COMPRESSABLE,
   189  			ResponseParameters: respParam,
   190  			Payload:            pl,
   191  		}
   192  		if err := stream.Send(req); err != nil {
   193  			logger.Fatalf("%v has error %v while sending %v", stream, err, req)
   194  		}
   195  		reply, err := stream.Recv()
   196  		if err != nil {
   197  			logger.Fatalf("%v.Recv() = %v", stream, err)
   198  		}
   199  		t := reply.GetPayload().GetType()
   200  		if t != testpb.PayloadType_COMPRESSABLE {
   201  			logger.Fatalf("Got the reply of type %d, want %d", t, testpb.PayloadType_COMPRESSABLE)
   202  		}
   203  		size := len(reply.GetPayload().GetBody())
   204  		if size != respSizes[index] {
   205  			logger.Fatalf("Got reply body of length %d, want %d", size, respSizes[index])
   206  		}
   207  		index++
   208  	}
   209  	if err := stream.CloseSend(); err != nil {
   210  		logger.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
   211  	}
   212  	if _, err := stream.Recv(); err != io.EOF {
   213  		logger.Fatalf("%v failed to complele the ping pong test: %v", stream, err)
   214  	}
   215  }
   216  
   217  // DoEmptyStream sets up a bi-directional streaming with zero message.
   218  func DoEmptyStream(ctx context.Context, tc testpb.TestServiceClient, args ...grpc.CallOption) {
   219  	stream, err := tc.FullDuplexCall(ctx, args...)
   220  	if err != nil {
   221  		logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
   222  	}
   223  	if err := stream.CloseSend(); err != nil {
   224  		logger.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
   225  	}
   226  	if _, err := stream.Recv(); err != io.EOF {
   227  		logger.Fatalf("%v failed to complete the empty stream test: %v", stream, err)
   228  	}
   229  }
   230  
   231  type testServer struct {
   232  	testpb.UnimplementedTestServiceServer
   233  }
   234  
   235  // NewTestServer creates a test server for test service.  opts carries optional
   236  // settings and does not need to be provided.  If multiple opts are provided,
   237  // only the first one is used.
   238  func NewTestServer() testpb.TestServiceServer {
   239  	return &testServer{}
   240  }
   241  
   242  func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   243  	return new(testpb.Empty), nil
   244  }
   245  
   246  func serverNewPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) {
   247  	if size < 0 {
   248  		return nil, fmt.Errorf("requested a response with invalid length %d", size)
   249  	}
   250  	body := make([]byte, size)
   251  	switch t {
   252  	case testpb.PayloadType_COMPRESSABLE:
   253  	default:
   254  		return nil, fmt.Errorf("unsupported payload type: %d", t)
   255  	}
   256  	return &testpb.Payload{
   257  		Type: t,
   258  		Body: body,
   259  	}, nil
   260  }
   261  
   262  func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   263  	st := in.GetResponseStatus()
   264  	if md, ok := metadata.FromIncomingContext(ctx); ok {
   265  		if initialMetadata, ok := md[initialMetadataKey]; ok {
   266  			header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
   267  			_ = grpc.SendHeader(ctx, header)
   268  		}
   269  		if trailingMetadata, ok := md[trailingMetadataKey]; ok {
   270  			trailer := metadata.Pairs(trailingMetadataKey, trailingMetadata[0])
   271  			_ = grpc.SetTrailer(ctx, trailer)
   272  		}
   273  	}
   274  	if st != nil && st.Code != 0 {
   275  		return nil, status.Error(codes.Code(st.Code), st.Message)
   276  	}
   277  	pl, err := serverNewPayload(in.GetResponseType(), in.GetResponseSize())
   278  	if err != nil {
   279  		return nil, err
   280  	}
   281  	return &testpb.SimpleResponse{
   282  		Payload: pl,
   283  	}, nil
   284  }
   285  
   286  func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
   287  	cs := args.GetResponseParameters()
   288  	for _, c := range cs {
   289  		if us := c.GetIntervalUs(); us > 0 {
   290  			time.Sleep(time.Duration(us) * time.Microsecond)
   291  		}
   292  		pl, err := serverNewPayload(args.GetResponseType(), c.GetSize())
   293  		if err != nil {
   294  			return err
   295  		}
   296  		if err := stream.Send(&testpb.StreamingOutputCallResponse{
   297  			Payload: pl,
   298  		}); err != nil {
   299  			return err
   300  		}
   301  	}
   302  	return nil
   303  }
   304  
   305  func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInputCallServer) error {
   306  	var sum int
   307  	for {
   308  		in, err := stream.Recv()
   309  		if err == io.EOF {
   310  			return stream.SendAndClose(&testpb.StreamingInputCallResponse{
   311  				AggregatedPayloadSize: int32(sum),
   312  			})
   313  		}
   314  		if err != nil {
   315  			return err
   316  		}
   317  		p := in.GetPayload().GetBody()
   318  		sum += len(p)
   319  	}
   320  }
   321  
   322  func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
   323  	if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
   324  		if initialMetadata, ok := md[initialMetadataKey]; ok {
   325  			header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
   326  			_ = stream.SendHeader(header)
   327  		}
   328  		if trailingMetadata, ok := md[trailingMetadataKey]; ok {
   329  			trailer := metadata.Pairs(trailingMetadataKey, trailingMetadata[0])
   330  			stream.SetTrailer(trailer)
   331  		}
   332  	}
   333  	for {
   334  		in, err := stream.Recv()
   335  		if err == io.EOF {
   336  			// read done.
   337  			return nil
   338  		}
   339  		if err != nil {
   340  			return err
   341  		}
   342  		st := in.GetResponseStatus()
   343  		if st != nil && st.Code != 0 {
   344  			return status.Error(codes.Code(st.Code), st.Message)
   345  		}
   346  
   347  		cs := in.GetResponseParameters()
   348  		for _, c := range cs {
   349  			if us := c.GetIntervalUs(); us > 0 {
   350  				time.Sleep(time.Duration(us) * time.Microsecond)
   351  			}
   352  			pl, err := serverNewPayload(in.GetResponseType(), c.GetSize())
   353  			if err != nil {
   354  				return err
   355  			}
   356  			if err := stream.Send(&testpb.StreamingOutputCallResponse{
   357  				Payload: pl,
   358  			}); err != nil {
   359  				return err
   360  			}
   361  		}
   362  	}
   363  }
   364  
   365  func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServer) error {
   366  	var msgBuf []*testpb.StreamingOutputCallRequest
   367  	for {
   368  		in, err := stream.Recv()
   369  		if err == io.EOF {
   370  			// read done.
   371  			break
   372  		}
   373  		if err != nil {
   374  			return err
   375  		}
   376  		msgBuf = append(msgBuf, in)
   377  	}
   378  	for _, m := range msgBuf {
   379  		cs := m.GetResponseParameters()
   380  		for _, c := range cs {
   381  			if us := c.GetIntervalUs(); us > 0 {
   382  				time.Sleep(time.Duration(us) * time.Microsecond)
   383  			}
   384  			pl, err := serverNewPayload(m.GetResponseType(), c.GetSize())
   385  			if err != nil {
   386  				return err
   387  			}
   388  			if err := stream.Send(&testpb.StreamingOutputCallResponse{
   389  				Payload: pl,
   390  			}); err != nil {
   391  				return err
   392  			}
   393  		}
   394  	}
   395  	return nil
   396  }
   397  

View as plain text