...
1 package service
2
import (
	"context"
	"database/sql"
	"fmt"
	"sort"

	"edge-infra.dev/pkg/lib/fog"
)
10
// getEARolesQuery looks up the EA rule privilege names attached to a set of
// OI role names. It inner-joins ea_rules_privileges with oi_role_privileges on
// privilege_id and returns one (role_name, privilege name) row per matching
// pair. $1 is the list of role names, matched with "= ANY ($1)".
//
// NOTE(review): getRoleMappingsFromDB binds $1 from a plain Go []string;
// whether that is accepted depends on the registered SQL driver's support
// for array/slice parameters — confirm for the driver in use.
const getEARolesQuery = `SELECT oi_role_privileges.role_name, ea_rules_privileges.name
FROM ea_rules_privileges
INNER JOIN oi_role_privileges
ON ea_rules_privileges.privilege_id = oi_role_privileges.privilege_id
WHERE role_name = ANY ($1);`
18
19
// privilegeSet is a set of privilege names. Values are empty structs, so the
// map tracks membership only.
type privilegeSet map[string]struct{}

// Slice returns the members of the set as a new slice sorted in ascending
// order. Sorting makes the result independent of Go's randomized map
// iteration order, so the same set always yields the same slice. An empty
// set yields an empty, non-nil slice.
func (ps privilegeSet) Slice() []string {
	privileges := make([]string, 0, len(ps))
	for k := range ps {
		privileges = append(privileges, k)
	}
	// Map iteration order is unspecified; sort for a stable, deterministic result.
	sort.Strings(privileges)
	return privileges
}

// Insert adds every entry of privileges to the set. Duplicates are collapsed,
// whether they repeat within privileges or are already members. ps must be
// non-nil (built with make or a literal); writing to a nil map panics.
func (ps privilegeSet) Insert(privileges []string) {
	for _, privilege := range privileges {
		ps[privilege] = struct{}{}
	}
}
35
36 type User struct {
37 db *sql.DB
38 cache *cache
39 }
40
41 func NewUser(db *sql.DB) (User, error) {
42 return User{
43 db: db,
44 cache: &cache{
45 cache: make(map[string][]string),
46 expiration: cacheExpiration,
47 },
48 }, nil
49 }
50
51 func (us *User) GetEARoles(ctx context.Context, roles []string) ([]string, error) {
52 if len(roles) == 0 {
53 return []string{}, nil
54 }
55
56 log := fog.FromContext(ctx)
57 log.Info("Checking cache with user roles", "userRoles", roles)
58 uncachedRoles, privilegeSet := us.checkCache(roles)
59
60 if len(uncachedRoles) != 0 {
61 log.Info("Querying database for role mappings", "userRoles", uncachedRoles)
62 roleMappings, err := us.getRoleMappingsFromDB(ctx, uncachedRoles)
63 if err != nil {
64 return nil, err
65 }
66 for role, privileges := range roleMappings {
67 us.cache.Insert(role, privileges)
68 privilegeSet.Insert(privileges)
69 }
70 }
71
72 return privilegeSet.Slice(), nil
73 }
74
75
76
77 func (us *User) checkCache(roles []string) (remainingRoles []string, privileges privilegeSet) {
78 privileges = make(privilegeSet)
79 for _, role := range roles {
80 cachedPrivileges := us.cache.Get(role)
81 if len(cachedPrivileges) != 0 {
82 privileges.Insert(cachedPrivileges)
83 } else {
84 remainingRoles = append(remainingRoles, role)
85 }
86 }
87 return remainingRoles, privileges
88 }
89
90 func (us *User) getRoleMappingsFromDB(ctx context.Context, roles []string) (roleMappings map[string][]string, err error) {
91 rows, err := us.db.QueryContext(ctx, getEARolesQuery, roles)
92 if err != nil {
93 return nil, fmt.Errorf("error querying earoles: %w", err)
94 }
95 defer rows.Close()
96
97 roleMappings = map[string][]string{}
98 for rows.Next() {
99 var role, privilege string
100 err := rows.Scan(&role, &privilege)
101 if err != nil {
102 return nil, fmt.Errorf("error scanning earoles row: %w", err)
103 }
104 roleMappings[role] = append(roleMappings[role], privilege)
105 }
106
107 if err := rows.Err(); err != nil {
108 return nil, fmt.Errorf("error while reading earoles: %w", err)
109 }
110 return roleMappings, nil
111 }
112
View as plain text