1
2
3
4 package grpc_opentracing_test
5
6 import (
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "strconv"
13 "strings"
14 "testing"
15
16 "github.com/opentracing/opentracing-go"
17 "github.com/opentracing/opentracing-go/log"
18 "github.com/opentracing/opentracing-go/mocktracer"
19 "github.com/stretchr/testify/assert"
20 "github.com/stretchr/testify/require"
21 "github.com/stretchr/testify/suite"
22 "google.golang.org/grpc"
23 "google.golang.org/grpc/codes"
24
25 "github.com/grpc-ecosystem/go-grpc-middleware"
26 "github.com/grpc-ecosystem/go-grpc-middleware/tags"
27 "github.com/grpc-ecosystem/go-grpc-middleware/testing"
28 pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
29 "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
30 )
31
32 var (
33 goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
34 fakeInboundTraceId = 1337
35 fakeInboundSpanId = 999
36 traceHeaderName = "uber-trace-id"
37 filterFunc = func(ctx context.Context, fullMethodName string) bool { return true }
38 unaryRequestHandlerFunc = func(span opentracing.Span, req interface{}) {
39 span.LogFields(log.Bool("unary-request-handler", true))
40 }
41 )
42
43 type tracingAssertService struct {
44 pb_testproto.TestServiceServer
45 T *testing.T
46 }
47
48 func (s *tracingAssertService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) {
49 assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail")
50 tags := grpc_ctxtags.Extract(ctx)
51 assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid")
52 assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid")
53 assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled")
54 assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true")
55 return s.TestServiceServer.Ping(ctx, ping)
56 }
57
58 func (s *tracingAssertService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) {
59 assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail")
60 return s.TestServiceServer.PingError(ctx, ping)
61 }
62
63 func (s *tracingAssertService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
64 assert.NotNil(s.T, opentracing.SpanFromContext(stream.Context()), "handlers must have the spancontext in their context, otherwise propagation will fail")
65 tags := grpc_ctxtags.Extract(stream.Context())
66 assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid")
67 assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid")
68 assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled")
69 assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true")
70 return s.TestServiceServer.PingList(ping, stream)
71 }
72
73 func (s *tracingAssertService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) {
74 assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail")
75 tags := grpc_ctxtags.Extract(ctx)
76 assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid")
77 assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid")
78 assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled")
79 assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "false", "sampled must be set to false")
80 return s.TestServiceServer.PingEmpty(ctx, empty)
81 }
82
83 func TestTaggingSuite(t *testing.T) {
84 mockTracer := mocktracer.New()
85 opts := []grpc_opentracing.Option{
86 grpc_opentracing.WithTracer(mockTracer),
87 grpc_opentracing.WithFilterFunc(filterFunc),
88 grpc_opentracing.WithTraceHeaderName(traceHeaderName),
89 grpc_opentracing.WithUnaryRequestHandlerFunc(unaryRequestHandlerFunc),
90 }
91 s := &OpentracingSuite{
92 mockTracer: mockTracer,
93 InterceptorTestSuite: makeInterceptorTestSuite(t, opts),
94 }
95 suite.Run(t, s)
96 }
97
98 func TestTaggingSuiteJaeger(t *testing.T) {
99 mockTracer := mocktracer.New()
100 mockTracer.RegisterInjector(opentracing.HTTPHeaders, jaegerFormatInjector{})
101 mockTracer.RegisterExtractor(opentracing.HTTPHeaders, jaegerFormatExtractor{})
102 opts := []grpc_opentracing.Option{
103 grpc_opentracing.WithTracer(mockTracer),
104 grpc_opentracing.WithUnaryRequestHandlerFunc(unaryRequestHandlerFunc),
105 }
106 s := &OpentracingSuite{
107 mockTracer: mockTracer,
108 InterceptorTestSuite: makeInterceptorTestSuite(t, opts),
109 }
110 suite.Run(t, s)
111 }
112
113 func makeInterceptorTestSuite(t *testing.T, opts []grpc_opentracing.Option) *grpc_testing.InterceptorTestSuite {
114 return &grpc_testing.InterceptorTestSuite{
115 TestService: &tracingAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}, T: t},
116 ClientOpts: []grpc.DialOption{
117 grpc.WithUnaryInterceptor(grpc_opentracing.UnaryClientInterceptor(opts...)),
118 grpc.WithStreamInterceptor(grpc_opentracing.StreamClientInterceptor(opts...)),
119 },
120 ServerOpts: []grpc.ServerOption{
121 grpc_middleware.WithStreamServerChain(
122 grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)),
123 grpc_opentracing.StreamServerInterceptor(opts...)),
124 grpc_middleware.WithUnaryServerChain(
125 grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)),
126 grpc_opentracing.UnaryServerInterceptor(opts...)),
127 },
128 }
129 }
130
131 type OpentracingSuite struct {
132 *grpc_testing.InterceptorTestSuite
133 mockTracer *mocktracer.MockTracer
134 }
135
136 func (s *OpentracingSuite) SetupTest() {
137 s.mockTracer.Reset()
138 }
139
140 func (s *OpentracingSuite) createContextFromFakeHttpRequestParent(ctx context.Context, sampled bool, opName string) context.Context {
141 jFlag := 0
142 if sampled {
143 jFlag = 1
144 }
145
146 if len(opName) == 0 {
147 opName = "/fake/parent/http/request"
148 }
149
150 hdr := http.Header{}
151 hdr.Set(traceHeaderName, fmt.Sprintf("%d:%d:%d:%d", fakeInboundTraceId, fakeInboundSpanId, fakeInboundSpanId, jFlag))
152 hdr.Set("mockpfx-ids-traceid", fmt.Sprint(fakeInboundTraceId))
153 hdr.Set("mockpfx-ids-spanid", fmt.Sprint(fakeInboundSpanId))
154 hdr.Set("mockpfx-ids-sampled", fmt.Sprint(sampled))
155
156 parentSpanContext, err := s.mockTracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(hdr))
157 require.NoError(s.T(), err, "parsing a fake HTTP request headers shouldn't fail, ever")
158 fakeSpan := s.mockTracer.StartSpan(
159 opName,
160
161 opentracing.ChildOf(parentSpanContext),
162 )
163 fakeSpan.Finish()
164 return opentracing.ContextWithSpan(ctx, fakeSpan)
165 }
166
167 func (s *OpentracingSuite) assertTracesCreated(methodName string) (clientSpan *mocktracer.MockSpan, serverSpan *mocktracer.MockSpan) {
168 spans := s.mockTracer.FinishedSpans()
169 for _, span := range spans {
170 s.T().Logf("span: %v, tags: %v", span, span.Tags())
171 }
172 require.Len(s.T(), spans, 3, "should record 3 spans: one fake inbound, one client, one server")
173 traceIdAssert := fmt.Sprintf("traceId=%d", fakeInboundTraceId)
174 for _, span := range spans {
175 assert.Contains(s.T(), span.String(), traceIdAssert, "not part of the fake parent trace: %v", span)
176 if span.OperationName == methodName {
177 kind := fmt.Sprintf("%v", span.Tag("span.kind"))
178 if kind == "client" {
179 clientSpan = span
180 } else if kind == "server" {
181 serverSpan = span
182 }
183 assert.EqualValues(s.T(), span.Tag("component"), "gRPC", "span must be tagged with gRPC component")
184 }
185 }
186 require.NotNil(s.T(), clientSpan, "client span must be there")
187 require.NotNil(s.T(), serverSpan, "server span must be there")
188 assert.EqualValues(s.T(), serverSpan.Tag("grpc.request.value"), "something", "grpc_ctxtags must be propagated, in this case ones from request fields")
189 return clientSpan, serverSpan
190 }
191
192 func (s *OpentracingSuite) TestPing_PropagatesTraces() {
193 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "")
194 _, err := s.Client.Ping(ctx, goodPing)
195 require.NoError(s.T(), err, "there must be not be an on a successful call")
196 s.assertTracesCreated("/mwitkow.testproto.TestService/Ping")
197 }
198
199 func (s *OpentracingSuite) TestPing_CustomOpName() {
200 customOpName := "customOpName"
201
202 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, customOpName)
203 _, err := s.Client.Ping(ctx, goodPing)
204 require.NoError(s.T(), err, "there must be not be an error on a successful call")
205
206 spans := s.mockTracer.FinishedSpans()
207 spanOpNames := make([]string, len(spans))
208 for _, span := range spans {
209 spanOpNames = append(spanOpNames, span.OperationName)
210 }
211
212 require.Contains(s.T(), spanOpNames, customOpName, "finished spans must contain the custom operation name")
213
214 }
215
216 func (s *OpentracingSuite) TestPing_WithUnaryRequestHandlerFunc() {
217 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "")
218 _, err := s.Client.Ping(ctx, goodPing)
219 require.NoError(s.T(), err, "there must be not be an on a successful call")
220
221 var hasLogKey bool
222 Loop:
223 for _, span := range s.mockTracer.FinishedSpans() {
224 for _, record := range span.Logs() {
225 for _, field := range record.Fields {
226 if field.Key == "unary-request-handler" {
227 hasLogKey = true
228 break Loop
229 }
230 }
231 }
232 }
233 require.True(s.T(), hasLogKey, "span field 'unary-request-handler' not found")
234 }
235
236 func (s *OpentracingSuite) TestPing_ClientContextTags() {
237 const name = "opentracing.custom"
238 ctx := grpc_opentracing.ClientAddContextTags(
239 s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, ""),
240 opentracing.Tags{name: ""},
241 )
242
243 _, err := s.Client.Ping(ctx, goodPing)
244 require.NoError(s.T(), err, "there must be not be an on a successful call")
245
246 for _, span := range s.mockTracer.FinishedSpans() {
247 if span.OperationName == "/mwitkow.testproto.TestService/Ping" {
248 kind := fmt.Sprintf("%v", span.Tag("span.kind"))
249 if kind == "client" {
250 assert.Contains(s.T(), span.Tags(), name, "custom opentracing.Tags must be included in context")
251 }
252 }
253 }
254 }
255
256 func (s *OpentracingSuite) TestPingList_PropagatesTraces() {
257 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "")
258 stream, err := s.Client.PingList(ctx, goodPing)
259 require.NoError(s.T(), err, "should not fail on establishing the stream")
260 for {
261 _, err := stream.Recv()
262 if err == io.EOF {
263 break
264 }
265 require.NoError(s.T(), err, "reading stream should not fail")
266 }
267 s.assertTracesCreated("/mwitkow.testproto.TestService/PingList")
268 }
269
270 func (s *OpentracingSuite) TestPingError_PropagatesTraces() {
271 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "")
272 erroringPing := &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(codes.OutOfRange)}
273 _, err := s.Client.PingError(ctx, erroringPing)
274 require.Error(s.T(), err, "there must be an error returned here")
275 clientSpan, serverSpan := s.assertTracesCreated("/mwitkow.testproto.TestService/PingError")
276 assert.Equal(s.T(), true, clientSpan.Tag("error"), "client span needs to be marked as an error")
277 assert.Equal(s.T(), true, serverSpan.Tag("error"), "server span needs to be marked as an error")
278 }
279
280 func (s *OpentracingSuite) TestPingEmpty_NotSampleTraces() {
281 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), false, "")
282 _, err := s.Client.PingEmpty(ctx, &pb_testproto.Empty{})
283 require.NoError(s.T(), err, "there must be not be an on a successful call")
284 }
285
286 type jaegerFormatInjector struct{}
287
288 func (jaegerFormatInjector) Inject(ctx mocktracer.MockSpanContext, carrier interface{}) error {
289 w := carrier.(opentracing.TextMapWriter)
290 flags := 0
291 if ctx.Sampled {
292 flags = 1
293 }
294 w.Set(traceHeaderName, fmt.Sprintf("%d:%d::%d", ctx.TraceID, ctx.SpanID, flags))
295
296 return nil
297 }
298
299 type jaegerFormatExtractor struct{}
300
301 func (jaegerFormatExtractor) Extract(carrier interface{}) (mocktracer.MockSpanContext, error) {
302 rval := mocktracer.MockSpanContext{Sampled: true}
303 reader, ok := carrier.(opentracing.TextMapReader)
304 if !ok {
305 return rval, opentracing.ErrInvalidCarrier
306 }
307 err := reader.ForeachKey(func(key, val string) error {
308 lowerKey := strings.ToLower(key)
309 switch {
310 case lowerKey == traceHeaderName:
311 parts := strings.Split(val, ":")
312 if len(parts) != 4 {
313 return errors.New("invalid trace id format")
314 }
315 traceId, err := strconv.Atoi(parts[0])
316 if err != nil {
317 return err
318 }
319 rval.TraceID = traceId
320 spanId, err := strconv.Atoi(parts[1])
321 if err != nil {
322 return err
323 }
324 rval.SpanID = spanId
325 flags, err := strconv.Atoi(parts[3])
326 if err != nil {
327 return err
328 }
329 rval.Sampled = flags%2 == 1
330 }
331 return nil
332 })
333 if rval.TraceID == 0 || rval.SpanID == 0 {
334 return rval, opentracing.ErrSpanContextNotFound
335 }
336 if err != nil {
337 return rval, err
338 }
339 return rval, nil
340 }
View as plain text