1 package middleware
2
3 import (
4 "bytes"
5 "context"
6 "database/sql"
7 "encoding/hex"
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "regexp"
13 "slices"
14 "strings"
15 "time"
16
17 "edge-infra.dev/pkg/edge/api/bsl/types"
18 "edge-infra.dev/pkg/edge/api/graph/model"
19 sqlstatements "edge-infra.dev/pkg/edge/api/sql"
20 "edge-infra.dev/pkg/edge/api/totp"
21 "edge-infra.dev/pkg/edge/api/utils"
22 "edge-infra.dev/pkg/edge/client"
23 "edge-infra.dev/pkg/lib/crypto"
24
25 "github.com/gin-gonic/gin"
26 "github.com/golang-jwt/jwt"
27 )
28
// Context keys under which the auth middleware stores per-request data.
// Each key is a distinct pointer, so lookups compare by identity and cannot
// collide with keys defined in other packages.
var (
	// userCtxKey holds the request's user; ForContext reads it back as a
	// *types.AuthUser.
	userCtxKey = &contextKey{"user"}
	// TerminalIDCtxKey holds the terminal ID resolved from an activation code.
	TerminalIDCtxKey = &contextKey{"terminalId"}
	// EdgeAPIEndpointCtxKey holds the https URL of the current request as
	// built by getEdgeEndpoint.
	EdgeAPIEndpointCtxKey = &contextKey{"edgeAPIEndpoint"}
	// ClusterEdgeIDCtxKey holds the cluster edge ID resolved from an edge
	// bootstrap token.
	ClusterEdgeIDCtxKey = &contextKey{"clusterEdgeId"}
)

// contextKey is an unexported key type for context values, which prevents
// collisions with context keys from other packages.
type contextKey struct {
	name string
}
37
38
// AuthMiddleware returns a gin handler that authenticates every request by
// delegating to handleEdgeAuth. jwtTokenSecret verifies bearer JWTs,
// totpSecret verifies TOTP tokens, and db is used to look up activation codes
// and edge bootstrap tokens.
func AuthMiddleware(jwtTokenSecret, totpSecret string, db *sql.DB) gin.HandlerFunc {
	handler := func(gc *gin.Context) {
		handleEdgeAuth(gc, jwtTokenSecret, totpSecret, db)
	}
	return handler
}
44
// validateActivationCode looks up a terminal activation code by its hash.
//
// The code is hashed with crypto.HashActivation, hex-encoded, and passed to
// sqlstatements.TerminalActivationQuery under a 30 second timeout. On success
// it returns a user holding only the RoleEdgeTerminal role together with the
// terminal ID from the matching row. Scan errors are returned unchanged, so
// callers can detect sql.ErrNoRows when no row matches.
func validateActivationCode(activationCode string, db *sql.DB) (*types.AuthUser, string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	lookupKey := hex.EncodeToString(crypto.HashActivation(activationCode))

	var (
		storedHash string // read only to match the query's column list
		terminalID string
	)
	if err := db.QueryRowContext(ctx, sqlstatements.TerminalActivationQuery, lookupKey).Scan(&storedHash, &terminalID); err != nil {
		return nil, "", err
	}

	user := &types.AuthUser{Roles: []string{string(model.RoleEdgeTerminal)}}
	return user, terminalID, nil
}
62
// validateEdgeBootstrapToken authenticates an edge bootstrap token.
//
// The token is hashed with crypto.HashEdgeBootstrapToken, hex-encoded, and
// looked up via sqlstatements.GetEdgeBootstrapToken under a 30 second timeout.
// The stored expiry must parse as RFC 3339 and must not be in the past, and
// the row must carry a non-empty cluster edge ID. On success it returns a user
// holding only the RoleEdgeBootstrap role plus that cluster edge ID.
//
// The row Scan error is returned unwrapped, so callers can still test for
// sql.ErrNoRows. An unparsable expiry is wrapped with %w so the cause is kept.
func validateEdgeBootstrapToken(bootstrapToken string, db *sql.DB) (*types.AuthUser, string, error) {
	reqCtx, cancelReq := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancelReq()

	hashEncoded := hex.EncodeToString(crypto.HashEdgeBootstrapToken([]byte(bootstrapToken)))
	row := db.QueryRowContext(reqCtx, sqlstatements.GetEdgeBootstrapToken, hashEncoded)

	// bootstrapTokenHash is scanned only to match the query's column list.
	var bootstrapTokenHash, bootstrapExpiry, clusterEdgeID string
	if err := row.Scan(&bootstrapTokenHash, &bootstrapExpiry, &clusterEdgeID); err != nil {
		return nil, "", err
	}

	expiryTime, err := time.Parse(time.RFC3339, bootstrapExpiry)
	if err != nil {
		// Previously the parse error was discarded; keep it in the chain.
		return nil, "", fmt.Errorf("could not parse bootstrap expiry time: %w", err)
	}

	if expiryTime.Before(time.Now().UTC()) {
		return nil, "", errors.New("bootstrap token has expired")
	}

	if clusterEdgeID == "" {
		return nil, "", errors.New("invalid cluster edge id returned on edge bootstrap token auth")
	}

	return &types.AuthUser{
		Roles: []string{string(model.RoleEdgeBootstrap)},
	}, clusterEdgeID, nil
}
92
// edgeBootstrapAuthHeaderRegex matches an Authorization header of the form
// "bearer <64 hex characters>", which is how edge bootstrap tokens are sent.
// The "bearer " prefix is matched case-sensitively. The pattern is compiled
// once at package initialisation instead of on every request.
var edgeBootstrapAuthHeaderRegex = regexp.MustCompile(`^(bearer )[A-Fa-f0-9]{64}$`)

