...

Source file src/edge-infra.dev/pkg/sds/interlock/internal/middleware/middleware.go

Documentation: edge-infra.dev/pkg/sds/interlock/internal/middleware

     1  package middleware
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"time"
     9  
    10  	"github.com/gin-contrib/requestid"
    11  	"github.com/gin-gonic/gin"
    12  	"github.com/go-logr/logr"
    13  
    14  	"edge-infra.dev/pkg/lib/fog"
    15  	"edge-infra.dev/pkg/sds/interlock/internal/constants"
    16  	"edge-infra.dev/pkg/sds/interlock/internal/observability"
    17  )
    18  
    19  const (
    20  	clientIPLabel  = "clientIP"
    21  	requestIDLabel = "X-Request-ID"
    22  )
    23  
    24  // responseWriter wraps the default ResponseWriter to add functionality to
    25  // capture the body of the response.
    26  type responseWriter struct {
    27  	gin.ResponseWriter
    28  	body *bytes.Buffer
    29  }
    30  
    31  func SetAccessControlHeaders() gin.HandlerFunc {
    32  	return func(c *gin.Context) {
    33  		c.Header("Access-Control-Allow-Origin", "*")
    34  		c.Header("Access-Control-Allow-Methods", "*")
    35  		c.Header("Access-Control-Allow-Headers", "*")
    36  	}
    37  }
    38  
    39  // Write the response bytes to the body buffer and then write the body with the
    40  // default ResponseWriter.
    41  func (r responseWriter) Write(b []byte) (int, error) {
    42  	r.body.Write(b)
    43  	return r.ResponseWriter.Write(b)
    44  }
    45  
    46  // SetLoggerInContext appends the request ID, remote/client IP, request method
    47  // and request URL to the provided logger and loads the logger into the context.
    48  func SetLoggerInContext(log logr.Logger) gin.HandlerFunc {
    49  	return func(c *gin.Context) {
    50  		requestID := requestid.Get(c)
    51  		requestInfo := map[string]string{
    52  			"method": c.Request.Method,
    53  			"path":   c.Request.URL.Path,
    54  			"query":  c.Request.URL.RawQuery,
    55  		}
    56  		ctcLog := log.WithValues(requestIDLabel, requestID).
    57  			WithValues(clientIPLabel, c.ClientIP()).
    58  			WithValues("request", requestInfo)
    59  
    60  		c.Request = c.Request.Clone(fog.IntoContext(c.Request.Context(), ctcLog))
    61  		c.Next()
    62  	}
    63  }
    64  
    65  // RequestLogger logs request results, adding structured logging to the default
    66  // gin implementation. The response status, response body (if error status) and
    67  // the time taken to respond is appended to the log values.
    68  func RequestLogger() gin.HandlerFunc {
    69  	return func(c *gin.Context) {
    70  		if c.FullPath() == "/health" || c.Request.Method == http.MethodOptions {
    71  			c.Next()
    72  			return
    73  		}
    74  
    75  		rw := &responseWriter{
    76  			ResponseWriter: c.Writer,
    77  			body:           bytes.NewBufferString(""),
    78  		}
    79  		c.Writer = rw
    80  
    81  		t := time.Now()
    82  		c.Next()
    83  
    84  		logRequest(c, rw.body.String(), time.Since(t))
    85  	}
    86  }
    87  
    88  // TODO: Limit this to only log errors unless in debug mode at release time
    89  // logRequest logs the request with the response results.
    90  func logRequest(c *gin.Context, responseBody string, timeElapsed time.Duration) {
    91  	status := c.Writer.Status()
    92  	latency := float64(timeElapsed.Microseconds()) / 1000
    93  	logmsg := fmt.Sprintf("%s %s %d %s", c.Request.Method, c.Request.URL.Path, status, http.StatusText(status))
    94  	responseInfo := map[string]interface{}{
    95  		"bodySize": c.Writer.Size(),
    96  		"status":   status,
    97  		"text":     http.StatusText(status),
    98  		"latency":  fmt.Sprintf("%.2f ms", latency),
    99  	}
   100  
   101  	if status >= 400 { // only append the response body to errors
   102  		responseInfo["body"] = responseBody
   103  	}
   104  	log := fog.FromContext(c.Request.Context()).
   105  		WithValues("response", responseInfo)
   106  
   107  	switch {
   108  	case status < 400: // for 100s,200s, 300s responses log info
   109  		log.Info(logmsg)
   110  	case status >= 400 && status < 500: // for 400s responses log info with audit key
   111  		log.WithValues(constants.AuditKey, "").Info(logmsg)
   112  	case status >= 500: // for 500 responses log error
   113  		log.Error(errors.New(http.StatusText(status)), logmsg)
   114  	}
   115  }
   116  
   117  // RequestRecorder records the request result as prometheus metrics. This is
   118  // included in the count of different response status' and the latency
   119  // histogram.
   120  func RequestRecorder() gin.HandlerFunc {
   121  	return func(c *gin.Context) {
   122  		t := time.Now()
   123  		c.Next()
   124  
   125  		path := c.Request.URL.Path
   126  		method := c.Request.Method
   127  		status := c.Writer.Status()
   128  		latency := float64(time.Since(t).Microseconds()) / 1000
   129  		observability.RecordRequest(path, method, status, latency)
   130  	}
   131  }
   132  

View as plain text