1 package database
2
3
4
5 import (
6 "context"
7 "database/sql"
8 "errors"
9 "fmt"
10
11 "github.com/go-logr/logr"
12
13 rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
14 )
15
// Dataset bundles the dependencies shared by the query helpers in this file.
// Construct it with New.
type Dataset struct {
	// log is the logger supplied to New.
	log logr.Logger
	// db is the database handle every query and transaction is executed against.
	db *sql.DB
}
20
21
22 func New(log logr.Logger, db *sql.DB) Dataset {
23 return Dataset{log: log, db: db}
24 }
25
26
// checkPostGresErr reports whether err, or any error wrapped inside it,
// exposes a SQLState() method whose value equals code.
//
// Matching on the method rather than a concrete type keeps this file free of
// a dependency on any particular driver package: any error type with a
// `SQLState() string` method qualifies. A nil err never matches.
func checkPostGresErr(err error, code string) bool {
	type sqlStateError interface {
		error
		SQLState() string
	}

	var target sqlStateError
	if !errors.As(err, &target) {
		return false
	}
	return target.SQLState() == code
}
37
38 func (ds Dataset) deleteValue(ctx context.Context, query string, args ...any) (rulesengine.DeleteResult, error) {
39 res, err := ds.db.ExecContext(ctx, query, args...)
40 if err != nil {
41 if checkPostGresErr(err, "23503") {
42 return rulesengine.DeleteResult{Errors: []rulesengine.Error{{Type: rulesengine.Conflict}}}, nil
43 }
44 return rulesengine.DeleteResult{}, fmt.Errorf("error in data:deleteValue: %v", err)
45 }
46 rows, err := res.RowsAffected()
47 if err != nil {
48 err = fmt.Errorf("error in data:deletevalue rowsAffected: %v", err)
49 }
50 return rulesengine.DeleteResult{RowsAffected: rows}, err
51 }
52
// selectNameResult holds one row read by readNames. The row's first column is
// scanned into id and its second into name.
type selectNameResult struct {
	id   string
	name string
}
57
58 func (ds Dataset) readNames(ctx context.Context, args []string, query string) (results []selectNameResult, err error) {
59 rows, err := ds.db.QueryContext(ctx, query, args)
60 if err != nil {
61 return nil, fmt.Errorf("error executing query: %w", err)
62 }
63 defer rows.Close()
64 for rows.Next() {
65 var id, name string
66 if err := rows.Scan(&id, &name); err != nil {
67 return nil, fmt.Errorf("error scanning sql row: %w", err)
68 }
69 results = append(results, selectNameResult{
70 id: id,
71 name: name,
72 })
73 }
74 if err := rows.Err(); err != nil {
75 return nil, fmt.Errorf("error while reading names: %w", err)
76 }
77 return results, nil
78 }
79
80
81 func (ds Dataset) addNames(ctx context.Context, names []string, query string) (feedback rulesengine.AddNameResult, err error) {
82 tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{})
83 if err != nil {
84 return rulesengine.AddNameResult{}, err
85 }
86
87 for _, name := range names {
88 _, err := tx.ExecContext(ctx, query, name)
89 if err != nil {
90 txErr := tx.Rollback()
91 if txErr != nil {
92 err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr)
93 }
94 return feedback, fmt.Errorf("error executing query: %w", err)
95 }
96 }
97
98 if txErr := tx.Commit(); txErr != nil {
99 err = fmt.Errorf("error committing transaction: %w", txErr)
100 }
101
102 return feedback, err
103 }
104
View as plain text