1 package database
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "testing"
8
9 "edge-infra.dev/pkg/lib/fog"
10 "edge-infra.dev/pkg/lib/uuid"
11 datasql "edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database/sql"
12
13 "github.com/DATA-DOG/go-sqlmock"
14 "github.com/stretchr/testify/assert"
15 )
16
17
18 func EqualError(message string) assert.ErrorAssertionFunc {
19 return func(t assert.TestingT, err error, i ...interface{}) bool {
20 return assert.EqualError(t, err, message, i...)
21 }
22 }
23
24 func initMockDB(t *testing.T) (db *sql.DB, mock sqlmock.Sqlmock) {
25 db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
26 if err != nil {
27 t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
28 }
29 return db, mock
30 }
31
32 func TestNew(t *testing.T) {
33 log := fog.New()
34 db, _, err := sqlmock.New()
35 assert.NoError(t, err)
36
37 expected := Dataset{log, db}
38 actual := New(log, db)
39
40 assert.Equal(t, expected, actual)
41 }
42
43 func TestGetProjectAndBannerID(t *testing.T) {
44 t.Parallel()
45
46 validBannerUUID, validProjectUUID := uuid.New().UUID, uuid.New().UUID
47
48 tests := map[string]struct {
49 banner string
50
51 expectations func(mock sqlmock.Sqlmock)
52 expProjectID string
53 expBannerID string
54 errorAssertion assert.ErrorAssertionFunc
55 }{
56 "Using Banner ID": {
57 banner: validBannerUUID,
58 expectations: func(mock sqlmock.Sqlmock) {
59 mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
60 WithArgs(validBannerUUID, validBannerUUID).
61 WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}).AddRow(validProjectUUID, validBannerUUID))
62 },
63 expProjectID: validProjectUUID,
64 expBannerID: validBannerUUID,
65 errorAssertion: assert.NoError,
66 },
67 "Using Banner Name": {
68 banner: "name",
69 expectations: func(mock sqlmock.Sqlmock) {
70 mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
71 WithArgs(nil, "name").
72 WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}).AddRow(validProjectUUID, validBannerUUID))
73 },
74 expProjectID: validProjectUUID,
75 expBannerID: validBannerUUID,
76 errorAssertion: assert.NoError,
77 },
78 "Query Error": {
79 banner: "name",
80 expectations: func(mock sqlmock.Sqlmock) {
81 mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
82 WithArgs(nil, "name").
83 WillReturnError(fmt.Errorf("error"))
84 },
85 expProjectID: "",
86 expBannerID: "",
87 errorAssertion: EqualError("error querying db in data:GetProjectIDAndBannerID: error"),
88 },
89 "Multiple Rows": {
90 banner: "name",
91 expectations: func(mock sqlmock.Sqlmock) {
92 mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
93 WithArgs(nil, "name").
94 WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}).AddRow(validProjectUUID, validBannerUUID).AddRow("oops", "double-oops"))
95 },
96 expProjectID: "",
97 expBannerID: "",
98 errorAssertion: EqualError("error scanning rows in data:GetProjectIDAndBannerID: error multiple rows returned in data:scanRows"),
99 },
100 "No Rows Returned": {
101 banner: "name",
102 expectations: func(mock sqlmock.Sqlmock) {
103 mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
104 WithArgs(nil, "name").
105 WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}))
106 },
107 expProjectID: "",
108 expBannerID: "",
109 errorAssertion: assert.NoError,
110 },
111 }
112
113 for name, tc := range tests {
114 tc := tc
115 t.Run(name, func(t *testing.T) {
116 t.Parallel()
117
118 db, mock := initMockDB(t)
119 defer db.Close()
120
121 ds := Dataset{db: db}
122
123 tc.expectations(mock)
124
125 projectID, bannerID, err := ds.GetProjectAndBannerID(context.Background(), tc.banner)
126 tc.errorAssertion(t, err)
127 assert.Equal(t, tc.expProjectID, projectID)
128 assert.Equal(t, tc.expBannerID, bannerID)
129
130 assert.NoError(t, mock.ExpectationsWereMet())
131 })
132 }
133 }
134
135
136 func TestGetStoreID(t *testing.T) {
137 t.Parallel()
138
139 validUUID := uuid.New().UUID
140
141 tests := map[string]struct {
142 store string
143 bannerID string
144
145 expectations func(mock sqlmock.Sqlmock)
146 expID string
147 errorAssertion assert.ErrorAssertionFunc
148 }{
149 "Using Store ID": {
150 store: validUUID,
151 bannerID: "id",
152 expectations: func(mock sqlmock.Sqlmock) {
153 mock.ExpectQuery(datasql.SelectStoreID).
154 WithArgs(validUUID, validUUID, "id").
155 WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}).AddRow(validUUID))
156 },
157 expID: validUUID,
158 errorAssertion: assert.NoError,
159 },
160 "Using Store Name": {
161 store: "name",
162 bannerID: "id",
163 expectations: func(mock sqlmock.Sqlmock) {
164 mock.ExpectQuery(datasql.SelectStoreID).
165 WithArgs(nil, "name", "id").
166 WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}).AddRow(validUUID))
167 },
168 expID: validUUID,
169 errorAssertion: assert.NoError,
170 },
171 "Query Error": {
172 store: "name",
173 bannerID: "id",
174 expectations: func(mock sqlmock.Sqlmock) {
175 mock.ExpectQuery(datasql.SelectStoreID).
176 WithArgs(nil, "name", "id").
177 WillReturnError(fmt.Errorf("error"))
178 },
179 expID: "",
180 errorAssertion: EqualError("error querying db in data:GetStoreID: error"),
181 },
182 "Multiple Rows": {
183 store: "name",
184 bannerID: "id",
185 expectations: func(mock sqlmock.Sqlmock) {
186 mock.ExpectQuery(datasql.SelectStoreID).
187 WithArgs(nil, "name", "id").
188 WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}).AddRow(validUUID).AddRow("oops"))
189 },
190 expID: "",
191 errorAssertion: EqualError("error scanning rows in data:GetStoreID: error multiple rows returned in data:scanRows"),
192 },
193 "No Rows Returned": {
194 store: "name",
195 bannerID: "id",
196 expectations: func(mock sqlmock.Sqlmock) {
197 mock.ExpectQuery(datasql.SelectStoreID).
198 WithArgs(nil, "name", "id").
199 WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}))
200 },
201 expID: "",
202 errorAssertion: assert.NoError,
203 },
204 }
205
206 for name, tc := range tests {
207 tc := tc
208 t.Run(name, func(t *testing.T) {
209 t.Parallel()
210
211 db, mock := initMockDB(t)
212 defer db.Close()
213
214 ds := Dataset{db: db}
215
216 tc.expectations(mock)
217
218 storeID, err := ds.GetStoreID(context.Background(), tc.store, tc.bannerID)
219 tc.errorAssertion(t, err)
220 assert.Equal(t, tc.expID, storeID)
221
222 assert.NoError(t, mock.ExpectationsWereMet())
223 })
224 }
225 }
226
227
228 func TestGetTerminalID(t *testing.T) {
229 t.Parallel()
230
231 validUUID := uuid.New().UUID
232
233 tests := map[string]struct {
234 terminal string
235 storeID string
236
237 expectations func(mock sqlmock.Sqlmock)
238 expID string
239 errorAssertion assert.ErrorAssertionFunc
240 }{
241 "Using Terminal ID": {
242 terminal: validUUID,
243 storeID: "id",
244 expectations: func(mock sqlmock.Sqlmock) {
245 mock.ExpectQuery(datasql.SelectTerminalID).
246 WithArgs(validUUID, validUUID, "id").
247 WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}).AddRow(validUUID))
248 },
249 expID: validUUID,
250 errorAssertion: assert.NoError,
251 },
252 "Using Terminal Name": {
253 terminal: "name",
254 storeID: "id",
255 expectations: func(mock sqlmock.Sqlmock) {
256 mock.ExpectQuery(datasql.SelectTerminalID).
257 WithArgs(nil, "name", "id").
258 WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}).AddRow(validUUID))
259 },
260 expID: validUUID,
261 errorAssertion: assert.NoError,
262 },
263 "Query Error": {
264 terminal: "name",
265 storeID: "id",
266 expectations: func(mock sqlmock.Sqlmock) {
267 mock.ExpectQuery(datasql.SelectTerminalID).
268 WithArgs(nil, "name", "id").
269 WillReturnError(fmt.Errorf("error"))
270 },
271 expID: "",
272 errorAssertion: EqualError("error querying db in data:GetTerminalID: error"),
273 },
274 "Multiple Rows": {
275 terminal: "name",
276 storeID: "id",
277 expectations: func(mock sqlmock.Sqlmock) {
278 mock.ExpectQuery(datasql.SelectTerminalID).
279 WithArgs(nil, "name", "id").
280 WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}).AddRow(validUUID).AddRow("oops"))
281 },
282 expID: "",
283 errorAssertion: EqualError("error scanning rows in data:GetTerminalID: error multiple rows returned in data:scanRows"),
284 },
285 "No Rows Returned": {
286 terminal: "name",
287 storeID: "id",
288 expectations: func(mock sqlmock.Sqlmock) {
289 mock.ExpectQuery(datasql.SelectTerminalID).
290 WithArgs(nil, "name", "id").
291 WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}))
292 },
293 expID: "",
294 errorAssertion: assert.NoError,
295 },
296 }
297
298 for name, tc := range tests {
299 tc := tc
300 t.Run(name, func(t *testing.T) {
301 t.Parallel()
302
303 db, mock := initMockDB(t)
304 defer db.Close()
305
306 ds := Dataset{db: db}
307
308 tc.expectations(mock)
309
310 terminalID, err := ds.GetTerminalID(context.Background(), tc.terminal, tc.storeID)
311 tc.errorAssertion(t, err)
312 assert.Equal(t, tc.expID, terminalID)
313
314 assert.NoError(t, mock.ExpectationsWereMet())
315 })
316 }
317 }
318
319 func mockRowsToSQLRows(mockRows *sqlmock.Rows) (*sql.Rows, error) {
320 db, mock, err := sqlmock.New()
321 if err != nil {
322 return nil, err
323 }
324 defer db.Close()
325 mock.ExpectQuery("select").WillReturnRows(mockRows)
326 rows, err := db.Query("select")
327 if err != nil {
328 return nil, err
329 }
330 return rows, nil
331 }
332
333 func TestScanRowsSingleColumn(t *testing.T) {
334 t.Parallel()
335
336 tests := map[string]struct {
337 mockRows *sqlmock.Rows
338 expRes string
339 errAssert assert.ErrorAssertionFunc
340 }{
341 "Valid": {
342 mockRows: sqlmock.NewRows([]string{"col"}).AddRow("val"),
343 expRes: "val",
344 errAssert: assert.NoError,
345 },
346 "No Rows": {
347 mockRows: sqlmock.NewRows([]string{}),
348 expRes: "",
349 errAssert: assert.NoError,
350 },
351 "Multiple Rows": {
352 mockRows: sqlmock.NewRows([]string{"col"}).AddRow("val").AddRow("oops"),
353 expRes: "val",
354 errAssert: EqualError("error multiple rows returned in data:scanRows"),
355 },
356 }
357
358 for name, tc := range tests {
359 tc := tc
360 t.Run(name, func(t *testing.T) {
361 t.Parallel()
362
363 rows, err := mockRowsToSQLRows(tc.mockRows)
364 assert.NoError(t, err)
365 defer rows.Close()
366
367 empty := ""
368 result := &empty
369 err = scanRowsForIDs(rows, &result)
370 tc.errAssert(t, err)
371 assert.Equal(t, tc.expRes, *result)
372 })
373 }
374 }
375
376 func TestScanRowsMultiColumn(t *testing.T) {
377 t.Parallel()
378
379 tests := map[string]struct {
380 mockRows *sqlmock.Rows
381 expRes1 string
382 expRes2 string
383 errAssert assert.ErrorAssertionFunc
384 }{
385 "Valid": {
386 mockRows: sqlmock.NewRows([]string{"col1", "col2"}).AddRow("val1", "val2"),
387 expRes1: "val1",
388 expRes2: "val2",
389 errAssert: assert.NoError,
390 },
391 "No Rows": {
392 mockRows: sqlmock.NewRows([]string{}),
393 expRes1: "",
394 expRes2: "",
395 errAssert: assert.NoError,
396 },
397 "Multiple Rows": {
398 mockRows: sqlmock.NewRows([]string{"col1", "col2"}).AddRow("val1", "val2").AddRow("oops1", "oops2"),
399 expRes1: "val1",
400 expRes2: "val2",
401 errAssert: EqualError("error multiple rows returned in data:scanRows"),
402 },
403 }
404
405 for name, tc := range tests {
406 tc := tc
407 t.Run(name, func(t *testing.T) {
408 t.Parallel()
409
410 rows, err := mockRowsToSQLRows(tc.mockRows)
411 assert.NoError(t, err)
412 defer rows.Close()
413
414 empty := ""
415 result1 := &empty
416 result2 := &empty
417 err = scanRowsForIDs(rows, &result1, &result2)
418 tc.errAssert(t, err)
419 assert.Equal(t, tc.expRes1, *result1)
420 assert.Equal(t, tc.expRes2, *result2)
421 })
422 }
423 }
424
425 func TestIsUUID(t *testing.T) {
426 t.Parallel()
427
428 tests := map[string]struct {
429 val string
430 exp bool
431 }{
432 "Valid UUID": {
433 val: uuid.New().UUID,
434 exp: true,
435 },
436 "Invalid UUID": {
437 val: "an-invalid-uuid",
438 exp: false,
439 },
440 "Empty String": {
441 val: "",
442 exp: false,
443 },
444 }
445
446 for name, tc := range tests {
447 tc := tc
448 t.Run(name, func(t *testing.T) {
449 t.Parallel()
450 assert.Equal(t, tc.exp, isUUID(tc.val))
451 })
452 }
453 }
454
455 func TestSafeStringDereference(t *testing.T) {
456 s := "a-string"
457
458 tests := map[string]struct {
459 input *string
460 expected string
461 }{
462 "Valid": {
463 input: &s,
464 expected: s,
465 },
466 "Empty String": {
467 input: new(string),
468 expected: "",
469 },
470 "Nil": {
471 input: nil,
472 expected: "",
473 },
474 }
475
476 for name, tc := range tests {
477 tc := tc
478 t.Run(name, func(t *testing.T) {
479 t.Parallel()
480 assert.Equal(t, tc.expected, safeStringDereference(tc.input))
481 })
482 }
483 }
484
View as plain text