...

Source file src/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/transport_test.go

Documentation: go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp

     1  // Copyright The OpenTelemetry Authors
     2  // SPDX-License-Identifier: Apache-2.0
     3  
     4  package otelhttp
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"errors"
    10  	"io"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"strings"
    14  	"testing"
    15  
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/require"
    18  
    19  	"go.opentelemetry.io/otel/codes"
    20  	"go.opentelemetry.io/otel/propagation"
    21  	"go.opentelemetry.io/otel/trace"
    22  )
    23  
    24  func TestTransportFormatter(t *testing.T) {
    25  	httpMethods := []struct {
    26  		name     string
    27  		method   string
    28  		expected string
    29  	}{
    30  		{
    31  			"GET method",
    32  			http.MethodGet,
    33  			"HTTP GET",
    34  		},
    35  		{
    36  			"HEAD method",
    37  			http.MethodHead,
    38  			"HTTP HEAD",
    39  		},
    40  		{
    41  			"POST method",
    42  			http.MethodPost,
    43  			"HTTP POST",
    44  		},
    45  		{
    46  			"PUT method",
    47  			http.MethodPut,
    48  			"HTTP PUT",
    49  		},
    50  		{
    51  			"PATCH method",
    52  			http.MethodPatch,
    53  			"HTTP PATCH",
    54  		},
    55  		{
    56  			"DELETE method",
    57  			http.MethodDelete,
    58  			"HTTP DELETE",
    59  		},
    60  		{
    61  			"CONNECT method",
    62  			http.MethodConnect,
    63  			"HTTP CONNECT",
    64  		},
    65  		{
    66  			"OPTIONS method",
    67  			http.MethodOptions,
    68  			"HTTP OPTIONS",
    69  		},
    70  		{
    71  			"TRACE method",
    72  			http.MethodTrace,
    73  			"HTTP TRACE",
    74  		},
    75  	}
    76  
    77  	for _, tc := range httpMethods {
    78  		t.Run(tc.name, func(t *testing.T) {
    79  			r, err := http.NewRequest(tc.method, "http://localhost/", nil)
    80  			if err != nil {
    81  				t.Fatal(err)
    82  			}
    83  			formattedName := "HTTP " + r.Method
    84  
    85  			if formattedName != tc.expected {
    86  				t.Fatalf("unexpected name: got %s, expected %s", formattedName, tc.expected)
    87  			}
    88  		})
    89  	}
    90  }
    91  
    92  func TestTransportBasics(t *testing.T) {
    93  	prop := propagation.TraceContext{}
    94  	content := []byte("Hello, world!")
    95  
    96  	ctx := context.Background()
    97  	sc := trace.NewSpanContext(trace.SpanContextConfig{
    98  		TraceID: trace.TraceID{0x01},
    99  		SpanID:  trace.SpanID{0x01},
   100  	})
   101  	ctx = trace.ContextWithRemoteSpanContext(ctx, sc)
   102  
   103  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   104  		ctx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
   105  		span := trace.SpanContextFromContext(ctx)
   106  		if span.SpanID() != sc.SpanID() {
   107  			t.Fatalf("testing remote SpanID: got %s, expected %s", span.SpanID(), sc.SpanID())
   108  		}
   109  		if _, err := w.Write(content); err != nil {
   110  			t.Fatal(err)
   111  		}
   112  	}))
   113  	defer ts.Close()
   114  
   115  	r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil)
   116  	if err != nil {
   117  		t.Fatal(err)
   118  	}
   119  
   120  	tr := NewTransport(http.DefaultTransport, WithPropagators(prop))
   121  
   122  	c := http.Client{Transport: tr}
   123  	res, err := c.Do(r)
   124  	if err != nil {
   125  		t.Fatal(err)
   126  	}
   127  
   128  	body, err := io.ReadAll(res.Body)
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  
   133  	if !bytes.Equal(body, content) {
   134  		t.Fatalf("unexpected content: got %s, expected %s", body, content)
   135  	}
   136  }
   137  
   138  func TestNilTransport(t *testing.T) {
   139  	prop := propagation.TraceContext{}
   140  	content := []byte("Hello, world!")
   141  
   142  	ctx := context.Background()
   143  	sc := trace.NewSpanContext(trace.SpanContextConfig{
   144  		TraceID: trace.TraceID{0x01},
   145  		SpanID:  trace.SpanID{0x01},
   146  	})
   147  	ctx = trace.ContextWithRemoteSpanContext(ctx, sc)
   148  
   149  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   150  		ctx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
   151  		span := trace.SpanContextFromContext(ctx)
   152  		if span.SpanID() != sc.SpanID() {
   153  			t.Fatalf("testing remote SpanID: got %s, expected %s", span.SpanID(), sc.SpanID())
   154  		}
   155  		if _, err := w.Write(content); err != nil {
   156  			t.Fatal(err)
   157  		}
   158  	}))
   159  	defer ts.Close()
   160  
   161  	r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil)
   162  	if err != nil {
   163  		t.Fatal(err)
   164  	}
   165  
   166  	tr := NewTransport(nil, WithPropagators(prop))
   167  
   168  	c := http.Client{Transport: tr}
   169  	res, err := c.Do(r)
   170  	if err != nil {
   171  		t.Fatal(err)
   172  	}
   173  
   174  	body, err := io.ReadAll(res.Body)
   175  	if err != nil {
   176  		t.Fatal(err)
   177  	}
   178  
   179  	if !bytes.Equal(body, content) {
   180  		t.Fatalf("unexpected content: got %s, expected %s", body, content)
   181  	}
   182  }
   183  
   184  const readSize = 42
   185  
   186  type readCloser struct {
   187  	readErr, closeErr error
   188  }
   189  
   190  func (rc readCloser) Read(p []byte) (n int, err error) {
   191  	return readSize, rc.readErr
   192  }
   193  
   194  func (rc readCloser) Close() error {
   195  	return rc.closeErr
   196  }
   197  
   198  type span struct {
   199  	trace.Span
   200  
   201  	ended       bool
   202  	recordedErr error
   203  
   204  	statusCode codes.Code
   205  	statusDesc string
   206  }
   207  
   208  func (s *span) End(...trace.SpanEndOption) {
   209  	s.ended = true
   210  }
   211  
   212  func (s *span) RecordError(err error, _ ...trace.EventOption) {
   213  	s.recordedErr = err
   214  }
   215  
   216  func (s *span) SetStatus(c codes.Code, d string) {
   217  	s.statusCode, s.statusDesc = c, d
   218  }
   219  
   220  func (s *span) assert(t *testing.T, ended bool, err error, c codes.Code, d string) { // nolint: revive  // ended is not a control flag.
   221  	if ended {
   222  		assert.True(t, s.ended, "not ended")
   223  	} else {
   224  		assert.False(t, s.ended, "ended")
   225  	}
   226  
   227  	if err == nil {
   228  		assert.NoError(t, s.recordedErr, "recorded an error")
   229  	} else {
   230  		assert.Equal(t, err, s.recordedErr)
   231  	}
   232  
   233  	assert.Equal(t, c, s.statusCode, "status codes not equal")
   234  	assert.Equal(t, d, s.statusDesc, "status description not equal")
   235  }
   236  
   237  func TestWrappedBodyRead(t *testing.T) {
   238  	s := new(span)
   239  	called := false
   240  	record := func(numBytes int64) { called = true }
   241  	wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{}}
   242  	n, err := wb.Read([]byte{})
   243  	assert.Equal(t, readSize, n, "wrappedBody returned wrong bytes")
   244  	assert.NoError(t, err)
   245  	s.assert(t, false, nil, codes.Unset, "")
   246  	assert.False(t, called, "record should not have been called")
   247  }
   248  
   249  func TestWrappedBodyReadEOFError(t *testing.T) {
   250  	s := new(span)
   251  	called := false
   252  	numRecorded := int64(0)
   253  	record := func(numBytes int64) {
   254  		called = true
   255  		numRecorded = numBytes
   256  	}
   257  	wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{readErr: io.EOF}}
   258  	n, err := wb.Read([]byte{})
   259  	assert.Equal(t, readSize, n, "wrappedBody returned wrong bytes")
   260  	assert.Equal(t, io.EOF, err)
   261  	s.assert(t, true, nil, codes.Unset, "")
   262  	assert.True(t, called, "record should have been called")
   263  	assert.Equal(t, int64(readSize), numRecorded, "record recorded wrong number of bytes")
   264  }
   265  
   266  func TestWrappedBodyReadError(t *testing.T) {
   267  	s := new(span)
   268  	called := false
   269  	record := func(int64) { called = true }
   270  	expectedErr := errors.New("test")
   271  	wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{readErr: expectedErr}}
   272  	n, err := wb.Read([]byte{})
   273  	assert.Equal(t, readSize, n, "wrappedBody returned wrong bytes")
   274  	assert.Equal(t, expectedErr, err)
   275  	s.assert(t, false, expectedErr, codes.Error, expectedErr.Error())
   276  	assert.False(t, called, "record should not have been called")
   277  }
   278  
   279  func TestWrappedBodyClose(t *testing.T) {
   280  	s := new(span)
   281  	called := false
   282  	record := func(int64) { called = true }
   283  	wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{}}
   284  	assert.NoError(t, wb.Close())
   285  	s.assert(t, true, nil, codes.Unset, "")
   286  	assert.True(t, called, "record should have been called")
   287  }
   288  
   289  func TestWrappedBodyClosePanic(t *testing.T) {
   290  	s := new(span)
   291  	var body io.ReadCloser
   292  	wb := newWrappedBody(s, func(n int64) {}, body)
   293  	assert.NotPanics(t, func() { wb.Close() }, "nil body should not panic on close")
   294  }
   295  
   296  func TestWrappedBodyCloseError(t *testing.T) {
   297  	s := new(span)
   298  	called := false
   299  	record := func(int64) { called = true }
   300  	expectedErr := errors.New("test")
   301  	wb := &wrappedBody{span: trace.Span(s), record: record, body: readCloser{closeErr: expectedErr}}
   302  	assert.Equal(t, expectedErr, wb.Close())
   303  	s.assert(t, true, nil, codes.Unset, "")
   304  	assert.True(t, called, "record should have been called")
   305  }
   306  
   307  type readWriteCloser struct {
   308  	readCloser
   309  
   310  	writeErr error
   311  }
   312  
   313  const writeSize = 1
   314  
   315  func (rwc readWriteCloser) Write([]byte) (int, error) {
   316  	return writeSize, rwc.writeErr
   317  }
   318  
   319  func TestNewWrappedBodyReadWriteCloserImplementation(t *testing.T) {
   320  	wb := newWrappedBody(nil, func(n int64) {}, readWriteCloser{})
   321  	assert.Implements(t, (*io.ReadWriteCloser)(nil), wb)
   322  }
   323  
   324  func TestNewWrappedBodyReadCloserImplementation(t *testing.T) {
   325  	wb := newWrappedBody(nil, func(n int64) {}, readCloser{})
   326  	assert.Implements(t, (*io.ReadCloser)(nil), wb)
   327  
   328  	_, ok := wb.(io.ReadWriteCloser)
   329  	assert.False(t, ok, "wrappedBody should not implement io.ReadWriteCloser")
   330  }
   331  
   332  func TestWrappedBodyWrite(t *testing.T) {
   333  	s := new(span)
   334  	var rwc io.ReadWriteCloser
   335  	assert.NotPanics(t, func() {
   336  		rwc = newWrappedBody(s, func(n int64) {}, readWriteCloser{}).(io.ReadWriteCloser)
   337  	})
   338  
   339  	n, err := rwc.Write([]byte{})
   340  	assert.Equal(t, writeSize, n, "wrappedBody returned wrong bytes")
   341  	assert.NoError(t, err)
   342  	s.assert(t, false, nil, codes.Unset, "")
   343  }
   344  
   345  func TestWrappedBodyWriteError(t *testing.T) {
   346  	s := new(span)
   347  	expectedErr := errors.New("test")
   348  	var rwc io.ReadWriteCloser
   349  	assert.NotPanics(t, func() {
   350  		rwc = newWrappedBody(s,
   351  			func(n int64) {},
   352  			readWriteCloser{
   353  				writeErr: expectedErr,
   354  			}).(io.ReadWriteCloser)
   355  	})
   356  	n, err := rwc.Write([]byte{})
   357  	assert.Equal(t, writeSize, n, "wrappedBody returned wrong bytes")
   358  	assert.ErrorIs(t, err, expectedErr)
   359  	s.assert(t, false, expectedErr, codes.Error, expectedErr.Error())
   360  }
   361  
   362  func TestTransportProtocolSwitch(t *testing.T) {
   363  	// This test validates the fix to #1329.
   364  
   365  	// Simulate a "101 Switching Protocols" response from the test server.
   366  	response := []byte(strings.Join([]string{
   367  		"HTTP/1.1 101 Switching Protocols",
   368  		"Upgrade: WebSocket",
   369  		"Connection: Upgrade",
   370  		"", "", // Needed for extra CRLF.
   371  	}, "\r\n"))
   372  
   373  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   374  		conn, buf, err := w.(http.Hijacker).Hijack()
   375  		require.NoError(t, err)
   376  
   377  		_, err = buf.Write(response)
   378  		require.NoError(t, err)
   379  		require.NoError(t, buf.Flush())
   380  		require.NoError(t, conn.Close())
   381  	}))
   382  	defer ts.Close()
   383  
   384  	ctx := context.Background()
   385  	r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, http.NoBody)
   386  	require.NoError(t, err)
   387  
   388  	c := http.Client{Transport: NewTransport(http.DefaultTransport)}
   389  	res, err := c.Do(r)
   390  	require.NoError(t, err)
   391  	t.Cleanup(func() { require.NoError(t, res.Body.Close()) })
   392  
   393  	assert.Implements(t, (*io.ReadWriteCloser)(nil), res.Body, "invalid body returned for protocol switch")
   394  }
   395  
   396  func TestTransportOriginRequestNotModify(t *testing.T) {
   397  	prop := propagation.TraceContext{}
   398  
   399  	ctx := context.Background()
   400  	sc := trace.NewSpanContext(trace.SpanContextConfig{
   401  		TraceID: trace.TraceID{0x01},
   402  		SpanID:  trace.SpanID{0x01},
   403  	})
   404  	ctx = trace.ContextWithRemoteSpanContext(ctx, sc)
   405  
   406  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   407  		w.WriteHeader(http.StatusOK)
   408  	}))
   409  	defer ts.Close()
   410  
   411  	r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, http.NoBody)
   412  	require.NoError(t, err)
   413  
   414  	expectedRequest := r.Clone(r.Context())
   415  
   416  	c := http.Client{Transport: NewTransport(http.DefaultTransport, WithPropagators(prop))}
   417  	res, err := c.Do(r)
   418  	require.NoError(t, err)
   419  
   420  	t.Cleanup(func() { require.NoError(t, res.Body.Close()) })
   421  
   422  	assert.Equal(t, expectedRequest, r)
   423  }
   424  

View as plain text