...

Source file src/github.com/linkerd/linkerd2/pkg/protohttp/protohttp_test.go

Documentation: github.com/linkerd/linkerd2/pkg/protohttp

     1  package protohttp
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"log"
    11  	"net/http"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/go-test/deep"
    16  	metricsPb "github.com/linkerd/linkerd2/viz/metrics-api/gen/viz"
    17  	"google.golang.org/grpc/codes"
    18  	"google.golang.org/grpc/status"
    19  	"google.golang.org/protobuf/proto"
    20  	kerrors "k8s.io/apimachinery/pkg/api/errors"
    21  	"k8s.io/apimachinery/pkg/runtime/schema"
    22  )
    23  
    24  type stubResponseWriter struct {
    25  	body    *bytes.Buffer
    26  	headers http.Header
    27  }
    28  
    29  func (w *stubResponseWriter) Header() http.Header {
    30  	return w.headers
    31  }
    32  
    33  func (w *stubResponseWriter) Write(p []byte) (int, error) {
    34  	n, err := w.body.Write(p)
    35  	return n, err
    36  }
    37  
    38  func (w *stubResponseWriter) WriteHeader(int) {}
    39  
    40  func (w *stubResponseWriter) Flush() {}
    41  
    42  type nonStreamingResponseWriter struct {
    43  }
    44  
    45  func (w *nonStreamingResponseWriter) Header() http.Header { return nil }
    46  
    47  func (w *nonStreamingResponseWriter) Write(p []byte) (int, error) { return -1, nil }
    48  
    49  func (w *nonStreamingResponseWriter) WriteHeader(int) {}
    50  
    51  func newStubResponseWriter() *stubResponseWriter {
    52  	return &stubResponseWriter{
    53  		headers: make(http.Header),
    54  		body:    bytes.NewBufferString(""),
    55  	}
    56  }
    57  
    58  func TestHttpRequestToProto(t *testing.T) {
    59  	someURL := "https://www.example.org/something"
    60  	someMethod := http.MethodPost
    61  
    62  	t.Run("Given a valid request, serializes its contents into protobuf object", func(t *testing.T) {
    63  		expectedProtoMessage := metricsPb.Pod{
    64  			Name:                "some-name",
    65  			PodIP:               "some-name",
    66  			Owner:               &metricsPb.Pod_Deployment{Deployment: "some-name"},
    67  			Status:              "some-name",
    68  			Added:               false,
    69  			ControllerNamespace: "some-name",
    70  			ControlPlane:        false,
    71  		}
    72  		payload, err := proto.Marshal(&expectedProtoMessage)
    73  		if err != nil {
    74  			t.Fatalf("Unexpected error: %v", err)
    75  		}
    76  
    77  		req, err := http.NewRequest(someMethod, someURL, bytes.NewReader(payload))
    78  		if err != nil {
    79  			t.Fatalf("Unexpected error: %v", err)
    80  		}
    81  
    82  		var actualProtoMessage metricsPb.Pod
    83  		err = HTTPRequestToProto(req, &actualProtoMessage)
    84  		if err != nil {
    85  			t.Fatalf("Unexpected error: %v", err)
    86  		}
    87  
    88  		if !proto.Equal(&actualProtoMessage, &expectedProtoMessage) {
    89  			t.Fatalf("Expected request to be [%s], but got [%s]", expectedProtoMessage.String(), actualProtoMessage.String())
    90  		}
    91  	})
    92  
    93  	t.Run("Given a broken request, returns http error", func(t *testing.T) {
    94  		var actualProtoMessage metricsPb.Pod
    95  
    96  		req, err := http.NewRequest(someMethod, someURL, strings.NewReader("not really protobuf"))
    97  		if err != nil {
    98  			t.Fatalf("Unexpected error: %v", err)
    99  		}
   100  
   101  		err = HTTPRequestToProto(req, &actualProtoMessage)
   102  		if err == nil {
   103  			t.Fatalf("Expecting error, got nothing")
   104  		}
   105  
   106  		var he HTTPError
   107  		if errors.As(err, &he) {
   108  			expectedStatusCode := http.StatusBadRequest
   109  			if he.Code != expectedStatusCode || he.WrappedError == nil {
   110  				t.Fatalf("Expected error status to be [%d] and contain wrapper error, got status [%d] and error [%s]", expectedStatusCode, he.Code, he.WrappedError)
   111  			}
   112  		} else {
   113  			t.Fatalf("Expected error to be httpError, got: %v", err)
   114  		}
   115  	})
   116  }
   117  
   118  func TestWriteErrorToHttpResponse(t *testing.T) {
   119  	t.Run("Writes generic error correctly to response", func(t *testing.T) {
   120  		expectedErrorStatusCode := defaultHTTPErrorStatusCode
   121  
   122  		responseWriter := newStubResponseWriter()
   123  		genericError := errors.New("expected generic error")
   124  
   125  		WriteErrorToHTTPResponse(responseWriter, genericError)
   126  
   127  		assertResponseHasProtobufContentType(t, responseWriter)
   128  
   129  		actualErrorStatusCode := responseWriter.headers.Get(errorHeader)
   130  		if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) {
   131  			t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode)
   132  		}
   133  
   134  		payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
   135  		if err != nil {
   136  			t.Fatalf("Unexpected error: %v", err)
   137  		}
   138  
   139  		expectedErrorPayload := metricsPb.ApiError{Error: genericError.Error()}
   140  		var actualErrorPayload metricsPb.ApiError
   141  		err = proto.Unmarshal(payloadRead, &actualErrorPayload)
   142  		if err != nil {
   143  			t.Fatalf("Unexpected error: %v", err)
   144  		}
   145  
   146  		if !proto.Equal(&actualErrorPayload, &expectedErrorPayload) {
   147  			t.Fatalf("Expecting error to be serialized as [%s], but got [%s]", expectedErrorPayload.String(), actualErrorPayload.String())
   148  		}
   149  	})
   150  
   151  	t.Run("Writes http specific error correctly to response", func(t *testing.T) {
   152  		expectedErrorStatusCode := http.StatusBadGateway
   153  		responseWriter := newStubResponseWriter()
   154  		httpError := HTTPError{
   155  			WrappedError: errors.New("expected to be wrapped"),
   156  			Code:         http.StatusBadGateway,
   157  		}
   158  
   159  		WriteErrorToHTTPResponse(responseWriter, httpError)
   160  
   161  		assertResponseHasProtobufContentType(t, responseWriter)
   162  
   163  		actualErrorStatusCode := responseWriter.headers.Get(errorHeader)
   164  		if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) {
   165  			t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode)
   166  		}
   167  
   168  		payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
   169  		if err != nil {
   170  			t.Fatalf("Unexpected error: %v", err)
   171  		}
   172  
   173  		expectedErrorPayload := metricsPb.ApiError{Error: httpError.WrappedError.Error()}
   174  		var actualErrorPayload metricsPb.ApiError
   175  		err = proto.Unmarshal(payloadRead, &actualErrorPayload)
   176  		if err != nil {
   177  			t.Fatalf("Unexpected error: %v", err)
   178  		}
   179  
   180  		if !proto.Equal(&actualErrorPayload, &expectedErrorPayload) {
   181  			t.Fatalf("Expecting error to be serialized as [%s], but got [%s]", expectedErrorPayload.String(), actualErrorPayload.String())
   182  		}
   183  	})
   184  
   185  	t.Run("Writes gRPC specific error correctly to response", func(t *testing.T) {
   186  		expectedErrorStatusCode := defaultHTTPErrorStatusCode
   187  
   188  		responseWriter := newStubResponseWriter()
   189  		expectedErrorMessage := "error message"
   190  		grpcError := status.Errorf(codes.AlreadyExists, expectedErrorMessage)
   191  
   192  		WriteErrorToHTTPResponse(responseWriter, grpcError)
   193  
   194  		assertResponseHasProtobufContentType(t, responseWriter)
   195  
   196  		actualErrorStatusCode := responseWriter.headers.Get(errorHeader)
   197  		if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) {
   198  			t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode)
   199  		}
   200  
   201  		payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
   202  		if err != nil {
   203  			t.Fatalf("Unexpected error: %v", err)
   204  		}
   205  
   206  		expectedErrorPayload := metricsPb.ApiError{Error: expectedErrorMessage}
   207  		var actualErrorPayload metricsPb.ApiError
   208  		err = proto.Unmarshal(payloadRead, &actualErrorPayload)
   209  		if err != nil {
   210  			t.Fatalf("Unexpected error: %v", err)
   211  		}
   212  
   213  		if actualErrorPayload.String() != expectedErrorPayload.String() {
   214  			t.Fatalf("Expecting error to be serialized as [%s], but got [%s]", expectedErrorPayload.String(), actualErrorPayload.String())
   215  		}
   216  	})
   217  }
   218  
   219  func TestDeserializePayloadFromReader(t *testing.T) {
   220  	t.Run("Can read message correctly based on payload size correct payload size to message", func(t *testing.T) {
   221  		expectedMessage := "this is the message"
   222  
   223  		messageWithSize := SerializeAsPayload([]byte(expectedMessage))
   224  		messageWithSomeNoise := append(messageWithSize, []byte("this is noise and should not be read")...)
   225  
   226  		actualMessage, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(messageWithSomeNoise)))
   227  		if err != nil {
   228  			t.Fatalf("Unexpected error: %v", err)
   229  		}
   230  
   231  		if string(actualMessage) != expectedMessage {
   232  			t.Fatalf("Expecting payload to contain message [%s], but it had [%s]", expectedMessage, actualMessage)
   233  		}
   234  	})
   235  
   236  	t.Run("Can multiple messages in the same stream", func(t *testing.T) {
   237  		expectedMessage1 := "Hit the road, Jack and don't you come back\n"
   238  		for i := 0; i < 450; i++ {
   239  			expectedMessage1 += fmt.Sprintf("no more (%d), ", i)
   240  		}
   241  
   242  		expectedMessage2 := "back street back, alright\n"
   243  		for i := 0; i < 450; i++ {
   244  			expectedMessage2 += fmt.Sprintf("tum (%d), ", i)
   245  		}
   246  
   247  		messageWithSize1 := SerializeAsPayload([]byte(expectedMessage1))
   248  		messageWithSize2 := SerializeAsPayload([]byte(expectedMessage2))
   249  
   250  		streamWithManyMessages := append(messageWithSize1, messageWithSize2...)
   251  		reader := bufio.NewReader(bytes.NewReader(streamWithManyMessages))
   252  
   253  		actualMessage1, err := deserializePayloadFromReader(reader)
   254  		if err != nil {
   255  			t.Fatalf("Unexpected error: %v", err)
   256  		}
   257  
   258  		actualMessage2, err := deserializePayloadFromReader(reader)
   259  		if err != nil {
   260  			t.Fatalf("Unexpected error: %v", err)
   261  		}
   262  
   263  		if string(actualMessage1) != expectedMessage1 {
   264  			t.Fatalf("Expecting payload to contain message:\n%s\nbut it had\n%s", expectedMessage1, actualMessage1)
   265  		}
   266  
   267  		if string(actualMessage2) != expectedMessage2 {
   268  			t.Fatalf("Expecting payload to contain message:\n%s\nbut it had\n%s", expectedMessage2, actualMessage2)
   269  		}
   270  	})
   271  
   272  	t.Run("Can read byte streams larger than Go's default buffer chunk size", func(t *testing.T) {
   273  		goDefaultChunkSize := 4000
   274  		expectedMessage := "Hit the road, Jack and don't you come back\n"
   275  		for i := 0; i < 450; i++ {
   276  			expectedMessage += fmt.Sprintf("no more (%d), ", i)
   277  		}
   278  
   279  		expectedMessageAsBytes := []byte(expectedMessage)
   280  		lengthOfInputData := len(expectedMessageAsBytes)
   281  
   282  		if lengthOfInputData < goDefaultChunkSize {
   283  			t.Fatalf("Test needs data larger than [%d] bytes, currently only [%d] bytes", goDefaultChunkSize, lengthOfInputData)
   284  		}
   285  
   286  		payload := SerializeAsPayload(expectedMessageAsBytes)
   287  		actualMessage, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(payload)))
   288  		if err != nil {
   289  			t.Fatalf("Unexpected error: %v", err)
   290  		}
   291  
   292  		if string(actualMessage) != expectedMessage {
   293  			t.Fatalf("Expecting payload to contain message:\n%s\n, but it had\n%s", expectedMessageAsBytes, actualMessage)
   294  		}
   295  	})
   296  
   297  	t.Run("Returns error when message has fewer bytes than declared message size", func(t *testing.T) {
   298  		expectedMessage := "this is the message"
   299  
   300  		messageWithSize := SerializeAsPayload([]byte(expectedMessage))
   301  		messageMissingOneCharacter := messageWithSize[:len(expectedMessage)-1]
   302  		_, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(messageMissingOneCharacter)))
   303  		if err == nil {
   304  			t.Fatalf("Expecting error, got nothing")
   305  		}
   306  	})
   307  }
   308  
   309  func TestNewStreamingWriter(t *testing.T) {
   310  	t.Run("Returns a streaming writer if the ResponseWriter is compatible with streaming", func(t *testing.T) {
   311  		rawWriter := newStubResponseWriter()
   312  		flushableWriter, err := NewStreamingWriter(rawWriter)
   313  		if err != nil {
   314  			t.Fatalf("Unexpected error: %v", err)
   315  		}
   316  
   317  		if flushableWriter != rawWriter {
   318  			t.Fatalf("Expected to return same instance of writer")
   319  		}
   320  
   321  		header := "Connection"
   322  		expectedValue := "keep-alive"
   323  		actualValue := rawWriter.Header().Get(header)
   324  		if actualValue != expectedValue {
   325  			t.Fatalf("Expected header [%s] to be set to [%s], but was [%s]", header, expectedValue, actualValue)
   326  		}
   327  
   328  		header = "Transfer-Encoding"
   329  		expectedValue = "chunked"
   330  		actualValue = rawWriter.Header().Get(header)
   331  		if actualValue != expectedValue {
   332  			t.Fatalf("Expected header [%s] to be set to [%s], but was [%s]", header, expectedValue, actualValue)
   333  		}
   334  	})
   335  
   336  	t.Run("Returns an error if writer does not support streaming", func(t *testing.T) {
   337  		_, err := NewStreamingWriter(&nonStreamingResponseWriter{})
   338  		if err == nil {
   339  			t.Fatalf("Expecting error, got nothing")
   340  		}
   341  	})
   342  }
   343  
   344  func TestCheckIfResponseHasError(t *testing.T) {
   345  	t.Run("returns nil if response doesn't contain linkerd-error header and is 200", func(t *testing.T) {
   346  		response := &http.Response{
   347  			Header:     make(http.Header),
   348  			StatusCode: http.StatusOK,
   349  		}
   350  		err := CheckIfResponseHasError(response)
   351  		if err != nil {
   352  			t.Fatalf("Unexpected error: %v", err)
   353  		}
   354  	})
   355  
   356  	t.Run("returns error in body if response contains linkerd-error header", func(t *testing.T) {
   357  		expectedErrorMessage := "expected error message"
   358  		protoInBytes, err := proto.Marshal(&metricsPb.ApiError{Error: expectedErrorMessage})
   359  		if err != nil {
   360  			t.Fatalf("Unexpected error: %v", err)
   361  		}
   362  
   363  		message := SerializeAsPayload(protoInBytes)
   364  		response := &http.Response{
   365  			Header:     make(http.Header),
   366  			Body:       io.NopCloser(bytes.NewReader(message)),
   367  			StatusCode: http.StatusInternalServerError,
   368  		}
   369  		response.Header.Set(errorHeader, "error")
   370  
   371  		err = CheckIfResponseHasError(response)
   372  		if err == nil {
   373  			t.Fatalf("Expecting error, got nothing")
   374  		}
   375  
   376  		actualErrorMessage := err.Error()
   377  		if actualErrorMessage != expectedErrorMessage {
   378  			t.Fatalf("Expected error message to be [%s], but it was [%s]", expectedErrorMessage, actualErrorMessage)
   379  		}
   380  	})
   381  
   382  	t.Run("returns Kubernetes StatusError if present", func(t *testing.T) {
   383  		statusError := kerrors.NewForbidden(
   384  			schema.GroupResource{Group: "group", Resource: "res"},
   385  			"name", errors.New("test-err"),
   386  		)
   387  		statusError.ErrStatus.Kind = "Status"
   388  		statusError.ErrStatus.APIVersion = "v1"
   389  		j, err := json.Marshal(statusError.ErrStatus)
   390  		if err != nil {
   391  			log.Fatalf("Failed to marshal JSON: %+v", statusError)
   392  		}
   393  		fmt.Printf("J: %+v\n", string(j))
   394  
   395  		response := &http.Response{
   396  			Header:     make(http.Header),
   397  			Body:       io.NopCloser(bytes.NewReader(j)),
   398  			StatusCode: http.StatusForbidden,
   399  			Status:     "403 Forbidden",
   400  		}
   401  
   402  		err = CheckIfResponseHasError(response)
   403  		expectedErr := HTTPError{Code: http.StatusForbidden, WrappedError: statusError}
   404  
   405  		if diff := deep.Equal(err, expectedErr); diff != nil {
   406  			t.Fatalf("%v", diff)
   407  		}
   408  	})
   409  
   410  	t.Run("returns error if response is not a 200", func(t *testing.T) {
   411  		response := &http.Response{
   412  			StatusCode: http.StatusServiceUnavailable,
   413  			Status:     "503 Service Unavailable",
   414  		}
   415  
   416  		err := CheckIfResponseHasError(response)
   417  		if err == nil {
   418  			t.Fatalf("Expecting error, got nothing")
   419  		}
   420  
   421  		expectedErrorMessage := "HTTP error, status Code [503] (unexpected API response)"
   422  		actualErrorMessage := err.Error()
   423  		if actualErrorMessage != expectedErrorMessage {
   424  			t.Fatalf("Expected error message to be [%s], but it was [%s]", expectedErrorMessage, actualErrorMessage)
   425  		}
   426  	})
   427  }
   428  
   429  func assertResponseHasProtobufContentType(t *testing.T, responseWriter *stubResponseWriter) {
   430  	actualContentType := responseWriter.headers.Get(contentTypeHeader)
   431  	expectedContentType := protobufContentType
   432  	if actualContentType != expectedContentType {
   433  		t.Fatalf("Expected content-type to be [%s], but got [%s]", expectedContentType, actualContentType)
   434  	}
   435  }
   436  

View as plain text