1 package services
2
3 import (
4 "context"
5 "database/sql/driver"
6 "fmt"
7 "strings"
8 "testing"
9
10 "github.com/DATA-DOG/go-sqlmock"
11 "github.com/stretchr/testify/assert"
12
13 "edge-infra.dev/pkg/edge/api/graph/model"
14 sqlquery "edge-infra.dev/pkg/edge/api/sql"
15 rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
16 )
17
18 func TestCreatePrivileges(t *testing.T) {
19 t.Parallel()
20 tests := map[string]struct {
21 privileges []*model.OperatorInterventionPrivilegeInput
22 expected *model.CreateOperatorInterventionPrivilegeResponse
23 }{
24 "Single Privilege": {
25 privileges: []*model.OperatorInterventionPrivilegeInput{
26 {Name: "privilege1"},
27 },
28 expected: &model.CreateOperatorInterventionPrivilegeResponse{},
29 },
30 "Multiple Privileges": {
31 privileges: []*model.OperatorInterventionPrivilegeInput{
32 {Name: "privilege1"},
33 {Name: "privilege2"},
34 {Name: "privilege3"},
35 },
36 expected: &model.CreateOperatorInterventionPrivilegeResponse{},
37 },
38 "No Privileges": {
39 privileges: []*model.OperatorInterventionPrivilegeInput{},
40 expected: &model.CreateOperatorInterventionPrivilegeResponse{
41 Errors: []*model.OperatorInterventionErrorResponse{
42 {Type: model.OperatorInterventionErrorTypeInvalidInput}}},
43 },
44 "Empty Privilege": {
45 privileges: []*model.OperatorInterventionPrivilegeInput{
46 {Name: ""},
47 },
48 expected: &model.CreateOperatorInterventionPrivilegeResponse{
49 Errors: []*model.OperatorInterventionErrorResponse{
50 {Type: model.OperatorInterventionErrorTypeInvalidInput,
51 Privilege: func() *string {
52 priv := ""
53 return &priv
54 }(),
55 }}},
56 },
57 "Invalid Privileges": {
58 privileges: []*model.OperatorInterventionPrivilegeInput{
59 {Name: "abc!"},
60 {Name: "abc 123"},
61 {Name: "abc_123"},
62 {Name: "123abc"},
63 {Name: "a"},
64 {Name: "-abc"},
65 },
66 expected: &model.CreateOperatorInterventionPrivilegeResponse{
67 Errors: []*model.OperatorInterventionErrorResponse{
68 {
69 Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string {
70 priv := "abc!"
71 return &priv
72 }(),
73 },
74 {
75 Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string {
76 priv := "abc 123"
77 return &priv
78 }(),
79 },
80 {
81 Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string {
82 priv := "abc_123"
83 return &priv
84 }(),
85 },
86 {
87 Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string {
88 priv := "123abc"
89 return &priv
90 }(),
91 },
92 {
93 Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string {
94 priv := "a"
95 return &priv
96 }(),
97 },
98 {
99 Type: model.OperatorInterventionErrorTypeInvalidInput, Privilege: func() *string {
100 priv := "-abc"
101 return &priv
102 }(),
103 },
104 },
105 },
106 },
107 }
108 for name, tc := range tests {
109 tc := tc
110 t.Run(name, func(t *testing.T) {
111 t.Parallel()
112
113 ctx := context.Background()
114
115
116 mockRulesEngine := &mockRulesEngine{
117 AddPrivilegesFunc: func(_ context.Context, _ []rulesengine.PostPrivilegePayload) (rulesengine.AddNameResult, error) {
118 return rulesengine.AddNameResult{}, nil
119 },
120 }
121
122
123 service := &operatorInterventionService{
124 reng: mockRulesEngine,
125 }
126
127
128 response, err := service.CreatePrivileges(ctx, tc.privileges)
129 assert.NoError(t, err)
130
131
132 assert.Equal(t, tc.expected, response)
133 })
134 }
135 }
136
137 func TestCreateCommands(t *testing.T) {
138 t.Parallel()
139
140 tests := map[string]struct {
141 input []*model.OperatorInterventionCommandInput
142 expected *model.CreateOperatorInterventionCommandResponse
143 }{
144 "Single Command": {
145 input: []*model.OperatorInterventionCommandInput{
146 {Name: "command1"},
147 },
148 expected: &model.CreateOperatorInterventionCommandResponse{},
149 },
150 "Valid": {
151 input: []*model.OperatorInterventionCommandInput{
152 {Name: "command.1"},
153 {Name: "command-2"},
154 {Name: "command_3"},
155 {Name: "command4"},
156 {Name: "/command5"},
157 },
158 expected: &model.CreateOperatorInterventionCommandResponse{},
159 },
160 "Invalid": {
161 input: []*model.OperatorInterventionCommandInput{
162 {Name: "command 1"},
163 {Name: "command!2"},
164 },
165 expected: &model.CreateOperatorInterventionCommandResponse{
166 Errors: []*model.OperatorInterventionErrorResponse{
167 {Type: model.OperatorInterventionErrorTypeInvalidInput, Command: func() *string {
168 command := "command 1"
169 return &command
170 }(),
171 },
172 {Type: model.OperatorInterventionErrorTypeInvalidInput, Command: func() *string {
173 command := "command!2"
174 return &command
175 }(),
176 },
177 },
178 },
179 },
180 "Empty": {
181 input: []*model.OperatorInterventionCommandInput{
182 {Name: ""},
183 },
184 expected: &model.CreateOperatorInterventionCommandResponse{
185 Errors: []*model.OperatorInterventionErrorResponse{
186 {Type: model.OperatorInterventionErrorTypeInvalidInput,
187 Command: func() *string {
188 command := ""
189 return &command
190 }(),
191 },
192 },
193 },
194 },
195 "Nil": {
196 expected: &model.CreateOperatorInterventionCommandResponse{
197 Errors: []*model.OperatorInterventionErrorResponse{
198 {Type: model.OperatorInterventionErrorTypeInvalidInput},
199 },
200 },
201 },
202 }
203
204 for name, tc := range tests {
205 tc := tc
206 t.Run(name, func(t *testing.T) {
207 t.Parallel()
208
209 mockRulesEngine := &mockRulesEngine{
210 AddCommandsFunc: func(_ context.Context, _ []rulesengine.PostCommandPayload) (rulesengine.AddNameResult, error) {
211 return rulesengine.AddNameResult{}, nil
212 },
213 }
214
215 o := operatorInterventionService{
216 reng: mockRulesEngine,
217 }
218
219
220 resp, err := o.CreateCommands(context.Background(), tc.input)
221 assert.NoError(t, err)
222 assert.Equal(t, tc.expected, resp)
223 })
224 }
225 }
226
227 func TestDeletePrivileges(t *testing.T) {
228 t.Parallel()
229
230 tests := map[string]struct {
231 privilege string
232 expected *model.DeleteOperatorInterventionPrivilegeResponse
233 }{
234 "Empty Privilege": {
235 privilege: "",
236 expected: &model.DeleteOperatorInterventionPrivilegeResponse{
237 Errors: []*model.OperatorInterventionErrorResponse{
238 {Type: model.OperatorInterventionErrorTypeInvalidInput},
239 },
240 },
241 },
242 "Valid Privilege": {
243 privilege: "privilege1",
244 expected: &model.DeleteOperatorInterventionPrivilegeResponse{},
245 },
246 "Non-Existing Privilege": {
247 privilege: "nonexistingprivilege",
248 expected: &model.DeleteOperatorInterventionPrivilegeResponse{
249 Errors: []*model.OperatorInterventionErrorResponse{
250 {Type: model.OperatorInterventionErrorTypeUnknownPrivilege, Privilege: func() *string {
251 priv := "nonexistingprivilege"
252 return &priv
253 }()},
254 },
255 },
256 },
257 }
258
259 for name, tc := range tests {
260 tc := tc
261 t.Run(name, func(t *testing.T) {
262 t.Parallel()
263
264 ctx := context.Background()
265
266
267
268 mockRulesEngine := &mockRulesEngine{
269 DeletePrivilegeFunc: func(_ context.Context, privilege string) (rulesengine.DeleteResult, error) {
270 if privilege == "nonexistingprivilege" {
271 return rulesengine.DeleteResult{Errors: []rulesengine.Error{
272 {Type: rulesengine.UnknownPrivilege, Privilege: privilege},
273 }}, nil
274 }
275 return rulesengine.DeleteResult{}, nil
276 },
277 }
278
279
280 service := &operatorInterventionService{
281 reng: mockRulesEngine,
282 }
283
284
285 response, err := service.DeletePrivilege(ctx, model.OperatorInterventionPrivilegeInput{Name: tc.privilege})
286 assert.NoError(t, err)
287
288
289 assert.Equal(t, tc.expected, response)
290 })
291 }
292 }
293
294 func TestDecomposeRoleMappings(t *testing.T) {
295 t.Parallel()
296
297 var INVALIDEDGEROLE = "INVALID_EDGE_ROLE"
298 var EmptyString = ""
299
300 tests := map[string]struct {
301 addOiRoleMappingInput []*model.UpdateOperatorInterventionRoleMappingInput
302
303 roles []string
304 privs []string
305 errors []*model.OperatorInterventionErrorResponse
306 }{
307 "Empty": {
308 addOiRoleMappingInput: []*model.UpdateOperatorInterventionRoleMappingInput{},
309 roles: nil,
310 privs: nil,
311 errors: nil,
312 },
313 "Single Mapping": {
314 addOiRoleMappingInput: []*model.UpdateOperatorInterventionRoleMappingInput{
315 {
316 Role: "EDGE_BANNER_ADMIN",
317 Privileges: []*model.OperatorInterventionPrivilegeInput{
318 {Name: "ea-admin"},
319 }},
320 },
321
322 roles: []string{"EDGE_BANNER_ADMIN"},
323 privs: []string{"ea-admin"},
324 errors: nil,
325 },
326 "Multiple Mappings": {
327 addOiRoleMappingInput: []*model.UpdateOperatorInterventionRoleMappingInput{
328 {
329 Role: "EDGE_BANNER_ADMIN",
330 Privileges: []*model.OperatorInterventionPrivilegeInput{
331 {Name: "ea-admin"},
332 {Name: "ea-read"},
333 },
334 },
335 {
336 Role: "EDGE_ORG_ADMIN",
337 Privileges: []*model.OperatorInterventionPrivilegeInput{
338 {Name: "ea-write"},
339 {Name: "ea-basic"},
340 },
341 },
342 },
343
344 roles: []string{"EDGE_BANNER_ADMIN", "EDGE_BANNER_ADMIN", "EDGE_ORG_ADMIN", "EDGE_ORG_ADMIN"},
345 privs: []string{"ea-admin", "ea-read", "ea-write", "ea-basic"},
346 errors: nil,
347 },
348 "Unknown Role and privilege": {
349 addOiRoleMappingInput: []*model.UpdateOperatorInterventionRoleMappingInput{
350 {
351 Role: "INVALID_EDGE_ROLE",
352 Privileges: []*model.OperatorInterventionPrivilegeInput{
353 {Name: "ea-admin"},
354 {Name: "ea-read"},
355 },
356 },
357 {
358 Role: "EDGE_ORG_ADMIN",
359 Privileges: []*model.OperatorInterventionPrivilegeInput{
360 {Name: ""},
361 {Name: "ea-basic"},
362 },
363 },
364 },
365
366 roles: []string{"EDGE_ORG_ADMIN"},
367 privs: []string{"ea-basic"},
368 errors: []*model.OperatorInterventionErrorResponse{
369 {Type: model.OperatorInterventionErrorTypeUnknownRole, Role: &INVALIDEDGEROLE},
370 {Type: model.OperatorInterventionErrorTypeUnknownPrivilege, Privilege: &EmptyString},
371 },
372 },
373 }
374
375 for name, tc := range tests {
376 tc := tc
377 t.Run(name, func(t *testing.T) {
378 t.Parallel()
379
380 roles, privs, errors := decomposeRoleMappings(tc.addOiRoleMappingInput)
381 assert.Equal(t, len(roles), len(privs), "Expected roles and privileges to be of equivalent length")
382
383 assert.Equal(t, tc.roles, roles, "unexpected roles")
384 assert.Equal(t, tc.privs, privs, "unexpected privileges")
385 assert.Equal(t, tc.errors, errors, "unexpected errors")
386 })
387 }
388 }
389
390
391 type mockRulesEngine struct {
392 ReadPrivilegesWithFilterFunc func(ctx context.Context, names []string) ([]rulesengine.Privilege, error)
393 AddPrivilegesFunc func(ctx context.Context, payload []rulesengine.PostPrivilegePayload) (rulesengine.AddNameResult, error)
394 DeletePrivilegeFunc func(ctx context.Context, privilege string) (rulesengine.DeleteResult, error)
395 GetDefaultRulesFunc func(ctx context.Context, privileges ...string) ([]rulesengine.ReturnRuleSet, error)
396 AddDefaultRulesForPrivilegesFunc func(ctx context.Context, ruleset rulesengine.RuleSets) (rulesengine.AddRuleResult, error)
397 DeleteDefaultRuleFunc func(ctx context.Context, command, privilege string) (rulesengine.DeleteResult, error)
398 AddCommandsFunc func(ctx context.Context, payload []rulesengine.PostCommandPayload) (rulesengine.AddNameResult, error)
399 rulesEngine
400 }
401
402 func (m *mockRulesEngine) ReadPrivilegesWithFilter(ctx context.Context, names []string) ([]rulesengine.Privilege, error) {
403 return m.ReadPrivilegesWithFilterFunc(ctx, names)
404 }
405
406 func (m *mockRulesEngine) AddPrivileges(ctx context.Context, payload []rulesengine.PostPrivilegePayload) (rulesengine.AddNameResult, error) {
407 return m.AddPrivilegesFunc(ctx, payload)
408 }
409
410 func (m *mockRulesEngine) DeletePrivilege(ctx context.Context, privilege string) (rulesengine.DeleteResult, error) {
411 return m.DeletePrivilegeFunc(ctx, privilege)
412 }
413
414 func (m *mockRulesEngine) AddCommands(ctx context.Context, payload []rulesengine.PostCommandPayload) (rulesengine.AddNameResult, error) {
415 return m.AddCommandsFunc(ctx, payload)
416 }
417
418 func (m *mockRulesEngine) GetDefaultRules(ctx context.Context, privileges ...string) ([]rulesengine.ReturnRuleSet, error) {
419 return m.GetDefaultRulesFunc(ctx, privileges...)
420 }
421
422 func (m *mockRulesEngine) AddDefaultRulesForPrivileges(ctx context.Context, ruleset rulesengine.RuleSets) (rulesengine.AddRuleResult, error) {
423 return m.AddDefaultRulesForPrivilegesFunc(ctx, ruleset)
424 }
425
426 func (m *mockRulesEngine) DeleteDefaultRule(ctx context.Context, command, privilege string) (rulesengine.DeleteResult, error) {
427 return m.DeleteDefaultRuleFunc(ctx, command, privilege)
428 }
429
430 func TestGenerateQueryParameters(t *testing.T) {
431 t.Parallel()
432
433 tests := map[string]struct {
434 roles []string
435 privileges []string
436
437 params []string
438 args []any
439 }{
440 "Simple": {
441 roles: []string{"EDGE_ORG_ADMIN"},
442 privileges: []string{"ea-basic"},
443
444 params: []string{"($1, $2)"},
445 args: []any{"EDGE_ORG_ADMIN", "ea-basic"},
446 },
447 "Multiple": {
448 roles: []string{"EDGE_ORG_ADMIN", "EDGE_ORG_ADMIN", "EDGE_BANNER_ADMIN"},
449 privileges: []string{"ea-basic", "ea-read", "ea-write"},
450
451 params: []string{"($1, $2)", "($3, $4)", "($5, $6)"},
452 args: []any{"EDGE_ORG_ADMIN", "ea-basic", "EDGE_ORG_ADMIN", "ea-read", "EDGE_BANNER_ADMIN", "ea-write"},
453 },
454 }
455
456 for name, tc := range tests {
457 tc := tc
458 t.Run(name, func(t *testing.T) {
459 t.Parallel()
460
461 params, args := generateQueryParameters(tc.roles, tc.privileges)
462
463
464
465 assert.Equal(t, len(params), len(args)/2, "Expected params slice to contain half the number of elements of args")
466
467 assert.Equal(t, tc.params, params)
468 assert.Equal(t, tc.args, args)
469 })
470 }
471 }
472
473
// StringSliceValueConverter is a driver.ValueConverter for sqlmock that turns
// a []string argument into a PostgreSQL array literal such as {a,b}. Every
// other value is delegated to driver.DefaultParameterConverter.
//
// sqlmock passes both the expected and the actual query arguments through
// this converter, so the literal must be injective. Elements are therefore
// quoted and escaped following PostgreSQL's array input rules whenever they
// would otherwise be ambiguous. Simple elements are left bare, so {a,b} is
// produced exactly as before.
type StringSliceValueConverter struct{}

