1 package database
2
3 import (
4 "context"
5 "database/sql/driver"
6 "strings"
7 "testing"
8
9 "github.com/DATA-DOG/go-sqlmock"
10 "github.com/jackc/pgconn"
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13
14 "edge-infra.dev/pkg/lib/fog"
15 rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
16 datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
17 )
18
19 func TestAddPrivilegesBulk(t *testing.T) {
20 db, mock := initMockDB(t)
21 defer db.Close()
22 ds := New(fog.New(), db)
23
24 mock.ExpectBegin()
25 mock.ExpectExec(datasql.InsertPrivilege).WithArgs("basic").WillReturnResult(sqlmock.NewResult(1, 1))
26 mock.ExpectExec(datasql.InsertPrivilege).WithArgs("admin").WillReturnResult(sqlmock.NewResult(2, 1))
27 mock.ExpectCommit()
28
29 res, err := ds.AddPrivileges(context.Background(), []string{"basic", "admin"})
30 assert.NoError(t, err)
31 assert.Nil(t, res.Conflicts)
32 if err := mock.ExpectationsWereMet(); err != nil {
33 t.Errorf("there were unfulfilled expectations: %s", err)
34 }
35 }
36
37 func TestDeletePrivilege(t *testing.T) {
38 db, mock := initMockDB(t)
39 defer db.Close()
40 ds := New(fog.New(), db)
41
42 mock.ExpectExec(datasql.DeletePrivilege).WithArgs("basic").WillReturnResult(sqlmock.NewResult(1, 1))
43
44 _, err := ds.DeletePrivilege(context.Background(), "basic")
45 assert.NoError(t, err)
46 if err := mock.ExpectationsWereMet(); err != nil {
47 t.Errorf("there were unfulfilled expectations: %s", err)
48 }
49 }
50
51 func TestDeletePrivilegeNoChange(t *testing.T) {
52 db, mock := initMockDB(t)
53 defer db.Close()
54 ds := New(fog.New(), db)
55
56 priv := "not-on-db"
57 mock.ExpectExec(datasql.DeletePrivilege).WithArgs(priv).WillReturnResult(sqlmock.NewResult(0, 0))
58
59 res, err := ds.DeletePrivilege(context.Background(), priv)
60 assert.NoError(t, err)
61 if err := mock.ExpectationsWereMet(); err != nil {
62 t.Errorf("there were unfulfilled expectations: %s", err)
63 }
64 assert.NotEmpty(t, res.Errors)
65 assert.Equal(t, rulesengine.UnknownPrivilege, res.Errors[0].Type)
66 assert.Equal(t, int64(0), res.RowsAffected)
67 }
68
69 func TestDeletePrivilegeConflict(t *testing.T) {
70 db, mock := initMockDB(t)
71 defer db.Close()
72 ds := New(fog.New(), db)
73
74 priv := "conflict"
75 mock.ExpectExec(datasql.DeletePrivilege).WithArgs(priv).WillReturnError(&pgconn.PgError{Code: "23503"})
76
77 res, err := ds.DeletePrivilege(context.Background(), priv)
78 assert.NoError(t, err)
79 if err := mock.ExpectationsWereMet(); err != nil {
80 t.Errorf("there were unfulfilled expectations: %s", err)
81 }
82 assert.NotEmpty(t, res.Errors)
83 assert.Equal(t, rulesengine.Conflict, res.Errors[0].Type)
84 assert.Equal(t, int64(0), res.RowsAffected)
85 }
86
87 func TestReadPrivilege(t *testing.T) {
88 db, mock := initMockDB(t)
89 defer db.Close()
90 ds := New(fog.New(), db)
91
92 mockRow := sqlmock.NewRows([]string{"privilege_id", "name"}).AddRow("test", "basic")
93 mock.ExpectQuery(datasql.SelectPrivilegeByName).WithArgs("basic").WillReturnRows(mockRow)
94
95 res, err := ds.ReadPrivilege(context.Background(), "basic")
96 assert.NoError(t, err)
97 assert.EqualValues(t, rulesengine.Privilege{Name: "basic", ID: "test"}, res)
98 if err := mock.ExpectationsWereMet(); err != nil {
99 t.Errorf("there were unfulfilled expectations: %s", err)
100 }
101 }
102
103
// StringSliceValueConverter is a driver value converter that additionally
// accepts []string arguments, which the default parameter converter does not
// handle.
type StringSliceValueConverter struct{}

