1 package protohttp
2
3 import (
4 "bufio"
5 "bytes"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log"
11 "net/http"
12 "strings"
13 "testing"
14
15 "github.com/go-test/deep"
16 metricsPb "github.com/linkerd/linkerd2/viz/metrics-api/gen/viz"
17 "google.golang.org/grpc/codes"
18 "google.golang.org/grpc/status"
19 "google.golang.org/protobuf/proto"
20 kerrors "k8s.io/apimachinery/pkg/api/errors"
21 "k8s.io/apimachinery/pkg/runtime/schema"
22 )
23
// stubResponseWriter is an in-memory http.ResponseWriter used by the tests in
// this file. Bytes written to it are accumulated in body and headers are kept
// in headers, so tests can inspect both afterwards. It also has a Flush
// method, which makes it satisfy http.Flusher.
type stubResponseWriter struct {
	body    *bytes.Buffer
	headers http.Header
}

// Header exposes the header map stored on the stub.
func (s *stubResponseWriter) Header() http.Header {
	return s.headers
}

// Write appends p to the captured body and forwards the buffer's result.
func (s *stubResponseWriter) Write(p []byte) (int, error) {
	return s.body.Write(p)
}

// WriteHeader is a no-op: the stub does not record the status code.
func (s *stubResponseWriter) WriteHeader(_ int) {}

// Flush is a no-op: written data is already available in body.
func (s *stubResponseWriter) Flush() {}
41
// nonStreamingResponseWriter is an http.ResponseWriter that has no Flush
// method, so it does not satisfy http.Flusher. The tests in this file use it
// as a writer that is not compatible with streaming.
type nonStreamingResponseWriter struct {
}

// Header returns a nil header map. Reading from a nil http.Header is safe,
// but callers must not try to set values on it.
func (w *nonStreamingResponseWriter) Header() http.Header { return nil }

// Write discards p and reports it as fully written. Returning len(p) with a
// nil error keeps the stub within the io.Writer contract, which requires
// 0 <= n <= len(p) and a non-nil error whenever n < len(p).
func (w *nonStreamingResponseWriter) Write(p []byte) (int, error) { return len(p), nil }

