package middleware

import (
	"bytes"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"

	"edge-infra.dev/pkg/edge/api/middleware"
	"edge-infra.dev/pkg/lib/fog"
	"github.com/gin-contrib/requestid"
	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

// TestSetLoggerInContext verifies that a logger obtained from the request
// context tags its log lines with the incoming correlation ID (sent in the
// CorrelationIDKey header) and with the request's operation ID.
func TestSetLoggerInContext(t *testing.T) {
	r := httptest.NewRecorder()
	buf := new(bytes.Buffer)
	router := createTestRouter(r, buf)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	const correlationID = "correlationID"
	req.Header.Set(CorrelationIDKey, correlationID)

	var operationID string
	router.Any("/", func(ctx *gin.Context) {
		operationID = fog.OperationID(ctx.Request.Context())
		fog.FromContext(ctx).Info("test")
	})
	router.ServeHTTP(r, req)

	// Without this guard the operation-ID check below would silently pass on
	// an empty id if the handler never ran or no operation ID was assigned.
	assert.NotEmpty(t, operationID)

	output := buf.String()
	assert.Contains(t, output, fmt.Sprintf(`"%s":"%s"`, correlationIDLabel, correlationID))
	assert.Contains(t, output, fmt.Sprintf(`"%s":{"%s":"%s"}`, fog.OperationKey, "id", operationID))
}

// TestRequestFinishedLog verifies that RequestBookendLogs logs both
// "Request received" and "Request completed", and that the completion entry
// reports isSuccessful=true for 2xx/3xx statuses and false for 4xx/5xx.
func TestRequestFinishedLog(t *testing.T) {
	tests := map[string]struct {
		status     int
		successful bool
	}{
		"200": {status: http.StatusOK, successful: true},
		"300": {status: http.StatusMultipleChoices, successful: true},
		"400": {status: http.StatusBadRequest, successful: false},
		"500": {status: http.StatusInternalServerError, successful: false},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			r := httptest.NewRecorder()
			buf := new(bytes.Buffer)
			router := createTestRouter(r, buf)

			router.Any("/", func(ctx *gin.Context) {
				ctx.Status(tc.status)
			})
			router.ServeHTTP(r, httptest.NewRequest(http.MethodGet, "/", nil))

			output := buf.String()
			assert.Contains(t, output, "Request received")
			assert.Contains(t, output, "Request completed")
			assert.Contains(t, output, fmt.Sprintf(`"isSuccessful":%t`, tc.successful))
		})
	}
}

// TestGetCorrelationID verifies that GetCorrelationID returns the value the
// client sent in the CorrelationIDKey request header.
func TestGetCorrelationID(t *testing.T) {
	r := httptest.NewRecorder()
	buf := new(bytes.Buffer)
	router := createTestRouter(r, buf)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	const expected = "correlationID"
	req.Header.Set(CorrelationIDKey, expected)

	var actual string
	router.Any("/", func(ctx *gin.Context) {
		actual = GetCorrelationID(ctx)
	})
	router.ServeHTTP(r, req)

	assert.Equal(t, expected, actual)
}

// TestNoDuplicatedLogValues sends several requests through the same router and
// verifies that each request's log output contains only its own correlation
// ID, not values carried over from earlier requests.
func TestNoDuplicatedLogValues(t *testing.T) {
	buf := new(bytes.Buffer)
	router := createTestRouter(httptest.NewRecorder(), buf)
	router.Any("/", func(ctx *gin.Context) {
		fog.FromContext(ctx).Info("test")
	})

	correlationIDs := []string{"123", "456", "789"}
	for i, correlationID := range correlationIDs {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.Header.Set(CorrelationIDKey, correlationID)
		// A fresh recorder per request keeps responses isolated.
		router.ServeHTTP(httptest.NewRecorder(), req)

		output := buf.String()
		assert.Contains(t, output, fmt.Sprintf(`"%s":"%s"`, correlationIDLabel, correlationID))
		// None of the previously used correlation IDs may leak into this request's logs.
		for _, previous := range correlationIDs[:i] {
			assert.NotContains(t, output, fmt.Sprintf(`"%s":"%s"`, correlationIDLabel, previous))
		}
		buf.Reset()
	}
}

// createTestRouter builds a gin engine wired with the middleware chain under
// test, writing all log output to buf. r is only used to create the test
// context; callers pass their own ResponseWriter to ServeHTTP.
func createTestRouter(r *httptest.ResponseRecorder, buf *bytes.Buffer) (router *gin.Engine) {
	_, router = gin.CreateTestContext(r)
	// Makes gin.Context.Value fall back to ctx.Request.Context(); the handlers
	// above pass *gin.Context directly to fog.FromContext.
	router.ContextWithFallback = true
	router.Use(middleware.SetRequestContext())
	// Order matters: presumably the logger middleware reads the correlation ID
	// populated by requestid, and the bookend logs use that logger — confirm.
	router.Use(requestid.New(requestid.WithCustomHeaderStrKey(CorrelationIDKey)))
	router.Use(SetLoggerInContext(fog.New(fog.To(buf))))
	router.Use(RequestBookendLogs())
	return router
}