package edgeencrypt

import (
	"context"
	"crypto/rsa"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/gin-contrib/requestid"
	"github.com/gin-gonic/gin"
	"github.com/go-logr/logr"
	"github.com/penglongli/gin-metrics/ginmetrics"

	"edge-infra.dev/pkg/lib/fog"
)

const (
	// CorrelationIDKey is the "X-Correlation-ID" key. Its consumers are not
	// in this file; presumably it names the correlation ID HTTP header.
	CorrelationIDKey = "X-Correlation-ID"

	// EncryptionDefaultBodySize is a request body size of 10 MB, expressed in
	// bytes. Presumably intended as the argument to MaxRequestBodySize.
	EncryptionDefaultBodySize = 10 << 20 // 10 MB
)

// claimsContextKey is the unexported context key under which encryption
// claims are stored, so that other packages cannot collide with it.
type claimsContextKey struct{}

// ClaimsIntoContext returns a copy of ctx that carries the given encryption
// claims.
func ClaimsIntoContext(ctx context.Context, claims *EncryptionClaims) context.Context {
	return context.WithValue(ctx, claimsContextKey{}, claims)
}

// ClaimsFromContext returns the encryption claims stored in ctx by
// ClaimsIntoContext, or nil when ctx carries none.
func ClaimsFromContext(ctx context.Context) *EncryptionClaims {
	// A failed type assertion yields the zero value, which is a nil pointer.
	claims, _ := ctx.Value(claimsContextKey{}).(*EncryptionClaims)
	return claims
}

// PublicKeyGetter returns the RSA public key that bearer tokens are checked
// against.
type PublicKeyGetter func(ctx context.Context) (*rsa.PublicKey, error)

// BearerToken returns a gin middleware that authenticates a request from its
// "Authorization: Bearer <token>" header and requires the token's claims to
// have the given role.
//
// The request is aborted with:
//   - 401 when the header is missing or malformed, the token is rejected by
//     FromToken, or the claims do not have the required role;
//   - 503 when getPublicKey fails.
//
// On success the claims are stored in the request context (retrievable with
// ClaimsFromContext) and the next handler runs.
func BearerToken(getPublicKey PublicKeyGetter, role Role) gin.HandlerFunc {
	return func(c *gin.Context) {
		ctx := c.Request.Context()

		auth := c.GetHeader("Authorization")
		if !strings.HasPrefix(auth, "Bearer ") {
			_ = c.Error(fmt.Errorf("invalid Authorization header"))
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		// The header must be exactly two space-separated fields:
		// the scheme and the token.
		token := strings.Split(auth, " ")
		if len(token) != 2 {
			_ = c.Error(fmt.Errorf("invalid Authorization header, token length: %d ", len(token)))
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		pk, err := getPublicKey(ctx)
		if err != nil {
			_ = c.Error(fmt.Errorf("failed to get BearerToken %s public key: %w", role, err))
			c.AbortWithStatus(http.StatusServiceUnavailable)
			return
		}

		claims, err := FromToken(pk, token[1])
		if err != nil {
			_ = c.Error(fmt.Errorf("invalid bearer token: %w", err))
			c.AbortWithStatus(http.StatusUnauthorized)
			return
		}

		if !claims.HasRole(role) {
			// The same message is recorded on the context and sent to the client.
			msg := fmt.Sprintf("invalid role: %s, required: %s", claims.Role, role)
			_ = c.Error(errors.New(msg))
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error": msg,
			})
			return
		}

		ctx = ClaimsIntoContext(ctx, claims)
		c.Request = c.Request.Clone(ctx)
		c.Next()
	}
}

// MetricServer instruments appRouter with the global ginmetrics monitor and
// returns a separate gin engine that exposes the collected metrics at
// "/metrics".
func MetricServer(appRouter *gin.Engine) *gin.Engine {
	metricRouter := gin.New()
	metricRouter.Use(gin.Recovery())

	// get global Monitor object
	m := ginmetrics.GetMonitor()
	// +optional set metric path, default /debug/metrics
	m.SetMetricPath("/metrics")
	// +optional set slow time, default 5s
	m.SetSlowTime(3)

	// Collect on the application router, but serve the endpoint only from
	// the dedicated metrics router.
	m.UseWithoutExposingEndpoint(appRouter)
	m.Expose(metricRouter)
	return metricRouter
}

// RequestLogger returns a gin middleware that stores a request-scoped logger
// (carrying correlation id, method, path, query and client IP) in the request
// context via fog.IntoContext, and logs an error for every response whose
// status is not 2xx.
//
// Requests routed to "/health" and OPTIONS requests are passed through
// without a logger and without logging.
func RequestLogger(log logr.Logger) gin.HandlerFunc {
	return func(c *gin.Context) {
		if c.FullPath() == "/health" || c.Request.Method == http.MethodOptions {
			c.Next()
			return
		}

		reqLog := log.WithValues(
			"correlation-id", requestid.Get(c),
			"method", c.Request.Method,
			"path", c.Request.URL.Path,
			"query", c.Request.URL.RawQuery,
			"ip", c.ClientIP(),
		)
		ctx := fog.IntoContext(c.Request.Context(), reqLog)
		c.Request = c.Request.Clone(ctx)

		start := time.Now()
		c.Next()

		status := c.Writer.Status()
		if status/100 == 2 {
			return
		}

		// Prefer the errors handlers recorded on the context; fall back to a
		// generic error built from the status code.
		var err error
		if len(c.Errors) == 0 {
			err = fmt.Errorf("invalid response status: %d", status)
		} else {
			err = errors.New(c.Errors.String())
		}
		reqLog.Error(err, "invalid response status", "status", status, "duration", time.Since(start))
	}
}

// MaxRequestBodySize returns a gin middleware that limits the request body to
// max bytes using http.MaxBytesReader.
//
// After the downstream handlers have run, the last error recorded on the gin
// context is inspected. If it shows that the limit was exceeded, the request
// is aborted with 400 and a JSON body {"error": "body too large"}.
//
// Two errors are treated as "limit exceeded":
//   - *http.MaxBytesError, which is what http.MaxBytesReader returns from
//     Read once the limit is hit;
//   - http.ErrMissingFile, which the original implementation matched.
//     NOTE(review): presumably this is what the form-file handlers surface
//     for a truncated upload; confirm against the callers.
func MaxRequestBodySize(max int64) gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, max)
		c.Next()

		if len(c.Errors) == 0 {
			return
		}

		last := c.Errors.Last()
		var tooLarge *http.MaxBytesError
		if errors.As(last, &tooLarge) || errors.Is(last, http.ErrMissingFile) {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
				"error": "body too large",
			})
		}
	}
}