1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package ochttp
16
17 import (
18 "bytes"
19 "context"
20 "encoding/hex"
21 "encoding/json"
22 "errors"
23 "fmt"
24 "io"
25 "io/ioutil"
26 "log"
27 "net"
28 "net/http"
29 "net/http/httptest"
30 "net/url"
31 "reflect"
32 "strings"
33 "testing"
34 "time"
35
36 "go.opencensus.io/plugin/ochttp/propagation/b3"
37 "go.opencensus.io/plugin/ochttp/propagation/tracecontext"
38 "go.opencensus.io/trace"
39 )
40
41 type testExporter struct {
42 spans []*trace.SpanData
43 }
44
45 func (t *testExporter) ExportSpan(s *trace.SpanData) {
46 t.spans = append(t.spans, s)
47 }
48
// testTransport is an http.RoundTripper that performs no network I/O: every
// request it is given is sent on ch, and the round trip then fails.
type testTransport struct {
	ch chan *http.Request
}

// RoundTrip sends req on tr.ch (blocking until there is room or a receiver)
// and then returns a nil response with a "noop" error.
func (tr *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	tr.ch <- req
	noopErr := errors.New("noop")
	return nil, noopErr
}
57
58 type testPropagator struct{}
59
60 func (t testPropagator) SpanContextFromRequest(req *http.Request) (sc trace.SpanContext, ok bool) {
61 header := req.Header.Get("trace")
62 buf, err := hex.DecodeString(header)
63 if err != nil {
64 log.Fatalf("Cannot decode trace header: %q", header)
65 }
66 r := bytes.NewReader(buf)
67 r.Read(sc.TraceID[:])
68 r.Read(sc.SpanID[:])
69 opts, err := r.ReadByte()
70 if err != nil {
71 log.Fatalf("Cannot read trace options from trace header: %q", header)
72 }
73 sc.TraceOptions = trace.TraceOptions(opts)
74 return sc, true
75 }
76
77 func (t testPropagator) SpanContextToRequest(sc trace.SpanContext, req *http.Request) {
78 var buf bytes.Buffer
79 buf.Write(sc.TraceID[:])
80 buf.Write(sc.SpanID[:])
81 buf.WriteByte(byte(sc.TraceOptions))
82 req.Header.Set("trace", hex.EncodeToString(buf.Bytes()))
83 }
84
85 func TestTransport_RoundTrip_Race(t *testing.T) {
86
87
88
89
90
91
92 transport := &testTransport{ch: make(chan *http.Request, 1)}
93 rt := &Transport{
94 Propagation: &testPropagator{},
95 Base: transport,
96 }
97 req, _ := http.NewRequest("GET", "http://foo.com", nil)
98 go func() {
99 fmt.Println(*req)
100 }()
101 rt.RoundTrip(req)
102 _ = <-transport.ch
103 }
104
105 func TestTransport_RoundTrip(t *testing.T) {
106 _, parent := trace.StartSpan(context.Background(), "parent")
107 tests := []struct {
108 name string
109 parent *trace.Span
110 }{
111 {
112 name: "no parent",
113 parent: nil,
114 },
115 {
116 name: "parent",
117 parent: parent,
118 },
119 }
120
121 for _, tt := range tests {
122 t.Run(tt.name, func(t *testing.T) {
123 transport := &testTransport{ch: make(chan *http.Request, 1)}
124
125 rt := &Transport{
126 Propagation: &testPropagator{},
127 Base: transport,
128 }
129
130 req, _ := http.NewRequest("GET", "http://foo.com", nil)
131 if tt.parent != nil {
132 req = req.WithContext(trace.NewContext(req.Context(), tt.parent))
133 }
134 rt.RoundTrip(req)
135
136 req = <-transport.ch
137 span := trace.FromContext(req.Context())
138
139 if header := req.Header.Get("trace"); header == "" {
140 t.Fatalf("Trace header = empty; want valid trace header")
141 }
142 if span == nil {
143 t.Fatalf("Got no spans in req context; want one")
144 }
145 if tt.parent != nil {
146 if got, want := span.SpanContext().TraceID, tt.parent.SpanContext().TraceID; got != want {
147 t.Errorf("span.SpanContext().TraceID=%v; want %v", got, want)
148 }
149 }
150 })
151 }
152 }
153
154 func TestHandler(t *testing.T) {
155 traceID := [16]byte{16, 84, 69, 170, 120, 67, 188, 139, 242, 6, 177, 32, 0, 16, 0, 0}
156 tests := []struct {
157 header string
158 wantTraceID trace.TraceID
159 wantTraceOptions trace.TraceOptions
160 }{
161 {
162 header: "105445aa7843bc8bf206b12000100000000000000000000000",
163 wantTraceID: traceID,
164 wantTraceOptions: trace.TraceOptions(0),
165 },
166 {
167 header: "105445aa7843bc8bf206b12000100000000000000000000001",
168 wantTraceID: traceID,
169 wantTraceOptions: trace.TraceOptions(1),
170 },
171 }
172
173 for _, tt := range tests {
174 t.Run(tt.header, func(t *testing.T) {
175 handler := &Handler{
176 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
177 span := trace.FromContext(r.Context())
178 sc := span.SpanContext()
179 if got, want := sc.TraceID, tt.wantTraceID; got != want {
180 t.Errorf("TraceID = %q; want %q", got, want)
181 }
182 if got, want := sc.TraceOptions, tt.wantTraceOptions; got != want {
183 t.Errorf("TraceOptions = %v; want %v", got, want)
184 }
185 }),
186 StartOptions: trace.StartOptions{Sampler: trace.ProbabilitySampler(0.0)},
187 Propagation: &testPropagator{},
188 }
189 req, _ := http.NewRequest("GET", "http://foo.com", nil)
190 req.Header.Add("trace", tt.header)
191 handler.ServeHTTP(nil, req)
192 })
193 }
194 }
195
196 var _ http.RoundTripper = (*traceTransport)(nil)
197
198 type collector []*trace.SpanData
199
200 func (c *collector) ExportSpan(s *trace.SpanData) {
201 *c = append(*c, s)
202 }
203
204 func TestEndToEnd(t *testing.T) {
205 tc := []struct {
206 name string
207 handler *Handler
208 transport *Transport
209 wantSameTraceID bool
210 wantLinks bool
211 }{
212 {
213 name: "internal default propagation",
214 handler: &Handler{},
215 transport: &Transport{},
216 wantSameTraceID: true,
217 },
218 {
219 name: "external default propagation",
220 handler: &Handler{IsPublicEndpoint: true},
221 transport: &Transport{},
222 wantSameTraceID: false,
223 wantLinks: true,
224 },
225 {
226 name: "internal TraceContext propagation",
227 handler: &Handler{Propagation: &tracecontext.HTTPFormat{}},
228 transport: &Transport{Propagation: &tracecontext.HTTPFormat{}},
229 wantSameTraceID: true,
230 },
231 {
232 name: "misconfigured propagation",
233 handler: &Handler{IsPublicEndpoint: true, Propagation: &tracecontext.HTTPFormat{}},
234 transport: &Transport{Propagation: &b3.HTTPFormat{}},
235 wantSameTraceID: false,
236 wantLinks: false,
237 },
238 }
239
240 for _, tt := range tc {
241 t.Run(tt.name, func(t *testing.T) {
242 var spans collector
243 trace.RegisterExporter(&spans)
244 defer trace.UnregisterExporter(&spans)
245
246
247 serverDone := make(chan struct{})
248 serverReturn := make(chan time.Time)
249 tt.handler.StartOptions.Sampler = trace.AlwaysSample()
250 url := serveHTTP(tt.handler, serverDone, serverReturn, 200)
251
252 ctx := context.Background()
253
254 req, err := http.NewRequest(
255 http.MethodPost,
256 fmt.Sprintf("%s/example/url/path?qparam=val", url),
257 strings.NewReader("expected-request-body"))
258 if err != nil {
259 t.Fatal(err)
260 }
261 req = req.WithContext(ctx)
262 tt.transport.StartOptions.Sampler = trace.AlwaysSample()
263 c := &http.Client{
264 Transport: tt.transport,
265 }
266 resp, err := c.Do(req)
267 if err != nil {
268 t.Fatal(err)
269 }
270 if resp.StatusCode != http.StatusOK {
271 t.Fatalf("resp.StatusCode = %d", resp.StatusCode)
272 }
273
274
275 serverReturn <- time.Now().Add(time.Millisecond)
276
277 respBody, err := ioutil.ReadAll(resp.Body)
278 if err != nil {
279 t.Fatal(err)
280 }
281 if got, want := string(respBody), "expected-response"; got != want {
282 t.Fatalf("respBody = %q; want %q", got, want)
283 }
284
285 resp.Body.Close()
286
287 <-serverDone
288 trace.UnregisterExporter(&spans)
289
290 if got, want := len(spans), 2; got != want {
291 t.Fatalf("len(spans) = %d; want %d", got, want)
292 }
293
294 var client, server *trace.SpanData
295 for _, sp := range spans {
296 switch sp.SpanKind {
297 case trace.SpanKindClient:
298 client = sp
299 if got, want := client.Name, "/example/url/path"; got != want {
300 t.Errorf("Span name: %q; want %q", got, want)
301 }
302 case trace.SpanKindServer:
303 server = sp
304 if got, want := server.Name, "/example/url/path"; got != want {
305 t.Errorf("Span name: %q; want %q", got, want)
306 }
307 default:
308 t.Fatalf("server or client span missing; kind = %v", sp.SpanKind)
309 }
310 }
311
312 if tt.wantSameTraceID {
313 if server.TraceID != client.TraceID {
314 t.Errorf("TraceID does not match: server.TraceID=%q client.TraceID=%q", server.TraceID, client.TraceID)
315 }
316 if !server.HasRemoteParent {
317 t.Errorf("server span should have remote parent")
318 }
319 if server.ParentSpanID != client.SpanID {
320 t.Errorf("server span should have client span as parent")
321 }
322 }
323 if !tt.wantSameTraceID {
324 if server.TraceID == client.TraceID {
325 t.Errorf("TraceID should not be trusted")
326 }
327 }
328 if tt.wantLinks {
329 if got, want := len(server.Links), 1; got != want {
330 t.Errorf("len(server.Links) = %d; want %d", got, want)
331 } else {
332 link := server.Links[0]
333 if got, want := link.Type, trace.LinkTypeParent; got != want {
334 t.Errorf("link.Type = %v; want %v", got, want)
335 }
336 }
337 }
338 if server.StartTime.Before(client.StartTime) {
339 t.Errorf("server span starts before client span")
340 }
341 if server.EndTime.After(client.EndTime) {
342 t.Errorf("client span ends before server span")
343 }
344 })
345 }
346 }
347
348 func serveHTTP(handler *Handler, done chan struct{}, wait chan time.Time, statusCode int) string {
349 handler.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
350 w.WriteHeader(statusCode)
351 w.(http.Flusher).Flush()
352
353
354 sleepUntil := <-wait
355 for time.Now().Before(sleepUntil) {
356 time.Sleep(time.Until(sleepUntil))
357 }
358
359 io.WriteString(w, "expected-response")
360 close(done)
361 })
362 server := httptest.NewServer(handler)
363 go func() {
364 <-done
365 server.Close()
366 }()
367 return server.URL
368 }
369
370 func TestSpanNameFromURL(t *testing.T) {
371 tests := []struct {
372 u string
373 want string
374 }{
375 {
376 u: "http://localhost:80/hello?q=a",
377 want: "/hello",
378 },
379 {
380 u: "/a/b?q=c",
381 want: "/a/b",
382 },
383 }
384 for _, tt := range tests {
385 t.Run(tt.u, func(t *testing.T) {
386 req, err := http.NewRequest("GET", tt.u, nil)
387 if err != nil {
388 t.Errorf("url issue = %v", err)
389 }
390 if got := spanNameFromURL(req); got != tt.want {
391 t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want)
392 }
393 })
394 }
395 }
396
397 func TestFormatSpanName(t *testing.T) {
398 formatSpanName := func(r *http.Request) string {
399 return r.Method + " " + r.URL.Path
400 }
401
402 handler := &Handler{
403 Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
404 resp.Write([]byte("Hello, world!"))
405 }),
406 FormatSpanName: formatSpanName,
407 }
408
409 server := httptest.NewServer(handler)
410 defer server.Close()
411
412 client := &http.Client{
413 Transport: &Transport{
414 FormatSpanName: formatSpanName,
415 StartOptions: trace.StartOptions{
416 Sampler: trace.AlwaysSample(),
417 },
418 },
419 }
420
421 tests := []struct {
422 u string
423 want string
424 }{
425 {
426 u: "/hello?q=a",
427 want: "GET /hello",
428 },
429 {
430 u: "/a/b?q=c",
431 want: "GET /a/b",
432 },
433 }
434
435 for _, tt := range tests {
436 t.Run(tt.u, func(t *testing.T) {
437 var te testExporter
438 trace.RegisterExporter(&te)
439 res, err := client.Get(server.URL + tt.u)
440 if err != nil {
441 t.Fatalf("error creating request: %v", err)
442 }
443 res.Body.Close()
444 trace.UnregisterExporter(&te)
445 if want, got := 2, len(te.spans); want != got {
446 t.Fatalf("got exported spans %#v, wanted two spans", te.spans)
447 }
448 if got := te.spans[0].Name; got != tt.want {
449 t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want)
450 }
451 if got := te.spans[1].Name; got != tt.want {
452 t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want)
453 }
454 })
455 }
456 }
457
458 func TestRequestAttributes(t *testing.T) {
459 tests := []struct {
460 name string
461 makeReq func() *http.Request
462 wantAttrs []trace.Attribute
463 }{
464 {
465 name: "GET example.com/hello",
466 makeReq: func() *http.Request {
467 req, _ := http.NewRequest("GET", "http://example.com:779/hello", nil)
468 req.Header.Add("User-Agent", "ua")
469 return req
470 },
471 wantAttrs: []trace.Attribute{
472 trace.StringAttribute("http.path", "/hello"),
473 trace.StringAttribute("http.url", "http://example.com:779/hello"),
474 trace.StringAttribute("http.host", "example.com:779"),
475 trace.StringAttribute("http.method", "GET"),
476 trace.StringAttribute("http.user_agent", "ua"),
477 },
478 },
479 }
480
481 for _, tt := range tests {
482 t.Run(tt.name, func(t *testing.T) {
483 req := tt.makeReq()
484 attrs := requestAttrs(req)
485
486 if got, want := attrs, tt.wantAttrs; !reflect.DeepEqual(got, want) {
487 t.Errorf("Request attributes = %#v; want %#v", got, want)
488 }
489 })
490 }
491 }
492
493 func TestResponseAttributes(t *testing.T) {
494 tests := []struct {
495 name string
496 resp *http.Response
497 wantAttrs []trace.Attribute
498 }{
499 {
500 name: "non-zero HTTP 200 response",
501 resp: &http.Response{StatusCode: 200},
502 wantAttrs: []trace.Attribute{
503 trace.Int64Attribute("http.status_code", 200),
504 },
505 },
506 {
507 name: "zero HTTP 500 response",
508 resp: &http.Response{StatusCode: 500},
509 wantAttrs: []trace.Attribute{
510 trace.Int64Attribute("http.status_code", 500),
511 },
512 },
513 }
514 for _, tt := range tests {
515 t.Run(tt.name, func(t *testing.T) {
516 attrs := responseAttrs(tt.resp)
517 if got, want := attrs, tt.wantAttrs; !reflect.DeepEqual(got, want) {
518 t.Errorf("Response attributes = %#v; want %#v", got, want)
519 }
520 })
521 }
522 }
523
// TestCase is one entry of testdata/http-out-test-cases.json, decoded with
// encoding/json.
type TestCase struct {
	// Name labels the subtest. Method, URL and Headers describe the request.
	// URL may contain "{host}"/"{port}" placeholders.
	Name, Method, URL string
	Headers           map[string]string

	// ResponseCode is the status the local test server replies with.
	ResponseCode int

	// Expected properties of the resulting client span. Attribute values may
	// contain "{host}"/"{port}" placeholders.
	SpanName, SpanStatus, SpanKind string
	SpanAttributes                 map[string]string
}
535
536 func TestAgainstSpecs(t *testing.T) {
537
538 fmt.Println("start")
539
540 dat, err := ioutil.ReadFile("testdata/http-out-test-cases.json")
541 if err != nil {
542 t.Fatalf("error reading file: %v", err)
543 }
544
545 tests := make([]TestCase, 0)
546 err = json.Unmarshal(dat, &tests)
547 if err != nil {
548 t.Fatalf("error parsing json: %v", err)
549 }
550
551 trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()})
552
553 for _, tt := range tests {
554 t.Run(tt.Name, func(t *testing.T) {
555 var spans collector
556 trace.RegisterExporter(&spans)
557 defer trace.UnregisterExporter(&spans)
558
559 handler := &Handler{}
560 transport := &Transport{}
561
562 serverDone := make(chan struct{})
563 serverReturn := make(chan time.Time)
564 host := ""
565 port := ""
566 serverRequired := strings.Contains(tt.URL, "{")
567 if serverRequired {
568
569 localServerURL := serveHTTP(handler, serverDone, serverReturn, tt.ResponseCode)
570 u, _ := url.Parse(localServerURL)
571 host, port, _ = net.SplitHostPort(u.Host)
572
573 tt.URL = strings.Replace(tt.URL, "{host}", host, 1)
574 tt.URL = strings.Replace(tt.URL, "{port}", port, 1)
575 }
576
577
578 ctx, _ := trace.StartSpan(
579 context.Background(),
580 "top-level")
581
582 req, err := http.NewRequest(
583 tt.Method,
584 tt.URL,
585 nil)
586 for headerName, headerValue := range tt.Headers {
587 req.Header.Add(headerName, headerValue)
588 }
589 if err != nil {
590 t.Fatal(err)
591 }
592 req = req.WithContext(ctx)
593 resp, err := transport.RoundTrip(req)
594 if err != nil {
595
596
597 }
598
599 if serverRequired {
600
601 serverReturn <- time.Now().Add(time.Millisecond)
602 }
603
604 if resp != nil {
605
606
607
608
609 ioutil.ReadAll(resp.Body)
610 resp.Body.Close()
611 if serverRequired {
612 <-serverDone
613 }
614 }
615 trace.UnregisterExporter(&spans)
616
617 var client *trace.SpanData
618 for _, sp := range spans {
619 if sp.SpanKind == trace.SpanKindClient {
620 client = sp
621 }
622 }
623
624 if client.Name != tt.SpanName {
625 t.Errorf("span names don't match: expected: %s, actual: %s", tt.SpanName, client.Name)
626 }
627
628 spanKindToStr := map[int]string{
629 trace.SpanKindClient: "Client",
630 trace.SpanKindServer: "Server",
631 }
632
633 if !strings.EqualFold(codeToStr[client.Status.Code], tt.SpanStatus) {
634 t.Errorf("span status don't match: expected: %s, actual: %d (%s)", tt.SpanStatus, client.Status.Code, codeToStr[client.Status.Code])
635 }
636
637 if !strings.EqualFold(spanKindToStr[client.SpanKind], tt.SpanKind) {
638 t.Errorf("span kind don't match: expected: %s, actual: %d (%s)", tt.SpanKind, client.SpanKind, spanKindToStr[client.SpanKind])
639 }
640
641 normalizedActualAttributes := map[string]string{}
642 for k, v := range client.Attributes {
643 normalizedActualAttributes[k] = fmt.Sprintf("%v", v)
644 }
645
646 normalizedExpectedAttributes := map[string]string{}
647 for k, v := range tt.SpanAttributes {
648 normalizedValue := v
649 normalizedValue = strings.Replace(normalizedValue, "{host}", host, 1)
650 normalizedValue = strings.Replace(normalizedValue, "{port}", port, 1)
651
652 normalizedExpectedAttributes[k] = normalizedValue
653 }
654
655 if got, want := normalizedActualAttributes, normalizedExpectedAttributes; !reflect.DeepEqual(got, want) {
656 t.Errorf("Request attributes = %#v; want %#v", got, want)
657 }
658 })
659 }
660 }
661
662 func TestStatusUnitTest(t *testing.T) {
663 tests := []struct {
664 in int
665 want trace.Status
666 }{
667 {200, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}},
668 {204, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}},
669 {100, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}},
670 {500, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}},
671 {400, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: `INVALID_ARGUMENT`}},
672 {422, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: `INVALID_ARGUMENT`}},
673 {499, trace.Status{Code: trace.StatusCodeCancelled, Message: `CANCELLED`}},
674 {404, trace.Status{Code: trace.StatusCodeNotFound, Message: `NOT_FOUND`}},
675 {600, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}},
676 {401, trace.Status{Code: trace.StatusCodeUnauthenticated, Message: `UNAUTHENTICATED`}},
677 {403, trace.Status{Code: trace.StatusCodePermissionDenied, Message: `PERMISSION_DENIED`}},
678 {301, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}},
679 {501, trace.Status{Code: trace.StatusCodeUnimplemented, Message: `UNIMPLEMENTED`}},
680 {409, trace.Status{Code: trace.StatusCodeAlreadyExists, Message: `ALREADY_EXISTS`}},
681 {429, trace.Status{Code: trace.StatusCodeResourceExhausted, Message: `RESOURCE_EXHAUSTED`}},
682 {503, trace.Status{Code: trace.StatusCodeUnavailable, Message: `UNAVAILABLE`}},
683 {504, trace.Status{Code: trace.StatusCodeDeadlineExceeded, Message: `DEADLINE_EXCEEDED`}},
684 }
685
686 for _, tt := range tests {
687 got, want := TraceStatus(tt.in, ""), tt.want
688 if got != want {
689 t.Errorf("status(%d) got = (%#v) want = (%#v)", tt.in, got, want)
690 }
691 }
692 }
693
View as plain text