...

Source file src/edge-infra.dev/pkg/edge/edgeencrypt/middleware.go

Documentation: edge-infra.dev/pkg/edge/edgeencrypt

     1  package edgeencrypt
     2  
     3  import (
     4  	"context"
     5  	"crypto/rsa"
     6  	"errors"
     7  	"fmt"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/gin-contrib/requestid"
    13  	"github.com/gin-gonic/gin"
    14  	"github.com/go-logr/logr"
    15  	"github.com/penglongli/gin-metrics/ginmetrics"
    16  
    17  	"edge-infra.dev/pkg/lib/fog"
    18  )
    19  
    20  const (
    21  	CorrelationIDKey          = "X-Correlation-ID"
    22  	EncryptionDefaultBodySize = 10 << 20 // 10 MB
    23  )
    24  
// claimsContextKey is the unexported key type under which ClaimsIntoContext
// stores encryption claims in a context.Context and ClaimsFromContext reads
// them back. Being a distinct unexported type, its values cannot collide with
// context keys defined by other packages.
type claimsContextKey struct{}
    26  
    27  // ClaimsIntoContext save encryption claims into context
    28  func ClaimsIntoContext(ctx context.Context, claims *EncryptionClaims) context.Context {
    29  	return context.WithValue(ctx, claimsContextKey{}, claims)
    30  }
    31  
    32  // ClaimsFromContext get encryption claims from context
    33  func ClaimsFromContext(ctx context.Context) *EncryptionClaims {
    34  	u, ok := ctx.Value(claimsContextKey{}).(*EncryptionClaims)
    35  	if ok {
    36  		return u
    37  	}
    38  	return nil
    39  }
    40  
// PublicKeyGetter is a function that returns an RSA public key for the given
// context, or an error when the key cannot be obtained. BearerToken invokes it
// on each request and passes the returned key to FromToken.
type PublicKeyGetter func(ctx context.Context) (*rsa.PublicKey, error)
    42  
    43  func BearerToken(getPublicKey PublicKeyGetter, role Role) gin.HandlerFunc {
    44  	return func(c *gin.Context) {
    45  		ctx := c.Request.Context()
    46  		auth := c.GetHeader("Authorization")
    47  		if !strings.HasPrefix(auth, "Bearer ") {
    48  			_ = c.Error(fmt.Errorf("invalid Authorization header"))
    49  			c.AbortWithStatus(http.StatusUnauthorized)
    50  			return
    51  		}
    52  
    53  		token := strings.Split(auth, " ")
    54  		if len(token) != 2 {
    55  			_ = c.Error(fmt.Errorf("invalid Authorization header, token length: %d ", len(token)))
    56  			c.AbortWithStatus(http.StatusUnauthorized)
    57  			return
    58  		}
    59  
    60  		pk, err := getPublicKey(ctx)
    61  		if err != nil {
    62  			_ = c.Error(fmt.Errorf("failed to get BearerToken %s public key: %w", role, err))
    63  			c.AbortWithStatus(http.StatusServiceUnavailable)
    64  			return
    65  		}
    66  
    67  		claims, err := FromToken(pk, token[1])
    68  		if err != nil {
    69  			_ = c.Error(fmt.Errorf("invalid bearer token: %w", err))
    70  			c.AbortWithStatus(http.StatusUnauthorized)
    71  			return
    72  		}
    73  
    74  		if !claims.HasRole(role) {
    75  			_ = c.Error(fmt.Errorf("invalid role: %s, required: %s", claims.Role, role))
    76  			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
    77  				"error": fmt.Sprintf("invalid role: %s, required: %s", claims.Role, role),
    78  			})
    79  			return
    80  		}
    81  
    82  		ctx = ClaimsIntoContext(ctx, claims)
    83  
    84  		c.Request = c.Request.Clone(ctx)
    85  
    86  		c.Next()
    87  	}
    88  }
    89  
    90  func MetricServer(appRouter *gin.Engine) *gin.Engine {
    91  	metricRouter := gin.New()
    92  	metricRouter.Use(gin.Recovery())
    93  
    94  	// get global Monitor object
    95  	m := ginmetrics.GetMonitor()
    96  
    97  	// +optional set metric path, default /debug/metrics
    98  	m.SetMetricPath("/metrics")
    99  
   100  	// +optional set slow time, default 5s
   101  	m.SetSlowTime(3)
   102  
   103  	m.UseWithoutExposingEndpoint(appRouter)
   104  
   105  	m.Expose(metricRouter)
   106  
   107  	return metricRouter
   108  }
   109  
   110  func RequestLogger(log logr.Logger) gin.HandlerFunc {
   111  	return func(c *gin.Context) {
   112  		if c.FullPath() == "/health" || c.Request.Method == http.MethodOptions {
   113  			c.Next()
   114  			return
   115  		}
   116  
   117  		log := log.WithValues("correlation-id", requestid.Get(c),
   118  			"method", c.Request.Method,
   119  			"path", c.Request.URL.Path,
   120  			"query", c.Request.URL.RawQuery,
   121  			"ip", c.ClientIP())
   122  
   123  		ctx := c.Request.Context()
   124  		ctx = fog.IntoContext(ctx, log)
   125  		c.Request = c.Request.Clone(ctx)
   126  
   127  		start := time.Now()
   128  		c.Next()
   129  		if c.Writer.Status()/100 != 2 {
   130  			var err error
   131  			if len(c.Errors) == 0 {
   132  				err = fmt.Errorf("invalid response status: %d", c.Writer.Status())
   133  			} else {
   134  				err = errors.New(c.Errors.String())
   135  			}
   136  			log.Error(err, "invalid response status", "status", c.Writer.Status(), "duration", time.Since(start))
   137  		}
   138  	}
   139  }
   140  
   141  func MaxRequestBodySize(max int64) gin.HandlerFunc {
   142  	return func(c *gin.Context) {
   143  		c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, max)
   144  		c.Next()
   145  		if len(c.Errors) > 0 && errors.Is(c.Errors.Last(), http.ErrMissingFile) {
   146  			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
   147  				"error": "body too large",
   148  			})
   149  		}
   150  	}
   151  }
   152  

View as plain text