package middleware import ( "bytes" "context" "database/sql" "encoding/hex" "errors" "fmt" "io" "net/http" "regexp" "slices" "strings" "time" "edge-infra.dev/pkg/edge/api/bsl/types" "edge-infra.dev/pkg/edge/api/graph/model" sqlstatements "edge-infra.dev/pkg/edge/api/sql" "edge-infra.dev/pkg/edge/api/totp" "edge-infra.dev/pkg/edge/api/utils" "edge-infra.dev/pkg/edge/client" "edge-infra.dev/pkg/lib/crypto" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" ) var userCtxKey = &contextKey{"user"} var TerminalIDCtxKey = &contextKey{"terminalId"} var EdgeAPIEndpointCtxKey = &contextKey{"edgeAPIEndpoint"} var ClusterEdgeIDCtxKey = &contextKey{"clusterEdgeId"} type contextKey struct { name string } // AuthMiddleware extract auth token from header, validate it and it to the http context. func AuthMiddleware(jwtTokenSecret, totpSecret string, db *sql.DB) gin.HandlerFunc { return func(c *gin.Context) { handleEdgeAuth(c, jwtTokenSecret, totpSecret, db) } } func validateActivationCode(activationCode string, db *sql.DB) (*types.AuthUser, string, error) { reqCtx, cancelReq := context.WithTimeout(context.Background(), time.Duration(30)*time.Second) defer cancelReq() hash := crypto.HashActivation(activationCode) encodedHash := hex.EncodeToString(hash) row := db.QueryRowContext(reqCtx, sqlstatements.TerminalActivationQuery, encodedHash) var activationHash, terminalID string if err := row.Scan(&activationHash, &terminalID); err != nil { return nil, "", err } return &types.AuthUser{ Roles: []string{string(model.RoleEdgeTerminal)}, }, terminalID, nil } func validateEdgeBootstrapToken(bootstrapToken string, db *sql.DB) (*types.AuthUser, string, error) { reqCtx, cancelReq := context.WithTimeout(context.Background(), time.Duration(30)*time.Second) defer cancelReq() hashEncoded := hex.EncodeToString(crypto.HashEdgeBootstrapToken([]byte(bootstrapToken))) row := db.QueryRowContext(reqCtx, sqlstatements.GetEdgeBootstrapToken, hashEncoded) var bootstrapTokenHash, bootstrapExpiry, clusterEdgeID string if err := row.Scan(&bootstrapTokenHash, &bootstrapExpiry, &clusterEdgeID); err != nil { return nil, "", err } timeUTCNow := time.Now().UTC() expiryTime, err := time.Parse(time.RFC3339, bootstrapExpiry) if err != nil { return nil, "", fmt.Errorf("could not parse bootstrap expiry time") } if expiryTime.Before(timeUTCNow) { return nil, "", fmt.Errorf("bootstrap token has expired") } if clusterEdgeID == "" { return nil, "", fmt.Errorf("invalid cluster edge id returned on edge bootstrap token auth") } return &types.AuthUser{ Roles: []string{string(model.RoleEdgeBootstrap)}, }, clusterEdgeID, nil } func handleEdgeAuth(c *gin.Context, jwtTokenSecret string, totpSecret string, db *sql.DB) { auth := c.GetHeader("Authorization") if auth == "" { c.Next() return } fullPath := strings.ToLower(c.FullPath()) if strings.Contains(fullPath, "/validate_token") { c.Next() return } defer c.Request.Body.Close() body, err := io.ReadAll(c.Request.Body) if err != nil { c.String(http.StatusForbidden, fmt.Sprintf("error reading request: %v", err)) return } c.Request.Body = io.NopCloser((bytes.NewBuffer(body))) edgeAPIEndpoint := getEdgeEndpoint(c) // Only use auth as activation code if the query is terminalBootstrap if strings.Contains(string(body), "terminalBootstrap") { if err := crypto.ValidActivationCode(auth); err != nil { c.String(http.StatusForbidden, fmt.Sprintf("%v - please check the activation code has been copied correctly", err)) return } user, terminalID, err := validateActivationCode(auth, db) if user != nil && err == nil { ctx := context.WithValue(c.Request.Context(), userCtxKey, user) ctx = context.WithValue(ctx, TerminalIDCtxKey, terminalID) ctx = context.WithValue(ctx, EdgeAPIEndpointCtxKey, edgeAPIEndpoint) c.Request = c.Request.WithContext(ctx) c.Next() return } if errors.Is(err, sql.ErrNoRows) { c.String(http.StatusForbidden, "activation code is incorrect or has been consumed, please check the activation code has been copied correctly. Otherwise request a new one.") } else { c.String(http.StatusForbidden, fmt.Sprintf("Please try again. There was an error attempting to validate the activation code: %v", err)) } return } // Only use auth as edge bootstrap token on mutation bootstrapCluster // token is a hex encoded string of 32 bytes (encoded length is 64 chars) r, _ := regexp.Compile("^(bearer )[A-Fa-f0-9]{64,64}$") if strings.Contains(string(body), "bootstrapCluster") && r.MatchString(auth) { auth := strings.Replace(auth, "bearer ", "", 1) user, clusterEdgeID, err := validateEdgeBootstrapToken(auth, db) if user != nil && err == nil { ctx := context.WithValue(c.Request.Context(), userCtxKey, user) ctx = context.WithValue(ctx, ClusterEdgeIDCtxKey, clusterEdgeID) ctx = context.WithValue(ctx, EdgeAPIEndpointCtxKey, edgeAPIEndpoint) c.Request = c.Request.WithContext(ctx) c.Next() return } c.String(http.StatusForbidden, fmt.Sprintf("invalid edge bootstrap token: %v", err)) return } tokenParts := strings.Split(auth, " ") if len(tokenParts) != 2 { c.String(http.StatusBadRequest, "invalid auth token") return } tokenType := strings.ToLower(tokenParts[0]) if !utils.Contains([]string{client.BearerToken, client.TotpToken}, tokenType) { c.String(http.StatusForbidden, "invalid auth token") return } user, err := getUserAndRolesFromAuth(tokenType, tokenParts[1], jwtTokenSecret, totpSecret) if err != nil { c.String(http.StatusForbidden, err.Error()) return } // put it in context ctx := context.WithValue(c.Request.Context(), userCtxKey, user) ctx = context.WithValue(ctx, EdgeAPIEndpointCtxKey, edgeAPIEndpoint) // and call the next with our new context c.Request = c.Request.WithContext(ctx) c.Next() } func getEdgeEndpoint(c *gin.Context) string { return fmt.Sprintf("https://%s%s", c.Request.Host, c.Request.RequestURI) } func getUserAndRolesFromAuth(tokenType, token, secret, totpSecret string) (*types.AuthUser, error) { if tokenType == client.BearerToken { return ValidateAndGetUser(token, secret) } if err := totp.ValidateTotpToken(token, totpSecret); err != nil { return nil, err } return &types.AuthUser{ Roles: []string{string(model.RoleTotpRole)}, }, nil } // ForContext finds the user from the context. REQUIRES Middleware to have run. func ForContext(ctx context.Context) *types.AuthUser { user, _ := ctx.Value(userCtxKey).(*types.AuthUser) return user } // NewContext creates a new context from the provided context with the supplied values and user ctx key. func NewContext(ctx context.Context, value interface{}) context.Context { return context.WithValue(ctx, userCtxKey, value) } // ValidateAndGetUser decode the jwt token into an auth user and using the secret to validate that the token is valid. func ValidateAndGetUser(tokenString, jwtTokenSecret string) (*types.AuthUser, error) { claims := jwt.MapClaims{} _, err := jwt.ParseWithClaims(tokenString, claims, func(_ *jwt.Token) (interface{}, error) { // the jwtTokenSecret is used by the jwt library to validate that the token is valid after decoding it. return []byte(jwtTokenSecret), nil }) if err != nil { return nil, err } var roles []string if claims[types.Roles] != nil { for _, role := range claims[types.Roles].([]interface{}) { roles = append(roles, fmt.Sprint(role)) } } user := &types.AuthUser{ Username: claims[types.Username].(string), Email: claims[types.Email].(string), Roles: roles, Organization: claims[types.Organization].(string), AuthProvider: claims[types.AuthProvider].(string), } if val, ok := claims[types.Token]; ok { user.Token = val.(string) } if val, ok := claims[types.RefreshToken]; ok { user.RefreshToken = val.(string) } return user, nil } // CreateToken creates a jwt token signed with a secret for validation. func CreateToken(username, email, organization, secret string, roles []string, token, authProvider, refreshToken string) (string, error) { mc := jwt.MapClaims{} mc[types.Username] = username mc[types.Email] = email mc[types.Roles] = roles mc[types.Organization] = organization mc[types.Token] = token mc[types.RefreshToken] = refreshToken mc[types.AuthProvider] = authProvider claims := jwt.NewWithClaims(jwt.SigningMethodHS512, mc) token, err := claims.SignedString([]byte(secret)) if err != nil { return "", err } return token, nil } func ValidateToken(c *gin.Context, auth string, banner string, jwtSecret string, getUserBanners func(context.Context, string) ([]*model.BannerInfo, error)) error { if auth == "" { return fmt.Errorf("no authorization header found") } if banner == "" { return fmt.Errorf("no banner header found") } tokenParts := strings.Split(auth, " ") if len(tokenParts) != 2 { return fmt.Errorf("invalid auth token") } // Validate outer JWT token and grab values user, err := ValidateAndGetUser(tokenParts[1], jwtSecret) if err != nil { return err } // Put user into context ctx := context.WithValue(c.Request.Context(), userCtxKey, user) c.Request = c.Request.WithContext(ctx) // Validate inner token and get user fq name name, err := ValidateBSLTokenAndGetName(user.Token) if err != nil { return err } // We do not need to check banner for org-admins or super admins - they have access to all for _, role := range user.Roles { if role == string(model.RoleEdgeOrgAdmin) || role == string(model.RoleEdgeSuperAdmin) { return nil } } // Grab user banners via BannerService function and validate if user has access banners, err := getUserBanners(ctx, name) if err != nil { return err } for _, b := range banners { if banner == b.BannerEdgeID { return nil } } return fmt.Errorf("user does not have access to banner") } func ValidateBSLTokenAndGetName(token string) (string, error) { bslClaims := jwt.MapClaims{} if _, _, err := new(jwt.Parser).ParseUnverified(token, bslClaims); err != nil { return "", err } return bslClaims["sub"].(string), bslClaims.Valid() } func GetEdgeRoles(ctx context.Context) ([]string, error) { u := ForContext(ctx) userRoles := []string{} edgeRoles := []string{ "EDGE_SUPER_ADMIN", "EDGE_ORG_ADMIN", "EDGE_BANNER_ADMIN", "EDGE_BANNER_OPERATOR", "EDGE_BANNER_VIEWER", "EDGE_OI_ADMIN", } for _, edgeRole := range edgeRoles { if slices.Contains(u.Roles, edgeRole) { userRoles = append(userRoles, edgeRole) } } if len(userRoles) == 0 { return userRoles, fmt.Errorf("edge role not found for context user") } return userRoles, nil }