1 package server
2
3 import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net/http"
8 "net/http/httptest"
9 "strings"
10 "testing"
11
12 rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules"
13
14 "github.com/gin-gonic/gin"
15 "github.com/google/shlex"
16 "github.com/stretchr/testify/assert"
17 )
18
// MockRulesEngine is an in-memory test double for RulesEngine.
//
// AddedNames holds the names the Read* methods report, under the keys
// "commands", "privs" and "rules". Each "rules" entry is a shell-style
// "<command> <privilege>" string (see ReadAllDefaultRules).
//
// The embedded RulesEngine is nil unless a caller sets it. Neither
// NewMockReng nor the tests in this file set it, so calling any RulesEngine
// method not defined on MockRulesEngine panics.
type MockRulesEngine struct {
	AddedNames map[string][]string
	// NOTE(review): Conflict is not read by any method visible in this file;
	// confirm whether other tests in the package still rely on it.
	Conflict bool
	RulesEngine
}
24
25 func NewMockReng() MockRulesEngine {
26 return MockRulesEngine{AddedNames: map[string][]string{
27 "commands": {},
28 "privs": {},
29 "rules": {},
30 }}
31 }
32
33 func (MockRulesEngine) GetEARolesForCommand(_ context.Context, _ rulesengine.Command, _ string) ([]string, error) {
34 return []string{}, nil
35 }
36
37 func (MockRulesEngine) UserHasRoles(_ string, _ []string, _ []string) bool {
38 return false
39 }
40
41 func (mreng *MockRulesEngine) ReadCommands(_ context.Context) ([]rulesengine.Command, error) {
42 lst := []rulesengine.Command{}
43
44 for _, name := range mreng.AddedNames["commands"] {
45 lst = append(lst, rulesengine.Command{ID: "test", Name: name})
46 }
47 return lst, nil
48 }
49
50 func (mreng *MockRulesEngine) ReadPrivileges(_ context.Context) ([]rulesengine.Privilege, error) {
51 lst := []rulesengine.Privilege{}
52
53 for _, name := range mreng.AddedNames["privs"] {
54 lst = append(lst, rulesengine.Privilege{ID: "test", Name: name})
55 }
56 return lst, nil
57 }
58
59 func (mreng *MockRulesEngine) ReadCommand(_ context.Context, name string) (rulesengine.Command, error) {
60 for _, namein := range mreng.AddedNames["commands"] {
61 if name == namein {
62 return rulesengine.Command{Name: name, ID: "test"}, nil
63 }
64 }
65 return rulesengine.Command{}, nil
66 }
67
68 func (mreng MockRulesEngine) ReadPrivilege(_ context.Context, name string) (rulesengine.Privilege, error) {
69 for _, namein := range mreng.AddedNames["privs"] {
70 if name == namein {
71 return rulesengine.Privilege{Name: name, ID: "test"}, nil
72 }
73 }
74 return rulesengine.Privilege{}, nil
75 }
76
77 func (mreng *MockRulesEngine) ReadAllDefaultRules(_ context.Context) ([]rulesengine.Rule, error) {
78 res := []rulesengine.Rule{}
79 for _, rule := range mreng.AddedNames["rules"] {
80 vals, err := shlex.Split(rule)
81 if err != nil {
82 return res, err
83 }
84 res = append(res, rulesengine.Rule{
85 Command: rulesengine.Command{Name: vals[0], ID: "test"},
86 Privileges: []rulesengine.Privilege{{Name: vals[1], ID: "test"}},
87 })
88 }
89 return res, nil
90 }
91 func (mreng *MockRulesEngine) ReadDefaultRulesForCommand(ctx context.Context, _ string) ([]rulesengine.Rule, error) {
92 return mreng.ReadAllDefaultRules(ctx)
93 }
94
95 func getTestGinContext(r *httptest.ResponseRecorder) (*gin.Context, *gin.Engine) {
96 gin.SetMode(gin.TestMode)
97 ctx, ginEngine := gin.CreateTestContext(r)
98 return ctx, ginEngine
99 }
100
// postDefaultRulesMock stubs AddDefaultRules for TestPostDefaultRules. That
// test leaves the embedded RulesEngine unset, so any other RulesEngine method
// called on it would panic.
type postDefaultRulesMock struct {
	RulesEngine

	// Canned values returned by AddDefaultRules.
	dataRet rulesengine.AddRuleResult
	errRet  error

	// Written by AddDefaultRules: how many times it has been called, and the
	// rules passed on the most recent call.
	callCount int
	rules     rulesengine.WriteRules
}
110
111 func (pdrb *postDefaultRulesMock) AddDefaultRules(_ context.Context, rules rulesengine.WriteRules) (rulesengine.AddRuleResult, error) {
112 pdrb.callCount = pdrb.callCount + 1
113 pdrb.rules = rules
114 return pdrb.dataRet, pdrb.errRet
115 }
116
117 func TestPostDefaultRules(t *testing.T) {
118 t.Parallel()
119
120 tests := map[string]struct {
121 reqBody string
122
123 mockDataRet rulesengine.AddRuleResult
124 mockErrRet error
125
126 expMockCalledCount int
127 expCalledRules rulesengine.WriteRules
128 expCode int
129
130 jsonAssert StringAssertionFunc
131 }{
132 "Ok": {
133 reqBody: `[
134 {"command": "ls", "privileges": ["read","write"]},
135 {"command": "cat", "privileges": ["read","write"]}
136 ]`,
137
138 expMockCalledCount: 1,
139 expCalledRules: rulesengine.WriteRules{
140 {Command: "ls", Privileges: []string{"read", "write"}},
141 {Command: "cat", Privileges: []string{"read", "write"}},
142 },
143
144 expCode: http.StatusOK,
145 jsonAssert: JSONEmpty(),
146 },
147 "Invalid JSON": {
148 reqBody: `[{"comm`,
149
150 expMockCalledCount: 0,
151
152 expCode: http.StatusBadRequest,
153 jsonAssert: JSONEmpty(),
154 },
155 "Invalid payload": {
156 reqBody: `[{"command": "", "privileges": ["", ""]}]`,
157
158 expMockCalledCount: 0,
159
160 expCode: http.StatusBadRequest,
161 jsonAssert: JSONEmpty(),
162 },
163
164 "Rulesengine Error": {
165 reqBody: `[
166 {"command": "ls", "privileges": ["read","write"]},
167 {"command": "cat", "privileges": ["read","write"]}
168 ]`,
169
170 mockDataRet: rulesengine.AddRuleResult{},
171 mockErrRet: fmt.Errorf("an error occurred"),
172
173 expMockCalledCount: 1,
174 expCalledRules: rulesengine.WriteRules{
175 {Command: "ls", Privileges: []string{"read", "write"}},
176 {Command: "cat", Privileges: []string{"read", "write"}},
177 },
178
179 expCode: http.StatusInternalServerError,
180 jsonAssert: JSONEmpty(),
181 },
182 "Rulesengine Conflict": {
183 reqBody: `[
184 {"command": "not-here", "privileges": ["read","write"]},
185 {"command": "cat", "privileges": ["read","not-here"]}
186 ]`,
187
188 mockDataRet: rulesengine.AddRuleResult{Errors: []rulesengine.Error{
189 {Command: "not-here", Type: rulesengine.UnknownCommand},
190 {Privilege: "not-here", Type: rulesengine.UnknownPrivilege},
191 }},
192 mockErrRet: nil,
193
194 expMockCalledCount: 1,
195 expCalledRules: rulesengine.WriteRules{
196 {Command: "not-here", Privileges: []string{"read", "write"}},
197 {Command: "cat", Privileges: []string{"read", "not-here"}},
198 },
199
200 expCode: http.StatusNotFound,
201 jsonAssert: JSONEq(`{
202 "errors": [
203 {"command":"not-here","type":"Unknown Command"},
204 {"privilege":"not-here","type":"Unknown Privilege"}
205 ]
206 }`),
207 },
208 }
209
210 for name, tc := range tests {
211 tc := tc
212 t.Run(name, func(t *testing.T) {
213 t.Parallel()
214
215 ruleseng := postDefaultRulesMock{
216 dataRet: tc.mockDataRet,
217 errRet: tc.mockErrRet,
218 }
219
220 log := newLogger()
221
222 r := httptest.NewRecorder()
223 _, ginEngine := getTestGinContext(r)
224 _, err := New(ginEngine, &ruleseng, log)
225 assert.NoError(t, err)
226
227 req, err := http.NewRequest(http.MethodPost, "/admin/rules/default/commands", strings.NewReader(tc.reqBody))
228 assert.NoError(t, err)
229
230 ginEngine.ServeHTTP(r, req)
231
232 assert.Equal(t, tc.expCode, r.Result().StatusCode)
233
234 assert.Equal(t, tc.expMockCalledCount, ruleseng.callCount)
235 assert.Equal(t, tc.expCalledRules, ruleseng.rules)
236
237 tc.jsonAssert(t, r.Body.String())
238 })
239 }
240 }
241
242
243 func TestReadAllDefaultRules(t *testing.T) {
244 log := newLogger()
245 t.Setenv("RCLI_RES_DATA_DIR", "./testdata")
246 ruleseng := MockRulesEngine{AddedNames: map[string][]string{"rules": {"ls basic"}}}
247
248 r := httptest.NewRecorder()
249 _, ginEngine := getTestGinContext(r)
250 _, err := New(ginEngine, &ruleseng, log)
251 assert.Nil(t, err)
252
253 req, err := http.NewRequest(http.MethodGet, "/admin/rules/default/commands", nil)
254 assert.NoError(t, err)
255 ginEngine.ServeHTTP(r, req)
256 response := r.Result()
257
258 assert.Equal(t, response.StatusCode, http.StatusOK)
259 var respData []rulesengine.Rule
260 err = json.Unmarshal(r.Body.Bytes(), &respData)
261 assert.NoError(t, err)
262 assert.Equal(t, []rulesengine.Rule{{Command: rulesengine.Command{Name: "ls", ID: "test"}, Privileges: []rulesengine.Privilege{{Name: "basic", ID: "test"}}}}, respData)
263 }
264
265
266 func TestReadDefaultRuleForCommand(t *testing.T) {
267 log := newLogger()
268 t.Setenv("RCLI_RES_DATA_DIR", "./testdata")
269 ruleseng := MockRulesEngine{AddedNames: map[string][]string{"rules": {"ls basic"}}}
270
271 r := httptest.NewRecorder()
272 _, ginEngine := getTestGinContext(r)
273 _, err := New(ginEngine, &ruleseng, log)
274 assert.Nil(t, err)
275
276 req, err := http.NewRequest(http.MethodGet, "/admin/rules/default/commands/ls", nil)
277 assert.NoError(t, err)
278 ginEngine.ServeHTTP(r, req)
279 response := r.Result()
280
281 assert.Equal(t, response.StatusCode, http.StatusOK)
282 var respData rulesengine.Rule
283 err = json.Unmarshal(r.Body.Bytes(), &respData)
284 assert.NoError(t, err)
285 assert.Equal(t, rulesengine.Rule{Command: rulesengine.Command{Name: "ls", ID: "test"}, Privileges: []rulesengine.Privilege{{Name: "basic", ID: "test"}}}, respData)
286 }
287
288
289 func TestReadDefaultRulesNoRules(t *testing.T) {
290 log := newLogger()
291 t.Setenv("RCLI_RES_DATA_DIR", "./testdata")
292 ruleseng := MockRulesEngine{}
293
294 r := httptest.NewRecorder()
295 _, ginEngine := getTestGinContext(r)
296 _, err := New(ginEngine, &ruleseng, log)
297 assert.Nil(t, err)
298
299 req, err := http.NewRequest(http.MethodGet, "/admin/rules/default/commands", nil)
300 assert.NoError(t, err)
301 ginEngine.ServeHTTP(r, req)
302 response := r.Result()
303
304 assert.Equal(t, response.StatusCode, http.StatusOK)
305 var respData []rulesengine.Rule
306 err = json.Unmarshal(r.Body.Bytes(), &respData)
307 assert.NoError(t, err)
308 assert.Equal(t, []rulesengine.Rule(nil), respData)
309 }
310
// Fixtures for TestReadAllRulesForCommand. Each RuleWithOverrides value is
// paired with the JSON body the handler is expected to return for it. The
// strings are compared with assert.JSONEq, so their whitespace is not
// significant.
var (
	// retVal has a command, one banner with one privilege, and one default
	// privilege.
	retVal = rulesengine.RuleWithOverrides{
		Command: rulesengine.Command{
			Name: "testCommand",
			ID:   "testCommandID",
		},
		Banners: []rulesengine.BannerPrivOverrides{{
			Banner: rulesengine.Banner{
				BannerName: "testBannerName",
				BannerID:   "testBannerID",
			},
			Privileges: []rulesengine.Privilege{
				{
					Name: "testPriv1",
					ID:   "testPrivID1",
				},
			},
		}},
		Default: rulesengine.DefaultRule{Privileges: []rulesengine.Privilege{{
			Name: "testPriv2",
			ID:   "testPrivID2",
		}}},
	}

	// retValString is the expected JSON body for retVal.
	retValString = `{
"command": {
"id": "testCommandID",
"name": "testCommand"
},
"default": {
"privileges": [
{
"id": "testPrivID2",
"name": "testPriv2"
}
]
},
"banners": [
{
"banner": {
"id": "testBannerID",
"name": "testBannerName"
},
"privileges": [
{
"id": "testPrivID1",
"name": "testPriv1"
}
]
}
]
}`
	// retvalNoBanners has default privileges but no banners. Banners is an
	// explicit empty slice, matching the `"banners": []` in the expected JSON.
	// NOTE(review): spelled "retval" unlike the other retVal* fixtures.
	retvalNoBanners = rulesengine.RuleWithOverrides{
		Command: rulesengine.Command{
			Name: "testCommand",
			ID:   "testCommandID",
		},
		Default: rulesengine.DefaultRule{
			Privileges: []rulesengine.Privilege{{
				Name: "testPriv2",
				ID:   "testPrivID2",
			}},
		},
		Banners: []rulesengine.BannerPrivOverrides{},
	}
	// retValStringNoBanners is the expected JSON body for retvalNoBanners.
	retValStringNoBanners = `{
"command": {
"id": "testCommandID",
"name": "testCommand"
},
"default": {
"privileges": [
{
"id": "testPrivID2",
"name": "testPriv2"
}
]
},
"banners": []
}`
	// retValNoDefaults has a banner but a zero DefaultRule. The expected
	// JSON shows the zero DefaultRule as an empty object.
	retValNoDefaults = rulesengine.RuleWithOverrides{
		Command: rulesengine.Command{
			Name: "testCommand",
			ID:   "testCommandID",
		},
		Banners: []rulesengine.BannerPrivOverrides{{
			Banner: rulesengine.Banner{
				BannerName: "testBannerName",
				BannerID:   "testBannerID",
			},
			Privileges: []rulesengine.Privilege{
				{
					Name: "testPriv1",
					ID:   "testPrivID1",
				},
			},
		}},
		Default: rulesengine.DefaultRule{},
	}
	// retValNoDefaultsString is the expected JSON body for retValNoDefaults.
	retValNoDefaultsString = `{
"command": {
"id": "testCommandID",
"name": "testCommand"
},
"default": {},
"banners": [
{
"banner": {
"id": "testBannerID",
"name": "testBannerName"
},
"privileges": [
{
"id": "testPrivID1",
"name": "testPriv1"
}
]
}
]
}`

	// retValCommandOnly has a command but neither banners nor default
	// privileges.
	retValCommandOnly = rulesengine.RuleWithOverrides{
		Command: rulesengine.Command{
			Name: "testCommand",
			ID:   "testCommandID",
		},
		Banners: []rulesengine.BannerPrivOverrides{},
		Default: rulesengine.DefaultRule{},
	}
	// retValCommandOnlyString is the expected JSON body for retValCommandOnly.
	retValCommandOnlyString = `{
"command": {
"id": "testCommandID",
"name": "testCommand"
},
"default": {},
"banners": []
}`
)
449
// getAllRulesMock stubs ReadAllRulesForCommand with a canned result for
// TestReadAllRulesForCommand.
// NOTE(review): it embeds rulesengine.RulesEngine, whereas the other mocks in
// this file embed this package's RulesEngine; confirm this is intentional.
type getAllRulesMock struct {
	rulesengine.RulesEngine
	// Returned unchanged by ReadAllRulesForCommand.
	retVal rulesengine.RuleWithOverrides
	// Returned unchanged by ReadAllRulesForCommand.
	retErr error
}
455
456 func (gm getAllRulesMock) ReadAllRulesForCommand(_ context.Context, _ string) (rulesengine.RuleWithOverrides, error) {
457 return gm.retVal, gm.retErr
458 }
459 func TestReadAllRulesForCommand(t *testing.T) {
460 t.Parallel()
461 tests := map[string]struct {
462
463 url string
464
465 mreng getAllRulesMock
466
467
468 expStatus int
469 expOut string
470 }{
471 "Nominal": {
472 url: "/admin/rules/commands/testCommand",
473 mreng: getAllRulesMock{
474 retVal: retVal,
475 retErr: nil,
476 },
477 expStatus: 200,
478 expOut: retValString,
479 },
480 "Internal Server Error": {
481 url: "/admin/rules/commands/testCommand",
482 mreng: getAllRulesMock{
483 retVal: rulesengine.RuleWithOverrides{},
484 retErr: fmt.Errorf("something went wrong"),
485 },
486 expStatus: 500,
487 },
488 "Command Not Listed": {
489 url: "/admin/rules/commands/testCommand",
490 mreng: getAllRulesMock{
491 retVal: rulesengine.RuleWithOverrides{},
492 retErr: nil,
493 },
494 expStatus: 200,
495 expOut: "null",
496 },
497 "No Banners": {
498 url: "/admin/rules/commands/testCommand",
499 mreng: getAllRulesMock{
500 retVal: retvalNoBanners,
501 retErr: nil,
502 },
503 expStatus: 200,
504 expOut: retValStringNoBanners,
505 },
506 "No Default Privileges": {
507 url: "/admin/rules/commands/testCommand",
508 mreng: getAllRulesMock{
509 retErr: nil,
510 retVal: retValNoDefaults,
511 },
512 expStatus: 200,
513 expOut: retValNoDefaultsString,
514 },
515 "No Rules": {
516 url: "/admin/rules/commands/testCommand",
517 mreng: getAllRulesMock{
518 retErr: nil,
519 retVal: retValCommandOnly,
520 },
521 expStatus: 200,
522 expOut: retValCommandOnlyString,
523 },
524 }
525 for name, tc := range tests {
526 tc := tc
527 t.Run(name, func(t *testing.T) {
528 t.Parallel()
529 log := newLogger()
530
531 r := httptest.NewRecorder()
532 _, ginEngine := getTestGinContext(r)
533 _, err := New(ginEngine, &tc.mreng, log)
534 assert.NoError(t, err)
535
536 req, err := http.NewRequest(http.MethodGet, tc.url, nil)
537 assert.NoError(t, err)
538
539 ginEngine.ServeHTTP(r, req)
540
541 assert.Equal(t, tc.expStatus, r.Result().StatusCode)
542 if tc.expStatus == 200 && r.Result().StatusCode == 200 {
543 assert.JSONEq(t, tc.expOut, r.Body.String())
544 }
545 })
546 }
547 }
548
View as plain text