package authservice import ( "context" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "path" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "edge-infra.dev/pkg/sds/emergencyaccess/apierror" "edge-infra.dev/pkg/sds/emergencyaccess/eaconst" "edge-infra.dev/pkg/sds/emergencyaccess/retriever" "edge-infra.dev/pkg/sds/emergencyaccess/types" ) // testing helper type type helper interface { Helper() } type mockDataset struct { Dataset } type mockRetriever struct { mockArtifact func(ctx context.Context, name string, artifactType retriever.ArtifactType) (retriever.Artifact, error) } func (m *mockRetriever) Artifact(ctx context.Context, name string, artifactType retriever.ArtifactType) (retriever.Artifact, error) { return m.mockArtifact(ctx, name, artifactType) } func EqualError(message string) assert.ErrorAssertionFunc { return func(t assert.TestingT, err error, i ...interface{}) bool { if help, ok := t.(helper); ok { help.Helper() } return assert.EqualError(t, err, message, i...) } } // assert.ErrorAssertionFunc that asserts the error is an api error with the given // code, and contains the given message in the error string func APIError(code apierror.ErrorCode, message string) assert.ErrorAssertionFunc { return func(tt assert.TestingT, err error, i ...interface{}) bool { if help, ok := tt.(helper); ok { help.Helper() } if !assert.ErrorContains(tt, err, message, i...) { return false } if !assert.Implements(tt, (*apierror.APIError)(nil), err, i...) { return false } e := err.(apierror.APIError) return assert.Equal(tt, code, e.Code(), i...) } } const ( validBannerID = "bannerID" storeID = "storeID" terminalID = "terminalID" username = "username" email = "user@ncr.com" role = "test" ) func TestSuccessAuthorizeCommand(t *testing.T) { ctx := context.Background() server := rulesEngineServer() userServer := userServiceServer() defer server.Close() defer userServer.Close() ds := mockDataset{} as, err := New( Config{RulesEngineHost: server.URL[7:], UserServiceHost: userServer.URL[7:]}, ds, nil, ) assert.NoError(t, err) ctx = types.UserIntoContext(ctx, types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}}) val, err := as.AuthorizeCommand(ctx, CommandAuthPayload{ Command: "ls", Target: Target{BannerID: validBannerID}}) assert.Nil(t, err) assert.True(t, val.Valid) } func TestGetEARolesForUserPass(t *testing.T) { t.Parallel() tests := map[string]struct { asGenerator func(ruleServer *httptest.Server, userServer *httptest.Server) (*AuthService, error) assertErr assert.ErrorAssertionFunc expRes []string }{ "Return EARoles from userservice": { asGenerator: func(ruleServer *httptest.Server, userServer *httptest.Server) (*AuthService, error) { ds := mockDataset{} return New( Config{RulesEngineHost: ruleServer.URL[7:], UserServiceHost: userServer.URL[7:]}, ds, nil, ) }, assertErr: assert.NoError, expRes: []string{role}, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() // setup rServer, uServer := rulesEngineServer(), userServiceServer() defer rServer.Close() defer uServer.Close() // create the authservice as, err := tc.asGenerator(rServer, uServer) assert.NoError(t, err) //call the function eaRoles, err := as.getRolesForUser( context.Background(), types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}}, ) // check the returned earoles match tc.assertErr(t, err) assert.Equal(t, tc.expRes, eaRoles) }) } } func TestGetEARolesForUserFail(t *testing.T) { t.Parallel() // setup. user server returns a predictable error. uServer := badUserServiceServer() defer uServer.Close() // create the authservice ds := mockDataset{} as, err := New( Config{UserServiceHost: uServer.URL[7:]}, ds, nil, ) assert.NoError(t, err) //call the function _, err = as.getRolesForUser( context.Background(), types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}}, ) // check the error matches badUserServiceServer error assert.Contains(t, err.Error(), "service returned status 500 Internal Server Error") } func TestFailAuthorizeCommand(t *testing.T) { t.Parallel() tests := map[string]struct { comPath string payload CommandAuthPayload assertError assert.ErrorAssertionFunc ctx context.Context }{ "404 bad relative path": { comPath: "badpath", payload: CommandAuthPayload{Command: "ls", Target: Target{BannerID: validBannerID}}, assertError: EqualError("rules engine returned status 404 Not Found"), ctx: types.UserIntoContext(context.Background(), types.User{Email: email, Roles: []string{role}, Banners: []string{validBannerID}}), }, "Permission denied on command": { comPath: "validatecommand", payload: CommandAuthPayload{Command: "rm", Target: Target{BannerID: validBannerID}}, assertError: assert.NoError, ctx: types.UserIntoContext(context.Background(), types.User{Email: email, Roles: []string{role}, Banners: []string{validBannerID}}), }, "No EARoles": { comPath: "validatecommand", payload: CommandAuthPayload{Command: "ls", Target: Target{BannerID: validBannerID}}, assertError: EqualError("60002: User Authorization Failure - User does not have the required roles. Error: no roles returned from userservice"), ctx: types.UserIntoContext(context.Background(), types.User{Email: email, Banners: []string{validBannerID}}), }, "403 no permission for target": { payload: CommandAuthPayload{Command: "rm", Target: Target{BannerID: validBannerID}}, assertError: EqualError("60003: User Authorization Failure - User not permitted to perform this action. Error: banner not found in user struct"), ctx: types.UserIntoContext(context.Background(), types.User{Email: email, Roles: []string{role}}), }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() server := rulesEngineServer() userServer := userServiceServer() defer server.Close() defer userServer.Close() ds := mockDataset{} as, err := New( Config{RulesEngineHost: server.URL[7:], UserServiceHost: userServer.URL[7:]}, ds, nil, ) assert.NoError(t, err) as.validateComPath = tc.comPath val, err := as.AuthorizeCommand(tc.ctx, tc.payload) tc.assertError(t, err) assert.False(t, val.Valid) }) } } func userServiceServer() *httptest.Server { mux := http.NewServeMux() mux.HandleFunc(path.Join("/", getEARolesPath), func(w http.ResponseWriter, r *http.Request) { values := r.URL.Query() role := values.Get("role") res := []string{role} // values.Get returns an empty string on bad match. want to return an empty slice if this is the case, not a []string{""} or []string{nil} if role == "" { res = []string{} } b, err := json.Marshal(res) if err != nil { return } _, err = w.Write(b) if err != nil { return } }) server := httptest.NewServer(mux) return server } func badUserServiceServer() *httptest.Server { mux := http.NewServeMux() mux.HandleFunc(path.Join("/", getEARolesPath), func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) }) server := httptest.NewServer(mux) return server } type rulesOpts func(*rulesEngineOpts) func expCommand(name string) rulesOpts { return func(reo *rulesEngineOpts) { reo.expCommand = name } } func expType(reqType eaconst.RequestType) rulesOpts { return func(reo *rulesEngineOpts) { reo.expType = reqType } } type rulesEngineOpts struct { expCommand string expType eaconst.RequestType } func rulesEngineServer(opts ...rulesOpts) *httptest.Server { o := rulesEngineOpts{ expCommand: "ls", expType: eaconst.Command, } for _, opt := range opts { opt(&o) } mux := http.NewServeMux() mux.HandleFunc(path.Join("/", defaultValidateComPath), func(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { return } var payload RulesEnginePayload err = json.Unmarshal(data, &payload) if err != nil { w.WriteHeader(http.StatusBadRequest) return } // comparison res := Response{Valid: checkPayload(o, payload)} b, err := json.Marshal(res) if err != nil { return } _, err = w.Write(b) if err != nil { return } }) server := httptest.NewServer(mux) return server } func badRulesEngineServer(...rulesOpts) *httptest.Server { mux := http.NewServeMux() mux.HandleFunc(path.Join("/", defaultValidateComPath), func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) }) server := httptest.NewServer(mux) return server } func checkPayload(o rulesEngineOpts, payload RulesEnginePayload) bool { if payload.Command.Name != o.expCommand { return false } if payload.Command.Type != o.expType { return false } if payload.Target.BannerID != validBannerID { return false } if len(payload.Identity.EAroles) == 0 { return false } if len(payload.Identity.EAroles) > 1 || payload.Identity.EAroles[0] != role { return false } return true } func TestAuthorizeRequest(t *testing.T) { t.Parallel() user := types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}} ctx := types.UserIntoContext(context.Background(), user) userServer := userServiceServer() t.Cleanup(userServer.Close) tests := map[string]struct { rulesServer func(...rulesOpts) *httptest.Server request Request retriever Retriever expData string expAttrs map[string]string }{ "Command": { rulesServer: rulesEngineServer, request: Request{ Data: json.RawMessage(`{ "command": "ls hello there" }`), Attributes: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion1_0), eaconst.RequestTypeKey: string(eaconst.Command), }, }, expData: `{ "command": "ls hello there" }`, expAttrs: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion1_0), eaconst.RequestTypeKey: string(eaconst.Command), }, }, "Executable": { rulesServer: func(...rulesOpts) *httptest.Server { return rulesEngineServer(expCommand("myScript"), expType(eaconst.Executable)) }, request: Request{ Data: json.RawMessage(`{ "executable": { "name": "myScript" } }`), Attributes: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion2_0), eaconst.RequestTypeKey: string(eaconst.Executable), }, }, retriever: &mockRetriever{ mockArtifact: func(_ context.Context, name string, artifactType retriever.ArtifactType) (retriever.Artifact, error) { if name != "myScript" { return retriever.Artifact{}, fmt.Errorf("mock retriever error, unexpected artifact name: got %q", name) } if artifactType != retriever.Executable { return retriever.Artifact{}, fmt.Errorf("mock retriever error, unexpected artifact type: got %q", artifactType) } sha, err := hex.DecodeString("0f95ed04c41face74eb0fb077282821ba0493d5b3cc2c1e725c1a58c6b8f51ba") if err != nil { return retriever.Artifact{}, err } return retriever.Artifact{ Name: "myScript", Type: retriever.Executable, Artifact: []byte("#!/bin/sh\n\necho hello\n\n"), SHA: sha, }, nil }, }, expData: `{ "executable": { "name": "myScript", "contents": "IyEvYmluL3NoCgplY2hvIGhlbGxvCgo=" }, "args": null }`, expAttrs: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion2_0), eaconst.RequestTypeKey: string(eaconst.Executable), }, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() ruleServer := tc.rulesServer() t.Cleanup(ruleServer.Close) ds := mockDataset{} as, err := New( Config{RulesEngineHost: ruleServer.URL[7:], UserServiceHost: userServer.URL[7:]}, ds, tc.retriever, ) require.NoError(t, err) payload := AuthorizeRequestPayload{ Request: tc.request, Target: Target{ BannerID: validBannerID, StoreID: storeID, TerminalID: terminalID, }, } req, err := as.AuthorizeRequest(ctx, payload) assert.NoError(t, err) data, err := req.Data() assert.NoError(t, err) assert.JSONEq(t, tc.expData, string(data)) assert.Equal(t, tc.expAttrs, req.Attributes()) }) } } func TestAuthorizeRequestFail(t *testing.T) { t.Parallel() validUser := types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}} validPayload := AuthorizeRequestPayload{ Request: Request{ Data: json.RawMessage(`{ "command": "ls hello there" }`), Attributes: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion1_0), eaconst.RequestTypeKey: string(eaconst.Command), }, }, Target: Target{ BannerID: validBannerID, StoreID: storeID, TerminalID: terminalID, }, } validScriptPayload := AuthorizeRequestPayload{ Request: Request{ Data: json.RawMessage(`{ "executable": { "name": "myScript" } }`), Attributes: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion2_0), eaconst.RequestTypeKey: string(eaconst.Executable), }, }, Target: Target{ BannerID: validBannerID, StoreID: storeID, TerminalID: terminalID, }, } tests := map[string]struct { ctx context.Context payload AuthorizeRequestPayload mockRetriever func(ctx context.Context, name string, artifactType retriever.ArtifactType) (retriever.Artifact, error) ruleServer func(...rulesOpts) *httptest.Server userServer func() *httptest.Server errAssert assert.ErrorAssertionFunc }{ "Failed To Create Request": { ctx: types.UserIntoContext(context.Background(), validUser), ruleServer: rulesEngineServer, userServer: userServiceServer, errAssert: EqualError("failed to create structured request from payload: failed to find version attribute"), }, "No User": { ctx: context.Background(), payload: validPayload, ruleServer: rulesEngineServer, userServer: userServiceServer, errAssert: EqualError("user struct not found in context"), }, "Get EA Roles Error": { ctx: types.UserIntoContext(context.Background(), validUser), payload: validPayload, ruleServer: rulesEngineServer, userServer: badUserServiceServer, errAssert: EqualError("error when getting ea roles: user service returned status 500 Internal Server Error"), }, "No EA Roles": { ctx: types.UserIntoContext(context.Background(), types.User{Email: email, Username: username}), payload: validPayload, ruleServer: rulesEngineServer, userServer: userServiceServer, errAssert: APIError(apierror.ErrUserMissingRoles, "no roles returned from userservice"), }, "Banner Not Authorized": { ctx: types.UserIntoContext(context.Background(), validUser), payload: AuthorizeRequestPayload{ Request: Request{ Data: json.RawMessage(`{ "command": "ls hello there" }`), Attributes: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion1_0), eaconst.RequestTypeKey: string(eaconst.Command), }, }, Target: Target{ BannerID: "invalid-banner", StoreID: storeID, TerminalID: terminalID, }, }, ruleServer: rulesEngineServer, userServer: userServiceServer, errAssert: APIError(apierror.ErrUserNotAuthorized, "banner not found in user struct"), }, "Invalid Command": { ctx: types.UserIntoContext(context.Background(), validUser), payload: AuthorizeRequestPayload{ Request: Request{ Data: json.RawMessage(`{ "command": "rm" }`), Attributes: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion1_0), eaconst.RequestTypeKey: string(eaconst.Command), }, }, Target: Target{ BannerID: validBannerID, StoreID: storeID, TerminalID: terminalID, }, }, ruleServer: rulesEngineServer, userServer: userServiceServer, errAssert: APIError(apierror.ErrUnauthorizedCommand, "command not authorized for user on target"), }, "Rules Engine Non-OK": { ctx: types.UserIntoContext(context.Background(), validUser), payload: validPayload, ruleServer: badRulesEngineServer, userServer: userServiceServer, errAssert: EqualError("rules engine returned status 500 Internal Server Error"), }, "Failed to retrieve artifact": { ctx: types.UserIntoContext(context.Background(), validUser), payload: validScriptPayload, mockRetriever: func(context.Context, string, retriever.ArtifactType) (retriever.Artifact, error) { return retriever.Artifact{}, fmt.Errorf("error retrieving artifact") }, ruleServer: func(...rulesOpts) *httptest.Server { return rulesEngineServer(expCommand("myScript"), expType(eaconst.Executable)) }, userServer: userServiceServer, errAssert: EqualError("failed to retrieve artifact: error retrieving artifact"), }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() ruleServer := tc.ruleServer() userServer := tc.userServer() defer ruleServer.Close() defer userServer.Close() retriever := &mockRetriever{ mockArtifact: tc.mockRetriever, } ds := mockDataset{} as, err := New( Config{RulesEngineHost: ruleServer.URL[7:], UserServiceHost: userServer.URL[7:]}, ds, retriever, ) require.NoError(t, err) req, err := as.AuthorizeRequest(tc.ctx, tc.payload) tc.errAssert(t, err) assert.Nil(t, req) }) } } type mockDatasetTestResolveTarget struct { Dataset projectID string bannerID string storeID string terminalID string } func (ds mockDatasetTestResolveTarget) GetProjectAndBannerID(_ context.Context, banner string) (projectID string, bannerID string, err error) { if banner == "" { err = fmt.Errorf("error GetProjectIDAndBannerID") } return ds.projectID, ds.bannerID, err } func (ds mockDatasetTestResolveTarget) GetStoreID(_ context.Context, store, _ string) (storeID string, err error) { if store == "" { err = fmt.Errorf("error GetStoreID") } return ds.storeID, err } func (ds mockDatasetTestResolveTarget) GetTerminalID(_ context.Context, terminal, _ string) (terminalID string, err error) { if terminal == "" { err = fmt.Errorf("error GetTerminalID") } return ds.terminalID, err } func TestResolveTarget(t *testing.T) { t.Parallel() tests := map[string]struct { payload ResolveTargetPayload ds mockDatasetTestResolveTarget expTarget Target errAssert assert.ErrorAssertionFunc }{ "Valid": { payload: ResolveTargetPayload{ Target: Target{ ProjectID: "p", BannerID: "b", StoreID: "s", TerminalID: "t", }, }, ds: mockDatasetTestResolveTarget{ projectID: "projectID", bannerID: "bannerID", storeID: "storeID", terminalID: "terminalID", }, expTarget: Target{ ProjectID: "projectID", BannerID: "bannerID", StoreID: "storeID", TerminalID: "terminalID", }, errAssert: assert.NoError, }, "GetProjectIDAndBannerID returns err": { payload: ResolveTargetPayload{}, ds: mockDatasetTestResolveTarget{}, expTarget: Target{}, errAssert: EqualError("error GetProjectIDAndBannerID"), }, "GetStoreID returns err": { payload: ResolveTargetPayload{ Target: Target{ ProjectID: "p", BannerID: "b", }, }, ds: mockDatasetTestResolveTarget{ projectID: "projectID", bannerID: "bannerID", }, expTarget: Target{}, errAssert: EqualError("error GetStoreID"), }, "GetTerminalID returns err": { payload: ResolveTargetPayload{ Target: Target{ ProjectID: "p", BannerID: "b", StoreID: "s", }, }, ds: mockDatasetTestResolveTarget{ projectID: "projectID", bannerID: "bannerID", storeID: "storeID", }, expTarget: Target{}, errAssert: EqualError("error GetTerminalID"), }, "Returned ProjectID nil": { payload: ResolveTargetPayload{ Target: Target{ ProjectID: "p", BannerID: "b", }, }, ds: mockDatasetTestResolveTarget{ bannerID: "bannerID", }, expTarget: Target{}, errAssert: EqualError("61202: Request Error - Invalid Target properties. Error: project not found for banner b"), }, "Returned BannerID nil": { payload: ResolveTargetPayload{ Target: Target{ ProjectID: "p", BannerID: "b", }, }, ds: mockDatasetTestResolveTarget{ projectID: "projectID", }, expTarget: Target{}, errAssert: EqualError("61202: Request Error - Invalid Target properties. Error: banner b not found"), }, "Returned StoreID nil": { payload: ResolveTargetPayload{ Target: Target{ ProjectID: "p", BannerID: "b", StoreID: "s", }, }, ds: mockDatasetTestResolveTarget{ projectID: "projectID", bannerID: "bannerID", }, expTarget: Target{}, errAssert: EqualError("61202: Request Error - Invalid Target properties. Error: store s not found in given banner b"), }, "Returned TerminalID nil": { payload: ResolveTargetPayload{ Target: Target{ ProjectID: "p", BannerID: "b", StoreID: "s", TerminalID: "t", }, }, ds: mockDatasetTestResolveTarget{ projectID: "projectID", bannerID: "bannerID", storeID: "storeID", }, expTarget: Target{}, errAssert: EqualError("61202: Request Error - Invalid Target properties. Error: terminal t not found in given store s and banner b"), }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() as, err := New( Config{}, tc.ds, nil, ) assert.NoError(t, err) target, err := as.ResolveTarget(context.Background(), tc.payload) tc.errAssert(t, err) assert.Equal(t, tc.expTarget, target) }) } } func TestAuthorizeTarget(t *testing.T) { t.Parallel() tests := map[string]struct { ctx context.Context target Target errorAssertion assert.ErrorAssertionFunc }{ "Valid": { ctx: types.UserIntoContext(context.Background(), types.User{ Banners: []string{validBannerID}, Username: username, Roles: []string{role}, }), target: Target{BannerID: validBannerID}, errorAssertion: assert.NoError, }, "Error, bannerID doesn't match": { ctx: types.UserIntoContext(context.Background(), types.User{ Banners: []string{validBannerID}, Username: username, Roles: []string{role}, }), target: Target{BannerID: "not-the-same-banner-id"}, errorAssertion: EqualError(apierror.E(apierror.ErrUserNotAuthorized, errors.New("banner not found in user struct"), "User was not assigned banner").Error()), }, "Error, no user in context": { ctx: context.Background(), target: Target{BannerID: validBannerID}, errorAssertion: EqualError("user struct not in context"), }, "Error no EARoles": { ctx: types.UserIntoContext(context.Background(), types.User{ Banners: []string{validBannerID}, Username: username, Roles: []string{}, }), target: Target{BannerID: validBannerID}, errorAssertion: EqualError(apierror.E(apierror.ErrUserMissingRoles, fmt.Errorf("no roles returned from userservice")).Error()), }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() ds := mockDataset{} uServer := userServiceServer() defer uServer.Close() as, err := New( Config{UserServiceHost: uServer.URL[7:]}, ds, nil, ) assert.NoError(t, err) err = as.AuthorizeTarget(tc.ctx, tc.target) tc.errorAssertion(t, err) }) } } func TestAuthorizeUser(t *testing.T) { t.Parallel() tests := map[string]struct { user types.User errAssert assert.ErrorAssertionFunc }{ "Valid": { user: types.User{Roles: []string{"role1", "role2", "role3"}}, errAssert: assert.NoError, }, "Invalid No Roles": { user: types.User{}, errAssert: EqualError("60002: User Authorization Failure - User does not have the required roles. Error: no roles returned from userservice"), }, "Invalid No Return": { user: types.User{Roles: []string{""}}, errAssert: EqualError("60002: User Authorization Failure - User does not have the required roles. Error: no roles returned from userservice"), }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() // setup userServer := userServiceServer() defer userServer.Close() ctx := types.UserIntoContext(context.Background(), tc.user) ds := mockDataset{} as, err := New( Config{UserServiceHost: userServer.URL[7:]}, ds, nil, ) assert.NoError(t, err) err = as.AuthorizeUser(ctx) tc.errAssert(t, err) }) } }