...

Source file src/edge-infra.dev/pkg/edge/api/middleware/auth.go

Documentation: edge-infra.dev/pkg/edge/api/middleware

     1  package middleware
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql"
     7  	"encoding/hex"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"regexp"
    13  	"slices"
    14  	"strings"
    15  	"time"
    16  
    17  	"edge-infra.dev/pkg/edge/api/bsl/types"
    18  	"edge-infra.dev/pkg/edge/api/graph/model"
    19  	sqlstatements "edge-infra.dev/pkg/edge/api/sql"
    20  	"edge-infra.dev/pkg/edge/api/totp"
    21  	"edge-infra.dev/pkg/edge/api/utils"
    22  	"edge-infra.dev/pkg/edge/client"
    23  	"edge-infra.dev/pkg/lib/crypto"
    24  
    25  	"github.com/gin-gonic/gin"
    26  	"github.com/golang-jwt/jwt"
    27  )
    28  
    29  var userCtxKey = &contextKey{"user"}
    30  var TerminalIDCtxKey = &contextKey{"terminalId"}
    31  var EdgeAPIEndpointCtxKey = &contextKey{"edgeAPIEndpoint"}
    32  var ClusterEdgeIDCtxKey = &contextKey{"clusterEdgeId"}
    33  
    34  type contextKey struct {
    35  	name string
    36  }
    37  
    38  // AuthMiddleware extract auth token from header, validate it and it to the http context.
    39  func AuthMiddleware(jwtTokenSecret, totpSecret string, db *sql.DB) gin.HandlerFunc {
    40  	return func(c *gin.Context) {
    41  		handleEdgeAuth(c, jwtTokenSecret, totpSecret, db)
    42  	}
    43  }
    44  
    45  func validateActivationCode(activationCode string, db *sql.DB) (*types.AuthUser, string, error) {
    46  	reqCtx, cancelReq := context.WithTimeout(context.Background(), time.Duration(30)*time.Second)
    47  	defer cancelReq()
    48  
    49  	hash := crypto.HashActivation(activationCode)
    50  	encodedHash := hex.EncodeToString(hash)
    51  
    52  	row := db.QueryRowContext(reqCtx, sqlstatements.TerminalActivationQuery, encodedHash)
    53  	var activationHash, terminalID string
    54  	if err := row.Scan(&activationHash, &terminalID); err != nil {
    55  		return nil, "", err
    56  	}
    57  
    58  	return &types.AuthUser{
    59  		Roles: []string{string(model.RoleEdgeTerminal)},
    60  	}, terminalID, nil
    61  }
    62  
    63  func validateEdgeBootstrapToken(bootstrapToken string, db *sql.DB) (*types.AuthUser, string, error) {
    64  	reqCtx, cancelReq := context.WithTimeout(context.Background(), time.Duration(30)*time.Second)
    65  	defer cancelReq()
    66  
    67  	hashEncoded := hex.EncodeToString(crypto.HashEdgeBootstrapToken([]byte(bootstrapToken)))
    68  	row := db.QueryRowContext(reqCtx, sqlstatements.GetEdgeBootstrapToken, hashEncoded)
    69  	var bootstrapTokenHash, bootstrapExpiry, clusterEdgeID string
    70  	if err := row.Scan(&bootstrapTokenHash, &bootstrapExpiry, &clusterEdgeID); err != nil {
    71  		return nil, "", err
    72  	}
    73  
    74  	timeUTCNow := time.Now().UTC()
    75  	expiryTime, err := time.Parse(time.RFC3339, bootstrapExpiry)
    76  	if err != nil {
    77  		return nil, "", fmt.Errorf("could not parse bootstrap expiry time")
    78  	}
    79  
    80  	if expiryTime.Before(timeUTCNow) {
    81  		return nil, "", fmt.Errorf("bootstrap token has expired")
    82  	}
    83  
    84  	if clusterEdgeID == "" {
    85  		return nil, "", fmt.Errorf("invalid cluster edge id returned on edge bootstrap token auth")
    86  	}
    87  
    88  	return &types.AuthUser{
    89  		Roles: []string{string(model.RoleEdgeBootstrap)},
    90  	}, clusterEdgeID, nil
    91  }
    92  
    93  func handleEdgeAuth(c *gin.Context, jwtTokenSecret string, totpSecret string, db *sql.DB) {
    94  	auth := c.GetHeader("Authorization")
    95  	if auth == "" {
    96  		c.Next()
    97  		return
    98  	}
    99  
   100  	fullPath := strings.ToLower(c.FullPath())
   101  	if strings.Contains(fullPath, "/validate_token") {
   102  		c.Next()
   103  		return
   104  	}
   105  
   106  	defer c.Request.Body.Close()
   107  	body, err := io.ReadAll(c.Request.Body)
   108  	if err != nil {
   109  		c.String(http.StatusForbidden, fmt.Sprintf("error reading request: %v", err))
   110  		return
   111  	}
   112  	c.Request.Body = io.NopCloser((bytes.NewBuffer(body)))
   113  
   114  	edgeAPIEndpoint := getEdgeEndpoint(c)
   115  
   116  	// Only use auth as activation code if the query is terminalBootstrap
   117  	if strings.Contains(string(body), "terminalBootstrap") {
   118  		if err := crypto.ValidActivationCode(auth); err != nil {
   119  			c.String(http.StatusForbidden, fmt.Sprintf("%v - please check the activation code has been copied correctly", err))
   120  			return
   121  		}
   122  
   123  		user, terminalID, err := validateActivationCode(auth, db)
   124  		if user != nil && err == nil {
   125  			ctx := context.WithValue(c.Request.Context(), userCtxKey, user)
   126  			ctx = context.WithValue(ctx, TerminalIDCtxKey, terminalID)
   127  			ctx = context.WithValue(ctx, EdgeAPIEndpointCtxKey, edgeAPIEndpoint)
   128  			c.Request = c.Request.WithContext(ctx)
   129  			c.Next()
   130  			return
   131  		}
   132  		if errors.Is(err, sql.ErrNoRows) {
   133  			c.String(http.StatusForbidden, "activation code is incorrect or has been consumed, please check the activation code has been copied correctly. Otherwise request a new one.")
   134  		} else {
   135  			c.String(http.StatusForbidden, fmt.Sprintf("Please try again. There was an error attempting to validate the activation code: %v", err))
   136  		}
   137  		return
   138  	}
   139  
   140  	// Only use auth as edge bootstrap token on mutation bootstrapCluster
   141  	// token is a hex encoded string of 32 bytes (encoded length is 64 chars)
   142  	r, _ := regexp.Compile("^(bearer )[A-Fa-f0-9]{64,64}$")
   143  	if strings.Contains(string(body), "bootstrapCluster") && r.MatchString(auth) {
   144  		auth := strings.Replace(auth, "bearer ", "", 1)
   145  		user, clusterEdgeID, err := validateEdgeBootstrapToken(auth, db)
   146  		if user != nil && err == nil {
   147  			ctx := context.WithValue(c.Request.Context(), userCtxKey, user)
   148  			ctx = context.WithValue(ctx, ClusterEdgeIDCtxKey, clusterEdgeID)
   149  			ctx = context.WithValue(ctx, EdgeAPIEndpointCtxKey, edgeAPIEndpoint)
   150  			c.Request = c.Request.WithContext(ctx)
   151  			c.Next()
   152  			return
   153  		}
   154  		c.String(http.StatusForbidden, fmt.Sprintf("invalid edge bootstrap token: %v", err))
   155  		return
   156  	}
   157  
   158  	tokenParts := strings.Split(auth, " ")
   159  	if len(tokenParts) != 2 {
   160  		c.String(http.StatusBadRequest, "invalid auth token")
   161  		return
   162  	}
   163  
   164  	tokenType := strings.ToLower(tokenParts[0])
   165  	if !utils.Contains([]string{client.BearerToken, client.TotpToken}, tokenType) {
   166  		c.String(http.StatusForbidden, "invalid auth token")
   167  		return
   168  	}
   169  
   170  	user, err := getUserAndRolesFromAuth(tokenType, tokenParts[1], jwtTokenSecret, totpSecret)
   171  	if err != nil {
   172  		c.String(http.StatusForbidden, err.Error())
   173  		return
   174  	}
   175  
   176  	// put it in context
   177  	ctx := context.WithValue(c.Request.Context(), userCtxKey, user)
   178  	ctx = context.WithValue(ctx, EdgeAPIEndpointCtxKey, edgeAPIEndpoint)
   179  
   180  	// and call the next with our new context
   181  	c.Request = c.Request.WithContext(ctx)
   182  	c.Next()
   183  }
   184  
   185  func getEdgeEndpoint(c *gin.Context) string {
   186  	return fmt.Sprintf("https://%s%s", c.Request.Host, c.Request.RequestURI)
   187  }
   188  
   189  func getUserAndRolesFromAuth(tokenType, token, secret, totpSecret string) (*types.AuthUser, error) {
   190  	if tokenType == client.BearerToken {
   191  		return ValidateAndGetUser(token, secret)
   192  	}
   193  	if err := totp.ValidateTotpToken(token, totpSecret); err != nil {
   194  		return nil, err
   195  	}
   196  	return &types.AuthUser{
   197  		Roles: []string{string(model.RoleTotpRole)},
   198  	}, nil
   199  }
   200  
   201  // ForContext finds the user from the context. REQUIRES Middleware to have run.
   202  func ForContext(ctx context.Context) *types.AuthUser {
   203  	user, _ := ctx.Value(userCtxKey).(*types.AuthUser)
   204  	return user
   205  }
   206  
   207  // NewContext creates a new context from the provided context with the supplied values and user ctx key.
   208  func NewContext(ctx context.Context, value interface{}) context.Context {
   209  	return context.WithValue(ctx, userCtxKey, value)
   210  }
   211  
   212  // ValidateAndGetUser decode the jwt token into an auth user and using the secret to validate that the token is valid.
   213  func ValidateAndGetUser(tokenString, jwtTokenSecret string) (*types.AuthUser, error) {
   214  	claims := jwt.MapClaims{}
   215  	_, err := jwt.ParseWithClaims(tokenString, claims, func(_ *jwt.Token) (interface{}, error) {
   216  		// the jwtTokenSecret is used by the jwt library to validate that the token is valid after decoding it.
   217  		return []byte(jwtTokenSecret), nil
   218  	})
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  
   223  	var roles []string
   224  	if claims[types.Roles] != nil {
   225  		for _, role := range claims[types.Roles].([]interface{}) {
   226  			roles = append(roles, fmt.Sprint(role))
   227  		}
   228  	}
   229  	user := &types.AuthUser{
   230  		Username:     claims[types.Username].(string),
   231  		Email:        claims[types.Email].(string),
   232  		Roles:        roles,
   233  		Organization: claims[types.Organization].(string),
   234  		AuthProvider: claims[types.AuthProvider].(string),
   235  	}
   236  	if val, ok := claims[types.Token]; ok {
   237  		user.Token = val.(string)
   238  	}
   239  	if val, ok := claims[types.RefreshToken]; ok {
   240  		user.RefreshToken = val.(string)
   241  	}
   242  	return user, nil
   243  }
   244  
   245  // CreateToken creates a jwt token signed with a secret for validation.
   246  func CreateToken(username, email, organization, secret string, roles []string, token, authProvider, refreshToken string) (string, error) {
   247  	mc := jwt.MapClaims{}
   248  	mc[types.Username] = username
   249  	mc[types.Email] = email
   250  	mc[types.Roles] = roles
   251  	mc[types.Organization] = organization
   252  	mc[types.Token] = token
   253  	mc[types.RefreshToken] = refreshToken
   254  	mc[types.AuthProvider] = authProvider
   255  
   256  	claims := jwt.NewWithClaims(jwt.SigningMethodHS512, mc)
   257  
   258  	token, err := claims.SignedString([]byte(secret))
   259  	if err != nil {
   260  		return "", err
   261  	}
   262  	return token, nil
   263  }
   264  
   265  func ValidateToken(c *gin.Context, auth string, banner string, jwtSecret string, getUserBanners func(context.Context, string) ([]*model.BannerInfo, error)) error {
   266  	if auth == "" {
   267  		return fmt.Errorf("no authorization header found")
   268  	}
   269  
   270  	if banner == "" {
   271  		return fmt.Errorf("no banner header found")
   272  	}
   273  
   274  	tokenParts := strings.Split(auth, " ")
   275  
   276  	if len(tokenParts) != 2 {
   277  		return fmt.Errorf("invalid auth token")
   278  	}
   279  
   280  	// Validate outer JWT token and grab values
   281  	user, err := ValidateAndGetUser(tokenParts[1], jwtSecret)
   282  	if err != nil {
   283  		return err
   284  	}
   285  
   286  	// Put user into context
   287  	ctx := context.WithValue(c.Request.Context(), userCtxKey, user)
   288  	c.Request = c.Request.WithContext(ctx)
   289  
   290  	// Validate inner token and get user fq name
   291  	name, err := ValidateBSLTokenAndGetName(user.Token)
   292  	if err != nil {
   293  		return err
   294  	}
   295  
   296  	// We do not need to check banner for org-admins or super admins - they have access to all
   297  	for _, role := range user.Roles {
   298  		if role == string(model.RoleEdgeOrgAdmin) || role == string(model.RoleEdgeSuperAdmin) {
   299  			return nil
   300  		}
   301  	}
   302  
   303  	// Grab user banners via BannerService function and validate if user has access
   304  	banners, err := getUserBanners(ctx, name)
   305  	if err != nil {
   306  		return err
   307  	}
   308  
   309  	for _, b := range banners {
   310  		if banner == b.BannerEdgeID {
   311  			return nil
   312  		}
   313  	}
   314  
   315  	return fmt.Errorf("user does not have access to banner")
   316  }
   317  
   318  func ValidateBSLTokenAndGetName(token string) (string, error) {
   319  	bslClaims := jwt.MapClaims{}
   320  	if _, _, err := new(jwt.Parser).ParseUnverified(token, bslClaims); err != nil {
   321  		return "", err
   322  	}
   323  
   324  	return bslClaims["sub"].(string), bslClaims.Valid()
   325  }
   326  
   327  func GetEdgeRoles(ctx context.Context) ([]string, error) {
   328  	u := ForContext(ctx)
   329  	userRoles := []string{}
   330  	edgeRoles := []string{
   331  		"EDGE_SUPER_ADMIN",
   332  		"EDGE_ORG_ADMIN",
   333  		"EDGE_BANNER_ADMIN",
   334  		"EDGE_BANNER_OPERATOR",
   335  		"EDGE_BANNER_VIEWER",
   336  		"EDGE_OI_ADMIN",
   337  	}
   338  	for _, edgeRole := range edgeRoles {
   339  		if slices.Contains(u.Roles, edgeRole) {
   340  			userRoles = append(userRoles, edgeRole)
   341  		}
   342  	}
   343  	if len(userRoles) == 0 {
   344  		return userRoles, fmt.Errorf("edge role not found for context user")
   345  	}
   346  	return userRoles, nil
   347  }
   348  

View as plain text