package middleware import ( "bytes" "errors" "fmt" "net/http" "time" "github.com/gin-contrib/requestid" "github.com/gin-gonic/gin" "github.com/go-logr/logr" "edge-infra.dev/pkg/lib/fog" "edge-infra.dev/pkg/sds/interlock/internal/constants" "edge-infra.dev/pkg/sds/interlock/internal/observability" ) const ( clientIPLabel = "clientIP" requestIDLabel = "X-Request-ID" ) // responseWriter wraps the default ResponseWriter to add functionality to // capture the body of the response. type responseWriter struct { gin.ResponseWriter body *bytes.Buffer } func SetAccessControlHeaders() gin.HandlerFunc { return func(c *gin.Context) { c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Methods", "*") c.Header("Access-Control-Allow-Headers", "*") } } // Write the response bytes to the body buffer and then write the body with the // default ResponseWriter. func (r responseWriter) Write(b []byte) (int, error) { r.body.Write(b) return r.ResponseWriter.Write(b) } // SetLoggerInContext appends the request ID, remote/client IP, request method // and request URL to the provided logger and loads the logger into the context. func SetLoggerInContext(log logr.Logger) gin.HandlerFunc { return func(c *gin.Context) { requestID := requestid.Get(c) requestInfo := map[string]string{ "method": c.Request.Method, "path": c.Request.URL.Path, "query": c.Request.URL.RawQuery, } ctcLog := log.WithValues(requestIDLabel, requestID). WithValues(clientIPLabel, c.ClientIP()). WithValues("request", requestInfo) c.Request = c.Request.Clone(fog.IntoContext(c.Request.Context(), ctcLog)) c.Next() } } // RequestLogger logs request results, adding structured logging to the default // gin implementation. The response status, response body (if error status) and // the time taken to respond is appended to the log values. func RequestLogger() gin.HandlerFunc { return func(c *gin.Context) { if c.FullPath() == "/health" || c.Request.Method == http.MethodOptions { c.Next() return } rw := &responseWriter{ ResponseWriter: c.Writer, body: bytes.NewBufferString(""), } c.Writer = rw t := time.Now() c.Next() logRequest(c, rw.body.String(), time.Since(t)) } } // TODO: Limit this to only log errors unless in debug mode at release time // logRequest logs the request with the response results. func logRequest(c *gin.Context, responseBody string, timeElapsed time.Duration) { status := c.Writer.Status() latency := float64(timeElapsed.Microseconds()) / 1000 logmsg := fmt.Sprintf("%s %s %d %s", c.Request.Method, c.Request.URL.Path, status, http.StatusText(status)) responseInfo := map[string]interface{}{ "bodySize": c.Writer.Size(), "status": status, "text": http.StatusText(status), "latency": fmt.Sprintf("%.2f ms", latency), } if status >= 400 { // only append the response body to errors responseInfo["body"] = responseBody } log := fog.FromContext(c.Request.Context()). WithValues("response", responseInfo) switch { case status < 400: // for 100s,200s, 300s responses log info log.Info(logmsg) case status >= 400 && status < 500: // for 400s responses log info with audit key log.WithValues(constants.AuditKey, "").Info(logmsg) case status >= 500: // for 500 responses log error log.Error(errors.New(http.StatusText(status)), logmsg) } } // RequestRecorder records the request result as prometheus metrics. This is // included in the count of different response status' and the latency // histogram. func RequestRecorder() gin.HandlerFunc { return func(c *gin.Context) { t := time.Now() c.Next() path := c.Request.URL.Path method := c.Request.Method status := c.Writer.Status() latency := float64(time.Since(t).Microseconds()) / 1000 observability.RecordRequest(path, method, status, latency) } }