// handleEdgeAuth authenticates a request from its Authorization header. It
// stores the result on the request context and then calls the next handler.
//
// The request passes through with no user attached (c.Next is called
// directly) when:
//   - there is no Authorization header, or
//   - the route path contains "/validate_token".
//
// Otherwise the body is read and then restored so downstream handlers can read
// it again. One of three schemes is then chosen:
//   - The body mentions "terminalBootstrap": the header is a terminal
//     activation code. On success the user, terminal ID and endpoint are
//     stored.
//   - The body mentions "bootstrapCluster" and the header matches
//     edgeBootstrapAuthHeaderRegex: the hex part is an edge bootstrap token. On
//     success the user, cluster edge ID and endpoint are stored.
//   - Otherwise the header must be "<type> <token>", where the lowercased
//     type is client.BearerToken or client.TotpToken. It is resolved by
//     getUserAndRolesFromAuth, and the user and endpoint are stored.
//
// Every failure writes a plain-text error response and aborts the gin chain,
// so no later handler runs for a rejected request.
func handleEdgeAuth(c *gin.Context, jwtTokenSecret string, totpSecret string, db *sql.DB) {
	auth := c.GetHeader("Authorization")
	if auth == "" {
		c.Next()
		return
	}

	fullPath := strings.ToLower(c.FullPath())
	if strings.Contains(fullPath, "/validate_token") {
		c.Next()
		return
	}

	// deny writes the error response and aborts. Returning from a gin
	// middleware without Abort would still run the remaining handlers.
	deny := func(status int, msg string) {
		c.String(status, "%s", msg)
		c.Abort()
	}

	defer c.Request.Body.Close()
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		deny(http.StatusForbidden, fmt.Sprintf("error reading request: %v", err))
		return
	}
	// Restore the consumed body so downstream handlers can read it.
	c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
	bodyStr := string(body)

	edgeAPIEndpoint := getEdgeEndpoint(c)

	if strings.Contains(bodyStr, "terminalBootstrap") {
		if err := crypto.ValidActivationCode(auth); err != nil {
			deny(http.StatusForbidden, fmt.Sprintf("%v - please check the activation code has been copied correctly", err))
			return
		}

		user, terminalID, err := validateActivationCode(auth, db)
		if user != nil && err == nil {
			ctx := context.WithValue(c.Request.Context(), userCtxKey, user)
			ctx = context.WithValue(ctx, TerminalIDCtxKey, terminalID)
			ctx = context.WithValue(ctx, EdgeAPIEndpointCtxKey, edgeAPIEndpoint)
			c.Request = c.Request.WithContext(ctx)
			c.Next()
			return
		}
		if errors.Is(err, sql.ErrNoRows) {
			deny(http.StatusForbidden, "activation code is incorrect or has been consumed, please check the activation code has been copied correctly. Otherwise request a new one.")
		} else {
			deny(http.StatusForbidden, fmt.Sprintf("Please try again. There was an error attempting to validate the activation code: %v", err))
		}
		return
	}

	if strings.Contains(bodyStr, "bootstrapCluster") && edgeBootstrapAuthHeaderRegex.MatchString(auth) {
		// The regex guarantees the "bearer " prefix is present.
		bootstrapToken := strings.TrimPrefix(auth, "bearer ")
		user, clusterEdgeID, err := validateEdgeBootstrapToken(bootstrapToken, db)
		if user != nil && err == nil {
			ctx := context.WithValue(c.Request.Context(), userCtxKey, user)
			ctx = context.WithValue(ctx, ClusterEdgeIDCtxKey, clusterEdgeID)
			ctx = context.WithValue(ctx, EdgeAPIEndpointCtxKey, edgeAPIEndpoint)
			c.Request = c.Request.WithContext(ctx)
			c.Next()
			return
		}
		deny(http.StatusForbidden, fmt.Sprintf("invalid edge bootstrap token: %v", err))
		return
	}

	tokenParts := strings.Split(auth, " ")
	if len(tokenParts) != 2 {
		deny(http.StatusBadRequest, "invalid auth token")
		return
	}

	tokenType := strings.ToLower(tokenParts[0])
	if !utils.Contains([]string{client.BearerToken, client.TotpToken}, tokenType) {
		deny(http.StatusForbidden, "invalid auth token")
		return
	}

	user, err := getUserAndRolesFromAuth(tokenType, tokenParts[1], jwtTokenSecret, totpSecret)
	if err != nil {
		deny(http.StatusForbidden, err.Error())
		return
	}

	ctx := context.WithValue(c.Request.Context(), userCtxKey, user)
	ctx = context.WithValue(ctx, EdgeAPIEndpointCtxKey, edgeAPIEndpoint)
	c.Request = c.Request.WithContext(ctx)
	c.Next()
}
184
185 func getEdgeEndpoint(c *gin.Context) string {
186 return fmt.Sprintf("https://%s%s", c.Request.Host, c.Request.RequestURI)
187 }
188
189 func getUserAndRolesFromAuth(tokenType, token, secret, totpSecret string) (*types.AuthUser, error) {
190 if tokenType == client.BearerToken {
191 return ValidateAndGetUser(token, secret)
192 }
193 if err := totp.ValidateTotpToken(token, totpSecret); err != nil {
194 return nil, err
195 }
196 return &types.AuthUser{
197 Roles: []string{string(model.RoleTotpRole)},
198 }, nil
199 }
200
201
// ForContext returns the *types.AuthUser stored on ctx under the user key (by
// the auth middleware or NewContext). It returns nil when no value is stored or
// the stored value has a different type.
func ForContext(ctx context.Context) *types.AuthUser {
	if user, ok := ctx.Value(userCtxKey).(*types.AuthUser); ok {
		return user
	}
	return nil
}
206
207
// NewContext returns a copy of ctx that carries value under the user key.
// ForContext reads it back only when value is a *types.AuthUser; any other
// type reads back as nil.
func NewContext(ctx context.Context, value interface{}) context.Context {
	derived := context.WithValue(ctx, userCtxKey, value)
	return derived
}
211
212
// ValidateAndGetUser verifies an Edge JWT signed with jwtTokenSecret and
// converts its claims into a *types.AuthUser.
//
// Only HMAC-signed tokens are accepted, because the key is a shared secret.
// Claim requirements:
//   - The username, email, organization and auth-provider claims must be
//     present as strings.
//   - The roles claim, when present and non-null, must be an array; each
//     element is converted with fmt.Sprint.
//   - The token and refresh-token claims are optional strings.
//
// A verification failure or malformed claim yields an error rather than a
// panic.
func ValidateAndGetUser(tokenString, jwtTokenSecret string) (*types.AuthUser, error) {
	claims := jwt.MapClaims{}
	_, err := jwt.ParseWithClaims(tokenString, claims, func(t *jwt.Token) (interface{}, error) {
		// Pin the algorithm family: the key is an HMAC secret, so tokens using
		// any other signing method must not be verified with it.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return []byte(jwtTokenSecret), nil
	})
	if err != nil {
		return nil, err
	}

	var roles []string
	if rawRoles := claims[types.Roles]; rawRoles != nil {
		list, ok := rawRoles.([]interface{})
		if !ok {
			return nil, fmt.Errorf("invalid token: %q claim is not a list", types.Roles)
		}
		for _, role := range list {
			roles = append(roles, fmt.Sprint(role))
		}
	}

	// stringClaim reads claims[key] as a string. A missing claim yields "" and
	// is an error only when required. A present claim of any other type is
	// always an error; the former unchecked assertions panicked here.
	stringClaim := func(key string, required bool) (string, error) {
		raw, present := claims[key]
		if !present {
			if required {
				return "", fmt.Errorf("invalid token: missing %q claim", key)
			}
			return "", nil
		}
		s, ok := raw.(string)
		if !ok {
			return "", fmt.Errorf("invalid token: %q claim is not a string", key)
		}
		return s, nil
	}

	user := &types.AuthUser{Roles: roles}
	fields := []struct {
		key      string
		required bool
		dst      *string
	}{
		{types.Username, true, &user.Username},
		{types.Email, true, &user.Email},
		{types.Organization, true, &user.Organization},
		{types.AuthProvider, true, &user.AuthProvider},
		{types.Token, false, &user.Token},
		{types.RefreshToken, false, &user.RefreshToken},
	}
	for _, f := range fields {
		v, err := stringClaim(f.key, f.required)
		if err != nil {
			return nil, err
		}
		*f.dst = v
	}
	return user, nil
}
244
245
246 func CreateToken(username, email, organization, secret string, roles []string, token, authProvider, refreshToken string) (string, error) {
247 mc := jwt.MapClaims{}
248 mc[types.Username] = username
249 mc[types.Email] = email
250 mc[types.Roles] = roles
251 mc[types.Organization] = organization
252 mc[types.Token] = token
253 mc[types.RefreshToken] = refreshToken
254 mc[types.AuthProvider] = authProvider
255
256 claims := jwt.NewWithClaims(jwt.SigningMethodHS512, mc)
257
258 token, err := claims.SignedString([]byte(secret))
259 if err != nil {
260 return "", err
261 }
262 return token, nil
263 }
264
// ValidateToken checks that auth holds a valid Edge JWT and that its user may
// access the banner identified by banner.
//
// Steps:
//   - auth must consist of exactly two space-separated parts. Only the second
//     part is used; it is verified with ValidateAndGetUser using jwtSecret.
//   - The resulting user is attached to c.Request's context before any
//     further check, so it remains attached even if a later check fails.
//   - The embedded BSL token (user.Token) must pass
//     ValidateBSLTokenAndGetName.
//   - Users holding RoleEdgeOrgAdmin or RoleEdgeSuperAdmin may access any
//     banner. Everyone else must find banner among the BannerEdgeID values
//     returned by getUserBanners for their BSL subject name.
func ValidateToken(c *gin.Context, auth string, banner string, jwtSecret string, getUserBanners func(context.Context, string) ([]*model.BannerInfo, error)) error {
	switch {
	case auth == "":
		return fmt.Errorf("no authorization header found")
	case banner == "":
		return fmt.Errorf("no banner header found")
	}

	parts := strings.Split(auth, " ")
	if len(parts) != 2 {
		return fmt.Errorf("invalid auth token")
	}

	user, err := ValidateAndGetUser(parts[1], jwtSecret)
	if err != nil {
		return err
	}

	ctx := context.WithValue(c.Request.Context(), userCtxKey, user)
	c.Request = c.Request.WithContext(ctx)

	name, err := ValidateBSLTokenAndGetName(user.Token)
	if err != nil {
		return err
	}

	// Org and super admins are not restricted to specific banners.
	for _, role := range user.Roles {
		isAdmin := role == string(model.RoleEdgeOrgAdmin) || role == string(model.RoleEdgeSuperAdmin)
		if isAdmin {
			return nil
		}
	}

	userBanners, err := getUserBanners(ctx, name)
	if err != nil {
		return err
	}

	for _, info := range userBanners {
		if info.BannerEdgeID == banner {
			return nil
		}
	}
	return fmt.Errorf("user does not have access to banner")
}
317
// ValidateBSLTokenAndGetName returns the subject ("sub") of a BSL JWT and
// checks the token's standard time-based claims.
//
// The signature is NOT verified (jwt.Parser.ParseUnverified). Callers must
// only pass tokens from a trusted source, such as the Token claim of an Edge
// JWT that ValidateAndGetUser has already verified. Time-based claims are
// checked with MapClaims.Valid.
//
// It returns an error if:
//   - the token cannot be parsed, or
//   - the "sub" claim is missing or not a string.
//
// When the claims fail Valid, the subject is still returned alongside that
// error.
func ValidateBSLTokenAndGetName(token string) (string, error) {
	bslClaims := jwt.MapClaims{}
	if _, _, err := new(jwt.Parser).ParseUnverified(token, bslClaims); err != nil {
		return "", err
	}

	// Checked assertion: the unchecked form panicked on tokens without a
	// string "sub" claim.
	name, ok := bslClaims["sub"].(string)
	if !ok {
		return "", errors.New("bsl token has no string \"sub\" claim")
	}

	return name, bslClaims.Valid()
}
326
// GetEdgeRoles returns the Edge roles held by the user stored on ctx.
//
// The result follows this fixed order: EDGE_SUPER_ADMIN, EDGE_ORG_ADMIN,
// EDGE_BANNER_ADMIN, EDGE_BANNER_OPERATOR, EDGE_BANNER_VIEWER, EDGE_OI_ADMIN.
// Any other role the user holds is ignored.
//
// An empty (non-nil) slice and an error are returned when ctx carries no user,
// or when the user holds none of these roles.
func GetEdgeRoles(ctx context.Context) ([]string, error) {
	userRoles := []string{}

	u := ForContext(ctx)
	if u == nil {
		// ForContext returns nil when no user is on ctx; reading u.Roles
		// would otherwise dereference a nil pointer and panic.
		return userRoles, errors.New("no user found in context")
	}

	edgeRoles := []string{
		"EDGE_SUPER_ADMIN",
		"EDGE_ORG_ADMIN",
		"EDGE_BANNER_ADMIN",
		"EDGE_BANNER_OPERATOR",
		"EDGE_BANNER_VIEWER",
		"EDGE_OI_ADMIN",
	}
	for _, edgeRole := range edgeRoles {
		if slices.Contains(u.Roles, edgeRole) {
			userRoles = append(userRoles, edgeRole)
		}
	}
	if len(userRoles) == 0 {
		return userRoles, fmt.Errorf("edge role not found for context user")
	}
	return userRoles, nil
}
348
View as plain text