1
2
3
4 package otelhttp
5
6 import (
7 "bytes"
8 "context"
9 "errors"
10 "io"
11 "net/http"
12 "net/http/httptest"
13 "strings"
14 "testing"
15
16 "github.com/stretchr/testify/assert"
17 "github.com/stretchr/testify/require"
18
19 "go.opentelemetry.io/otel/codes"
20 "go.opentelemetry.io/otel/propagation"
21 "go.opentelemetry.io/otel/trace"
22 )
23
// TestTransportFormatter builds a request for every standard HTTP method and
// checks that "HTTP " followed by the request's method equals the expected
// span name literal.
func TestTransportFormatter(t *testing.T) {
	cases := []struct {
		name     string
		method   string
		expected string
	}{
		{name: "GET method", method: http.MethodGet, expected: "HTTP GET"},
		{name: "HEAD method", method: http.MethodHead, expected: "HTTP HEAD"},
		{name: "POST method", method: http.MethodPost, expected: "HTTP POST"},
		{name: "PUT method", method: http.MethodPut, expected: "HTTP PUT"},
		{name: "PATCH method", method: http.MethodPatch, expected: "HTTP PATCH"},
		{name: "DELETE method", method: http.MethodDelete, expected: "HTTP DELETE"},
		{name: "CONNECT method", method: http.MethodConnect, expected: "HTTP CONNECT"},
		{name: "OPTIONS method", method: http.MethodOptions, expected: "HTTP OPTIONS"},
		{name: "TRACE method", method: http.MethodTrace, expected: "HTTP TRACE"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req, err := http.NewRequest(tc.method, "http://localhost/", nil)
			if err != nil {
				t.Fatal(err)
			}

			// The name is derived from the method stored on the request, not
			// from the table entry, so the request construction is exercised.
			if got := "HTTP " + req.Method; got != tc.expected {
				t.Fatalf("unexpected name: got %s, expected %s", got, tc.expected)
			}
		})
	}
}
91
92 func TestTransportBasics(t *testing.T) {
93 prop := propagation.TraceContext{}
94 content := []byte("Hello, world!")
95
96 ctx := context.Background()
97 sc := trace.NewSpanContext(trace.SpanContextConfig{
98 TraceID: trace.TraceID{0x01},
99 SpanID: trace.SpanID{0x01},
100 })
101 ctx = trace.ContextWithRemoteSpanContext(ctx, sc)
102
103 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
104 ctx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
105 span := trace.SpanContextFromContext(ctx)
106 if span.SpanID() != sc.SpanID() {
107 t.Fatalf("testing remote SpanID: got %s, expected %s", span.SpanID(), sc.SpanID())
108 }
109 if _, err := w.Write(content); err != nil {
110 t.Fatal(err)
111 }
112 }))
113 defer ts.Close()
114
115 r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil)
116 if err != nil {
117 t.Fatal(err)
118 }
119
120 tr := NewTransport(http.DefaultTransport, WithPropagators(prop))
121
122 c := http.Client{Transport: tr}
123 res, err := c.Do(r)
124 if err != nil {
125 t.Fatal(err)
126 }
127
128 body, err := io.ReadAll(res.Body)
129 if err != nil {
130 t.Fatal(err)
131 }
132
133 if !bytes.Equal(body, content) {
134 t.Fatalf("unexpected content: got %s, expected %s", body, content)
135 }
136 }
137
138 func TestNilTransport(t *testing.T) {
139 prop := propagation.TraceContext{}
140 content := []byte("Hello, world!")
141
142 ctx := context.Background()
143 sc := trace.NewSpanContext(trace.SpanContextConfig{
144 TraceID: trace.TraceID{0x01},
145 SpanID: trace.SpanID{0x01},
146 })
147 ctx = trace.ContextWithRemoteSpanContext(ctx, sc)
148
149 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
150 ctx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
151 span := trace.SpanContextFromContext(ctx)
152 if span.SpanID() != sc.SpanID() {
153 t.Fatalf("testing remote SpanID: got %s, expected %s", span.SpanID(), sc.SpanID())
154 }
155 if _, err := w.Write(content); err != nil {
156 t.Fatal(err)
157 }
158 }))
159 defer ts.Close()
160
161 r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil)
162 if err != nil {
163 t.Fatal(err)
164 }
165
166 tr := NewTransport(nil, WithPropagators(prop))
167
168 c := http.Client{Transport: tr}
169 res, err := c.Do(r)
170 if err != nil {
171 t.Fatal(err)
172 }
173
174 body, err := io.ReadAll(res.Body)
175 if err != nil {
176 t.Fatal(err)
177 }
178
179 if !bytes.Equal(body, content) {
180 t.Fatalf("unexpected content: got %s, expected %s", body, content)
181 }
182 }
183
184 const readSize = 42
185
186 type readCloser struct {
187 readErr, closeErr error
188 }
189
190 func (rc readCloser) Read(p []byte) (n int, err error) {
191 return readSize, rc.readErr
192 }
193
194 func (rc readCloser) Close() error {
195 return rc.closeErr
196 }
197
198 type span struct {
199 trace.Span
200
201 ended bool
202 recordedErr error
203
204 statusCode codes.Code
205 statusDesc string
206 }
207
208 func (s *span) End(...trace.SpanEndOption) {
209 s.ended = true
210 }
211
212 func (s *span) RecordError(err error, _ ...trace.EventOption) {
213 s.recordedErr = err
214 }
215
216 func (s *span) SetStatus(c codes.Code, d string) {
217 s.statusCode, s.statusDesc = c, d
218 }
219
220 func (s *span) assert(t *testing.T, ended bool, err error, c codes.Code, d string) {
221 if ended {
222 assert.True(t, s.ended, "not ended")
223 } else {
224 assert.False(t, s.ended, "ended")
225 }
226
227 if err == nil {
228 assert.NoError(t, s.recordedErr, "recorded an error")
229 } else {
230 assert.Equal(t, err, s.recordedErr)
231 }
232
233 assert.Equal(t, c, s.statusCode, "status codes not equal")
234 assert.Equal(t, d, s.statusDesc, "status description not equal")
235 }
236
237 func TestWrappedBodyRead(t *testing.T) {
238 s := new(span)
239 called := false
240 record := func(numBytes int64) { called = true }
241 wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{}}
242 n, err := wb.Read([]byte{})
243 assert.Equal(t, readSize, n, "wrappedBody returned wrong bytes")
244 assert.NoError(t, err)
245 s.assert(t, false, nil, codes.Unset, "")
246 assert.False(t, called, "record should not have been called")
247 }
248
249 func TestWrappedBodyReadEOFError(t *testing.T) {
250 s := new(span)
251 called := false
252 numRecorded := int64(0)
253 record := func(numBytes int64) {
254 called = true
255 numRecorded = numBytes
256 }
257 wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{readErr: io.EOF}}
258 n, err := wb.Read([]byte{})
259 assert.Equal(t, readSize, n, "wrappedBody returned wrong bytes")
260 assert.Equal(t, io.EOF, err)
261 s.assert(t, true, nil, codes.Unset, "")
262 assert.True(t, called, "record should have been called")
263 assert.Equal(t, int64(readSize), numRecorded, "record recorded wrong number of bytes")
264 }
265
266 func TestWrappedBodyReadError(t *testing.T) {
267 s := new(span)
268 called := false
269 record := func(int64) { called = true }
270 expectedErr := errors.New("test")
271 wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{readErr: expectedErr}}
272 n, err := wb.Read([]byte{})
273 assert.Equal(t, readSize, n, "wrappedBody returned wrong bytes")
274 assert.Equal(t, expectedErr, err)
275 s.assert(t, false, expectedErr, codes.Error, expectedErr.Error())
276 assert.False(t, called, "record should not have been called")
277 }
278
279 func TestWrappedBodyClose(t *testing.T) {
280 s := new(span)
281 called := false
282 record := func(int64) { called = true }
283 wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{}}
284 assert.NoError(t, wb.Close())
285 s.assert(t, true, nil, codes.Unset, "")
286 assert.True(t, called, "record should have been called")
287 }
288
289 func TestWrappedBodyClosePanic(t *testing.T) {
290 s := new(span)
291 var body io.ReadCloser
292 wb := newWrappedBody(s, func(n int64) {}, body)
293 assert.NotPanics(t, func() { wb.Close() }, "nil body should not panic on close")
294 }
295
296 func TestWrappedBodyCloseError(t *testing.T) {
297 s := new(span)
298 called := false
299 record := func(int64) { called = true }
300 expectedErr := errors.New("test")
301 wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{closeErr: expectedErr}}
302 assert.Equal(t, expectedErr, wb.Close())
303 s.assert(t, true, nil, codes.Unset, "")
304 assert.True(t, called, "record should have been called")
305 }
306
307 type readWriteCloser struct {
308 readCloser
309
310 writeErr error
311 }
312
313 const writeSize = 1
314
315 func (rwc readWriteCloser) Write([]byte) (int, error) {
316 return writeSize, rwc.writeErr
317 }
318
319 func TestNewWrappedBodyReadWriteCloserImplementation(t *testing.T) {
320 wb := newWrappedBody(nil, func(n int64) {}, readWriteCloser{})
321 assert.Implements(t, (*io.ReadWriteCloser)(nil), wb)
322 }
323
324 func TestNewWrappedBodyReadCloserImplementation(t *testing.T) {
325 wb := newWrappedBody(nil, func(n int64) {}, readCloser{})
326 assert.Implements(t, (*io.ReadCloser)(nil), wb)
327
328 _, ok := wb.(io.ReadWriteCloser)
329 assert.False(t, ok, "wrappedBody should not implement io.ReadWriteCloser")
330 }
331
332 func TestWrappedBodyWrite(t *testing.T) {
333 s := new(span)
334 var rwc io.ReadWriteCloser
335 assert.NotPanics(t, func() {
336 rwc = newWrappedBody(s, func(n int64) {}, readWriteCloser{}).(io.ReadWriteCloser)
337 })
338
339 n, err := rwc.Write([]byte{})
340 assert.Equal(t, writeSize, n, "wrappedBody returned wrong bytes")
341 assert.NoError(t, err)
342 s.assert(t, false, nil, codes.Unset, "")
343 }
344
345 func TestWrappedBodyWriteError(t *testing.T) {
346 s := new(span)
347 expectedErr := errors.New("test")
348 var rwc io.ReadWriteCloser
349 assert.NotPanics(t, func() {
350 rwc = newWrappedBody(s,
351 func(n int64) {},
352 readWriteCloser{
353 writeErr: expectedErr,
354 }).(io.ReadWriteCloser)
355 })
356 n, err := rwc.Write([]byte{})
357 assert.Equal(t, writeSize, n, "wrappedBody returned wrong bytes")
358 assert.ErrorIs(t, err, expectedErr)
359 s.assert(t, false, expectedErr, codes.Error, expectedErr.Error())
360 }
361
362 func TestTransportProtocolSwitch(t *testing.T) {
363
364
365
366 response := []byte(strings.Join([]string{
367 "HTTP/1.1 101 Switching Protocols",
368 "Upgrade: WebSocket",
369 "Connection: Upgrade",
370 "", "",
371 }, "\r\n"))
372
373 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
374 conn, buf, err := w.(http.Hijacker).Hijack()
375 require.NoError(t, err)
376
377 _, err = buf.Write(response)
378 require.NoError(t, err)
379 require.NoError(t, buf.Flush())
380 require.NoError(t, conn.Close())
381 }))
382 defer ts.Close()
383
384 ctx := context.Background()
385 r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, http.NoBody)
386 require.NoError(t, err)
387
388 c := http.Client{Transport: NewTransport(http.DefaultTransport)}
389 res, err := c.Do(r)
390 require.NoError(t, err)
391 t.Cleanup(func() { require.NoError(t, res.Body.Close()) })
392
393 assert.Implements(t, (*io.ReadWriteCloser)(nil), res.Body, "invalid body returned for protocol switch")
394 }
395
396 func TestTransportOriginRequestNotModify(t *testing.T) {
397 prop := propagation.TraceContext{}
398
399 ctx := context.Background()
400 sc := trace.NewSpanContext(trace.SpanContextConfig{
401 TraceID: trace.TraceID{0x01},
402 SpanID: trace.SpanID{0x01},
403 })
404 ctx = trace.ContextWithRemoteSpanContext(ctx, sc)
405
406 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
407 w.WriteHeader(http.StatusOK)
408 }))
409 defer ts.Close()
410
411 r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, http.NoBody)
412 require.NoError(t, err)
413
414 expectedRequest := r.Clone(r.Context())
415
416 c := http.Client{Transport: NewTransport(http.DefaultTransport, WithPropagators(prop))}
417 res, err := c.Do(r)
418 require.NoError(t, err)
419
420 t.Cleanup(func() { require.NoError(t, res.Body.Close()) })
421
422 assert.Equal(t, expectedRequest, r)
423 }
424
View as plain text