...

Source file src/github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing/interceptors_test.go

Documentation: github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing

     1  // Copyright 2017 Michal Witkowski. All Rights Reserved.
     2  // See LICENSE for licensing terms.
     3  
     4  package grpc_opentracing_test
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"strconv"
    13  	"strings"
    14  	"testing"
    15  
    16  	"github.com/opentracing/opentracing-go"
    17  	"github.com/opentracing/opentracing-go/log"
    18  	"github.com/opentracing/opentracing-go/mocktracer"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/stretchr/testify/suite"
    22  	"google.golang.org/grpc"
    23  	"google.golang.org/grpc/codes"
    24  
    25  	"github.com/grpc-ecosystem/go-grpc-middleware"
    26  	"github.com/grpc-ecosystem/go-grpc-middleware/tags"
    27  	"github.com/grpc-ecosystem/go-grpc-middleware/testing"
    28  	pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
    29  	"github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
    30  )
    31  
    32  var (
    33  	goodPing                = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
    34  	fakeInboundTraceId      = 1337
    35  	fakeInboundSpanId       = 999
    36  	traceHeaderName         = "uber-trace-id"
    37  	filterFunc              = func(ctx context.Context, fullMethodName string) bool { return true }
    38  	unaryRequestHandlerFunc = func(span opentracing.Span, req interface{}) {
    39  		span.LogFields(log.Bool("unary-request-handler", true))
    40  	}
    41  )
    42  
    43  type tracingAssertService struct {
    44  	pb_testproto.TestServiceServer
    45  	T *testing.T
    46  }
    47  
    48  func (s *tracingAssertService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) {
    49  	assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail")
    50  	tags := grpc_ctxtags.Extract(ctx)
    51  	assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid")
    52  	assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid")
    53  	assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled")
    54  	assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true")
    55  	return s.TestServiceServer.Ping(ctx, ping)
    56  }
    57  
    58  func (s *tracingAssertService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) {
    59  	assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail")
    60  	return s.TestServiceServer.PingError(ctx, ping)
    61  }
    62  
    63  func (s *tracingAssertService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
    64  	assert.NotNil(s.T, opentracing.SpanFromContext(stream.Context()), "handlers must have the spancontext in their context, otherwise propagation will fail")
    65  	tags := grpc_ctxtags.Extract(stream.Context())
    66  	assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid")
    67  	assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid")
    68  	assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled")
    69  	assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true")
    70  	return s.TestServiceServer.PingList(ping, stream)
    71  }
    72  
    73  func (s *tracingAssertService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) {
    74  	assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail")
    75  	tags := grpc_ctxtags.Extract(ctx)
    76  	assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid")
    77  	assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid")
    78  	assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled")
    79  	assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "false", "sampled must be set to false")
    80  	return s.TestServiceServer.PingEmpty(ctx, empty)
    81  }
    82  
    83  func TestTaggingSuite(t *testing.T) {
    84  	mockTracer := mocktracer.New()
    85  	opts := []grpc_opentracing.Option{
    86  		grpc_opentracing.WithTracer(mockTracer),
    87  		grpc_opentracing.WithFilterFunc(filterFunc),
    88  		grpc_opentracing.WithTraceHeaderName(traceHeaderName),
    89  		grpc_opentracing.WithUnaryRequestHandlerFunc(unaryRequestHandlerFunc),
    90  	}
    91  	s := &OpentracingSuite{
    92  		mockTracer:           mockTracer,
    93  		InterceptorTestSuite: makeInterceptorTestSuite(t, opts),
    94  	}
    95  	suite.Run(t, s)
    96  }
    97  
    98  func TestTaggingSuiteJaeger(t *testing.T) {
    99  	mockTracer := mocktracer.New()
   100  	mockTracer.RegisterInjector(opentracing.HTTPHeaders, jaegerFormatInjector{})
   101  	mockTracer.RegisterExtractor(opentracing.HTTPHeaders, jaegerFormatExtractor{})
   102  	opts := []grpc_opentracing.Option{
   103  		grpc_opentracing.WithTracer(mockTracer),
   104  		grpc_opentracing.WithUnaryRequestHandlerFunc(unaryRequestHandlerFunc),
   105  	}
   106  	s := &OpentracingSuite{
   107  		mockTracer:           mockTracer,
   108  		InterceptorTestSuite: makeInterceptorTestSuite(t, opts),
   109  	}
   110  	suite.Run(t, s)
   111  }
   112  
   113  func makeInterceptorTestSuite(t *testing.T, opts []grpc_opentracing.Option) *grpc_testing.InterceptorTestSuite {
   114  	return &grpc_testing.InterceptorTestSuite{
   115  		TestService: &tracingAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}, T: t},
   116  		ClientOpts: []grpc.DialOption{
   117  			grpc.WithUnaryInterceptor(grpc_opentracing.UnaryClientInterceptor(opts...)),
   118  			grpc.WithStreamInterceptor(grpc_opentracing.StreamClientInterceptor(opts...)),
   119  		},
   120  		ServerOpts: []grpc.ServerOption{
   121  			grpc_middleware.WithStreamServerChain(
   122  				grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)),
   123  				grpc_opentracing.StreamServerInterceptor(opts...)),
   124  			grpc_middleware.WithUnaryServerChain(
   125  				grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)),
   126  				grpc_opentracing.UnaryServerInterceptor(opts...)),
   127  		},
   128  	}
   129  }
   130  
   131  type OpentracingSuite struct {
   132  	*grpc_testing.InterceptorTestSuite
   133  	mockTracer *mocktracer.MockTracer
   134  }
   135  
   136  func (s *OpentracingSuite) SetupTest() {
   137  	s.mockTracer.Reset()
   138  }
   139  
   140  func (s *OpentracingSuite) createContextFromFakeHttpRequestParent(ctx context.Context, sampled bool, opName string) context.Context {
   141  	jFlag := 0
   142  	if sampled {
   143  		jFlag = 1
   144  	}
   145  
   146  	if len(opName) == 0 {
   147  		opName = "/fake/parent/http/request"
   148  	}
   149  
   150  	hdr := http.Header{}
   151  	hdr.Set(traceHeaderName, fmt.Sprintf("%d:%d:%d:%d", fakeInboundTraceId, fakeInboundSpanId, fakeInboundSpanId, jFlag))
   152  	hdr.Set("mockpfx-ids-traceid", fmt.Sprint(fakeInboundTraceId))
   153  	hdr.Set("mockpfx-ids-spanid", fmt.Sprint(fakeInboundSpanId))
   154  	hdr.Set("mockpfx-ids-sampled", fmt.Sprint(sampled))
   155  
   156  	parentSpanContext, err := s.mockTracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(hdr))
   157  	require.NoError(s.T(), err, "parsing a fake HTTP request headers shouldn't fail, ever")
   158  	fakeSpan := s.mockTracer.StartSpan(
   159  		opName,
   160  		// this is magical, it attaches the new span to the parent parentSpanContext, and creates an unparented one if empty.
   161  		opentracing.ChildOf(parentSpanContext),
   162  	)
   163  	fakeSpan.Finish()
   164  	return opentracing.ContextWithSpan(ctx, fakeSpan)
   165  }
   166  
   167  func (s *OpentracingSuite) assertTracesCreated(methodName string) (clientSpan *mocktracer.MockSpan, serverSpan *mocktracer.MockSpan) {
   168  	spans := s.mockTracer.FinishedSpans()
   169  	for _, span := range spans {
   170  		s.T().Logf("span: %v, tags: %v", span, span.Tags())
   171  	}
   172  	require.Len(s.T(), spans, 3, "should record 3 spans: one fake inbound, one client, one server")
   173  	traceIdAssert := fmt.Sprintf("traceId=%d", fakeInboundTraceId)
   174  	for _, span := range spans {
   175  		assert.Contains(s.T(), span.String(), traceIdAssert, "not part of the fake parent trace: %v", span)
   176  		if span.OperationName == methodName {
   177  			kind := fmt.Sprintf("%v", span.Tag("span.kind"))
   178  			if kind == "client" {
   179  				clientSpan = span
   180  			} else if kind == "server" {
   181  				serverSpan = span
   182  			}
   183  			assert.EqualValues(s.T(), span.Tag("component"), "gRPC", "span must be tagged with gRPC component")
   184  		}
   185  	}
   186  	require.NotNil(s.T(), clientSpan, "client span must be there")
   187  	require.NotNil(s.T(), serverSpan, "server span must be there")
   188  	assert.EqualValues(s.T(), serverSpan.Tag("grpc.request.value"), "something", "grpc_ctxtags must be propagated, in this case ones from request fields")
   189  	return clientSpan, serverSpan
   190  }
   191  
   192  func (s *OpentracingSuite) TestPing_PropagatesTraces() {
   193  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "")
   194  	_, err := s.Client.Ping(ctx, goodPing)
   195  	require.NoError(s.T(), err, "there must be not be an on a successful call")
   196  	s.assertTracesCreated("/mwitkow.testproto.TestService/Ping")
   197  }
   198  
   199  func (s *OpentracingSuite) TestPing_CustomOpName() {
   200  	customOpName := "customOpName"
   201  
   202  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, customOpName)
   203  	_, err := s.Client.Ping(ctx, goodPing)
   204  	require.NoError(s.T(), err, "there must be not be an error on a successful call")
   205  
   206  	spans := s.mockTracer.FinishedSpans()
   207  	spanOpNames := make([]string, len(spans))
   208  	for _, span := range spans {
   209  		spanOpNames = append(spanOpNames, span.OperationName)
   210  	}
   211  
   212  	require.Contains(s.T(), spanOpNames, customOpName, "finished spans must contain the custom operation name")
   213  
   214  }
   215  
   216  func (s *OpentracingSuite) TestPing_WithUnaryRequestHandlerFunc() {
   217  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "")
   218  	_, err := s.Client.Ping(ctx, goodPing)
   219  	require.NoError(s.T(), err, "there must be not be an on a successful call")
   220  
   221  	var hasLogKey bool
   222  Loop:
   223  	for _, span := range s.mockTracer.FinishedSpans() {
   224  		for _, record := range span.Logs() {
   225  			for _, field := range record.Fields {
   226  				if field.Key == "unary-request-handler" {
   227  					hasLogKey = true
   228  					break Loop
   229  				}
   230  			}
   231  		}
   232  	}
   233  	require.True(s.T(), hasLogKey, "span field 'unary-request-handler' not found")
   234  }
   235  
   236  func (s *OpentracingSuite) TestPing_ClientContextTags() {
   237  	const name = "opentracing.custom"
   238  	ctx := grpc_opentracing.ClientAddContextTags(
   239  		s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, ""),
   240  		opentracing.Tags{name: ""},
   241  	)
   242  
   243  	_, err := s.Client.Ping(ctx, goodPing)
   244  	require.NoError(s.T(), err, "there must be not be an on a successful call")
   245  
   246  	for _, span := range s.mockTracer.FinishedSpans() {
   247  		if span.OperationName == "/mwitkow.testproto.TestService/Ping" {
   248  			kind := fmt.Sprintf("%v", span.Tag("span.kind"))
   249  			if kind == "client" {
   250  				assert.Contains(s.T(), span.Tags(), name, "custom opentracing.Tags must be included in context")
   251  			}
   252  		}
   253  	}
   254  }
   255  
   256  func (s *OpentracingSuite) TestPingList_PropagatesTraces() {
   257  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "")
   258  	stream, err := s.Client.PingList(ctx, goodPing)
   259  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   260  	for {
   261  		_, err := stream.Recv()
   262  		if err == io.EOF {
   263  			break
   264  		}
   265  		require.NoError(s.T(), err, "reading stream should not fail")
   266  	}
   267  	s.assertTracesCreated("/mwitkow.testproto.TestService/PingList")
   268  }
   269  
   270  func (s *OpentracingSuite) TestPingError_PropagatesTraces() {
   271  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "")
   272  	erroringPing := &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(codes.OutOfRange)}
   273  	_, err := s.Client.PingError(ctx, erroringPing)
   274  	require.Error(s.T(), err, "there must be an error returned here")
   275  	clientSpan, serverSpan := s.assertTracesCreated("/mwitkow.testproto.TestService/PingError")
   276  	assert.Equal(s.T(), true, clientSpan.Tag("error"), "client span needs to be marked as an error")
   277  	assert.Equal(s.T(), true, serverSpan.Tag("error"), "server span needs to be marked as an error")
   278  }
   279  
   280  func (s *OpentracingSuite) TestPingEmpty_NotSampleTraces() {
   281  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), false, "")
   282  	_, err := s.Client.PingEmpty(ctx, &pb_testproto.Empty{})
   283  	require.NoError(s.T(), err, "there must be not be an on a successful call")
   284  }
   285  
   286  type jaegerFormatInjector struct{}
   287  
   288  func (jaegerFormatInjector) Inject(ctx mocktracer.MockSpanContext, carrier interface{}) error {
   289  	w := carrier.(opentracing.TextMapWriter)
   290  	flags := 0
   291  	if ctx.Sampled {
   292  		flags = 1
   293  	}
   294  	w.Set(traceHeaderName, fmt.Sprintf("%d:%d::%d", ctx.TraceID, ctx.SpanID, flags))
   295  
   296  	return nil
   297  }
   298  
   299  type jaegerFormatExtractor struct{}
   300  
   301  func (jaegerFormatExtractor) Extract(carrier interface{}) (mocktracer.MockSpanContext, error) {
   302  	rval := mocktracer.MockSpanContext{Sampled: true}
   303  	reader, ok := carrier.(opentracing.TextMapReader)
   304  	if !ok {
   305  		return rval, opentracing.ErrInvalidCarrier
   306  	}
   307  	err := reader.ForeachKey(func(key, val string) error {
   308  		lowerKey := strings.ToLower(key)
   309  		switch {
   310  		case lowerKey == traceHeaderName:
   311  			parts := strings.Split(val, ":")
   312  			if len(parts) != 4 {
   313  				return errors.New("invalid trace id format")
   314  			}
   315  			traceId, err := strconv.Atoi(parts[0])
   316  			if err != nil {
   317  				return err
   318  			}
   319  			rval.TraceID = traceId
   320  			spanId, err := strconv.Atoi(parts[1])
   321  			if err != nil {
   322  				return err
   323  			}
   324  			rval.SpanID = spanId
   325  			flags, err := strconv.Atoi(parts[3])
   326  			if err != nil {
   327  				return err
   328  			}
   329  			rval.Sampled = flags%2 == 1
   330  		}
   331  		return nil
   332  	})
   333  	if rval.TraceID == 0 || rval.SpanID == 0 {
   334  		return rval, opentracing.ErrSpanContextNotFound
   335  	}
   336  	if err != nil {
   337  		return rval, err
   338  	}
   339  	return rval, nil
   340  }

View as plain text