package database

import (
	"context"
	"fmt"
	"io"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"

	"edge-infra.dev/pkg/lib/fog"
	rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
	datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql"
)

// TestAddDefaultRule checks that AddDefaultRules inserts a single default rule
// inside a transaction (begin, insert, commit) and reports no rule errors.
func TestAddDefaultRule(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	command, priv := "ls", "basic"
	mock.ExpectBegin()
	mock.ExpectExec(datasql.InsertRuleDefault).WithArgs(command, priv).WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()

	res, err := ds.AddDefaultRules(context.Background(), []rulesengine.RuleSegment{{Command: rulesengine.Command{Name: command}, Privilege: rulesengine.Privilege{Name: priv}}})
	assert.NoError(t, err)
	assert.Nil(t, res.Errors)

	assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expectations")
}

// TestAddDefaultRuleDefer exercises how AddDefaultRules finishes its
// transaction: commit on success, error wrapping when commit fails, rollback
// after an insert error, and rollback when an inserted segment references an
// unknown privilege. The cases also cover the rollback itself failing.
func TestAddDefaultRuleDefer(t *testing.T) {
	t.Parallel()

	// Discard log output so the parallel subtests stay quiet.
	ctx := fog.IntoContext(context.Background(), fog.New(fog.To(io.Discard)))

	tests := map[string]struct {
		inputRules     []rulesengine.RuleSegment
		expRes         rulesengine.AddRuleResult
		expectations   func(mock sqlmock.Sqlmock)
		errorAssertion assert.ErrorAssertionFunc
	}{
		"Successful commit": {
			inputRules: nil,
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				mock.ExpectCommit()
			},
			errorAssertion: assert.NoError,
		},
		"Commit Returns Error": {
			inputRules: nil,
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				mock.ExpectCommit().WillReturnError(fmt.Errorf("commit error"))
			},
			errorAssertion: EqualError("error committing transaction: commit error"),
		},
		"Rollback on error": {
			inputRules: []rulesengine.RuleSegment{{Command: rulesengine.Command{Name: "command"}, Privilege: rulesengine.Privilege{Name: "priv"}}},
			expectations: func(mock sqlmock.Sqlmock) {
				// Return an error immediately from the insert query
				mock.ExpectBegin()
				mock.ExpectExec(datasql.InsertRuleDefault).WithArgs("command", "priv").WillReturnResult(sqlmock.NewErrorResult(fmt.Errorf("insert error")))
				// Rollback is successful
				mock.ExpectRollback()
			},
			errorAssertion: EqualError("error getting query result: insert error"),
		},
		"Rollback on error returns an error": {
			inputRules: []rulesengine.RuleSegment{{Command: rulesengine.Command{Name: "command"}, Privilege: rulesengine.Privilege{Name: "priv"}}},
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				// Returns an error
				mock.ExpectExec(datasql.InsertRuleDefault).WithArgs("command", "priv").WillReturnResult(sqlmock.NewErrorResult(fmt.Errorf("insert error")))
				mock.ExpectRollback().WillReturnError(fmt.Errorf("rollback error"))
			},
			// Contains both insert error and rollback error
			errorAssertion: EqualError("error rolling back transaction due to application error (error getting query result: insert error): rollback error"),
		},
		"Rollback on conflicts": {
			inputRules: []rulesengine.RuleSegment{{Command: rulesengine.Command{Name: "command"}, Privilege: rulesengine.Privilege{Name: "priv"}}},
			expRes:     rulesengine.AddRuleResult{Errors: []rulesengine.Error{{Privilege: "priv", Type: rulesengine.UnknownPrivilege}}},
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				// No row inserted, so the segment's IDs are looked up to explain why.
				mock.ExpectExec(datasql.InsertRuleDefault).
					WithArgs("command", "priv").
					WillReturnResult(sqlmock.NewResult(0, 0))
				mock.ExpectQuery(datasql.GetIDsForRuleSegment).
					WithArgs("command", "priv").
					WillReturnRows(
						// Missing the privilege row
						sqlmock.NewRows([]string{"type", "id"}).
							AddRow("command", "command-uuid"),
					)
				mock.ExpectRollback()
			},
			errorAssertion: assert.NoError,
		},
		"Rollback on Conflicts returns error": {
			inputRules: []rulesengine.RuleSegment{{Command: rulesengine.Command{Name: "command"}, Privilege: rulesengine.Privilege{Name: "priv"}}},
			expRes:     rulesengine.AddRuleResult{Errors: []rulesengine.Error{{Privilege: "priv", Type: rulesengine.UnknownPrivilege}}},
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				mock.ExpectExec(datasql.InsertRuleDefault).
					WithArgs("command", "priv").
					WillReturnResult(sqlmock.NewResult(0, 0))
				mock.ExpectQuery(datasql.GetIDsForRuleSegment).
					WithArgs("command", "priv").
					WillReturnRows(
						// Missing the privilege row
						sqlmock.NewRows([]string{"type", "id"}).
							AddRow("command", "command-uuid"),
					)
				mock.ExpectRollback().WillReturnError(fmt.Errorf("rollback error"))
			},
			errorAssertion: EqualError("error rolling back transaction due to errors: rollback error"),
		},
	}

	for name, tc := range tests {
		tc := tc
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			db, mock := initMockDB(t)
			defer db.Close()
			ds := Dataset{db: db}

			tc.expectations(mock)

			res, err := ds.AddDefaultRules(ctx, tc.inputRules)
			tc.errorAssertion(t, err)
			assert.Equal(t, tc.expRes, res)

			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}

