package database

import (
	"context"
	"database/sql/driver"
	"strings"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/jackc/pgconn"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"edge-infra.dev/pkg/lib/fog"
	rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
	datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
)

// TestAddPrivilegesBulk checks that AddPrivileges issues one InsertPrivilege
// per name between a Begin and a Commit, and reports no conflicts when every
// insert succeeds.
func TestAddPrivilegesBulk(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	mock.ExpectBegin()
	mock.ExpectExec(datasql.InsertPrivilege).WithArgs("basic").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec(datasql.InsertPrivilege).WithArgs("admin").WillReturnResult(sqlmock.NewResult(2, 1))
	mock.ExpectCommit()

	res, err := ds.AddPrivileges(context.Background(), []string{"basic", "admin"})
	// Fail fast: the result's fields are not meaningful when err != nil.
	require.NoError(t, err)
	assert.Nil(t, res.Conflicts)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

// TestDeletePrivilege checks that DeletePrivilege executes the DeletePrivilege
// statement with the given name and returns no error when a row is removed.
func TestDeletePrivilege(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	mock.ExpectExec(datasql.DeletePrivilege).WithArgs("basic").WillReturnResult(sqlmock.NewResult(1, 1))

	_, err := ds.DeletePrivilege(context.Background(), "basic")
	assert.NoError(t, err)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

// TestDeletePrivilegeNoChange checks that deleting a name that affects zero
// rows is not a Go error but is reported as an UnknownPrivilege result error.
func TestDeletePrivilegeNoChange(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	priv := "not-on-db"
	mock.ExpectExec(datasql.DeletePrivilege).WithArgs(priv).WillReturnResult(sqlmock.NewResult(0, 0))

	res, err := ds.DeletePrivilege(context.Background(), priv)
	require.NoError(t, err)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}

	// require (not assert): indexing Errors[0] below would panic on an empty slice.
	require.NotEmpty(t, res.Errors)
	assert.Equal(t, rulesengine.UnknownPrivilege, res.Errors[0].Type)
	assert.Equal(t, int64(0), res.RowsAffected)
}

// TestDeletePrivilegeConflict checks that a foreign-key violation from the
// database (SQLSTATE 23503) is reported as a Conflict result error rather than
// returned as a Go error.
func TestDeletePrivilegeConflict(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	priv := "conflict"
	mock.ExpectExec(datasql.DeletePrivilege).WithArgs(priv).WillReturnError(&pgconn.PgError{Code: "23503"})

	res, err := ds.DeletePrivilege(context.Background(), priv)
	require.NoError(t, err)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}

	// require (not assert): indexing Errors[0] below would panic on an empty slice.
	require.NotEmpty(t, res.Errors)
	assert.Equal(t, rulesengine.Conflict, res.Errors[0].Type)
	assert.Equal(t, int64(0), res.RowsAffected)
}