// WriteHeader ignores the status code.
func (w *nonStreamingResponseWriter) WriteHeader(int) {}
50
51 func newStubResponseWriter() *stubResponseWriter {
52 return &stubResponseWriter{
53 headers: make(http.Header),
54 body: bytes.NewBufferString(""),
55 }
56 }
57
58 func TestHttpRequestToProto(t *testing.T) {
59 someURL := "https://www.example.org/something"
60 someMethod := http.MethodPost
61
62 t.Run("Given a valid request, serializes its contents into protobuf object", func(t *testing.T) {
63 expectedProtoMessage := metricsPb.Pod{
64 Name: "some-name",
65 PodIP: "some-name",
66 Owner: &metricsPb.Pod_Deployment{Deployment: "some-name"},
67 Status: "some-name",
68 Added: false,
69 ControllerNamespace: "some-name",
70 ControlPlane: false,
71 }
72 payload, err := proto.Marshal(&expectedProtoMessage)
73 if err != nil {
74 t.Fatalf("Unexpected error: %v", err)
75 }
76
77 req, err := http.NewRequest(someMethod, someURL, bytes.NewReader(payload))
78 if err != nil {
79 t.Fatalf("Unexpected error: %v", err)
80 }
81
82 var actualProtoMessage metricsPb.Pod
83 err = HTTPRequestToProto(req, &actualProtoMessage)
84 if err != nil {
85 t.Fatalf("Unexpected error: %v", err)
86 }
87
88 if !proto.Equal(&actualProtoMessage, &expectedProtoMessage) {
89 t.Fatalf("Expected request to be [%s], but got [%s]", expectedProtoMessage.String(), actualProtoMessage.String())
90 }
91 })
92
93 t.Run("Given a broken request, returns http error", func(t *testing.T) {
94 var actualProtoMessage metricsPb.Pod
95
96 req, err := http.NewRequest(someMethod, someURL, strings.NewReader("not really protobuf"))
97 if err != nil {
98 t.Fatalf("Unexpected error: %v", err)
99 }
100
101 err = HTTPRequestToProto(req, &actualProtoMessage)
102 if err == nil {
103 t.Fatalf("Expecting error, got nothing")
104 }
105
106 var he HTTPError
107 if errors.As(err, &he) {
108 expectedStatusCode := http.StatusBadRequest
109 if he.Code != expectedStatusCode || he.WrappedError == nil {
110 t.Fatalf("Expected error status to be [%d] and contain wrapper error, got status [%d] and error [%s]", expectedStatusCode, he.Code, he.WrappedError)
111 }
112 } else {
113 t.Fatalf("Expected error to be httpError, got: %v", err)
114 }
115 })
116 }
117
118 func TestWriteErrorToHttpResponse(t *testing.T) {
119 t.Run("Writes generic error correctly to response", func(t *testing.T) {
120 expectedErrorStatusCode := defaultHTTPErrorStatusCode
121
122 responseWriter := newStubResponseWriter()
123 genericError := errors.New("expected generic error")
124
125 WriteErrorToHTTPResponse(responseWriter, genericError)
126
127 assertResponseHasProtobufContentType(t, responseWriter)
128
129 actualErrorStatusCode := responseWriter.headers.Get(errorHeader)
130 if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) {
131 t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode)
132 }
133
134 payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
135 if err != nil {
136 t.Fatalf("Unexpected error: %v", err)
137 }
138
139 expectedErrorPayload := metricsPb.ApiError{Error: genericError.Error()}
140 var actualErrorPayload metricsPb.ApiError
141 err = proto.Unmarshal(payloadRead, &actualErrorPayload)
142 if err != nil {
143 t.Fatalf("Unexpected error: %v", err)
144 }
145
146 if !proto.Equal(&actualErrorPayload, &expectedErrorPayload) {
147 t.Fatalf("Expecting error to be serialized as [%s], but got [%s]", expectedErrorPayload.String(), actualErrorPayload.String())
148 }
149 })
150
151 t.Run("Writes http specific error correctly to response", func(t *testing.T) {
152 expectedErrorStatusCode := http.StatusBadGateway
153 responseWriter := newStubResponseWriter()
154 httpError := HTTPError{
155 WrappedError: errors.New("expected to be wrapped"),
156 Code: http.StatusBadGateway,
157 }
158
159 WriteErrorToHTTPResponse(responseWriter, httpError)
160
161 assertResponseHasProtobufContentType(t, responseWriter)
162
163 actualErrorStatusCode := responseWriter.headers.Get(errorHeader)
164 if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) {
165 t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode)
166 }
167
168 payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
169 if err != nil {
170 t.Fatalf("Unexpected error: %v", err)
171 }
172
173 expectedErrorPayload := metricsPb.ApiError{Error: httpError.WrappedError.Error()}
174 var actualErrorPayload metricsPb.ApiError
175 err = proto.Unmarshal(payloadRead, &actualErrorPayload)
176 if err != nil {
177 t.Fatalf("Unexpected error: %v", err)
178 }
179
180 if !proto.Equal(&actualErrorPayload, &expectedErrorPayload) {
181 t.Fatalf("Expecting error to be serialized as [%s], but got [%s]", expectedErrorPayload.String(), actualErrorPayload.String())
182 }
183 })
184
185 t.Run("Writes gRPC specific error correctly to response", func(t *testing.T) {
186 expectedErrorStatusCode := defaultHTTPErrorStatusCode
187
188 responseWriter := newStubResponseWriter()
189 expectedErrorMessage := "error message"
190 grpcError := status.Errorf(codes.AlreadyExists, expectedErrorMessage)
191
192 WriteErrorToHTTPResponse(responseWriter, grpcError)
193
194 assertResponseHasProtobufContentType(t, responseWriter)
195
196 actualErrorStatusCode := responseWriter.headers.Get(errorHeader)
197 if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) {
198 t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode)
199 }
200
201 payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes())))
202 if err != nil {
203 t.Fatalf("Unexpected error: %v", err)
204 }
205
206 expectedErrorPayload := metricsPb.ApiError{Error: expectedErrorMessage}
207 var actualErrorPayload metricsPb.ApiError
208 err = proto.Unmarshal(payloadRead, &actualErrorPayload)
209 if err != nil {
210 t.Fatalf("Unexpected error: %v", err)
211 }
212
213 if actualErrorPayload.String() != expectedErrorPayload.String() {
214 t.Fatalf("Expecting error to be serialized as [%s], but got [%s]", expectedErrorPayload.String(), actualErrorPayload.String())
215 }
216 })
217 }
218
219 func TestDeserializePayloadFromReader(t *testing.T) {
220 t.Run("Can read message correctly based on payload size correct payload size to message", func(t *testing.T) {
221 expectedMessage := "this is the message"
222
223 messageWithSize := SerializeAsPayload([]byte(expectedMessage))
224 messageWithSomeNoise := append(messageWithSize, []byte("this is noise and should not be read")...)
225
226 actualMessage, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(messageWithSomeNoise)))
227 if err != nil {
228 t.Fatalf("Unexpected error: %v", err)
229 }
230
231 if string(actualMessage) != expectedMessage {
232 t.Fatalf("Expecting payload to contain message [%s], but it had [%s]", expectedMessage, actualMessage)
233 }
234 })
235
236 t.Run("Can multiple messages in the same stream", func(t *testing.T) {
237 expectedMessage1 := "Hit the road, Jack and don't you come back\n"
238 for i := 0; i < 450; i++ {
239 expectedMessage1 += fmt.Sprintf("no more (%d), ", i)
240 }
241
242 expectedMessage2 := "back street back, alright\n"
243 for i := 0; i < 450; i++ {
244 expectedMessage2 += fmt.Sprintf("tum (%d), ", i)
245 }
246
247 messageWithSize1 := SerializeAsPayload([]byte(expectedMessage1))
248 messageWithSize2 := SerializeAsPayload([]byte(expectedMessage2))
249
250 streamWithManyMessages := append(messageWithSize1, messageWithSize2...)
251 reader := bufio.NewReader(bytes.NewReader(streamWithManyMessages))
252
253 actualMessage1, err := deserializePayloadFromReader(reader)
254 if err != nil {
255 t.Fatalf("Unexpected error: %v", err)
256 }
257
258 actualMessage2, err := deserializePayloadFromReader(reader)
259 if err != nil {
260 t.Fatalf("Unexpected error: %v", err)
261 }
262
263 if string(actualMessage1) != expectedMessage1 {
264 t.Fatalf("Expecting payload to contain message:\n%s\nbut it had\n%s", expectedMessage1, actualMessage1)
265 }
266
267 if string(actualMessage2) != expectedMessage2 {
268 t.Fatalf("Expecting payload to contain message:\n%s\nbut it had\n%s", expectedMessage2, actualMessage2)
269 }
270 })
271
272 t.Run("Can read byte streams larger than Go's default buffer chunk size", func(t *testing.T) {
273 goDefaultChunkSize := 4000
274 expectedMessage := "Hit the road, Jack and don't you come back\n"
275 for i := 0; i < 450; i++ {
276 expectedMessage += fmt.Sprintf("no more (%d), ", i)
277 }
278
279 expectedMessageAsBytes := []byte(expectedMessage)
280 lengthOfInputData := len(expectedMessageAsBytes)
281
282 if lengthOfInputData < goDefaultChunkSize {
283 t.Fatalf("Test needs data larger than [%d] bytes, currently only [%d] bytes", goDefaultChunkSize, lengthOfInputData)
284 }
285
286 payload := SerializeAsPayload(expectedMessageAsBytes)
287 actualMessage, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(payload)))
288 if err != nil {
289 t.Fatalf("Unexpected error: %v", err)
290 }
291
292 if string(actualMessage) != expectedMessage {
293 t.Fatalf("Expecting payload to contain message:\n%s\n, but it had\n%s", expectedMessageAsBytes, actualMessage)
294 }
295 })
296
297 t.Run("Returns error when message has fewer bytes than declared message size", func(t *testing.T) {
298 expectedMessage := "this is the message"
299
300 messageWithSize := SerializeAsPayload([]byte(expectedMessage))
301 messageMissingOneCharacter := messageWithSize[:len(expectedMessage)-1]
302 _, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(messageMissingOneCharacter)))
303 if err == nil {
304 t.Fatalf("Expecting error, got nothing")
305 }
306 })
307 }
308
309 func TestNewStreamingWriter(t *testing.T) {
310 t.Run("Returns a streaming writer if the ResponseWriter is compatible with streaming", func(t *testing.T) {
311 rawWriter := newStubResponseWriter()
312 flushableWriter, err := NewStreamingWriter(rawWriter)
313 if err != nil {
314 t.Fatalf("Unexpected error: %v", err)
315 }
316
317 if flushableWriter != rawWriter {
318 t.Fatalf("Expected to return same instance of writer")
319 }
320
321 header := "Connection"
322 expectedValue := "keep-alive"
323 actualValue := rawWriter.Header().Get(header)
324 if actualValue != expectedValue {
325 t.Fatalf("Expected header [%s] to be set to [%s], but was [%s]", header, expectedValue, actualValue)
326 }
327
328 header = "Transfer-Encoding"
329 expectedValue = "chunked"
330 actualValue = rawWriter.Header().Get(header)
331 if actualValue != expectedValue {
332 t.Fatalf("Expected header [%s] to be set to [%s], but was [%s]", header, expectedValue, actualValue)
333 }
334 })
335
336 t.Run("Returns an error if writer does not support streaming", func(t *testing.T) {
337 _, err := NewStreamingWriter(&nonStreamingResponseWriter{})
338 if err == nil {
339 t.Fatalf("Expecting error, got nothing")
340 }
341 })
342 }
343
344 func TestCheckIfResponseHasError(t *testing.T) {
345 t.Run("returns nil if response doesn't contain linkerd-error header and is 200", func(t *testing.T) {
346 response := &http.Response{
347 Header: make(http.Header),
348 StatusCode: http.StatusOK,
349 }
350 err := CheckIfResponseHasError(response)
351 if err != nil {
352 t.Fatalf("Unexpected error: %v", err)
353 }
354 })
355
356 t.Run("returns error in body if response contains linkerd-error header", func(t *testing.T) {
357 expectedErrorMessage := "expected error message"
358 protoInBytes, err := proto.Marshal(&metricsPb.ApiError{Error: expectedErrorMessage})
359 if err != nil {
360 t.Fatalf("Unexpected error: %v", err)
361 }
362
363 message := SerializeAsPayload(protoInBytes)
364 response := &http.Response{
365 Header: make(http.Header),
366 Body: io.NopCloser(bytes.NewReader(message)),
367 StatusCode: http.StatusInternalServerError,
368 }
369 response.Header.Set(errorHeader, "error")
370
371 err = CheckIfResponseHasError(response)
372 if err == nil {
373 t.Fatalf("Expecting error, got nothing")
374 }
375
376 actualErrorMessage := err.Error()
377 if actualErrorMessage != expectedErrorMessage {
378 t.Fatalf("Expected error message to be [%s], but it was [%s]", expectedErrorMessage, actualErrorMessage)
379 }
380 })
381
382 t.Run("returns Kubernetes StatusError if present", func(t *testing.T) {
383 statusError := kerrors.NewForbidden(
384 schema.GroupResource{Group: "group", Resource: "res"},
385 "name", errors.New("test-err"),
386 )
387 statusError.ErrStatus.Kind = "Status"
388 statusError.ErrStatus.APIVersion = "v1"
389 j, err := json.Marshal(statusError.ErrStatus)
390 if err != nil {
391 log.Fatalf("Failed to marshal JSON: %+v", statusError)
392 }
393 fmt.Printf("J: %+v\n", string(j))
394
395 response := &http.Response{
396 Header: make(http.Header),
397 Body: io.NopCloser(bytes.NewReader(j)),
398 StatusCode: http.StatusForbidden,
399 Status: "403 Forbidden",
400 }
401
402 err = CheckIfResponseHasError(response)
403 expectedErr := HTTPError{Code: http.StatusForbidden, WrappedError: statusError}
404
405 if diff := deep.Equal(err, expectedErr); diff != nil {
406 t.Fatalf("%v", diff)
407 }
408 })
409
410 t.Run("returns error if response is not a 200", func(t *testing.T) {
411 response := &http.Response{
412 StatusCode: http.StatusServiceUnavailable,
413 Status: "503 Service Unavailable",
414 }
415
416 err := CheckIfResponseHasError(response)
417 if err == nil {
418 t.Fatalf("Expecting error, got nothing")
419 }
420
421 expectedErrorMessage := "HTTP error, status Code [503] (unexpected API response)"
422 actualErrorMessage := err.Error()
423 if actualErrorMessage != expectedErrorMessage {
424 t.Fatalf("Expected error message to be [%s], but it was [%s]", expectedErrorMessage, actualErrorMessage)
425 }
426 })
427 }
428
429 func assertResponseHasProtobufContentType(t *testing.T, responseWriter *stubResponseWriter) {
430 actualContentType := responseWriter.headers.Get(contentTypeHeader)
431 expectedContentType := protobufContentType
432 if actualContentType != expectedContentType {
433 t.Fatalf("Expected content-type to be [%s], but got [%s]", expectedContentType, actualContentType)
434 }
435 }
436
View as plain text