package database

import (
	"context"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/jackc/pgconn"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"edge-infra.dev/pkg/lib/fog"
	rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
	datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
)

// TestAddCommands verifies that AddCommands executes datasql.InsertCommand
// once per command between a Begin and a Commit, and reports no conflicts
// when every insert succeeds.
func TestAddCommands(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	mock.ExpectBegin()
	mock.ExpectExec(datasql.InsertCommand).WithArgs("ls").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec(datasql.InsertCommand).WithArgs("mv").WillReturnResult(sqlmock.NewResult(2, 1))
	mock.ExpectCommit()

	res, err := ds.AddCommands(context.Background(), []string{"ls", "mv"})
	// Stop on error: res is not meaningful when err != nil.
	require.NoError(t, err)
	assert.Nil(t, res.Conflicts)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

// TestDeleteCommand verifies that DeleteCommand executes
// datasql.DeleteCommand with the command name and returns no error when a
// row is deleted.
func TestDeleteCommand(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	mock.ExpectExec(datasql.DeleteCommand).WithArgs("ls").WillReturnResult(sqlmock.NewResult(1, 1))

	_, err := ds.DeleteCommand(context.Background(), "ls")
	assert.NoError(t, err)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

// TestDeleteCommandNoChange verifies that deleting a command which affects no
// rows is not returned as a Go error. Instead it is reported in res.Errors
// with type rulesengine.UnknownCommand, and RowsAffected is 0.
func TestDeleteCommandNoChange(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	command := "not-on-db"
	mock.ExpectExec(datasql.DeleteCommand).WithArgs(command).WillReturnResult(sqlmock.NewResult(0, 0))

	res, err := ds.DeleteCommand(context.Background(), command)
	assert.NoError(t, err)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}

	// require, not assert: res.Errors[0] below would panic on an empty slice.
	require.NotEmpty(t, res.Errors)
	assert.Equal(t, rulesengine.UnknownCommand, res.Errors[0].Type)
	assert.Equal(t, int64(0), res.RowsAffected)
}

// TestDeleteCommandConflict verifies that a PostgreSQL error with SQLSTATE
// 23503 (foreign_key_violation) from the delete is not returned as a Go
// error. Instead it is reported in res.Errors with type rulesengine.Conflict,
// and RowsAffected is 0.
func TestDeleteCommandConflict(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	command := "ls"
	mock.ExpectExec(datasql.DeleteCommand).WithArgs(command).WillReturnError(&pgconn.PgError{Code: "23503"})

	res, err := ds.DeleteCommand(context.Background(), command)
	assert.NoError(t, err)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}

	// require, not assert: res.Errors[0] below would panic on an empty slice.
	require.NotEmpty(t, res.Errors)
	assert.Equal(t, rulesengine.Conflict, res.Errors[0].Type)
	assert.Equal(t, int64(0), res.RowsAffected)
}

// TestReadCommand verifies that ReadCommand runs datasql.SelectCommandByName
// and maps the (command_id, name) row onto a rulesengine.Command.
func TestReadCommand(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	mockRow := sqlmock.NewRows([]string{"command_id", "name"}).AddRow("test", "ls")
	mock.ExpectQuery(datasql.SelectCommandByName).WithArgs("ls").WillReturnRows(mockRow)

	res, err := ds.ReadCommand(context.Background(), "ls")
	assert.NoError(t, err)
	assert.EqualValues(t, rulesengine.Command{Name: "ls", ID: "test"}, res)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

// TestReadAllCommands verifies that ReadAllCommands runs
// datasql.SelectAllCommands and returns every row as a rulesengine.Command.
func TestReadAllCommands(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	mockRow := sqlmock.NewRows([]string{"command_id", "name"}).AddRow("test", "ls")
	mock.ExpectQuery(datasql.SelectAllCommands).WillReturnRows(mockRow)

	res, err := ds.ReadAllCommands(context.Background())
	assert.NoError(t, err)
	assert.EqualValues(t, []rulesengine.Command{{Name: "ls", ID: "test"}}, res)

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("there were unfulfilled expectations: %s", err)
	}
}

// TestReadCommandsWithFilter verifies which query ReadCommandsWithFilter
// issues for different filters:
//   - A non-empty name filter is passed as one []string argument to
//     datasql.SelectCommandsByName, and only matching rows are returned.
//   - A nil filter falls back to datasql.SelectAllCommands.
//
//nolint:dupl
func TestReadCommandsWithFilter(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		args         []string
		mockRowsFunc func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery
		expectedRes  []rulesengine.Command
	}{
		"All Valid": {
			args: []string{"name1", "name2", "name3"},
			mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
				mockRows := sqlmock.NewRows([]string{"command_id", "name"}).
					AddRow("id1", "name1").
					AddRow("id2", "name2").
					AddRow("id3", "name3")
				return mock.ExpectQuery(datasql.SelectCommandsByName).
					WithArgs([]string{"name1", "name2", "name3"}).
					WillReturnRows(mockRows)
			},
			expectedRes: []rulesengine.Command{
				{ID: "id1", Name: "name1"},
				{ID: "id2", Name: "name2"},
				{ID: "id3", Name: "name3"},
			},
		},
		"No Rows Returned": {
			args: []string{"nonexistant1", "nonexistant2", "nonexistant3"},
			mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
				mockRows := sqlmock.NewRows([]string{"command_id", "name"})
				return mock.ExpectQuery(datasql.SelectCommandsByName).
					WithArgs([]string{"nonexistant1", "nonexistant2", "nonexistant3"}).
					WillReturnRows(mockRows)
			},
			// expectedRes left nil: no rows is expected to yield a nil result.
		},
		"Some Valid": {
			args: []string{"name1", "nonexistant2", "name3"},
			mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
				mockRows := sqlmock.NewRows([]string{"command_id", "name"}).
					AddRow("id1", "name1").
					AddRow("id3", "name3")
				return mock.ExpectQuery(datasql.SelectCommandsByName).
					WithArgs([]string{"name1", "nonexistant2", "name3"}).
					WillReturnRows(mockRows)
			},
			expectedRes: []rulesengine.Command{
				{ID: "id1", Name: "name1"},
				{ID: "id3", Name: "name3"},
			},
		},
		"No Filter Passed In": {
			mockRowsFunc: func(mock sqlmock.Sqlmock) *sqlmock.ExpectedQuery {
				mockRows := sqlmock.NewRows([]string{"command_id", "name"}).AddRow("all", "all")
				return mock.ExpectQuery(datasql.SelectAllCommands).WillReturnRows(mockRows)
			},
			expectedRes: []rulesengine.Command{
				{ID: "all", Name: "all"},
			},
		},
	}

	for name, tc := range tests {
		tc := tc // per-iteration copy for the parallel closure (needed before Go 1.22)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			// Built directly instead of via initMockDB so a custom value
			// converter can be set. NOTE(review): StringSliceValueConverter
			// presumably lets sqlmock accept the []string query argument —
			// confirm against its definition.
			db, mock, err := sqlmock.New(
				sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
				sqlmock.ValueConverterOption(StringSliceValueConverter{}),
			)
			require.NoError(t, err, "an error '%s' was not expected when opening a stub database connection", err)
			defer db.Close()
			ds := New(fog.New(), db)

			tc.mockRowsFunc(mock)

			res, err := ds.ReadCommandsWithFilter(context.Background(), tc.args)
			assert.NoError(t, err)
			assert.EqualValues(t, tc.expectedRes, res)

			if err := mock.ExpectationsWereMet(); err != nil {
				t.Errorf("there were unfulfilled expectations: %s", err)
			}
		})
	}
}