// ConvertValue renders a []string as a single brace-wrapped, comma-separated
// string (for example {a,b,c}); elements are joined verbatim with no quoting
// or escaping. Any other value is delegated to
// driver.DefaultParameterConverter.
func (c StringSliceValueConverter) ConvertValue(v interface{}) (driver.Value, error) {
	switch items := v.(type) {
	case []string:
		joined := strings.Join(items, ",")
		return "{" + joined + "}", nil
	default:
		return driver.DefaultParameterConverter.ConvertValue(v)
	}
}
117
118
119 func TestReadPrivilegesWithFilter(t *testing.T) {
120 t.Parallel()
121
122 tests := map[string]struct {
123 args []string
124 mockRowsFunc func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery
125 expectedRes []rulesengine.Privilege
126 }{
127 "All Valid": {
128 args: []string{"name1", "name2", "name3"},
129 mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
130 mockRows := sqlmock.NewRows([]string{"privilege_id", "name"}).AddRow("id1", "name1").AddRow("id2", "name2").AddRow("id3", "name3")
131 return mock.ExpectQuery(datasql.SelectPrivilegesByName).WithArgs([]string{"name1", "name2", "name3"}).WillReturnRows(mockRows)
132 },
133 expectedRes: []rulesengine.Privilege{
134 {ID: "id1", Name: "name1"},
135 {ID: "id2", Name: "name2"},
136 {ID: "id3", Name: "name3"},
137 },
138 },
139 "No Rows Returned": {
140 args: []string{"nonexistant1", "nonexistant2", "nonexistant3"},
141 mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
142 mockRows := sqlmock.NewRows([]string{"privilege_id", "name"})
143 return mock.ExpectQuery(datasql.SelectPrivilegesByName).WithArgs([]string{"nonexistant1", "nonexistant2", "nonexistant3"}).WillReturnRows(mockRows)
144 },
145 },
146 "Some Valid": {
147 args: []string{"name1", "nonexistant2", "name3"},
148 mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
149 mockRows := sqlmock.NewRows([]string{"privilege_id", "name"}).AddRow("id1", "name1").AddRow("id3", "name3")
150 return mock.ExpectQuery(datasql.SelectPrivilegesByName).WithArgs([]string{"name1", "nonexistant2", "name3"}).WillReturnRows(mockRows)
151 },
152 expectedRes: []rulesengine.Privilege{
153 {ID: "id1", Name: "name1"},
154 {ID: "id3", Name: "name3"},
155 },
156 },
157 "No Filter Passed In": {
158 mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
159 mockRows := sqlmock.NewRows([]string{"privilege_id", "name"}).AddRow("all", "all")
160 return mock.ExpectQuery(datasql.SelectAllPrivileges).WillReturnRows(mockRows)
161 },
162 expectedRes: []rulesengine.Privilege{
163 {ID: "all", Name: "all"},
164 },
165 },
166 }
167
168 for name, tc := range tests {
169 tc := tc
170 t.Run(name, func(t *testing.T) {
171 t.Parallel()
172
173 db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual), sqlmock.ValueConverterOption(StringSliceValueConverter{}))
174 require.NoError(t, err, "an error '%s' was not expected when opening a stub database connection", err)
175 defer db.Close()
176
177 ds := New(fog.New(), db)
178 tc.mockRowsFunc(mock)
179
180 res, err := ds.ReadPrivilegesWithFilter(context.Background(), tc.args)
181
182 assert.EqualValues(t, tc.expectedRes, res)
183 assert.NoError(t, err)
184 if err := mock.ExpectationsWereMet(); err != nil {
185 t.Errorf("there were unfulfilled expectations: %s", err)
186 }
187 })
188 }
189 }
190
191 func TestReadAllPrivileges(t *testing.T) {
192 db, mock := initMockDB(t)
193 defer db.Close()
194 ds := New(fog.New(), db)
195
196 mockRow := sqlmock.NewRows([]string{"privilege_id", "name"}).AddRow("test", "basic")
197 mock.ExpectQuery(datasql.SelectAllPrivileges).WillReturnRows(mockRow)
198
199 res, err := ds.ReadAllPrivileges(context.Background())
200 assert.EqualValues(t, []rulesengine.Privilege{{Name: "basic", ID: "test"}}, res)
201 assert.Nil(t, err)
202 if err := mock.ExpectationsWereMet(); err != nil {
203 t.Errorf("there were unfulfilled expectations: %s", err)
204 }
205 }
206
View as plain text