package database

import (
	"context"
	"fmt"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"

	"edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
	rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
	datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
)

// bannerID is a fixed, UUID-formatted banner identifier shared by the test
// cases in this package.
const bannerID = "2f9f5965-ed2a-4262-9fd9-9d2d8f8bee8a"

// TestEARoles exercises Dataset.EARoles against a mocked database. Each case
// registers exactly one expected SelectPrivNamesForCommandAndBanner query,
// calls EARoles, and then checks the returned error, the returned privilege
// names, and that every registered mock expectation was consumed.
func TestEARoles(t *testing.T) {
	t.Parallel()

	// errorEquals builds an assertion requiring the error's message to equal
	// msg exactly. Extra message arguments are deliberately ignored.
	errorEquals := func(msg string) assert.ErrorAssertionFunc {
		return func(tt assert.TestingT, err error, _ ...interface{}) bool {
			return assert.EqualError(tt, err, msg)
		}
	}

	// Error text the mock returns when the command type is not a valid enum value.
	const enumErr = `Error: invalid input value for enum command_type: "Type" (SQLSTATE 22P02)`

	cases := []struct {
		name      string
		banner    string
		cmd       rulesengine.Command
		setupMock func(m sqlmock.Sqlmock)
		want      []string
		checkErr  assert.ErrorAssertionFunc
	}{
		{
			name:   "Success",
			banner: bannerID,
			cmd:    rulesengine.Command{Name: "ls", Type: eaconst.Command},
			setupMock: func(m sqlmock.Sqlmock) {
				rows := sqlmock.NewRows([]string{"privilegeName"}).AddRow("ea-read")
				m.ExpectQuery(datasql.SelectPrivNamesForCommandAndBanner).
					WithArgs("ls", "command", bannerID).
					WillReturnRows(rows)
			},
			want:     []string{"ea-read"},
			checkErr: assert.NoError,
		},
		{
			name:   "Success Executable",
			banner: bannerID,
			cmd:    rulesengine.Command{Name: "myScript", Type: eaconst.Executable},
			setupMock: func(m sqlmock.Sqlmock) {
				rows := sqlmock.NewRows([]string{"privilegeName"}).AddRow("ea-read")
				m.ExpectQuery(datasql.SelectPrivNamesForCommandAndBanner).
					WithArgs("myScript", "executable", bannerID).
					WillReturnRows(rows)
			},
			want:     []string{"ea-read"},
			checkErr: assert.NoError,
		},
		{
			name:   "Unknown Type",
			banner: bannerID,
			cmd:    rulesengine.Command{Name: "myScript", Type: eaconst.RequestType("invalidType")},
			setupMock: func(m sqlmock.Sqlmock) {
				m.ExpectQuery(datasql.SelectPrivNamesForCommandAndBanner).
					WithArgs("myScript", "invalidType", bannerID).
					WillReturnError(fmt.Errorf(enumErr))
			},
			want:     nil,
			checkErr: errorEquals(enumErr),
		},
		{
			name:   "Fail no banner",
			banner: "",
			cmd:    rulesengine.Command{Name: "ls", Type: eaconst.Command},
			setupMock: func(m sqlmock.Sqlmock) {
				m.ExpectQuery(datasql.SelectPrivNamesForCommandAndBanner).
					WithArgs("ls", "command", "").
					WillReturnError(fmt.Errorf("uuid error"))
			},
			want:     nil,
			checkErr: errorEquals("uuid error"),
		},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for the parallel subtest (required before Go 1.22)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			db, mock := initMockDB(t)
			defer db.Close()

			tc.setupMock(mock)

			ds := Dataset{db: db}
			got, err := ds.EARoles(context.Background(), tc.banner, tc.cmd)

			tc.checkErr(t, err)
			assert.Equal(t, tc.want, got)
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}