1 package database
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7
8 "edge-infra.dev/pkg/lib/fog"
9 rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
10 datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
11 )
12
// AddDefaultRules inserts each of the given rule segments as a default rule
// inside a single transaction.
//
// When an insert affects no rows, the command and privilege IDs are looked up
// by name. If either ID is missing, an UnknownCommand and/or UnknownPrivilege
// entry is appended to feedback.Errors. Otherwise the rule is treated as
// already present and only logged.
//
// The transaction is rolled back when a non-nil error is returned or when
// feedback.Errors is non-empty; otherwise it is committed. Errors from
// Rollback/Commit are folded into the returned error.
func (ds Dataset) AddDefaultRules(ctx context.Context, rules []rulesengine.RuleSegment) (feedback rulesengine.AddRuleResult, err error) {
	tx, err := ds.db.BeginTx(ctx, nil)
	if err != nil {
		return feedback, fmt.Errorf("error beginning transaction: %w", err)
	}

	log := fog.FromContext(ctx).WithName("dataset")

	// The transaction's fate is decided from the named results, so the body
	// below must always assign to the outer err (never a shadowed copy).
	defer func() {
		switch {
		case err != nil:
			if txErr := tx.Rollback(); txErr != nil {
				err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr)
			}
		case len(feedback.Errors) != 0:
			if txErr := tx.Rollback(); txErr != nil {
				err = fmt.Errorf("error rolling back transaction due to errors: %w", txErr)
			}
		default:
			if txErr := tx.Commit(); txErr != nil {
				err = fmt.Errorf("error committing transaction: %w", txErr)
			}
		}
	}()

	for _, rule := range rules {
		var res sql.Result
		res, err = tx.ExecContext(ctx, datasql.InsertRuleDefault, rule.Command.Name, rule.Privilege.Name)
		if err != nil {
			return feedback, fmt.Errorf("error executing query: %w", err)
		}

		var rowsAffected int64
		rowsAffected, err = res.RowsAffected()
		if err != nil {
			return feedback, fmt.Errorf("error getting query result: %w", err)
		}

		if rowsAffected != 0 {
			continue
		}

		// Nothing was inserted: find out whether the command or privilege
		// name is unknown, or whether the rule simply already exists.
		rule, err = populateRuleSegmentIDs(ctx, tx, rule)
		if err != nil {
			return feedback, err
		}

		missingID := false

		if rule.Command.ID == "" {
			missingID = true
			feedback.Errors = append(feedback.Errors, rulesengine.Error{
				Command: rule.Command.Name,
				Type:    rulesengine.UnknownCommand,
			})
		}

		if rule.Privilege.ID == "" {
			missingID = true
			feedback.Errors = append(feedback.Errors, rulesengine.Error{
				Privilege: rule.Privilege.Name,
				Type:      rulesengine.UnknownPrivilege,
			})
		}

		if missingID {
			continue
		}

		log.Info(
			"Default rule already present in database",
			"commandName", rule.Command.Name,
			"privilegeName", rule.Privilege.Name,
		)
	}
	return feedback, nil
}
93
// populateRuleSegmentIDs looks up, within tx, the IDs of the segment's command
// and privilege by name. It returns a copy of segment with Command.ID and/or
// Privilege.ID set from the returned rows.
//
// Each result row is scanned as (kind, id). A kind of "command" sets
// Command.ID, and "privilege" sets Privilege.ID. Any other kind is an error.
// An ID whose kind produces no row is left unchanged.
func populateRuleSegmentIDs(ctx context.Context, tx *sql.Tx, segment rulesengine.RuleSegment) (rulesengine.RuleSegment, error) {
	rows, err := tx.QueryContext(ctx, datasql.GetIDsForRuleSegment, segment.Command.Name, segment.Privilege.Name)
	if err != nil {
		return segment, fmt.Errorf("error finding ID's (command %s, privilege %s): %w", segment.Command.Name, segment.Privilege.Name, err)
	}
	defer rows.Close()

	for rows.Next() {
		var kind, value string
		if err := rows.Scan(&kind, &value); err != nil {
			return segment, fmt.Errorf("error scanning for ID's: %w", err)
		}

		switch kind {
		case "command":
			segment.Command.ID = value
		case "privilege":
			segment.Privilege.ID = value
		default:
			// Report the unexpected kind itself, not the ID that came with it.
			return segment, fmt.Errorf("unknown kind for segment: %s", kind)
		}
	}

	if err := rows.Err(); err != nil {
		return segment, fmt.Errorf("error finding ID's for segment: %w", err)
	}

	return segment, nil
}
124
// scanDefaultRulesRows reads every row from rows into a RuleSegment.
//
// Each row is scanned as (command name, privilege name, command ID,
// privilege ID), in that order. The function does not close rows; the caller
// owns them. On any scan or iteration error it returns a nil slice and a
// wrapped error.
func scanDefaultRulesRows(rows *sql.Rows) ([]rulesengine.RuleSegment, error) {
	var res []rulesengine.RuleSegment
	for rows.Next() {
		var comname, privname, comid, privid string
		if err := rows.Scan(&comname, &privname, &comid, &privid); err != nil {
			return nil, fmt.Errorf("error in data:scanDefaultRulesRows: %w", err)
		}
		res = append(res, rulesengine.RuleSegment{
			Command: rulesengine.Command{
				Name: comname,
				ID:   comid,
			},
			Privilege: rulesengine.Privilege{
				Name: privname,
				ID:   privid,
			},
		})
	}
	if err := rows.Err(); err != nil {
		// Iteration stopped early, so res is incomplete; don't hand it back.
		return nil, fmt.Errorf("error in data:scanDefaultRulesRows on rows.Err: %w", err)
	}
	return res, nil
}
153
// ReadAllDefaultRules returns every default rule segment selected by
// datasql.SelectAllDefaultRules. The query error is wrapped with %w so callers
// can still inspect the underlying cause.
func (ds Dataset) ReadAllDefaultRules(ctx context.Context) ([]rulesengine.RuleSegment, error) {
	rows, err := ds.db.QueryContext(ctx, datasql.SelectAllDefaultRules)
	if err != nil {
		return nil, fmt.Errorf("error in data:ReadAllDefaultRules: %w", err)
	}
	defer rows.Close()

	return scanDefaultRulesRows(rows)
}
163
164 func assembleRules(ruleSegments []rulesengine.RuleSegment) []rulesengine.Rule {
165 rules := []rulesengine.Rule{}
166 for _, ruleSegment := range ruleSegments {
167 rules = appendRule(rules, ruleSegment)
168 }
169 return rules
170 }
171
172 func appendRule(rules []rulesengine.Rule, ruleSegment rulesengine.RuleSegment) []rulesengine.Rule {
173 if len(rules) == 0 {
174 return []rulesengine.Rule{{Command: ruleSegment.Command, Privileges: []rulesengine.Privilege{ruleSegment.Privilege}}}
175 }
176 for idx, rule := range rules {
177 if rule.Command.ID == ruleSegment.Command.ID {
178
179 rules[idx].Privileges = append(rule.Privileges, ruleSegment.Privilege)
180 return rules
181 }
182 }
183 rules = append(rules, rulesengine.Rule{Command: ruleSegment.Command, Privileges: []rulesengine.Privilege{ruleSegment.Privilege}})
184 return rules
185 }
186
187
188
// ReadDefaultRulesForCommand returns the default rule segments selected by
// datasql.SelectDefaultRulesByCommandName for commandName.
//
// On a query error it returns an empty (non-nil) slice together with the
// error wrapped with %w. The wrapping matches ReadAllDefaultRules and keeps
// the cause inspectable.
func (ds Dataset) ReadDefaultRulesForCommand(ctx context.Context, commandName string) ([]rulesengine.RuleSegment, error) {
	rows, err := ds.db.QueryContext(ctx, datasql.SelectDefaultRulesByCommandName, commandName)
	if err != nil {
		return []rulesengine.RuleSegment{}, fmt.Errorf("error in data:ReadDefaultRulesForCommand: %w", err)
	}
	defer rows.Close()

	return scanDefaultRulesRows(rows)
}
198
// DeleteDefaultRule deletes the default rule pairing commandName with
// privilegeName.
//
// If the delete removed rows or reported errors, that result is returned
// directly. Otherwise a read-only transaction looks up the command and
// privilege IDs, and res.Errors is filled with checkUnknownDeleteRuleErrs to
// explain why nothing was deleted. The read-only transaction is rolled back
// when err is non-nil and committed otherwise.
func (ds Dataset) DeleteDefaultRule(ctx context.Context, commandName, privilegeName string) (res rulesengine.DeleteResult, err error) {
	res, err = ds.deleteValue(ctx, datasql.DeleteDefaultRule, commandName, privilegeName)
	if err != nil {
		return res, fmt.Errorf("error deleting privilege from default rule: %w", err)
	}

	deletedSomething := res.RowsAffected != 0
	alreadyReported := len(res.Errors) != 0
	if deletedSomething || alreadyReported {
		return res, nil
	}

	// Nothing was deleted and no reason was given: inspect which part of the
	// rule is missing.
	tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
	if err != nil {
		return res, fmt.Errorf("error creating transaction: %w", err)
	}
	defer func() {
		if err != nil {
			if txErr := tx.Rollback(); txErr != nil {
				err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr)
			}
			return
		}
		if txErr := tx.Commit(); txErr != nil {
			err = fmt.Errorf("error committing transaction: %w", txErr)
		}
	}()

	lookup, err := populateRuleSegmentIDs(ctx, tx, rulesengine.RuleSegment{
		Command:   rulesengine.Command{Name: commandName},
		Privilege: rulesengine.Privilege{Name: privilegeName},
	})
	if err != nil {
		return res, fmt.Errorf("error discovering ID's: %w", err)
	}

	res.Errors = checkUnknownDeleteRuleErrs(lookup)
	return res, nil
}
241
242 func checkUnknownDeleteRuleErrs(segment rulesengine.RuleSegment) []rulesengine.Error {
243 var errs []rulesengine.Error
244
245 if segment.Command.ID == "" {
246 errs = append(errs, rulesengine.Error{Type: rulesengine.UnknownCommand, Command: segment.Command.Name})
247 }
248 if segment.Privilege.ID == "" {
249 errs = append(errs, rulesengine.Error{Type: rulesengine.UnknownPrivilege, Privilege: segment.Privilege.Name})
250 }
251 if len(errs) == 0 {
252 errs = append(errs, rulesengine.Error{Type: rulesengine.UnknownRule, Command: segment.Command.Name, Privilege: segment.Privilege.Name})
253 }
254
255 return errs
256 }
257
View as plain text