...
1 package database
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "testing"
8
9 "github.com/DATA-DOG/go-sqlmock"
10 "github.com/stretchr/testify/assert"
11
12 rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
13 datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
14 )
15
16 func initMockDB(t *testing.T) (db *sql.DB, mock sqlmock.Sqlmock) {
17 db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
18 if err != nil {
19 t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
20 }
21 return db, mock
22 }
23
24 func TestAddNamesDefer(t *testing.T) {
25 t.Parallel()
26
27 tests := map[string]struct {
28 names []string
29 expRes rulesengine.AddNameResult
30
31 expectations func(mock sqlmock.Sqlmock)
32 errorAssertion assert.ErrorAssertionFunc
33 }{
34 "Successful commit": {
35 names: nil,
36 expectations: func(mock sqlmock.Sqlmock) {
37 mock.ExpectBegin()
38 mock.ExpectCommit().WillReturnError(nil)
39 },
40 errorAssertion: assert.NoError,
41 },
42 "Commit Returns Error": {
43 names: nil,
44 expectations: func(mock sqlmock.Sqlmock) {
45 mock.ExpectBegin()
46 mock.ExpectCommit().WillReturnError(fmt.Errorf("commit error"))
47 },
48 errorAssertion: EqualError("error committing transaction: commit error"),
49 },
50 "Begin and do nothing on conflict": {
51 names: []string{"a"},
52 expRes: rulesengine.AddNameResult{},
53
54 expectations: func(mock sqlmock.Sqlmock) {
55 mock.ExpectBegin()
56 mock.ExpectExec(datasql.InsertCommand).
57 WithArgs("a").WillReturnResult(sqlmock.NewResult(0, 0))
58 mock.ExpectCommit()
59 },
60 errorAssertion: assert.NoError,
61 },
62 "Begin and add name, no conflict": {
63 names: []string{"a"},
64 expRes: rulesengine.AddNameResult{},
65 expectations: func(mock sqlmock.Sqlmock) {
66 mock.ExpectBegin()
67 mock.ExpectExec(datasql.InsertCommand).
68 WithArgs("a").WillReturnResult(sqlmock.NewResult(0, 1))
69 mock.ExpectCommit()
70 },
71 errorAssertion: assert.NoError,
72 },
73 }
74
75 for name, tc := range tests {
76 tc := tc
77 t.Run(name, func(t *testing.T) {
78 t.Parallel()
79
80 db, mock := initMockDB(t)
81 defer db.Close()
82
83 ds := Dataset{db: db}
84
85 tc.expectations(mock)
86
87 res, err := ds.addNames(context.Background(), tc.names, datasql.InsertCommand)
88 tc.errorAssertion(t, err)
89 assert.Equal(t, tc.expRes, res)
90
91 assert.NoError(t, mock.ExpectationsWereMet())
92 })
93 }
94 }
95
View as plain text