// TestReadAllDefaultRules checks that ReadAllDefaultRules maps each row of
// SelectAllDefaultRules into a RuleSegment carrying both names and IDs.
func TestReadAllDefaultRules(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	// &comname, &privname, &comid, &privid
	mockRow := sqlmock.NewRows([]string{"comname", "privname", "command_id", "privilege_id"}).AddRow("ls", "basic", "test_comid", "test_privid")
	mock.ExpectQuery(datasql.SelectAllDefaultRules).WillReturnRows(mockRow)

	res, err := ds.ReadAllDefaultRules(context.Background())
	assert.NoError(t, err)
	assert.EqualValues(t, []rulesengine.RuleSegment{{
		Command:   rulesengine.Command{ID: "test_comid", Name: "ls"},
		Privilege: rulesengine.Privilege{ID: "test_privid", Name: "basic"},
	}}, res)

	assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expectations")
}

// TestReadDefaultRulesForCommand checks that ReadDefaultRulesForCommand queries
// by command name and maps the returned rows into RuleSegments.
func TestReadDefaultRulesForCommand(t *testing.T) {
	db, mock := initMockDB(t)
	defer db.Close()
	ds := New(fog.New(), db)

	// &comname, &privname, &comid, &privid
	mockRow := sqlmock.NewRows([]string{"comname", "privname", "command_id", "privilege_id"}).AddRow("ls", "basic", "test_comid", "test_privid")
	mock.ExpectQuery(datasql.SelectDefaultRulesByCommandName).WithArgs("ls").WillReturnRows(mockRow)

	res, err := ds.ReadDefaultRulesForCommand(context.Background(), "ls")
	assert.NoError(t, err)
	assert.EqualValues(t, []rulesengine.RuleSegment{{
		Command:   rulesengine.Command{ID: "test_comid", Name: "ls"},
		Privilege: rulesengine.Privilege{ID: "test_privid", Name: "basic"},
	}}, res)

	assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expectations")
}

