package server import ( "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "strings" "testing" rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules" "github.com/gin-gonic/gin" "github.com/google/shlex" "github.com/stretchr/testify/assert" ) type MockRulesEngine struct { AddedNames map[string][]string Conflict bool RulesEngine } func NewMockReng() MockRulesEngine { return MockRulesEngine{AddedNames: map[string][]string{ "commands": {}, "privs": {}, "rules": {}, }} } func (MockRulesEngine) GetEARolesForCommand(_ context.Context, _ rulesengine.Command, _ string) ([]string, error) { return []string{}, nil } func (MockRulesEngine) UserHasRoles(_ string, _ []string, _ []string) bool { return false } func (mreng *MockRulesEngine) ReadCommands(_ context.Context) ([]rulesengine.Command, error) { lst := []rulesengine.Command{} for _, name := range mreng.AddedNames["commands"] { lst = append(lst, rulesengine.Command{ID: "test", Name: name}) } return lst, nil } func (mreng *MockRulesEngine) ReadPrivileges(_ context.Context) ([]rulesengine.Privilege, error) { lst := []rulesengine.Privilege{} for _, name := range mreng.AddedNames["privs"] { lst = append(lst, rulesengine.Privilege{ID: "test", Name: name}) } return lst, nil } func (mreng *MockRulesEngine) ReadCommand(_ context.Context, name string) (rulesengine.Command, error) { for _, namein := range mreng.AddedNames["commands"] { if name == namein { return rulesengine.Command{Name: name, ID: "test"}, nil } } return rulesengine.Command{}, nil } func (mreng MockRulesEngine) ReadPrivilege(_ context.Context, name string) (rulesengine.Privilege, error) { for _, namein := range mreng.AddedNames["privs"] { if name == namein { return rulesengine.Privilege{Name: name, ID: "test"}, nil } } return rulesengine.Privilege{}, nil } func (mreng *MockRulesEngine) ReadAllDefaultRules(_ context.Context) ([]rulesengine.Rule, error) { res := []rulesengine.Rule{} for _, rule := range mreng.AddedNames["rules"] { vals, err := shlex.Split(rule) if err != nil { return res, err } res = append(res, rulesengine.Rule{ Command: rulesengine.Command{Name: vals[0], ID: "test"}, Privileges: []rulesengine.Privilege{{Name: vals[1], ID: "test"}}, }) } return res, nil } func (mreng *MockRulesEngine) ReadDefaultRulesForCommand(ctx context.Context, _ string) ([]rulesengine.Rule, error) { return mreng.ReadAllDefaultRules(ctx) } func getTestGinContext(r *httptest.ResponseRecorder) (*gin.Context, *gin.Engine) { gin.SetMode(gin.TestMode) ctx, ginEngine := gin.CreateTestContext(r) return ctx, ginEngine } type postDefaultRulesMock struct { RulesEngine dataRet rulesengine.AddRuleResult errRet error callCount int rules rulesengine.WriteRules } func (pdrb *postDefaultRulesMock) AddDefaultRules(_ context.Context, rules rulesengine.WriteRules) (rulesengine.AddRuleResult, error) { pdrb.callCount = pdrb.callCount + 1 pdrb.rules = rules return pdrb.dataRet, pdrb.errRet } func TestPostDefaultRules(t *testing.T) { t.Parallel() tests := map[string]struct { reqBody string mockDataRet rulesengine.AddRuleResult mockErrRet error expMockCalledCount int expCalledRules rulesengine.WriteRules expCode int jsonAssert StringAssertionFunc }{ "Ok": { reqBody: `[ {"command": "ls", "privileges": ["read","write"]}, {"command": "cat", "privileges": ["read","write"]} ]`, expMockCalledCount: 1, expCalledRules: rulesengine.WriteRules{ {Command: "ls", Privileges: []string{"read", "write"}}, {Command: "cat", Privileges: []string{"read", "write"}}, }, expCode: http.StatusOK, jsonAssert: JSONEmpty(), }, "Invalid JSON": { reqBody: `[{"comm`, expMockCalledCount: 0, expCode: http.StatusBadRequest, jsonAssert: JSONEmpty(), }, "Invalid payload": { reqBody: `[{"command": "", "privileges": ["", ""]}]`, expMockCalledCount: 0, expCode: http.StatusBadRequest, jsonAssert: JSONEmpty(), }, "Rulesengine Error": { reqBody: `[ {"command": "ls", "privileges": ["read","write"]}, {"command": "cat", "privileges": ["read","write"]} ]`, mockDataRet: rulesengine.AddRuleResult{}, mockErrRet: fmt.Errorf("an error occurred"), expMockCalledCount: 1, expCalledRules: rulesengine.WriteRules{ {Command: "ls", Privileges: []string{"read", "write"}}, {Command: "cat", Privileges: []string{"read", "write"}}, }, expCode: http.StatusInternalServerError, jsonAssert: JSONEmpty(), }, "Rulesengine Conflict": { reqBody: `[ {"command": "not-here", "privileges": ["read","write"]}, {"command": "cat", "privileges": ["read","not-here"]} ]`, mockDataRet: rulesengine.AddRuleResult{Errors: []rulesengine.Error{ {Command: "not-here", Type: rulesengine.UnknownCommand}, {Privilege: "not-here", Type: rulesengine.UnknownPrivilege}, }}, mockErrRet: nil, expMockCalledCount: 1, expCalledRules: rulesengine.WriteRules{ {Command: "not-here", Privileges: []string{"read", "write"}}, {Command: "cat", Privileges: []string{"read", "not-here"}}, }, expCode: http.StatusNotFound, jsonAssert: JSONEq(`{ "errors": [ {"command":"not-here","type":"Unknown Command"}, {"privilege":"not-here","type":"Unknown Privilege"} ] }`), }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() ruleseng := postDefaultRulesMock{ dataRet: tc.mockDataRet, errRet: tc.mockErrRet, } log := newLogger() r := httptest.NewRecorder() _, ginEngine := getTestGinContext(r) _, err := New(ginEngine, &ruleseng, log) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, "/admin/rules/default/commands", strings.NewReader(tc.reqBody)) assert.NoError(t, err) ginEngine.ServeHTTP(r, req) assert.Equal(t, tc.expCode, r.Result().StatusCode) assert.Equal(t, tc.expMockCalledCount, ruleseng.callCount) assert.Equal(t, tc.expCalledRules, ruleseng.rules) tc.jsonAssert(t, r.Body.String()) }) } } // nolint:dupl func TestReadAllDefaultRules(t *testing.T) { log := newLogger() t.Setenv("RCLI_RES_DATA_DIR", "./testdata") ruleseng := MockRulesEngine{AddedNames: map[string][]string{"rules": {"ls basic"}}} r := httptest.NewRecorder() _, ginEngine := getTestGinContext(r) _, err := New(ginEngine, &ruleseng, log) assert.Nil(t, err) req, err := http.NewRequest(http.MethodGet, "/admin/rules/default/commands", nil) assert.NoError(t, err) ginEngine.ServeHTTP(r, req) response := r.Result() assert.Equal(t, response.StatusCode, http.StatusOK) var respData []rulesengine.Rule err = json.Unmarshal(r.Body.Bytes(), &respData) assert.NoError(t, err) assert.Equal(t, []rulesengine.Rule{{Command: rulesengine.Command{Name: "ls", ID: "test"}, Privileges: []rulesengine.Privilege{{Name: "basic", ID: "test"}}}}, respData) } // nolint:dupl func TestReadDefaultRuleForCommand(t *testing.T) { log := newLogger() t.Setenv("RCLI_RES_DATA_DIR", "./testdata") ruleseng := MockRulesEngine{AddedNames: map[string][]string{"rules": {"ls basic"}}} r := httptest.NewRecorder() _, ginEngine := getTestGinContext(r) _, err := New(ginEngine, &ruleseng, log) assert.Nil(t, err) req, err := http.NewRequest(http.MethodGet, "/admin/rules/default/commands/ls", nil) assert.NoError(t, err) ginEngine.ServeHTTP(r, req) response := r.Result() assert.Equal(t, response.StatusCode, http.StatusOK) var respData rulesengine.Rule err = json.Unmarshal(r.Body.Bytes(), &respData) assert.NoError(t, err) assert.Equal(t, rulesengine.Rule{Command: rulesengine.Command{Name: "ls", ID: "test"}, Privileges: []rulesengine.Privilege{{Name: "basic", ID: "test"}}}, respData) } // nolint:dupl func TestReadDefaultRulesNoRules(t *testing.T) { log := newLogger() t.Setenv("RCLI_RES_DATA_DIR", "./testdata") ruleseng := MockRulesEngine{} r := httptest.NewRecorder() _, ginEngine := getTestGinContext(r) _, err := New(ginEngine, &ruleseng, log) assert.Nil(t, err) req, err := http.NewRequest(http.MethodGet, "/admin/rules/default/commands", nil) assert.NoError(t, err) ginEngine.ServeHTTP(r, req) response := r.Result() assert.Equal(t, response.StatusCode, http.StatusOK) var respData []rulesengine.Rule err = json.Unmarshal(r.Body.Bytes(), &respData) assert.NoError(t, err) assert.Equal(t, []rulesengine.Rule(nil), respData) } var ( retVal = rulesengine.RuleWithOverrides{ Command: rulesengine.Command{ Name: "testCommand", ID: "testCommandID", }, Banners: []rulesengine.BannerPrivOverrides{{ Banner: rulesengine.Banner{ BannerName: "testBannerName", BannerID: "testBannerID", }, Privileges: []rulesengine.Privilege{ { Name: "testPriv1", ID: "testPrivID1", }, }, }}, Default: rulesengine.DefaultRule{Privileges: []rulesengine.Privilege{{ Name: "testPriv2", ID: "testPrivID2", }}}, } retValString = `{ "command": { "id": "testCommandID", "name": "testCommand" }, "default": { "privileges": [ { "id": "testPrivID2", "name": "testPriv2" } ] }, "banners": [ { "banner": { "id": "testBannerID", "name": "testBannerName" }, "privileges": [ { "id": "testPrivID1", "name": "testPriv1" } ] } ] }` retvalNoBanners = rulesengine.RuleWithOverrides{ Command: rulesengine.Command{ Name: "testCommand", ID: "testCommandID", }, Default: rulesengine.DefaultRule{ Privileges: []rulesengine.Privilege{{ Name: "testPriv2", ID: "testPrivID2", }}, }, Banners: []rulesengine.BannerPrivOverrides{}, } retValStringNoBanners = `{ "command": { "id": "testCommandID", "name": "testCommand" }, "default": { "privileges": [ { "id": "testPrivID2", "name": "testPriv2" } ] }, "banners": [] }` retValNoDefaults = rulesengine.RuleWithOverrides{ Command: rulesengine.Command{ Name: "testCommand", ID: "testCommandID", }, Banners: []rulesengine.BannerPrivOverrides{{ Banner: rulesengine.Banner{ BannerName: "testBannerName", BannerID: "testBannerID", }, Privileges: []rulesengine.Privilege{ { Name: "testPriv1", ID: "testPrivID1", }, }, }}, Default: rulesengine.DefaultRule{}, } retValNoDefaultsString = `{ "command": { "id": "testCommandID", "name": "testCommand" }, "default": {}, "banners": [ { "banner": { "id": "testBannerID", "name": "testBannerName" }, "privileges": [ { "id": "testPrivID1", "name": "testPriv1" } ] } ] }` retValCommandOnly = rulesengine.RuleWithOverrides{ Command: rulesengine.Command{ Name: "testCommand", ID: "testCommandID", }, Banners: []rulesengine.BannerPrivOverrides{}, Default: rulesengine.DefaultRule{}, } retValCommandOnlyString = `{ "command": { "id": "testCommandID", "name": "testCommand" }, "default": {}, "banners": [] }` ) type getAllRulesMock struct { rulesengine.RulesEngine retVal rulesengine.RuleWithOverrides retErr error } func (gm getAllRulesMock) ReadAllRulesForCommand(_ context.Context, _ string) (rulesengine.RuleWithOverrides, error) { return gm.retVal, gm.retErr } func TestReadAllRulesForCommand(t *testing.T) { t.Parallel() tests := map[string]struct { // request details url string // defined rules engine that we will be using mreng getAllRulesMock // expected results from http expStatus int expOut string }{ "Nominal": { url: "/admin/rules/commands/testCommand", mreng: getAllRulesMock{ retVal: retVal, retErr: nil, }, expStatus: 200, expOut: retValString, }, "Internal Server Error": { url: "/admin/rules/commands/testCommand", mreng: getAllRulesMock{ retVal: rulesengine.RuleWithOverrides{}, retErr: fmt.Errorf("something went wrong"), }, expStatus: 500, }, "Command Not Listed": { url: "/admin/rules/commands/testCommand", mreng: getAllRulesMock{ retVal: rulesengine.RuleWithOverrides{}, retErr: nil, }, expStatus: 200, expOut: "null", }, "No Banners": { url: "/admin/rules/commands/testCommand", mreng: getAllRulesMock{ retVal: retvalNoBanners, retErr: nil, }, expStatus: 200, expOut: retValStringNoBanners, }, "No Default Privileges": { url: "/admin/rules/commands/testCommand", mreng: getAllRulesMock{ retErr: nil, retVal: retValNoDefaults, }, expStatus: 200, expOut: retValNoDefaultsString, }, "No Rules": { url: "/admin/rules/commands/testCommand", mreng: getAllRulesMock{ retErr: nil, retVal: retValCommandOnly, }, expStatus: 200, expOut: retValCommandOnlyString, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() log := newLogger() r := httptest.NewRecorder() _, ginEngine := getTestGinContext(r) _, err := New(ginEngine, &tc.mreng, log) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, tc.url, nil) assert.NoError(t, err) ginEngine.ServeHTTP(r, req) assert.Equal(t, tc.expStatus, r.Result().StatusCode) if tc.expStatus == 200 && r.Result().StatusCode == 200 { assert.JSONEq(t, tc.expOut, r.Body.String()) } }) } }