package database import ( "context" "database/sql" "fmt" datasql "edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database/sql" "github.com/go-logr/logr" "github.com/google/uuid" ) type Dataset struct { log logr.Logger db *sql.DB } func New(log logr.Logger, db *sql.DB) Dataset { return Dataset{log: log, db: db} } func isUUID(val string) bool { _, err := uuid.Parse(val) return err == nil } func (ds Dataset) GetProjectAndBannerID(ctx context.Context, banner string) (projectID, bannerID string, err error) { // If banner value is a valid UUID, check for both id and name, otherwise just name var rows *sql.Rows var bannerUUID *string if isUUID(banner) { bannerUUID = &banner } rows, err = ds.db.QueryContext(ctx, datasql.SelectProjectIDAndBannerID, bannerUUID, banner) if err != nil { return "", "", fmt.Errorf("error querying db in data:GetProjectIDAndBannerID: %w", err) } defer rows.Close() var retProjectID *string var retBannerID *string err = scanRowsForIDs(rows, &retProjectID, &retBannerID) if err != nil { return "", "", fmt.Errorf("error scanning rows in data:GetProjectIDAndBannerID: %w", err) } return safeStringDereference(retProjectID), safeStringDereference(retBannerID), nil } func (ds Dataset) GetStoreID(ctx context.Context, store, bannerID string) (storeID string, err error) { var rows *sql.Rows var storeUUID *string if isUUID(store) { storeUUID = &store } rows, err = ds.db.QueryContext(ctx, datasql.SelectStoreID, storeUUID, store, bannerID) if err != nil { return "", fmt.Errorf("error querying db in data:GetStoreID: %w", err) } defer rows.Close() var retStoreID *string err = scanRowsForIDs(rows, &retStoreID) if err != nil { return "", fmt.Errorf("error scanning rows in data:GetStoreID: %w", err) } return safeStringDereference(retStoreID), nil } func (ds Dataset) GetTerminalID(ctx context.Context, terminal, storeID string) (terminalID string, err error) { var rows *sql.Rows var terminalUUID *string if isUUID(terminal) { terminalUUID = &terminal } rows, err = ds.db.QueryContext(ctx, datasql.SelectTerminalID, terminalUUID, terminal, storeID) if err != nil { return "", fmt.Errorf("error querying db in data:GetTerminalID: %w", err) } defer rows.Close() var retTerminalID *string err = scanRowsForIDs(rows, &retTerminalID) if err != nil { return "", fmt.Errorf("error scanning rows in data:GetTerminalID: %w", err) } return safeStringDereference(retTerminalID), nil } func scanRowsForIDs(rows *sql.Rows, vals ...any) (err error) { if rows.Next() { err := rows.Scan(vals...) if err != nil { return fmt.Errorf("error scanning rows in data:scanRows: %w", err) } } if rows.Next() { return fmt.Errorf("error multiple rows returned in data:scanRows") } err = rows.Err() if err != nil { return fmt.Errorf("error data:scanRows on rows.Err: %w", err) } return nil } func safeStringDereference(s *string) string { if s == nil { s = new(string) } return *s }