// TestDeleteDefaultRule checks DeleteDefaultRule. When the delete affects no
// rows, it is expected to look up the segment's IDs in a transaction and
// report which part (command, privilege, or the rule association) is unknown.
func TestDeleteDefaultRule(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		commandName  string
		privName     string
		expectations func(mock sqlmock.Sqlmock)
		expRes       rulesengine.DeleteResult
		expErr       assert.ErrorAssertionFunc
	}{
		"Deleted": {
			commandName: "ls",
			privName:    "basic",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectExec(datasql.DeleteDefaultRule).
					WithArgs("ls", "basic").
					WillReturnResult(sqlmock.NewResult(1, 1))
			},
			expRes: rulesengine.DeleteResult{RowsAffected: 1},
			expErr: assert.NoError,
		},
		"Unknown Command": {
			commandName: "unknownCommand",
			privName:    "basic",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectExec(datasql.DeleteDefaultRule).
					WithArgs("unknownCommand", "basic").
					WillReturnResult(sqlmock.NewResult(0, 0))
				mock.ExpectBegin()
				// Only the privilege is known.
				mock.ExpectQuery(datasql.GetIDsForRuleSegment).
					WithArgs("unknownCommand", "basic").
					WillReturnRows(sqlmock.NewRows([]string{"type", "id"}).
						AddRow("privilege", "basic"),
					)
				mock.ExpectCommit()
			},
			expRes: rulesengine.DeleteResult{RowsAffected: 0, Errors: []rulesengine.Error{{Type: rulesengine.UnknownCommand, Command: "unknownCommand"}}},
			expErr: assert.NoError,
		},
		"Unknown Privilege": {
			commandName: "ls",
			privName:    "unknownPriv",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectExec(datasql.DeleteDefaultRule).
					WithArgs("ls", "unknownPriv").
					WillReturnResult(sqlmock.NewResult(0, 0))
				mock.ExpectBegin()
				// Only the command is known.
				mock.ExpectQuery(datasql.GetIDsForRuleSegment).
					WithArgs("ls", "unknownPriv").
					WillReturnRows(sqlmock.NewRows([]string{"type", "id"}).
						AddRow("command", "ls"),
					)
				mock.ExpectCommit()
			},
			expRes: rulesengine.DeleteResult{RowsAffected: 0, Errors: []rulesengine.Error{{Type: rulesengine.UnknownPrivilege, Privilege: "unknownPriv"}}},
			expErr: assert.NoError,
		},
		"Both Unknown": {
			commandName: "unknownCommand",
			privName:    "unknownPriv",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectExec(datasql.DeleteDefaultRule).
					WithArgs("unknownCommand", "unknownPriv").
					WillReturnResult(sqlmock.NewResult(0, 0))
				mock.ExpectBegin()
				// Neither the command nor the privilege is known.
				mock.ExpectQuery(datasql.GetIDsForRuleSegment).
					WithArgs("unknownCommand", "unknownPriv").
					WillReturnRows(sqlmock.NewRows([]string{"type", "id"}))
				mock.ExpectCommit()
			},
			expRes: rulesengine.DeleteResult{
				RowsAffected: 0,
				Errors: []rulesengine.Error{
					{Type: rulesengine.UnknownCommand, Command: "unknownCommand"},
					{Type: rulesengine.UnknownPrivilege, Privilege: "unknownPriv"},
				},
			},
			expErr: assert.NoError,
		},
		"Unknown Rule association": {
			commandName: "ls",
			privName:    "basic",
			expectations: func(mock sqlmock.Sqlmock) {
				// Pin the arguments like the other cases so a wrong
				// command/privilege would be caught.
				mock.ExpectExec(datasql.DeleteDefaultRule).
					WithArgs("ls", "basic").
					WillReturnResult(sqlmock.NewResult(0, 0))
				mock.ExpectBegin()
				// All rows are returned
				mock.ExpectQuery(datasql.GetIDsForRuleSegment).
					WithArgs("ls", "basic").
					WillReturnRows(sqlmock.NewRows([]string{"type", "id"}).
						AddRow("command", "ls").
						AddRow("privilege", "basic"),
					)
				mock.ExpectCommit()
			},
			expRes: rulesengine.DeleteResult{RowsAffected: 0, Errors: []rulesengine.Error{
				{Type: rulesengine.UnknownRule, Command: "ls", Privilege: "basic"},
			}},
			expErr: assert.NoError,
		},
	}

	for name, tc := range tests {
		tc := tc
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			// Expected SQL is compared verbatim against the datasql constants.
			db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
			if err != nil {
				// Continuing would dereference a nil mock/db.
				t.Fatalf("creating sqlmock: %v", err)
			}
			defer db.Close()

			tc.expectations(mock)

			ds := Dataset{db: db}
			res, err := ds.DeleteDefaultRule(context.Background(), tc.commandName, tc.privName)
			tc.expErr(t, err)
			assert.Equal(t, tc.expRes, res)

			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}

