package service

import (
	"context"
	"database/sql"
	"fmt"

	"edge-infra.dev/pkg/lib/fog"
)

const (
	// getEARolesQuery selects (role name, privilege name) pairs for every
	// role whose name appears in the array bound to $1.
	getEARolesQuery = `SELECT oi_role_privileges.role_name, ea_rules_privileges.name FROM ea_rules_privileges INNER JOIN oi_role_privileges ON ea_rules_privileges.privilege_id = oi_role_privileges.privilege_id WHERE role_name = ANY ($1);`
)

// privilegeSet is a set of privilege names. It is backed by a map with
// empty-struct values so that each privilege is stored at most once.
type privilegeSet map[string]struct{}

// Slice returns the members of the set as a slice. The order of the
// result follows map iteration order and is therefore unspecified.
func (ps privilegeSet) Slice() []string {
	out := make([]string, 0, len(ps))
	for name := range ps {
		out = append(out, name)
	}
	return out
}

// Insert adds every entry of privileges to the set.
func (ps privilegeSet) Insert(privileges []string) {
	for i := range privileges {
		ps[privileges[i]] = struct{}{}
	}
}

// User resolves role names to EA privileges, consulting a cache before
// falling back to the database.
type User struct {
	db    *sql.DB
	cache *cache
}

// NewUser builds a User backed by db with an empty cache configured with
// cacheExpiration. The returned error is always nil.
func NewUser(db *sql.DB) (User, error) {
	roleCache := &cache{
		cache:      make(map[string][]string),
		expiration: cacheExpiration,
	}
	return User{db: db, cache: roleCache}, nil
}

// GetEARoles returns the de-duplicated privileges mapped to the given
// roles. Roles found in the cache are served from it; the remaining roles
// are looked up in the database and the mappings found are added to the
// cache. An empty roles slice yields an empty, non-nil result.
func (us *User) GetEARoles(ctx context.Context, roles []string) ([]string, error) {
	if len(roles) == 0 {
		return []string{}, nil
	}

	log := fog.FromContext(ctx)
	log.Info("Checking cache with user roles", "userRoles", roles)

	missing, granted := us.checkCache(roles)
	if len(missing) == 0 {
		// Everything was served from the cache; no database round trip.
		return granted.Slice(), nil
	}

	log.Info("Querying database for role mappings", "userRoles", missing)
	mappings, err := us.getRoleMappingsFromDB(ctx, missing)
	if err != nil {
		return nil, err
	}

	for role, rolePrivileges := range mappings {
		us.cache.Insert(role, rolePrivileges)
		granted.Insert(rolePrivileges)
	}
	return granted.Slice(), nil
}

// checkCache checks the userservice cache for privileges matching the input roles.
// Returns a set of privileges in the cache and any roles missing from the cache func (us *User) checkCache(roles []string) (remainingRoles []string, privileges privilegeSet) { privileges = make(privilegeSet) for _, role := range roles { cachedPrivileges := us.cache.Get(role) if len(cachedPrivileges) != 0 { privileges.Insert(cachedPrivileges) } else { remainingRoles = append(remainingRoles, role) } } return remainingRoles, privileges } func (us *User) getRoleMappingsFromDB(ctx context.Context, roles []string) (roleMappings map[string][]string, err error) { rows, err := us.db.QueryContext(ctx, getEARolesQuery, roles) if err != nil { return nil, fmt.Errorf("error querying earoles: %w", err) } defer rows.Close() roleMappings = map[string][]string{} for rows.Next() { var role, privilege string err := rows.Scan(&role, &privilege) if err != nil { return nil, fmt.Errorf("error scanning earoles row: %w", err) } roleMappings[role] = append(roleMappings[role], privilege) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error while reading earoles: %w", err) } return roleMappings, nil }