1 package database
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7
8 datasql "edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database/sql"
9
10 "github.com/go-logr/logr"
11 "github.com/google/uuid"
12 )
13
14 type Dataset struct {
15 log logr.Logger
16 db *sql.DB
17 }
18
19 func New(log logr.Logger, db *sql.DB) Dataset {
20 return Dataset{log: log, db: db}
21 }
22
23 func isUUID(val string) bool {
24 _, err := uuid.Parse(val)
25 return err == nil
26 }
27
28 func (ds Dataset) GetProjectAndBannerID(ctx context.Context, banner string) (projectID, bannerID string, err error) {
29
30 var rows *sql.Rows
31 var bannerUUID *string
32 if isUUID(banner) {
33 bannerUUID = &banner
34 }
35 rows, err = ds.db.QueryContext(ctx, datasql.SelectProjectIDAndBannerID, bannerUUID, banner)
36 if err != nil {
37 return "", "", fmt.Errorf("error querying db in data:GetProjectIDAndBannerID: %w", err)
38 }
39 defer rows.Close()
40
41 var retProjectID *string
42 var retBannerID *string
43 err = scanRowsForIDs(rows, &retProjectID, &retBannerID)
44 if err != nil {
45 return "", "", fmt.Errorf("error scanning rows in data:GetProjectIDAndBannerID: %w", err)
46 }
47 return safeStringDereference(retProjectID), safeStringDereference(retBannerID), nil
48 }
49
50 func (ds Dataset) GetStoreID(ctx context.Context, store, bannerID string) (storeID string, err error) {
51 var rows *sql.Rows
52 var storeUUID *string
53 if isUUID(store) {
54 storeUUID = &store
55 }
56 rows, err = ds.db.QueryContext(ctx, datasql.SelectStoreID, storeUUID, store, bannerID)
57 if err != nil {
58 return "", fmt.Errorf("error querying db in data:GetStoreID: %w", err)
59 }
60 defer rows.Close()
61
62 var retStoreID *string
63 err = scanRowsForIDs(rows, &retStoreID)
64 if err != nil {
65 return "", fmt.Errorf("error scanning rows in data:GetStoreID: %w", err)
66 }
67 return safeStringDereference(retStoreID), nil
68 }
69
70 func (ds Dataset) GetTerminalID(ctx context.Context, terminal, storeID string) (terminalID string, err error) {
71 var rows *sql.Rows
72 var terminalUUID *string
73 if isUUID(terminal) {
74 terminalUUID = &terminal
75 }
76 rows, err = ds.db.QueryContext(ctx, datasql.SelectTerminalID, terminalUUID, terminal, storeID)
77 if err != nil {
78 return "", fmt.Errorf("error querying db in data:GetTerminalID: %w", err)
79 }
80 defer rows.Close()
81
82 var retTerminalID *string
83 err = scanRowsForIDs(rows, &retTerminalID)
84 if err != nil {
85 return "", fmt.Errorf("error scanning rows in data:GetTerminalID: %w", err)
86 }
87 return safeStringDereference(retTerminalID), nil
88 }
89
// scanRowsForIDs scans at most one row from rows into vals. It fails if the
// result set holds more than one row.
//
// vals are passed straight to rows.Scan, so they must be pointers matching the
// selected columns. The callers in this file pass **string so that a NULL
// column yields a nil *string.
//
// When rows is empty, Scan is never called: vals are left untouched and nil is
// returned. An error is returned if any of these happen:
//   - scanning the first row fails;
//   - a second row exists;
//   - rows.Err reports an iteration failure.
//
// rows is not closed here; the caller remains responsible for closing it.
func scanRowsForIDs(rows *sql.Rows, vals ...any) error {
	if rows.Next() {
		// Scope err to this statement instead of shadowing an outer result.
		if err := rows.Scan(vals...); err != nil {
			return fmt.Errorf("error scanning rows in data:scanRows: %w", err)
		}
	}

	// A second row means the lookup was ambiguous, so report it rather than
	// silently using the first match.
	if rows.Next() {
		return fmt.Errorf("error multiple rows returned in data:scanRows")
	}

	// rows.Next returning false can mean end-of-results or an iteration
	// failure; only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("error data:scanRows on rows.Err: %w", err)
	}
	return nil
}
108
// safeStringDereference returns the string s points to, or "" when s is nil.
// It lets nullable scanned columns be returned as plain strings.
func safeStringDereference(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}
115
View as plain text