1 package database
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7
8 "edge-infra.dev/pkg/lib/fog"
9 rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
10 datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
11 )
12
13 func (ds Dataset) ReadRulesForBanner(ctx context.Context, bannerName string) ([]rulesengine.RuleSegment, error) {
14 rows, err := ds.db.QueryContext(ctx, datasql.SelectAllRulesForBanner, bannerName)
15 if err != nil {
16 return nil, fmt.Errorf("error reading banner rules: %w", err)
17 }
18 defer rows.Close()
19
20 return scanDefaultRulesRows(rows)
21 }
22
23 func (ds Dataset) ReadBannerRulesForCommandAndBanner(ctx context.Context, bannerName string, commandName string) ([]rulesengine.RuleSegment, error) {
24 rows, err := ds.db.QueryContext(ctx, datasql.SelectBannerRulesForCommandAndBanner, bannerName, commandName)
25 if err != nil {
26 return nil, fmt.Errorf("error reading banner rules: %w", err)
27 }
28 defer rows.Close()
29
30 return scanBannerRulesRows(rows)
31 }
32
33 func (ds Dataset) ReadBannerRulesForCommand(ctx context.Context, commandName string) ([]rulesengine.RuleSegment, error) {
34 rows, err := ds.db.QueryContext(ctx, datasql.SelectAllBannerRulesForCommand, commandName)
35 if err != nil {
36 return nil, fmt.Errorf("error reading banner rules: %w", err)
37 }
38 defer rows.Close()
39
40 return scanBannerRulesRows(rows)
41 }
42
43 func (ds Dataset) DeletePrivilegeFromBannerRule(ctx context.Context, bannerName, commandName, privilegeName string) (rulesengine.DeleteResult, error) {
44 res, err := ds.deleteValue(ctx, datasql.DeletePrivilegeFromBannerRule, bannerName, commandName, privilegeName)
45 if err != nil {
46 return res, fmt.Errorf("error deleting privilege from banner: %w", err)
47 }
48
49 if res.RowsAffected != 0 || len(res.Errors) != 0 {
50
51 return res, nil
52 }
53
54
55
56
57 tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
58 if err != nil {
59 return res, fmt.Errorf("error creating transaction: %w", err)
60 }
61
62 segment := rulesengine.RuleSegment{
63 Banner: rulesengine.Banner{BannerName: bannerName},
64 Command: rulesengine.Command{Name: commandName},
65 Privilege: rulesengine.Privilege{Name: privilegeName},
66 }
67
68
69 segment, err = ds.populateBannerSegmentIDs(ctx, tx, segment)
70 if err != nil {
71 if txErr := tx.Rollback(); txErr != nil {
72 err = fmt.Errorf("error rolling back transaction (%w): %w", txErr, err)
73 }
74 return res, fmt.Errorf("error discovering ID's: %w", err)
75 }
76
77 if err := tx.Commit(); err != nil {
78 return res, fmt.Errorf("error committing read transaction: %w", err)
79 }
80
81
82 if segment.Command.ID == "" {
83 res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownCommand, Command: segment.Command.Name})
84 }
85 if segment.Privilege.ID == "" {
86 res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownPrivilege, Privilege: segment.Privilege.Name})
87 }
88 if segment.Banner.BannerID == "" {
89 res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownBanner, Banner: segment.Banner.BannerName})
90 }
91
92
93 if len(res.Errors) == 0 {
94 res.Errors = append(res.Errors, rulesengine.Error{Type: rulesengine.UnknownRule, Banner: bannerName, Command: commandName, Privilege: privilegeName})
95 }
96
97 return res, nil
98 }
99
100 func (ds Dataset) ReadRulesForAllBanners(ctx context.Context) ([]rulesengine.RuleSegment, error) {
101 rows, err := ds.db.QueryContext(ctx, datasql.SelectAllRulesFarAllBanners)
102 if err != nil {
103 return nil, fmt.Errorf("error querying for all banner rules: %w", err)
104 }
105 defer rows.Close()
106
107 return scanBannerRulesRows(rows)
108 }
109
110 func scanBannerRulesRows(rows *sql.Rows) ([]rulesengine.RuleSegment, error) {
111 var res []rulesengine.RuleSegment
112 for rows.Next() {
113 var b rulesengine.Banner
114 var c rulesengine.Command
115 var p rulesengine.Privilege
116
117 err := rows.Scan(&b.BannerName, &c.Name, &p.Name, &b.BannerID, &c.ID, &p.ID)
118 if err != nil {
119 return nil, fmt.Errorf("error scanning banner row: %w", err)
120 }
121
122 res = append(res, rulesengine.RuleSegment{
123 Banner: b,
124 Command: c,
125 Privilege: p,
126 })
127 }
128
129 if err := rows.Err(); err != nil {
130 return nil, fmt.Errorf("error while reading all banner rules: %w", err)
131 }
132
133 return res, nil
134 }
135
136 func (ds Dataset) AddBannerRules(ctx context.Context, rules []rulesengine.RuleSegment) (feedback rulesengine.AddRuleResult, err error) {
137 tx, err := ds.db.BeginTx(ctx, nil)
138 if err != nil {
139 return feedback, fmt.Errorf("error beginning transaction: %w", err)
140 }
141
142 log := fog.FromContext(ctx).WithName("dataset")
143
144 defer func() {
145 switch {
146 case err != nil:
147 if txErr := tx.Rollback(); txErr != nil {
148 err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr)
149 }
150 case len(feedback.Errors) != 0:
151 if txErr := tx.Rollback(); txErr != nil {
152 err = fmt.Errorf("error rolling back transaction due to errors: %w", txErr)
153 }
154 default:
155 if txErr := tx.Commit(); txErr != nil {
156 err = fmt.Errorf("error committing transaction: %w", txErr)
157 }
158 }
159 }()
160
161 for _, rule := range rules {
162 var res sql.Result
163 res, err = tx.ExecContext(ctx, datasql.InsertBannerRule, rule.Banner.BannerName, rule.Command.Name, rule.Privilege.Name)
164 if err != nil {
165 return feedback, fmt.Errorf("error executing query: %w", err)
166 }
167
168 var rowsAffected int64
169 rowsAffected, err = res.RowsAffected()
170 if err != nil {
171 return feedback, fmt.Errorf("error getting query result: %w", err)
172 }
173
174
175
176
177
178 if rowsAffected == 0 {
179 rule, err = ds.populateBannerSegmentIDs(ctx, tx, rule)
180 if err != nil {
181 return feedback, err
182 }
183
184 errors, missingID := createErrors(rule)
185
186 if missingID {
187
188 feedback.Errors = append(feedback.Errors, errors...)
189 continue
190 }
191
192 log.Info(
193 "Rule already present in database",
194 "bannerName", rule.Banner.BannerName,
195 "commandName", rule.Command.Name,
196 "privilegeName", rule.Privilege.Name,
197 )
198 }
199 }
200
201 return feedback, nil
202 }
203
204
205 func (ds Dataset) populateBannerSegmentIDs(ctx context.Context, tx *sql.Tx, segment rulesengine.RuleSegment) (rulesengine.RuleSegment, error) {
206 rows, err := tx.QueryContext(ctx, datasql.GetIDsForBannerSegment, segment.Banner.BannerName, segment.Command.Name, segment.Privilege.Name)
207 if err != nil {
208 return segment, fmt.Errorf("error finding ID's (banner %s, command %s, privilege %s): %w", segment.Banner.BannerName, segment.Command.Name, segment.Privilege.Name, err)
209 }
210 defer rows.Close()
211
212 for rows.Next() {
213 var kind, value string
214 err := rows.Scan(&kind, &value)
215 if err != nil {
216 return segment, fmt.Errorf("error scanning for ID's: %w", err)
217 }
218
219 switch kind {
220 case "banner":
221 segment.Banner.BannerID = value
222 case "command":
223 segment.Command.ID = value
224 case "privilege":
225 segment.Privilege.ID = value
226 default:
227 return segment, fmt.Errorf("unknown kind for segment: %s", value)
228 }
229 }
230
231 if err := rows.Err(); err != nil {
232 return segment, fmt.Errorf("error finding ID's for sgement: %w", err)
233 }
234
235 return segment, nil
236 }
237
238
239
240 func createErrors(rule rulesengine.RuleSegment) ([]rulesengine.Error, bool) {
241 missingID := false
242 var errors []rulesengine.Error
243
244 if rule.Banner.BannerID == "" {
245 missingID = true
246 errors = append(errors, rulesengine.Error{
247 Banner: rule.Banner.BannerName,
248 Type: rulesengine.UnknownBanner,
249 })
250 }
251
252 if rule.Command.ID == "" {
253 missingID = true
254 errors = append(errors, rulesengine.Error{
255 Command: rule.Command.Name,
256 Type: rulesengine.UnknownCommand,
257 })
258 }
259
260 if rule.Privilege.ID == "" {
261 missingID = true
262 errors = append(errors, rulesengine.Error{
263 Privilege: rule.Privilege.Name,
264 Type: rulesengine.UnknownPrivilege,
265 })
266 }
267
268 return errors, missingID
269 }
270
View as plain text