...
1 package middleware
2
3 import (
4 "fmt"
5 "net/http"
6 "net/http/httptest"
7 "testing"
8
9 "github.com/go-chi/chi"
10 )
11
12 func maintainDefaultRequestId() func() {
13 original := RequestIDHeader
14
15 return func() {
16 RequestIDHeader = original
17 }
18 }
19
20 func TestRequestID(t *testing.T) {
21 tests := map[string]struct {
22 requestIDHeader string
23 request func() *http.Request
24 expectedResponse string
25 }{
26 "Retrieves Request Id from default header": {
27 "X-Request-Id",
28 func() *http.Request {
29 req, _ := http.NewRequest("GET", "/", nil)
30 req.Header.Add("X-Request-Id", "req-123456")
31
32 return req
33 },
34 "RequestID: req-123456",
35 },
36 "Retrieves Request Id from custom header": {
37 "X-Trace-Id",
38 func() *http.Request {
39 req, _ := http.NewRequest("GET", "/", nil)
40 req.Header.Add("X-Trace-Id", "trace:abc123")
41
42 return req
43 },
44 "RequestID: trace:abc123",
45 },
46 }
47
48 defer maintainDefaultRequestId()()
49
50 for _, test := range tests {
51 w := httptest.NewRecorder()
52
53 r := chi.NewRouter()
54
55 RequestIDHeader = test.requestIDHeader
56
57 r.Use(RequestID)
58
59 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
60 requestID := GetReqID(r.Context())
61 response := fmt.Sprintf("RequestID: %s", requestID)
62
63 w.Write([]byte(response))
64 })
65 r.ServeHTTP(w, test.request())
66
67 if w.Body.String() != test.expectedResponse {
68 t.Fatalf("RequestID was not the expected value")
69 }
70 }
71 }
72
View as plain text