package services import ( "context" "database/sql/driver" "fmt" "strings" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "edge-infra.dev/pkg/edge/api/graph/model" sqlquery "edge-infra.dev/pkg/edge/api/sql" rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules" ) func TestCreatePrivileges(t *testing.T) { t.Parallel() tests := map[string]struct { privileges []*model.OperatorInterventionPrivilegeInput expected *model.CreateOperatorInterventionPrivilegeResponse }{ "Single Privilege": { privileges: []*model.OperatorInterventionPrivilegeInput{ {Name: "privilege1"}, }, expected: &model.CreateOperatorInterventionPrivilegeResponse{}, }, "Multiple Privileges": { privileges: []*model.OperatorInterventionPrivilegeInput{ {Name: "privilege1"}, {Name: "privilege2"}, {Name: "privilege3"}, }, expected: &model.CreateOperatorInterventionPrivilegeResponse{}, }, "No Privileges": { privileges: []*model.OperatorInterventionPrivilegeInput{}, expected: &model.CreateOperatorInterventionPrivilegeResponse{ Errors: []*model.OperatorInterventionErrorResponse{ {Type: model.OperatorInterventionErrorTypeInvalidInput}}}, }, "Empty Privilege": { privileges: []*model.OperatorInterventionPrivilegeInput{ {Name: ""}, }, expected: &model.CreateOperatorInterventionPrivilegeResponse{ Errors: []*model.OperatorInterventionErrorResponse{ {Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string { priv := "" return &priv }(), }}}, }, "Invalid Privileges": { privileges: []*model.OperatorInterventionPrivilegeInput{ {Name: "abc!"}, // special char {Name: "abc 123"}, // space {Name: "abc_123"}, // underscore {Name: "123abc"}, // starts with a number {Name: "a"}, // too short {Name: "-abc"}, // starts with a special char }, expected: &model.CreateOperatorInterventionPrivilegeResponse{ Errors: []*model.OperatorInterventionErrorResponse{ { Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string { priv := "abc!" return &priv }(), }, { Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string { priv := "abc 123" return &priv }(), }, { Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string { priv := "abc_123" return &priv }(), }, { Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string { priv := "123abc" return &priv }(), }, { Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string { priv := "a" return &priv }(), }, { Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string { priv := "-abc" return &priv }(), }, }, }, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() ctx := context.Background() // Create a mock RulesEngine mockRulesEngine := &mockRulesEngine{ AddPrivilegesFunc: func(_ context.Context, _ []rulesengine.PostPrivilegePayload) (rulesengine.AddNameResult, error) { return rulesengine.AddNameResult{}, nil }, } // Create a new instance of the operatorInterventionService with the mockRulesEngine service := &operatorInterventionService{ reng: mockRulesEngine, } // Call the CreateOperatorInterventionPrivileges method response, err := service.CreatePrivileges(ctx, tc.privileges) assert.NoError(t, err) // Check the response assert.Equal(t, tc.expected, response) }) } } func TestCreateCommands(t *testing.T) { t.Parallel() tests := map[string]struct { input []*model.OperatorInterventionCommandInput expected *model.CreateOperatorInterventionCommandResponse }{ "Single Command": { input: []*model.OperatorInterventionCommandInput{ {Name: "command1"}, }, expected: &model.CreateOperatorInterventionCommandResponse{}, }, "Valid": { input: []*model.OperatorInterventionCommandInput{ {Name: "command.1"}, // period {Name: "command-2"}, // hyphen {Name: "command_3"}, // underscore {Name: "command4"}, // no special characters {Name: "/command5"}, // slash }, expected: &model.CreateOperatorInterventionCommandResponse{}, }, "Invalid": { input: []*model.OperatorInterventionCommandInput{ {Name: "command 1"}, // space {Name: "command!2"}, // other special characters }, expected: &model.CreateOperatorInterventionCommandResponse{ Errors: []*model.OperatorInterventionErrorResponse{ {Type: model.OperatorInterventionErrorTypeInvalidInput, Command: func() *string { command := "command 1" return &command }(), }, {Type: model.OperatorInterventionErrorTypeInvalidInput, Command: func() *string { command := "command!2" return &command }(), }, }, }, }, "Empty": { input: []*model.OperatorInterventionCommandInput{ {Name: ""}, }, expected: &model.CreateOperatorInterventionCommandResponse{ Errors: []*model.OperatorInterventionErrorResponse{ {Type: model.OperatorInterventionErrorTypeInvalidInput, Command: func() *string { command := "" return &command }(), }, }, }, }, "Nil": { expected: &model.CreateOperatorInterventionCommandResponse{ Errors: []*model.OperatorInterventionErrorResponse{ {Type: model.OperatorInterventionErrorTypeInvalidInput}, }, }, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() // Create a mock RulesEngine mockRulesEngine := &mockRulesEngine{ AddCommandsFunc: func(_ context.Context, _ []rulesengine.PostCommandPayload) (rulesengine.AddNameResult, error) { return rulesengine.AddNameResult{}, nil }, } // Create a new instance of the operatorInterventionService with the mockRulesEngine o := operatorInterventionService{ reng: mockRulesEngine, } // Call the CreateOperatorInterventionCommands method and check the response resp, err := o.CreateCommands(context.Background(), tc.input) assert.NoError(t, err) assert.Equal(t, tc.expected, resp) }) } } func TestDeletePrivileges(t *testing.T) { t.Parallel() tests := map[string]struct { privilege string expected *model.DeleteOperatorInterventionPrivilegeResponse }{ "Empty Privilege": { privilege: "", expected: &model.DeleteOperatorInterventionPrivilegeResponse{ Errors: []*model.OperatorInterventionErrorResponse{ {Type: model.OperatorInterventionErrorTypeInvalidInput}, }, }, }, "Valid Privilege": { privilege: "privilege1", expected: &model.DeleteOperatorInterventionPrivilegeResponse{}, }, "Non-Existing Privilege": { privilege: "nonexistingprivilege", expected: &model.DeleteOperatorInterventionPrivilegeResponse{ Errors: []*model.OperatorInterventionErrorResponse{ {Type: model.OperatorInterventionErrorTypeUnknownPrivilege, Privilege: func() *string { priv := "nonexistingprivilege" return &priv }()}, }, }, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() ctx := context.Background() // Create a mock RulesEngine that checks specifically for "nonexistingprivilege" and returns an error // if true, else returns no error mockRulesEngine := &mockRulesEngine{ DeletePrivilegeFunc: func(_ context.Context, privilege string) (rulesengine.DeleteResult, error) { if privilege == "nonexistingprivilege" { return rulesengine.DeleteResult{Errors: []rulesengine.Error{ {Type: rulesengine.UnknownPrivilege, Privilege: privilege}, }}, nil } return rulesengine.DeleteResult{}, nil }, } // Create a new instance of the operatorInterventionService with the mockRulesEngine service := &operatorInterventionService{ reng: mockRulesEngine, } // Call the DeleteOperatorInterventionPrivilege method response, err := service.DeletePrivilege(ctx, model.OperatorInterventionPrivilegeInput{Name: tc.privilege}) assert.NoError(t, err) // Check the response assert.Equal(t, tc.expected, response) }) } } func TestDecomposeRoleMappings(t *testing.T) { t.Parallel() var INVALIDEDGEROLE = "INVALID_EDGE_ROLE" var EmptyString = "" tests := map[string]struct { addOiRoleMappingInput []*model.UpdateOperatorInterventionRoleMappingInput roles []string privs []string errors []*model.OperatorInterventionErrorResponse }{ "Empty": { addOiRoleMappingInput: []*model.UpdateOperatorInterventionRoleMappingInput{}, roles: nil, privs: nil, errors: nil, }, "Single Mapping": { addOiRoleMappingInput: []*model.UpdateOperatorInterventionRoleMappingInput{ { Role: "EDGE_BANNER_ADMIN", Privileges: []*model.OperatorInterventionPrivilegeInput{ {Name: "ea-admin"}, }}, }, roles: []string{"EDGE_BANNER_ADMIN"}, privs: []string{"ea-admin"}, errors: nil, }, "Multiple Mappings": { addOiRoleMappingInput: []*model.UpdateOperatorInterventionRoleMappingInput{ { Role: "EDGE_BANNER_ADMIN", Privileges: []*model.OperatorInterventionPrivilegeInput{ {Name: "ea-admin"}, {Name: "ea-read"}, }, }, { Role: "EDGE_ORG_ADMIN", Privileges: []*model.OperatorInterventionPrivilegeInput{ {Name: "ea-write"}, {Name: "ea-basic"}, }, }, }, roles: []string{"EDGE_BANNER_ADMIN", "EDGE_BANNER_ADMIN", "EDGE_ORG_ADMIN", "EDGE_ORG_ADMIN"}, privs: []string{"ea-admin", "ea-read", "ea-write", "ea-basic"}, errors: nil, }, "Unknown Role and privilege": { addOiRoleMappingInput: []*model.UpdateOperatorInterventionRoleMappingInput{ { Role: "INVALID_EDGE_ROLE", Privileges: []*model.OperatorInterventionPrivilegeInput{ {Name: "ea-admin"}, {Name: "ea-read"}, }, }, { Role: "EDGE_ORG_ADMIN", Privileges: []*model.OperatorInterventionPrivilegeInput{ {Name: ""}, {Name: "ea-basic"}, }, }, }, roles: []string{"EDGE_ORG_ADMIN"}, privs: []string{"ea-basic"}, errors: []*model.OperatorInterventionErrorResponse{ {Type: model.OperatorInterventionErrorTypeUnknownRole, Role: &INVALIDEDGEROLE}, {Type: model.OperatorInterventionErrorTypeUnknownPrivilege, Privilege: &EmptyString}, }, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() roles, privs, errors := decomposeRoleMappings(tc.addOiRoleMappingInput) assert.Equal(t, len(roles), len(privs), "Expected roles and privileges to be of equivalent length") assert.Equal(t, tc.roles, roles, "unexpected roles") assert.Equal(t, tc.privs, privs, "unexpected privileges") assert.Equal(t, tc.errors, errors, "unexpected errors") }) } } // Mock implementation of the RulesEngine interface type mockRulesEngine struct { ReadPrivilegesWithFilterFunc func(ctx context.Context, names []string) ([]rulesengine.Privilege, error) AddPrivilegesFunc func(ctx context.Context, payload []rulesengine.PostPrivilegePayload) (rulesengine.AddNameResult, error) DeletePrivilegeFunc func(ctx context.Context, privilege string) (rulesengine.DeleteResult, error) GetDefaultRulesFunc func(ctx context.Context, privileges ...string) ([]rulesengine.ReturnRuleSet, error) AddDefaultRulesForPrivilegesFunc func(ctx context.Context, ruleset rulesengine.RuleSets) (rulesengine.AddRuleResult, error) DeleteDefaultRuleFunc func(ctx context.Context, command, privilege string) (rulesengine.DeleteResult, error) AddCommandsFunc func(ctx context.Context, payload []rulesengine.PostCommandPayload) (rulesengine.AddNameResult, error) rulesEngine } func (m *mockRulesEngine) ReadPrivilegesWithFilter(ctx context.Context, names []string) ([]rulesengine.Privilege, error) { return m.ReadPrivilegesWithFilterFunc(ctx, names) } func (m *mockRulesEngine) AddPrivileges(ctx context.Context, payload []rulesengine.PostPrivilegePayload) (rulesengine.AddNameResult, error) { return m.AddPrivilegesFunc(ctx, payload) } func (m *mockRulesEngine) DeletePrivilege(ctx context.Context, privilege string) (rulesengine.DeleteResult, error) { return m.DeletePrivilegeFunc(ctx, privilege) } func (m *mockRulesEngine) AddCommands(ctx context.Context, payload []rulesengine.PostCommandPayload) (rulesengine.AddNameResult, error) { return m.AddCommandsFunc(ctx, payload) } func (m *mockRulesEngine) GetDefaultRules(ctx context.Context, privileges ...string) ([]rulesengine.ReturnRuleSet, error) { return m.GetDefaultRulesFunc(ctx, privileges...) } func (m *mockRulesEngine) AddDefaultRulesForPrivileges(ctx context.Context, ruleset rulesengine.RuleSets) (rulesengine.AddRuleResult, error) { return m.AddDefaultRulesForPrivilegesFunc(ctx, ruleset) } func (m *mockRulesEngine) DeleteDefaultRule(ctx context.Context, command, privilege string) (rulesengine.DeleteResult, error) { return m.DeleteDefaultRuleFunc(ctx, command, privilege) } func TestGenerateQueryParameters(t *testing.T) { t.Parallel() tests := map[string]struct { roles []string privileges []string params []string args []any }{ "Simple": { roles: []string{"EDGE_ORG_ADMIN"}, privileges: []string{"ea-basic"}, params: []string{"($1, $2)"}, args: []any{"EDGE_ORG_ADMIN", "ea-basic"}, }, "Multiple": { roles: []string{"EDGE_ORG_ADMIN", "EDGE_ORG_ADMIN", "EDGE_BANNER_ADMIN"}, privileges: []string{"ea-basic", "ea-read", "ea-write"}, params: []string{"($1, $2)", "($3, $4)", "($5, $6)"}, args: []any{"EDGE_ORG_ADMIN", "ea-basic", "EDGE_ORG_ADMIN", "ea-read", "EDGE_BANNER_ADMIN", "ea-write"}, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() params, args := generateQueryParameters(tc.roles, tc.privileges) // Each entry in params (e.g. `($1, $2)`) contains two placeholders, // so should be half the length of args assert.Equal(t, len(params), len(args)/2, "Expected params slice to contain half the number of elements of args") assert.Equal(t, tc.params, params) assert.Equal(t, tc.args, args) }) } } // StringSliceValueConverter converts a slice of strings to a PostgreSQL array representation. type StringSliceValueConverter struct{} // ConvertValue implements the driver.ValueConverter interface. func (c StringSliceValueConverter) ConvertValue(v interface{}) (driver.Value, error) { if vv, ok := v.([]string); ok { // Convert []string to a PostgreSQL array representation. // Note: Proper escaping and handling of special characters is necessary for production code. arrayStr := "{" + strings.Join(vv, ",") + "}" return arrayStr, nil } // Fallback for other types. return driver.DefaultParameterConverter.ConvertValue(v) } func TestFindMissingPrivileges(t *testing.T) { t.Parallel() tests := map[string]struct { allPrivileges []string dbContents []string mockAssertions func(sqlmock.Sqlmock) // If mockAssertions is not-nil, dbContents is not read errorAssertion assert.ErrorAssertionFunc expectedMissingPrivs []string }{ "All overlap": { allPrivileges: []string{"a", "b"}, dbContents: []string{"a", "b"}, errorAssertion: assert.NoError, expectedMissingPrivs: []string{}, }, "No Overlap": { allPrivileges: []string{"a", "b"}, dbContents: []string{"c", "d"}, errorAssertion: assert.NoError, expectedMissingPrivs: []string{"a", "b"}, }, "Partial overlap": { allPrivileges: []string{"a", "c"}, dbContents: []string{"c", "d"}, errorAssertion: assert.NoError, expectedMissingPrivs: []string{"a"}, }, "Empty DB": { allPrivileges: []string{"a", "c"}, dbContents: []string{}, errorAssertion: assert.NoError, expectedMissingPrivs: []string{"a", "c"}, }, "Query Error": { allPrivileges: []string{"a", "b"}, mockAssertions: func(s sqlmock.Sqlmock) { s.ExpectQuery(sqlquery.GetOiPrivilegesSubset). WithArgs([]string{"a", "b"}). WillReturnError(fmt.Errorf("error")) }, errorAssertion: assert.Error, expectedMissingPrivs: nil, }, "Rows Close Error": { allPrivileges: []string{"a", "b"}, mockAssertions: func(s sqlmock.Sqlmock) { s.ExpectQuery(sqlquery.GetOiPrivilegesSubset). WithArgs([]string{"a", "b"}). WillReturnRows(sqlmock.NewRows([]string{"privilege_name"}). AddRow("a"). CloseError(fmt.Errorf("error")), ). RowsWillBeClosed() }, errorAssertion: assert.Error, expectedMissingPrivs: nil, }, "Scan error": { allPrivileges: []string{"a", "b"}, mockAssertions: func(s sqlmock.Sqlmock) { s.ExpectQuery(sqlquery.GetOiPrivilegesSubset). WithArgs([]string{"a", "b"}). WillReturnRows(sqlmock.NewRows([]string{"privilege_name"}). AddRow("a"). AddRow("b"). AddRow("c"). RowError(1, fmt.Errorf("error")), ). RowsWillBeClosed() }, errorAssertion: assert.Error, expectedMissingPrivs: nil, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual), sqlmock.ValueConverterOption(StringSliceValueConverter{})) if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer db.Close() mock.ExpectBegin() if tc.mockAssertions != nil { tc.mockAssertions(mock) } else { rows := sqlmock.NewRows([]string{"privilege_name"}) for _, priv := range tc.dbContents { rows = rows.AddRow(priv) } mock.ExpectQuery(sqlquery.GetOiPrivilegesSubset). WithArgs(tc.allPrivileges). WillReturnRows(rows). RowsWillBeClosed() } transaction, err := db.BeginTx(context.Background(), nil) assert.NoError(t, err) out, err := findMissingPrivs(context.Background(), transaction, tc.allPrivileges) tc.errorAssertion(t, err) assert.Equal(t, tc.expectedMissingPrivs, out) assert.NoError(t, mock.ExpectationsWereMet()) }) } } func TestDifference(t *testing.T) { // Test the set difference function defined in this package t.Parallel() tests := map[string]struct { superset []string subset []string exp []string }{ "No Entries": { superset: []string{}, subset: []string{}, exp: []string{}, }, "Standard": { superset: []string{"a", "b", "c", "d"}, subset: []string{"a", "b"}, exp: []string{"c", "d"}, }, "Equal sets": { superset: []string{"a", "b"}, subset: []string{"a", "b"}, exp: []string{}, }, "No overlap": { superset: []string{"a", "b"}, subset: []string{"c", "d"}, exp: []string{"a", "b"}, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() diff := difference(tc.superset, tc.subset) assert.Equal(t, tc.exp, diff) }) } } type mockRulesEngineTestDeleteCommandErrorType struct { rulesEngine errType rulesengine.ErrorType } func (reng mockRulesEngineTestDeleteCommandErrorType) DeleteCommand(_ context.Context, command string) (rulesengine.DeleteResult, error) { return rulesengine.DeleteResult{ Errors: []rulesengine.Error{ {Type: reng.errType, Command: command}, }, }, nil } func TestDeleteCommandErrorType(t *testing.T) { t.Parallel() command := "command" tests := map[string]struct { command string errType rulesengine.ErrorType expected *model.DeleteOperatorInterventionCommandResponse }{ "Conflict": { command: command, errType: rulesengine.Conflict, expected: &model.DeleteOperatorInterventionCommandResponse{ Errors: []*model.OperatorInterventionErrorResponse{ { Type: model.OperatorInterventionErrorTypeConflict, Command: &command, }, }, }, }, "Existing Command": { command: command, errType: rulesengine.UnknownCommand, expected: &model.DeleteOperatorInterventionCommandResponse{ Errors: []*model.OperatorInterventionErrorResponse{ { Type: model.OperatorInterventionErrorTypeUnknownCommand, Command: &command, }, }, }, }, "Empty Command": { expected: &model.DeleteOperatorInterventionCommandResponse{ Errors: []*model.OperatorInterventionErrorResponse{ { Type: model.OperatorInterventionErrorTypeInvalidInput, }, }, }, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() mockReng := mockRulesEngineTestDeleteCommandErrorType{errType: tc.errType} o := operatorInterventionService{reng: mockReng} payload := model.OperatorInterventionCommandInput{Name: tc.command} resp, err := o.DeleteCommand(context.Background(), payload) assert.NoError(t, err) assert.Equal(t, tc.expected, resp) }) } }