// TestReadPrivilege checks that ReadPrivilege maps the privilege_id and name
// columns of SelectPrivilegeByName onto a rulesengine.Privilege.
func TestReadPrivilege(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	mockRow := sqlmock.NewRows([]string{"privilege_id", "name"}).AddRow("test", "basic")
	mock.ExpectQuery(datasql.SelectPrivilegeByName).WithArgs("basic").WillReturnRows(mockRow)

	res, err := ds.ReadPrivilege(context.Background(), "basic")
	require.NoError(t, err)
	assert.EqualValues(t, rulesengine.Privilege{Name: "basic", ID: "test"}, res)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

// StringSliceValueConverter is a test-only driver.ValueConverter that lets
// sqlmock accept []string query arguments by rendering them as a PostgreSQL
// array literal such as "{a,b,c}".
type StringSliceValueConverter struct{}

// ConvertValue implements the driver.ValueConverter interface.
//
// A []string is joined verbatim between braces; elements are NOT quoted or
// escaped, which is sufficient only for the plain identifiers used in these
// tests. Any other type is delegated to driver.DefaultParameterConverter.
func (c StringSliceValueConverter) ConvertValue(v interface{}) (driver.Value, error) {
	if vv, ok := v.([]string); ok {
		arrayStr := "{" + strings.Join(vv, ",") + "}"
		return arrayStr, nil
	}
	// Fallback for other types.
	return driver.DefaultParameterConverter.ConvertValue(v)
}

// TestReadPrivilegesWithFilter checks that ReadPrivilegesWithFilter queries by
// name when a filter is given, falls back to SelectAllPrivileges when the
// filter is nil, and returns only the rows the database yields.
//
//nolint:dupl
func TestReadPrivilegesWithFilter(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		args         []string
		mockRowsFunc func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery
		expectedRes  []rulesengine.Privilege
	}{
		"All Valid": {
			args: []string{"name1", "name2", "name3"},
			mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
				mockRows := sqlmock.NewRows([]string{"privilege_id", "name"}).
					AddRow("id1", "name1").
					AddRow("id2", "name2").
					AddRow("id3", "name3")
				return mock.ExpectQuery(datasql.SelectPrivilegesByName).
					WithArgs([]string{"name1", "name2", "name3"}).
					WillReturnRows(mockRows)
			},
			expectedRes: []rulesengine.Privilege{
				{ID: "id1", Name: "name1"},
				{ID: "id2", Name: "name2"},
				{ID: "id3", Name: "name3"},
			},
		},
		"No Rows Returned": {
			args: []string{"nonexistant1", "nonexistant2", "nonexistant3"},
			mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
				mockRows := sqlmock.NewRows([]string{"privilege_id", "name"})
				return mock.ExpectQuery(datasql.SelectPrivilegesByName).
					WithArgs([]string{"nonexistant1", "nonexistant2", "nonexistant3"}).
					WillReturnRows(mockRows)
			},
		},
		"Some Valid": {
			args: []string{"name1", "nonexistant2", "name3"},
			mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
				mockRows := sqlmock.NewRows([]string{"privilege_id", "name"}).
					AddRow("id1", "name1").
					AddRow("id3", "name3")
				return mock.ExpectQuery(datasql.SelectPrivilegesByName).
					WithArgs([]string{"name1", "nonexistant2", "name3"}).
					WillReturnRows(mockRows)
			},
			expectedRes: []rulesengine.Privilege{
				{ID: "id1", Name: "name1"},
				{ID: "id3", Name: "name3"},
			},
		},
		"No Filter Passed In": {
			mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
				mockRows := sqlmock.NewRows([]string{"privilege_id", "name"}).AddRow("all", "all")
				return mock.ExpectQuery(datasql.SelectAllPrivileges).WillReturnRows(mockRows)
			},
			expectedRes: []rulesengine.Privilege{
				{ID: "all", Name: "all"},
			},
		},
	}

	for name, tc := range tests {
		tc := tc // capture per-iteration copy for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			// A dedicated mock with exact query matching and the []string
			// converter, so array arguments can be matched by WithArgs.
			db, mock, err := sqlmock.New(
				sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
				sqlmock.ValueConverterOption(StringSliceValueConverter{}),
			)
			require.NoError(t, err, "an error '%s' was not expected when opening a stub database connection", err)
			defer db.Close()
			ds := New(fog.New(), db)

			tc.mockRowsFunc(mock)

			res, err := ds.ReadPrivilegesWithFilter(context.Background(), tc.args)
			// Check the error first so a DB failure is reported as the cause
			// instead of surfacing as an unrelated value mismatch.
			require.NoError(t, err)
			assert.EqualValues(t, tc.expectedRes, res)

			if err := mock.ExpectationsWereMet(); err != nil {
				t.Errorf("there were unfulfilled expectations: %s", err)
			}
		})
	}
}

// TestReadAllPrivileges checks that ReadAllPrivileges maps every row of
// SelectAllPrivileges onto a rulesengine.Privilege.
func TestReadAllPrivileges(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	mockRow := sqlmock.NewRows([]string{"privilege_id", "name"}).AddRow("test", "basic")
	mock.ExpectQuery(datasql.SelectAllPrivileges).WillReturnRows(mockRow)

	res, err := ds.ReadAllPrivileges(context.Background())
	require.NoError(t, err)
	assert.EqualValues(t, []rulesengine.Privilege{{Name: "basic", ID: "test"}}, res)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}