// TestAppendRule checks that assembleRules groups RuleSegments sharing a
// command into a single Rule holding all of that command's privileges.
func TestAppendRule(t *testing.T) {
	rulesObjects := []rulesengine.RuleSegment{
		{
			Command:   rulesengine.Command{ID: "testcomID1", Name: "com1"},
			Privilege: rulesengine.Privilege{ID: "testPrivID1", Name: "priv1"},
		},
		{
			Command:   rulesengine.Command{ID: "testcomID1", Name: "com1"},
			Privilege: rulesengine.Privilege{ID: "testPrivID2", Name: "priv2"},
		},
		{
			Command:   rulesengine.Command{ID: "testcomID2", Name: "com2"},
			Privilege: rulesengine.Privilege{ID: "testPrivID3", Name: "priv3"},
		},
	}
	expectedRules := []rulesengine.Rule{
		{
			Command: rulesengine.Command{ID: "testcomID1", Name: "com1"},
			Privileges: []rulesengine.Privilege{
				{ID: "testPrivID1", Name: "priv1"},
				{ID: "testPrivID2", Name: "priv2"},
			},
		},
		{
			Command: rulesengine.Command{ID: "testcomID2", Name: "com2"},
			Privileges: []rulesengine.Privilege{
				{ID: "testPrivID3", Name: "priv3"},
			},
		},
	}

	actualRules := assembleRules(rulesObjects)
	assert.EqualValues(t, expectedRules, actualRules)
}

// TestCheckUnknownDeleteRuleErrs checks that a command or privilege without an
// ID is reported as unknown. When both have IDs, the rule association itself
// is reported as unknown (UnknownRule).
func TestCheckUnknownDeleteRuleErrs(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		segment rulesengine.RuleSegment
		expErrs []rulesengine.Error
	}{
		"UnknownCommand": {
			segment: rulesengine.RuleSegment{
				Command:   rulesengine.Command{Name: "unknownCommand"},
				Privilege: rulesengine.Privilege{Name: "knownPriv", ID: "privID"},
			},
			expErrs: []rulesengine.Error{{Type: rulesengine.UnknownCommand, Command: "unknownCommand"}},
		},
		"UnknownPrivilege": {
			segment: rulesengine.RuleSegment{
				Command:   rulesengine.Command{Name: "knownCommand", ID: "commandID"},
				Privilege: rulesengine.Privilege{Name: "unknownPriv"},
			},
			expErrs: []rulesengine.Error{{Type: rulesengine.UnknownPrivilege, Privilege: "unknownPriv"}},
		},
		"UnknownCommand and UnknownPrivilege": {
			segment: rulesengine.RuleSegment{
				Command:   rulesengine.Command{Name: "unknownCommand"},
				Privilege: rulesengine.Privilege{Name: "unknownPriv"},
			},
			expErrs: []rulesengine.Error{
				{Type: rulesengine.UnknownCommand, Command: "unknownCommand"},
				{Type: rulesengine.UnknownPrivilege, Privilege: "unknownPriv"},
			},
		},
		"UnknownRule": {
			segment: rulesengine.RuleSegment{
				Command:   rulesengine.Command{Name: "knownCommand", ID: "commandID"},
				Privilege: rulesengine.Privilege{Name: "knownPriv", ID: "privID"},
			},
			expErrs: []rulesengine.Error{{Type: rulesengine.UnknownRule, Command: "knownCommand", Privilege: "knownPriv"}},
		},
	}

	for name, tc := range tests {
		tc := tc
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			errs := checkUnknownDeleteRuleErrs(tc.segment)
			assert.Equal(t, tc.expErrs, errs)
		})
	}
}