1 package edgeencrypt
2
3 import (
4 "context"
5 "crypto/rsa"
6 "errors"
7 "fmt"
8 "net/http"
9 "strings"
10 "time"
11
12 "github.com/gin-contrib/requestid"
13 "github.com/gin-gonic/gin"
14 "github.com/go-logr/logr"
15 "github.com/penglongli/gin-metrics/ginmetrics"
16
17 "edge-infra.dev/pkg/lib/fog"
18 )
19
// Package-level defaults shared by the edge-encryption HTTP layer.
const (
	// CorrelationIDKey is the "X-Correlation-ID" header name; presumably used
	// to propagate a per-request correlation ID (not referenced in this file).
	CorrelationIDKey = "X-Correlation-ID"

	// EncryptionDefaultBodySize is 10 MiB (10 * 1024 * 1024 bytes); presumably
	// the default limit handed to MaxRequestBodySize — confirm at call sites.
	EncryptionDefaultBodySize = 10 * 1024 * 1024
)
24
25 type claimsContextKey struct{}
26
27
28 func ClaimsIntoContext(ctx context.Context, claims *EncryptionClaims) context.Context {
29 return context.WithValue(ctx, claimsContextKey{}, claims)
30 }
31
32
33 func ClaimsFromContext(ctx context.Context) *EncryptionClaims {
34 u, ok := ctx.Value(claimsContextKey{}).(*EncryptionClaims)
35 if ok {
36 return u
37 }
38 return nil
39 }
40
41 type PublicKeyGetter func(ctx context.Context) (*rsa.PublicKey, error)
42
43 func BearerToken(getPublicKey PublicKeyGetter, role Role) gin.HandlerFunc {
44 return func(c *gin.Context) {
45 ctx := c.Request.Context()
46 auth := c.GetHeader("Authorization")
47 if !strings.HasPrefix(auth, "Bearer ") {
48 _ = c.Error(fmt.Errorf("invalid Authorization header"))
49 c.AbortWithStatus(http.StatusUnauthorized)
50 return
51 }
52
53 token := strings.Split(auth, " ")
54 if len(token) != 2 {
55 _ = c.Error(fmt.Errorf("invalid Authorization header, token length: %d ", len(token)))
56 c.AbortWithStatus(http.StatusUnauthorized)
57 return
58 }
59
60 pk, err := getPublicKey(ctx)
61 if err != nil {
62 _ = c.Error(fmt.Errorf("failed to get BearerToken %s public key: %w", role, err))
63 c.AbortWithStatus(http.StatusServiceUnavailable)
64 return
65 }
66
67 claims, err := FromToken(pk, token[1])
68 if err != nil {
69 _ = c.Error(fmt.Errorf("invalid bearer token: %w", err))
70 c.AbortWithStatus(http.StatusUnauthorized)
71 return
72 }
73
74 if !claims.HasRole(role) {
75 _ = c.Error(fmt.Errorf("invalid role: %s, required: %s", claims.Role, role))
76 c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
77 "error": fmt.Sprintf("invalid role: %s, required: %s", claims.Role, role),
78 })
79 return
80 }
81
82 ctx = ClaimsIntoContext(ctx, claims)
83
84 c.Request = c.Request.Clone(ctx)
85
86 c.Next()
87 }
88 }
89
90 func MetricServer(appRouter *gin.Engine) *gin.Engine {
91 metricRouter := gin.New()
92 metricRouter.Use(gin.Recovery())
93
94
95 m := ginmetrics.GetMonitor()
96
97
98 m.SetMetricPath("/metrics")
99
100
101 m.SetSlowTime(3)
102
103 m.UseWithoutExposingEndpoint(appRouter)
104
105 m.Expose(metricRouter)
106
107 return metricRouter
108 }
109
110 func RequestLogger(log logr.Logger) gin.HandlerFunc {
111 return func(c *gin.Context) {
112 if c.FullPath() == "/health" || c.Request.Method == http.MethodOptions {
113 c.Next()
114 return
115 }
116
117 log := log.WithValues("correlation-id", requestid.Get(c),
118 "method", c.Request.Method,
119 "path", c.Request.URL.Path,
120 "query", c.Request.URL.RawQuery,
121 "ip", c.ClientIP())
122
123 ctx := c.Request.Context()
124 ctx = fog.IntoContext(ctx, log)
125 c.Request = c.Request.Clone(ctx)
126
127 start := time.Now()
128 c.Next()
129 if c.Writer.Status()/100 != 2 {
130 var err error
131 if len(c.Errors) == 0 {
132 err = fmt.Errorf("invalid response status: %d", c.Writer.Status())
133 } else {
134 err = errors.New(c.Errors.String())
135 }
136 log.Error(err, "invalid response status", "status", c.Writer.Status(), "duration", time.Since(start))
137 }
138 }
139 }
140
141 func MaxRequestBodySize(max int64) gin.HandlerFunc {
142 return func(c *gin.Context) {
143 c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, max)
144 c.Next()
145 if len(c.Errors) > 0 && errors.Is(c.Errors.Last(), http.ErrMissingFile) {
146 c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
147 "error": "body too large",
148 })
149 }
150 }
151 }
152
View as plain text