...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/user/service/userservice.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/user/service

     1  package service
     2  
import (
	"context"
	"database/sql"
	"fmt"
	"sort"

	"edge-infra.dev/pkg/lib/fog"
)
    10  
    11  const (
    12  	getEARolesQuery = `SELECT oi_role_privileges.role_name, ea_rules_privileges.name 
    13  	FROM ea_rules_privileges 
    14  	INNER JOIN oi_role_privileges 
    15  		ON ea_rules_privileges.privilege_id = oi_role_privileges.privilege_id 
    16  	WHERE role_name = ANY ($1);`
    17  )
    18  
    19  // privilegeSet represents a set of privileges using a map of empty structs. A map is used to ensure uniqueness.
    20  type privilegeSet map[string]struct{}
    21  
    22  func (ps privilegeSet) Slice() []string {
    23  	privileges := make([]string, 0, len(ps))
    24  	for k := range ps {
    25  		privileges = append(privileges, k)
    26  	}
    27  	return privileges
    28  }
    29  
    30  func (ps privilegeSet) Insert(privileges []string) {
    31  	for _, privilege := range privileges {
    32  		ps[privilege] = struct{}{}
    33  	}
    34  }
    35  
    36  type User struct {
    37  	db    *sql.DB
    38  	cache *cache
    39  }
    40  
    41  func NewUser(db *sql.DB) (User, error) {
    42  	return User{
    43  		db: db,
    44  		cache: &cache{
    45  			cache:      make(map[string][]string),
    46  			expiration: cacheExpiration,
    47  		},
    48  	}, nil
    49  }
    50  
    51  func (us *User) GetEARoles(ctx context.Context, roles []string) ([]string, error) {
    52  	if len(roles) == 0 {
    53  		return []string{}, nil
    54  	}
    55  
    56  	log := fog.FromContext(ctx)
    57  	log.Info("Checking cache with user roles", "userRoles", roles)
    58  	uncachedRoles, privilegeSet := us.checkCache(roles)
    59  
    60  	if len(uncachedRoles) != 0 {
    61  		log.Info("Querying database for role mappings", "userRoles", uncachedRoles)
    62  		roleMappings, err := us.getRoleMappingsFromDB(ctx, uncachedRoles)
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  		for role, privileges := range roleMappings {
    67  			us.cache.Insert(role, privileges)
    68  			privilegeSet.Insert(privileges)
    69  		}
    70  	}
    71  
    72  	return privilegeSet.Slice(), nil
    73  }
    74  
    75  // Check's the userservice cache for privileges matching the input roles.
    76  // Returns a set of privileges in the cache and any roles missing from the cache
    77  func (us *User) checkCache(roles []string) (remainingRoles []string, privileges privilegeSet) {
    78  	privileges = make(privilegeSet)
    79  	for _, role := range roles {
    80  		cachedPrivileges := us.cache.Get(role)
    81  		if len(cachedPrivileges) != 0 {
    82  			privileges.Insert(cachedPrivileges)
    83  		} else {
    84  			remainingRoles = append(remainingRoles, role)
    85  		}
    86  	}
    87  	return remainingRoles, privileges
    88  }
    89  
    90  func (us *User) getRoleMappingsFromDB(ctx context.Context, roles []string) (roleMappings map[string][]string, err error) {
    91  	rows, err := us.db.QueryContext(ctx, getEARolesQuery, roles)
    92  	if err != nil {
    93  		return nil, fmt.Errorf("error querying earoles: %w", err)
    94  	}
    95  	defer rows.Close()
    96  
    97  	roleMappings = map[string][]string{}
    98  	for rows.Next() {
    99  		var role, privilege string
   100  		err := rows.Scan(&role, &privilege)
   101  		if err != nil {
   102  			return nil, fmt.Errorf("error scanning earoles row: %w", err)
   103  		}
   104  		roleMappings[role] = append(roleMappings[role], privilege)
   105  	}
   106  
   107  	if err := rows.Err(); err != nil {
   108  		return nil, fmt.Errorf("error while reading earoles: %w", err)
   109  	}
   110  	return roleMappings, nil
   111  }
   112  

View as plain text