package database // Contract: Be a service layer for data storage and retrieval import ( "context" "database/sql" "errors" "fmt" "github.com/go-logr/logr" rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules" ) type Dataset struct { log logr.Logger db *sql.DB } // New generates a new dataset connected to the input SQL database func New(log logr.Logger, db *sql.DB) Dataset { return Dataset{log: log, db: db} } // returns true if there the err is from PostGres and the code matches the one given func checkPostGresErr(err error, code string) bool { var sqlErr interface { error SQLState() string } if errors.As(err, &sqlErr) { return sqlErr.SQLState() == code } return false } func (ds Dataset) deleteValue(ctx context.Context, query string, args ...any) (rulesengine.DeleteResult, error) { res, err := ds.db.ExecContext(ctx, query, args...) if err != nil { if checkPostGresErr(err, "23503") { return rulesengine.DeleteResult{Errors: []rulesengine.Error{{Type: rulesengine.Conflict}}}, nil } return rulesengine.DeleteResult{}, fmt.Errorf("error in data:deleteValue: %v", err) } rows, err := res.RowsAffected() if err != nil { err = fmt.Errorf("error in data:deletevalue rowsAffected: %v", err) } return rulesengine.DeleteResult{RowsAffected: rows}, err } type selectNameResult struct { id string name string } func (ds Dataset) readNames(ctx context.Context, args []string, query string) (results []selectNameResult, err error) { rows, err := ds.db.QueryContext(ctx, query, args) if err != nil { return nil, fmt.Errorf("error executing query: %w", err) } defer rows.Close() for rows.Next() { var id, name string if err := rows.Scan(&id, &name); err != nil { return nil, fmt.Errorf("error scanning sql row: %w", err) } results = append(results, selectNameResult{ id: id, name: name, }) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error while reading names: %w", err) } return results, nil } // Best effort function to add names to the database. AddNameResult.Conflicts is no longer used. func (ds Dataset) addNames(ctx context.Context, names []string, query string) (feedback rulesengine.AddNameResult, err error) { tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return rulesengine.AddNameResult{}, err } for _, name := range names { _, err := tx.ExecContext(ctx, query, name) if err != nil { txErr := tx.Rollback() if txErr != nil { err = fmt.Errorf("error rolling back transaction due to application error (%w): %w", err, txErr) } return feedback, fmt.Errorf("error executing query: %w", err) } } if txErr := tx.Commit(); txErr != nil { err = fmt.Errorf("error committing transaction: %w", txErr) } return feedback, err }