// ConvertValue implements driver.ValueConverter.
//
// A []string (including a nil one) becomes a string of the form {e1,e2,...}.
// An element is wrapped in double quotes, with '"' and '\' backslash-escaped,
// when any of the following holds:
//   - it is empty;
//   - it equals NULL (case-insensitive);
//   - it contains a brace, comma, double quote, backslash or whitespace.
//
// Without this quoting, distinct slices such as {"a,b"} and {"a","b"} would
// collapse to the same literal.
func (c StringSliceValueConverter) ConvertValue(v any) (driver.Value, error) {
	vv, ok := v.([]string)
	if !ok {
		return driver.DefaultParameterConverter.ConvertValue(v)
	}

	var b strings.Builder
	b.WriteByte('{')
	for i, elem := range vv {
		if i > 0 {
			b.WriteByte(',')
		}
		if elem != "" && !strings.EqualFold(elem, "NULL") && !strings.ContainsAny(elem, "{},\"\\ \t\n\r\v\f") {
			b.WriteString(elem)
			continue
		}
		b.WriteByte('"')
		for _, r := range elem {
			if r == '"' || r == '\\' {
				b.WriteByte('\\')
			}
			b.WriteRune(r)
		}
		b.WriteByte('"')
	}
	b.WriteByte('}')
	return b.String(), nil
}
487
488 func TestFindMissingPrivileges(t *testing.T) {
489 t.Parallel()
490
491 tests := map[string]struct {
492 allPrivileges []string
493 dbContents []string
494 mockAssertions func(sqlmock.Sqlmock)
495
496 errorAssertion assert.ErrorAssertionFunc
497 expectedMissingPrivs []string
498 }{
499 "All overlap": {
500 allPrivileges: []string{"a", "b"},
501 dbContents: []string{"a", "b"},
502
503 errorAssertion: assert.NoError,
504 expectedMissingPrivs: []string{},
505 },
506 "No Overlap": {
507 allPrivileges: []string{"a", "b"},
508 dbContents: []string{"c", "d"},
509
510 errorAssertion: assert.NoError,
511 expectedMissingPrivs: []string{"a", "b"},
512 },
513 "Partial overlap": {
514 allPrivileges: []string{"a", "c"},
515 dbContents: []string{"c", "d"},
516
517 errorAssertion: assert.NoError,
518 expectedMissingPrivs: []string{"a"},
519 },
520 "Empty DB": {
521 allPrivileges: []string{"a", "c"},
522 dbContents: []string{},
523
524 errorAssertion: assert.NoError,
525 expectedMissingPrivs: []string{"a", "c"},
526 },
527
528 "Query Error": {
529 allPrivileges: []string{"a", "b"},
530 mockAssertions: func(s sqlmock.Sqlmock) {
531 s.ExpectQuery(sqlquery.GetOiPrivilegesSubset).
532 WithArgs([]string{"a", "b"}).
533 WillReturnError(fmt.Errorf("error"))
534 },
535 errorAssertion: assert.Error,
536 expectedMissingPrivs: nil,
537 },
538 "Rows Close Error": {
539 allPrivileges: []string{"a", "b"},
540 mockAssertions: func(s sqlmock.Sqlmock) {
541 s.ExpectQuery(sqlquery.GetOiPrivilegesSubset).
542 WithArgs([]string{"a", "b"}).
543 WillReturnRows(sqlmock.NewRows([]string{"privilege_name"}).
544 AddRow("a").
545 CloseError(fmt.Errorf("error")),
546 ).
547 RowsWillBeClosed()
548 },
549 errorAssertion: assert.Error,
550 expectedMissingPrivs: nil,
551 },
552 "Scan error": {
553 allPrivileges: []string{"a", "b"},
554 mockAssertions: func(s sqlmock.Sqlmock) {
555 s.ExpectQuery(sqlquery.GetOiPrivilegesSubset).
556 WithArgs([]string{"a", "b"}).
557 WillReturnRows(sqlmock.NewRows([]string{"privilege_name"}).
558 AddRow("a").
559 AddRow("b").
560 AddRow("c").
561 RowError(1, fmt.Errorf("error")),
562 ).
563 RowsWillBeClosed()
564 },
565 errorAssertion: assert.Error,
566 expectedMissingPrivs: nil,
567 },
568 }
569
570 for name, tc := range tests {
571 tc := tc
572 t.Run(name, func(t *testing.T) {
573 t.Parallel()
574
575 db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual), sqlmock.ValueConverterOption(StringSliceValueConverter{}))
576 if err != nil {
577 t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
578 }
579 defer db.Close()
580
581 mock.ExpectBegin()
582
583 if tc.mockAssertions != nil {
584 tc.mockAssertions(mock)
585 } else {
586 rows := sqlmock.NewRows([]string{"privilege_name"})
587 for _, priv := range tc.dbContents {
588 rows = rows.AddRow(priv)
589 }
590 mock.ExpectQuery(sqlquery.GetOiPrivilegesSubset).
591 WithArgs(tc.allPrivileges).
592 WillReturnRows(rows).
593 RowsWillBeClosed()
594 }
595
596 transaction, err := db.BeginTx(context.Background(), nil)
597 assert.NoError(t, err)
598
599 out, err := findMissingPrivs(context.Background(), transaction, tc.allPrivileges)
600 tc.errorAssertion(t, err)
601
602 assert.Equal(t, tc.expectedMissingPrivs, out)
603
604 assert.NoError(t, mock.ExpectationsWereMet())
605 })
606 }
607 }
608
609 func TestDifference(t *testing.T) {
610
611 t.Parallel()
612
613 tests := map[string]struct {
614 superset []string
615 subset []string
616 exp []string
617 }{
618 "No Entries": {
619 superset: []string{},
620 subset: []string{},
621 exp: []string{},
622 },
623 "Standard": {
624 superset: []string{"a", "b", "c", "d"},
625 subset: []string{"a", "b"},
626 exp: []string{"c", "d"},
627 },
628 "Equal sets": {
629 superset: []string{"a", "b"},
630 subset: []string{"a", "b"},
631 exp: []string{},
632 },
633 "No overlap": {
634 superset: []string{"a", "b"},
635 subset: []string{"c", "d"},
636 exp: []string{"a", "b"},
637 },
638 }
639
640 for name, tc := range tests {
641 tc := tc
642 t.Run(name, func(t *testing.T) {
643 t.Parallel()
644
645 diff := difference(tc.superset, tc.subset)
646 assert.Equal(t, tc.exp, diff)
647 })
648 }
649 }
650
651 type mockRulesEngineTestDeleteCommandErrorType struct {
652 rulesEngine
653 errType rulesengine.ErrorType
654 }
655
656 func (reng mockRulesEngineTestDeleteCommandErrorType) DeleteCommand(_ context.Context, command string) (rulesengine.DeleteResult, error) {
657 return rulesengine.DeleteResult{
658 Errors: []rulesengine.Error{
659 {Type: reng.errType, Command: command},
660 },
661 }, nil
662 }
663
664 func TestDeleteCommandErrorType(t *testing.T) {
665 t.Parallel()
666
667 command := "command"
668
669 tests := map[string]struct {
670 command string
671 errType rulesengine.ErrorType
672 expected *model.DeleteOperatorInterventionCommandResponse
673 }{
674 "Conflict": {
675 command: command,
676 errType: rulesengine.Conflict,
677 expected: &model.DeleteOperatorInterventionCommandResponse{
678 Errors: []*model.OperatorInterventionErrorResponse{
679 {
680 Type: model.OperatorInterventionErrorTypeConflict,
681 Command: &command,
682 },
683 },
684 },
685 },
686 "Existing Command": {
687 command: command,
688 errType: rulesengine.UnknownCommand,
689 expected: &model.DeleteOperatorInterventionCommandResponse{
690 Errors: []*model.OperatorInterventionErrorResponse{
691 {
692 Type: model.OperatorInterventionErrorTypeUnknownCommand,
693 Command: &command,
694 },
695 },
696 },
697 },
698 "Empty Command": {
699 expected: &model.DeleteOperatorInterventionCommandResponse{
700 Errors: []*model.OperatorInterventionErrorResponse{
701 {
702 Type: model.OperatorInterventionErrorTypeInvalidInput,
703 },
704 },
705 },
706 },
707 }
708
709 for name, tc := range tests {
710 tc := tc
711 t.Run(name, func(t *testing.T) {
712 t.Parallel()
713
714 mockReng := mockRulesEngineTestDeleteCommandErrorType{errType: tc.errType}
715 o := operatorInterventionService{reng: mockReng}
716
717 payload := model.OperatorInterventionCommandInput{Name: tc.command}
718 resp, err := o.DeleteCommand(context.Background(), payload)
719
720 assert.NoError(t, err)
721 assert.Equal(t, tc.expected, resp)
722 })
723 }
724 }
725
View as plain text