...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/authservice/authservice_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/authservice

     1  package authservice
     2  
     3  import (
     4  	"context"
     5  	"encoding/hex"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"path"
    13  	"testing"
    14  
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  
    18  	"edge-infra.dev/pkg/sds/emergencyaccess/apierror"
    19  	"edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
    20  	"edge-infra.dev/pkg/sds/emergencyaccess/retriever"
    21  	"edge-infra.dev/pkg/sds/emergencyaccess/types"
    22  )
    23  
    24  // testing helper type
    25  type helper interface {
    26  	Helper()
    27  }
    28  
    29  type mockDataset struct {
    30  	Dataset
    31  }
    32  
    33  type mockRetriever struct {
    34  	mockArtifact func(ctx context.Context, name string, artifactType retriever.ArtifactType) (retriever.Artifact, error)
    35  }
    36  
    37  func (m *mockRetriever) Artifact(ctx context.Context, name string, artifactType retriever.ArtifactType) (retriever.Artifact, error) {
    38  	return m.mockArtifact(ctx, name, artifactType)
    39  }
    40  
    41  func EqualError(message string) assert.ErrorAssertionFunc {
    42  	return func(t assert.TestingT, err error, i ...interface{}) bool {
    43  		if help, ok := t.(helper); ok {
    44  			help.Helper()
    45  		}
    46  
    47  		return assert.EqualError(t, err, message, i...)
    48  	}
    49  }
    50  
    51  // assert.ErrorAssertionFunc that asserts the error is an api error with the given
    52  // code, and contains the given message in the error string
    53  func APIError(code apierror.ErrorCode, message string) assert.ErrorAssertionFunc {
    54  	return func(tt assert.TestingT, err error, i ...interface{}) bool {
    55  		if help, ok := tt.(helper); ok {
    56  			help.Helper()
    57  		}
    58  
    59  		if !assert.ErrorContains(tt, err, message, i...) {
    60  			return false
    61  		}
    62  
    63  		if !assert.Implements(tt, (*apierror.APIError)(nil), err, i...) {
    64  			return false
    65  		}
    66  
    67  		e := err.(apierror.APIError)
    68  		return assert.Equal(tt, code, e.Code(), i...)
    69  	}
    70  }
    71  
    72  const (
    73  	validBannerID = "bannerID"
    74  	storeID       = "storeID"
    75  	terminalID    = "terminalID"
    76  	username      = "username"
    77  	email         = "user@ncr.com"
    78  	role          = "test"
    79  )
    80  
    81  func TestSuccessAuthorizeCommand(t *testing.T) {
    82  	ctx := context.Background()
    83  
    84  	server := rulesEngineServer()
    85  	userServer := userServiceServer()
    86  	defer server.Close()
    87  	defer userServer.Close()
    88  
    89  	ds := mockDataset{}
    90  	as, err := New(
    91  		Config{RulesEngineHost: server.URL[7:], UserServiceHost: userServer.URL[7:]},
    92  		ds,
    93  		nil,
    94  	)
    95  	assert.NoError(t, err)
    96  
    97  	ctx = types.UserIntoContext(ctx, types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}})
    98  	val, err := as.AuthorizeCommand(ctx,
    99  		CommandAuthPayload{
   100  			Command: "ls",
   101  			Target:  Target{BannerID: validBannerID}})
   102  	assert.Nil(t, err)
   103  	assert.True(t, val.Valid)
   104  }
   105  
   106  func TestGetEARolesForUserPass(t *testing.T) {
   107  	t.Parallel()
   108  	tests := map[string]struct {
   109  		asGenerator func(ruleServer *httptest.Server, userServer *httptest.Server) (*AuthService, error)
   110  		assertErr   assert.ErrorAssertionFunc
   111  		expRes      []string
   112  	}{
   113  		"Return EARoles from userservice": {
   114  			asGenerator: func(ruleServer *httptest.Server, userServer *httptest.Server) (*AuthService, error) {
   115  				ds := mockDataset{}
   116  				return New(
   117  					Config{RulesEngineHost: ruleServer.URL[7:], UserServiceHost: userServer.URL[7:]},
   118  					ds,
   119  					nil,
   120  				)
   121  			},
   122  			assertErr: assert.NoError,
   123  			expRes:    []string{role},
   124  		},
   125  	}
   126  
   127  	for name, tc := range tests {
   128  		tc := tc
   129  		t.Run(name, func(t *testing.T) {
   130  			t.Parallel()
   131  			// setup
   132  			rServer, uServer := rulesEngineServer(), userServiceServer()
   133  			defer rServer.Close()
   134  			defer uServer.Close()
   135  
   136  			// create the authservice
   137  			as, err := tc.asGenerator(rServer, uServer)
   138  			assert.NoError(t, err)
   139  
   140  			//call the function
   141  			eaRoles, err := as.getRolesForUser(
   142  				context.Background(),
   143  				types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}},
   144  			)
   145  
   146  			// check the returned earoles match
   147  			tc.assertErr(t, err)
   148  			assert.Equal(t, tc.expRes, eaRoles)
   149  		})
   150  	}
   151  }
   152  func TestGetEARolesForUserFail(t *testing.T) {
   153  	t.Parallel()
   154  	// setup. user server returns a predictable error.
   155  	uServer := badUserServiceServer()
   156  	defer uServer.Close()
   157  
   158  	// create the authservice
   159  	ds := mockDataset{}
   160  	as, err := New(
   161  		Config{UserServiceHost: uServer.URL[7:]},
   162  		ds,
   163  		nil,
   164  	)
   165  	assert.NoError(t, err)
   166  
   167  	//call the function
   168  	_, err = as.getRolesForUser(
   169  		context.Background(),
   170  		types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}},
   171  	)
   172  
   173  	// check the error matches badUserServiceServer error
   174  	assert.Contains(t, err.Error(), "service returned status 500 Internal Server Error")
   175  }
   176  
   177  func TestFailAuthorizeCommand(t *testing.T) {
   178  	t.Parallel()
   179  	tests := map[string]struct {
   180  		comPath     string
   181  		payload     CommandAuthPayload
   182  		assertError assert.ErrorAssertionFunc
   183  		ctx         context.Context
   184  	}{
   185  		"404 bad relative path": {
   186  			comPath:     "badpath",
   187  			payload:     CommandAuthPayload{Command: "ls", Target: Target{BannerID: validBannerID}},
   188  			assertError: EqualError("rules engine returned status 404 Not Found"),
   189  			ctx:         types.UserIntoContext(context.Background(), types.User{Email: email, Roles: []string{role}, Banners: []string{validBannerID}}),
   190  		},
   191  		"Permission denied on command": {
   192  			comPath:     "validatecommand",
   193  			payload:     CommandAuthPayload{Command: "rm", Target: Target{BannerID: validBannerID}},
   194  			assertError: assert.NoError,
   195  			ctx:         types.UserIntoContext(context.Background(), types.User{Email: email, Roles: []string{role}, Banners: []string{validBannerID}}),
   196  		},
   197  		"No EARoles": {
   198  			comPath:     "validatecommand",
   199  			payload:     CommandAuthPayload{Command: "ls", Target: Target{BannerID: validBannerID}},
   200  			assertError: EqualError("60002: User Authorization Failure - User does not have the required roles. Error: no roles returned from userservice"),
   201  			ctx:         types.UserIntoContext(context.Background(), types.User{Email: email, Banners: []string{validBannerID}}),
   202  		},
   203  		"403 no permission for target": {
   204  			payload:     CommandAuthPayload{Command: "rm", Target: Target{BannerID: validBannerID}},
   205  			assertError: EqualError("60003: User Authorization Failure - User not permitted to perform this action. Error: banner not found in user struct"),
   206  			ctx:         types.UserIntoContext(context.Background(), types.User{Email: email, Roles: []string{role}}),
   207  		},
   208  	}
   209  
   210  	for name, tc := range tests {
   211  		tc := tc
   212  		t.Run(name, func(t *testing.T) {
   213  			t.Parallel()
   214  			server := rulesEngineServer()
   215  			userServer := userServiceServer()
   216  			defer server.Close()
   217  			defer userServer.Close()
   218  
   219  			ds := mockDataset{}
   220  			as, err := New(
   221  				Config{RulesEngineHost: server.URL[7:], UserServiceHost: userServer.URL[7:]},
   222  				ds,
   223  				nil,
   224  			)
   225  			assert.NoError(t, err)
   226  
   227  			as.validateComPath = tc.comPath
   228  			val, err := as.AuthorizeCommand(tc.ctx, tc.payload)
   229  			tc.assertError(t, err)
   230  			assert.False(t, val.Valid)
   231  		})
   232  	}
   233  }
   234  
   235  func userServiceServer() *httptest.Server {
   236  	mux := http.NewServeMux()
   237  	mux.HandleFunc(path.Join("/", getEARolesPath), func(w http.ResponseWriter, r *http.Request) {
   238  		values := r.URL.Query()
   239  		role := values.Get("role")
   240  
   241  		res := []string{role}
   242  		// values.Get returns an empty string on bad match. want to return an empty slice if this is the case, not a []string{""} or []string{nil}
   243  		if role == "" {
   244  			res = []string{}
   245  		}
   246  		b, err := json.Marshal(res)
   247  		if err != nil {
   248  			return
   249  		}
   250  		_, err = w.Write(b)
   251  		if err != nil {
   252  			return
   253  		}
   254  	})
   255  	server := httptest.NewServer(mux)
   256  	return server
   257  }
   258  
   259  func badUserServiceServer() *httptest.Server {
   260  	mux := http.NewServeMux()
   261  	mux.HandleFunc(path.Join("/", getEARolesPath), func(w http.ResponseWriter, _ *http.Request) {
   262  		w.WriteHeader(http.StatusInternalServerError)
   263  	})
   264  	server := httptest.NewServer(mux)
   265  	return server
   266  }
   267  
   268  type rulesOpts func(*rulesEngineOpts)
   269  
   270  func expCommand(name string) rulesOpts {
   271  	return func(reo *rulesEngineOpts) {
   272  		reo.expCommand = name
   273  	}
   274  }
   275  
   276  func expType(reqType eaconst.RequestType) rulesOpts {
   277  	return func(reo *rulesEngineOpts) {
   278  		reo.expType = reqType
   279  	}
   280  }
   281  
   282  type rulesEngineOpts struct {
   283  	expCommand string
   284  	expType    eaconst.RequestType
   285  }
   286  
   287  func rulesEngineServer(opts ...rulesOpts) *httptest.Server {
   288  	o := rulesEngineOpts{
   289  		expCommand: "ls",
   290  		expType:    eaconst.Command,
   291  	}
   292  	for _, opt := range opts {
   293  		opt(&o)
   294  	}
   295  
   296  	mux := http.NewServeMux()
   297  	mux.HandleFunc(path.Join("/", defaultValidateComPath), func(w http.ResponseWriter, r *http.Request) {
   298  		data, err := io.ReadAll(r.Body)
   299  		if err != nil {
   300  			return
   301  		}
   302  
   303  		var payload RulesEnginePayload
   304  		err = json.Unmarshal(data, &payload)
   305  		if err != nil {
   306  			w.WriteHeader(http.StatusBadRequest)
   307  			return
   308  		}
   309  
   310  		// comparison
   311  		res := Response{Valid: checkPayload(o, payload)}
   312  
   313  		b, err := json.Marshal(res)
   314  		if err != nil {
   315  			return
   316  		}
   317  
   318  		_, err = w.Write(b)
   319  		if err != nil {
   320  			return
   321  		}
   322  	})
   323  	server := httptest.NewServer(mux)
   324  	return server
   325  }
   326  
   327  func badRulesEngineServer(...rulesOpts) *httptest.Server {
   328  	mux := http.NewServeMux()
   329  	mux.HandleFunc(path.Join("/", defaultValidateComPath), func(w http.ResponseWriter, _ *http.Request) {
   330  		w.WriteHeader(http.StatusInternalServerError)
   331  	})
   332  	server := httptest.NewServer(mux)
   333  	return server
   334  }
   335  
   336  func checkPayload(o rulesEngineOpts, payload RulesEnginePayload) bool {
   337  	if payload.Command.Name != o.expCommand {
   338  		return false
   339  	}
   340  
   341  	if payload.Command.Type != o.expType {
   342  		return false
   343  	}
   344  
   345  	if payload.Target.BannerID != validBannerID {
   346  		return false
   347  	}
   348  	if len(payload.Identity.EAroles) == 0 {
   349  		return false
   350  	}
   351  
   352  	if len(payload.Identity.EAroles) > 1 || payload.Identity.EAroles[0] != role {
   353  		return false
   354  	}
   355  	return true
   356  }
   357  
   358  func TestAuthorizeRequest(t *testing.T) {
   359  	t.Parallel()
   360  
   361  	user := types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}}
   362  	ctx := types.UserIntoContext(context.Background(), user)
   363  
   364  	userServer := userServiceServer()
   365  	t.Cleanup(userServer.Close)
   366  
   367  	tests := map[string]struct {
   368  		rulesServer func(...rulesOpts) *httptest.Server
   369  		request     Request
   370  		retriever   Retriever
   371  		expData     string
   372  		expAttrs    map[string]string
   373  	}{
   374  		"Command": {
   375  			rulesServer: rulesEngineServer,
   376  			request: Request{
   377  				Data: json.RawMessage(`{
   378  					"command": "ls hello there"
   379  				}`),
   380  				Attributes: map[string]string{
   381  					eaconst.VersionKey:     string(eaconst.MessageVersion1_0),
   382  					eaconst.RequestTypeKey: string(eaconst.Command),
   383  				},
   384  			},
   385  			expData: `{
   386  				"command": "ls hello there"
   387  			}`,
   388  			expAttrs: map[string]string{
   389  				eaconst.VersionKey:     string(eaconst.MessageVersion1_0),
   390  				eaconst.RequestTypeKey: string(eaconst.Command),
   391  			},
   392  		},
   393  		"Executable": {
   394  			rulesServer: func(...rulesOpts) *httptest.Server {
   395  				return rulesEngineServer(expCommand("myScript"), expType(eaconst.Executable))
   396  			},
   397  			request: Request{
   398  				Data: json.RawMessage(`{
   399  					"executable": {
   400  						"name": "myScript"
   401  					}
   402  				}`),
   403  				Attributes: map[string]string{
   404  					eaconst.VersionKey:     string(eaconst.MessageVersion2_0),
   405  					eaconst.RequestTypeKey: string(eaconst.Executable),
   406  				},
   407  			},
   408  			retriever: &mockRetriever{
   409  				mockArtifact: func(_ context.Context, name string, artifactType retriever.ArtifactType) (retriever.Artifact, error) {
   410  					if name != "myScript" {
   411  						return retriever.Artifact{}, fmt.Errorf("mock retriever error, unexpected artifact name: got %q", name)
   412  					}
   413  					if artifactType != retriever.Executable {
   414  						return retriever.Artifact{}, fmt.Errorf("mock retriever error, unexpected artifact type: got %q", artifactType)
   415  					}
   416  
   417  					sha, err := hex.DecodeString("0f95ed04c41face74eb0fb077282821ba0493d5b3cc2c1e725c1a58c6b8f51ba")
   418  					if err != nil {
   419  						return retriever.Artifact{}, err
   420  					}
   421  					return retriever.Artifact{
   422  						Name:     "myScript",
   423  						Type:     retriever.Executable,
   424  						Artifact: []byte("#!/bin/sh\n\necho hello\n\n"),
   425  						SHA:      sha,
   426  					}, nil
   427  				},
   428  			},
   429  			expData: `{
   430  				"executable": {
   431  					"name": "myScript",
   432  					"contents": "IyEvYmluL3NoCgplY2hvIGhlbGxvCgo="
   433  				},
   434  				"args": null
   435  			}`,
   436  			expAttrs: map[string]string{
   437  				eaconst.VersionKey:     string(eaconst.MessageVersion2_0),
   438  				eaconst.RequestTypeKey: string(eaconst.Executable),
   439  			},
   440  		},
   441  	}
   442  
   443  	for name, tc := range tests {
   444  		tc := tc
   445  		t.Run(name, func(t *testing.T) {
   446  			t.Parallel()
   447  
   448  			ruleServer := tc.rulesServer()
   449  			t.Cleanup(ruleServer.Close)
   450  
   451  			ds := mockDataset{}
   452  			as, err := New(
   453  				Config{RulesEngineHost: ruleServer.URL[7:], UserServiceHost: userServer.URL[7:]},
   454  				ds,
   455  				tc.retriever,
   456  			)
   457  			require.NoError(t, err)
   458  
   459  			payload := AuthorizeRequestPayload{
   460  				Request: tc.request,
   461  				Target: Target{
   462  					BannerID:   validBannerID,
   463  					StoreID:    storeID,
   464  					TerminalID: terminalID,
   465  				},
   466  			}
   467  			req, err := as.AuthorizeRequest(ctx, payload)
   468  			assert.NoError(t, err)
   469  
   470  			data, err := req.Data()
   471  			assert.NoError(t, err)
   472  
   473  			assert.JSONEq(t, tc.expData, string(data))
   474  			assert.Equal(t, tc.expAttrs, req.Attributes())
   475  		})
   476  	}
   477  }
   478  
   479  func TestAuthorizeRequestFail(t *testing.T) {
   480  	t.Parallel()
   481  
   482  	validUser := types.User{Email: email, Username: username, Roles: []string{role}, Banners: []string{validBannerID}}
   483  	validPayload := AuthorizeRequestPayload{
   484  		Request: Request{
   485  			Data: json.RawMessage(`{
   486  				"command": "ls hello there"
   487  			}`),
   488  			Attributes: map[string]string{
   489  				eaconst.VersionKey:     string(eaconst.MessageVersion1_0),
   490  				eaconst.RequestTypeKey: string(eaconst.Command),
   491  			},
   492  		},
   493  		Target: Target{
   494  			BannerID:   validBannerID,
   495  			StoreID:    storeID,
   496  			TerminalID: terminalID,
   497  		},
   498  	}
   499  
   500  	validScriptPayload := AuthorizeRequestPayload{
   501  		Request: Request{
   502  			Data: json.RawMessage(`{
   503  				"executable": {
   504  					"name": "myScript"
   505  				}
   506  			}`),
   507  			Attributes: map[string]string{
   508  				eaconst.VersionKey:     string(eaconst.MessageVersion2_0),
   509  				eaconst.RequestTypeKey: string(eaconst.Executable),
   510  			},
   511  		},
   512  		Target: Target{
   513  			BannerID:   validBannerID,
   514  			StoreID:    storeID,
   515  			TerminalID: terminalID,
   516  		},
   517  	}
   518  
   519  	tests := map[string]struct {
   520  		ctx           context.Context
   521  		payload       AuthorizeRequestPayload
   522  		mockRetriever func(ctx context.Context, name string, artifactType retriever.ArtifactType) (retriever.Artifact, error)
   523  		ruleServer    func(...rulesOpts) *httptest.Server
   524  		userServer    func() *httptest.Server
   525  		errAssert     assert.ErrorAssertionFunc
   526  	}{
   527  		"Failed To Create Request": {
   528  			ctx:        types.UserIntoContext(context.Background(), validUser),
   529  			ruleServer: rulesEngineServer,
   530  			userServer: userServiceServer,
   531  			errAssert:  EqualError("failed to create structured request from payload: failed to find version attribute"),
   532  		},
   533  		"No User": {
   534  			ctx:        context.Background(),
   535  			payload:    validPayload,
   536  			ruleServer: rulesEngineServer,
   537  			userServer: userServiceServer,
   538  			errAssert:  EqualError("user struct not found in context"),
   539  		},
   540  		"Get EA Roles Error": {
   541  			ctx:        types.UserIntoContext(context.Background(), validUser),
   542  			payload:    validPayload,
   543  			ruleServer: rulesEngineServer,
   544  			userServer: badUserServiceServer,
   545  			errAssert:  EqualError("error when getting ea roles: user service returned status 500 Internal Server Error"),
   546  		},
   547  		"No EA Roles": {
   548  			ctx:        types.UserIntoContext(context.Background(), types.User{Email: email, Username: username}),
   549  			payload:    validPayload,
   550  			ruleServer: rulesEngineServer,
   551  			userServer: userServiceServer,
   552  			errAssert:  APIError(apierror.ErrUserMissingRoles, "no roles returned from userservice"),
   553  		},
   554  		"Banner Not Authorized": {
   555  			ctx: types.UserIntoContext(context.Background(), validUser),
   556  			payload: AuthorizeRequestPayload{
   557  				Request: Request{
   558  					Data: json.RawMessage(`{
   559  						"command": "ls hello there"
   560  					}`),
   561  					Attributes: map[string]string{
   562  						eaconst.VersionKey:     string(eaconst.MessageVersion1_0),
   563  						eaconst.RequestTypeKey: string(eaconst.Command),
   564  					},
   565  				},
   566  				Target: Target{
   567  					BannerID:   "invalid-banner",
   568  					StoreID:    storeID,
   569  					TerminalID: terminalID,
   570  				},
   571  			},
   572  			ruleServer: rulesEngineServer,
   573  			userServer: userServiceServer,
   574  			errAssert:  APIError(apierror.ErrUserNotAuthorized, "banner not found in user struct"),
   575  		},
   576  		"Invalid Command": {
   577  			ctx: types.UserIntoContext(context.Background(), validUser),
   578  			payload: AuthorizeRequestPayload{
   579  				Request: Request{
   580  					Data: json.RawMessage(`{
   581  						"command": "rm"
   582  					}`),
   583  					Attributes: map[string]string{
   584  						eaconst.VersionKey:     string(eaconst.MessageVersion1_0),
   585  						eaconst.RequestTypeKey: string(eaconst.Command),
   586  					},
   587  				},
   588  				Target: Target{
   589  					BannerID:   validBannerID,
   590  					StoreID:    storeID,
   591  					TerminalID: terminalID,
   592  				},
   593  			},
   594  			ruleServer: rulesEngineServer,
   595  			userServer: userServiceServer,
   596  			errAssert:  APIError(apierror.ErrUnauthorizedCommand, "command not authorized for user on target"),
   597  		},
   598  		"Rules Engine Non-OK": {
   599  			ctx:        types.UserIntoContext(context.Background(), validUser),
   600  			payload:    validPayload,
   601  			ruleServer: badRulesEngineServer,
   602  			userServer: userServiceServer,
   603  			errAssert:  EqualError("rules engine returned status 500 Internal Server Error"),
   604  		},
   605  		"Failed to retrieve artifact": {
   606  			ctx:     types.UserIntoContext(context.Background(), validUser),
   607  			payload: validScriptPayload,
   608  			mockRetriever: func(context.Context, string, retriever.ArtifactType) (retriever.Artifact, error) {
   609  				return retriever.Artifact{}, fmt.Errorf("error retrieving artifact")
   610  			},
   611  			ruleServer: func(...rulesOpts) *httptest.Server {
   612  				return rulesEngineServer(expCommand("myScript"), expType(eaconst.Executable))
   613  			},
   614  			userServer: userServiceServer,
   615  			errAssert:  EqualError("failed to retrieve artifact: error retrieving artifact"),
   616  		},
   617  	}
   618  
   619  	for name, tc := range tests {
   620  		tc := tc
   621  		t.Run(name, func(t *testing.T) {
   622  			t.Parallel()
   623  
   624  			ruleServer := tc.ruleServer()
   625  			userServer := tc.userServer()
   626  			defer ruleServer.Close()
   627  			defer userServer.Close()
   628  
   629  			retriever := &mockRetriever{
   630  				mockArtifact: tc.mockRetriever,
   631  			}
   632  
   633  			ds := mockDataset{}
   634  			as, err := New(
   635  				Config{RulesEngineHost: ruleServer.URL[7:], UserServiceHost: userServer.URL[7:]},
   636  				ds,
   637  				retriever,
   638  			)
   639  			require.NoError(t, err)
   640  
   641  			req, err := as.AuthorizeRequest(tc.ctx, tc.payload)
   642  			tc.errAssert(t, err)
   643  			assert.Nil(t, req)
   644  		})
   645  	}
   646  }
   647  
   648  type mockDatasetTestResolveTarget struct {
   649  	Dataset
   650  
   651  	projectID  string
   652  	bannerID   string
   653  	storeID    string
   654  	terminalID string
   655  }
   656  
   657  func (ds mockDatasetTestResolveTarget) GetProjectAndBannerID(_ context.Context, banner string) (projectID string, bannerID string, err error) {
   658  	if banner == "" {
   659  		err = fmt.Errorf("error GetProjectIDAndBannerID")
   660  	}
   661  	return ds.projectID, ds.bannerID, err
   662  }
   663  
   664  func (ds mockDatasetTestResolveTarget) GetStoreID(_ context.Context, store, _ string) (storeID string, err error) {
   665  	if store == "" {
   666  		err = fmt.Errorf("error GetStoreID")
   667  	}
   668  	return ds.storeID, err
   669  }
   670  
   671  func (ds mockDatasetTestResolveTarget) GetTerminalID(_ context.Context, terminal, _ string) (terminalID string, err error) {
   672  	if terminal == "" {
   673  		err = fmt.Errorf("error GetTerminalID")
   674  	}
   675  	return ds.terminalID, err
   676  }
   677  
   678  func TestResolveTarget(t *testing.T) {
   679  	t.Parallel()
   680  
   681  	tests := map[string]struct {
   682  		payload ResolveTargetPayload
   683  		ds      mockDatasetTestResolveTarget
   684  
   685  		expTarget Target
   686  		errAssert assert.ErrorAssertionFunc
   687  	}{
   688  		"Valid": {
   689  			payload: ResolveTargetPayload{
   690  				Target: Target{
   691  					ProjectID:  "p",
   692  					BannerID:   "b",
   693  					StoreID:    "s",
   694  					TerminalID: "t",
   695  				},
   696  			},
   697  			ds: mockDatasetTestResolveTarget{
   698  				projectID:  "projectID",
   699  				bannerID:   "bannerID",
   700  				storeID:    "storeID",
   701  				terminalID: "terminalID",
   702  			},
   703  			expTarget: Target{
   704  				ProjectID:  "projectID",
   705  				BannerID:   "bannerID",
   706  				StoreID:    "storeID",
   707  				TerminalID: "terminalID",
   708  			},
   709  			errAssert: assert.NoError,
   710  		},
   711  		"GetProjectIDAndBannerID returns err": {
   712  			payload:   ResolveTargetPayload{},
   713  			ds:        mockDatasetTestResolveTarget{},
   714  			expTarget: Target{},
   715  			errAssert: EqualError("error GetProjectIDAndBannerID"),
   716  		},
   717  		"GetStoreID returns err": {
   718  			payload: ResolveTargetPayload{
   719  				Target: Target{
   720  					ProjectID: "p",
   721  					BannerID:  "b",
   722  				},
   723  			},
   724  			ds: mockDatasetTestResolveTarget{
   725  				projectID: "projectID",
   726  				bannerID:  "bannerID",
   727  			},
   728  			expTarget: Target{},
   729  			errAssert: EqualError("error GetStoreID"),
   730  		},
   731  		"GetTerminalID returns err": {
   732  			payload: ResolveTargetPayload{
   733  				Target: Target{
   734  					ProjectID: "p",
   735  					BannerID:  "b",
   736  					StoreID:   "s",
   737  				},
   738  			},
   739  			ds: mockDatasetTestResolveTarget{
   740  				projectID: "projectID",
   741  				bannerID:  "bannerID",
   742  				storeID:   "storeID",
   743  			},
   744  			expTarget: Target{},
   745  			errAssert: EqualError("error GetTerminalID"),
   746  		},
   747  		"Returned ProjectID nil": {
   748  			payload: ResolveTargetPayload{
   749  				Target: Target{
   750  					ProjectID: "p",
   751  					BannerID:  "b",
   752  				},
   753  			},
   754  			ds: mockDatasetTestResolveTarget{
   755  				bannerID: "bannerID",
   756  			},
   757  			expTarget: Target{},
   758  			errAssert: EqualError("61202: Request Error - Invalid Target properties. Error: project not found for banner b"),
   759  		},
   760  		"Returned BannerID nil": {
   761  			payload: ResolveTargetPayload{
   762  				Target: Target{
   763  					ProjectID: "p",
   764  					BannerID:  "b",
   765  				},
   766  			},
   767  			ds: mockDatasetTestResolveTarget{
   768  				projectID: "projectID",
   769  			},
   770  			expTarget: Target{},
   771  			errAssert: EqualError("61202: Request Error - Invalid Target properties. Error: banner b not found"),
   772  		},
   773  		"Returned StoreID nil": {
   774  			payload: ResolveTargetPayload{
   775  				Target: Target{
   776  					ProjectID: "p",
   777  					BannerID:  "b",
   778  					StoreID:   "s",
   779  				},
   780  			},
   781  			ds: mockDatasetTestResolveTarget{
   782  				projectID: "projectID",
   783  				bannerID:  "bannerID",
   784  			},
   785  			expTarget: Target{},
   786  			errAssert: EqualError("61202: Request Error - Invalid Target properties. Error: store s not found in given banner b"),
   787  		},
   788  		"Returned TerminalID nil": {
   789  			payload: ResolveTargetPayload{
   790  				Target: Target{
   791  					ProjectID:  "p",
   792  					BannerID:   "b",
   793  					StoreID:    "s",
   794  					TerminalID: "t",
   795  				},
   796  			},
   797  			ds: mockDatasetTestResolveTarget{
   798  				projectID: "projectID",
   799  				bannerID:  "bannerID",
   800  				storeID:   "storeID",
   801  			},
   802  			expTarget: Target{},
   803  			errAssert: EqualError("61202: Request Error - Invalid Target properties. Error: terminal t not found in given store s and banner b"),
   804  		},
   805  	}
   806  
   807  	for name, tc := range tests {
   808  		tc := tc
   809  		t.Run(name, func(t *testing.T) {
   810  			t.Parallel()
   811  
   812  			as, err := New(
   813  				Config{},
   814  				tc.ds,
   815  				nil,
   816  			)
   817  			assert.NoError(t, err)
   818  
   819  			target, err := as.ResolveTarget(context.Background(), tc.payload)
   820  
   821  			tc.errAssert(t, err)
   822  			assert.Equal(t, tc.expTarget, target)
   823  		})
   824  	}
   825  }
   826  
   827  func TestAuthorizeTarget(t *testing.T) {
   828  	t.Parallel()
   829  
   830  	tests := map[string]struct {
   831  		ctx            context.Context
   832  		target         Target
   833  		errorAssertion assert.ErrorAssertionFunc
   834  	}{
   835  		"Valid": {
   836  			ctx: types.UserIntoContext(context.Background(), types.User{
   837  				Banners:  []string{validBannerID},
   838  				Username: username,
   839  				Roles:    []string{role},
   840  			}),
   841  			target:         Target{BannerID: validBannerID},
   842  			errorAssertion: assert.NoError,
   843  		},
   844  		"Error, bannerID doesn't match": {
   845  			ctx: types.UserIntoContext(context.Background(), types.User{
   846  				Banners:  []string{validBannerID},
   847  				Username: username,
   848  				Roles:    []string{role},
   849  			}),
   850  			target:         Target{BannerID: "not-the-same-banner-id"},
   851  			errorAssertion: EqualError(apierror.E(apierror.ErrUserNotAuthorized, errors.New("banner not found in user struct"), "User was not assigned banner").Error()),
   852  		},
   853  		"Error, no user in context": {
   854  			ctx:            context.Background(),
   855  			target:         Target{BannerID: validBannerID},
   856  			errorAssertion: EqualError("user struct not in context"),
   857  		},
   858  		"Error no EARoles": {
   859  			ctx: types.UserIntoContext(context.Background(), types.User{
   860  				Banners:  []string{validBannerID},
   861  				Username: username,
   862  				Roles:    []string{},
   863  			}),
   864  			target:         Target{BannerID: validBannerID},
   865  			errorAssertion: EqualError(apierror.E(apierror.ErrUserMissingRoles, fmt.Errorf("no roles returned from userservice")).Error()),
   866  		},
   867  	}
   868  	for name, tc := range tests {
   869  		tc := tc
   870  		t.Run(name, func(t *testing.T) {
   871  			t.Parallel()
   872  			ds := mockDataset{}
   873  			uServer := userServiceServer()
   874  			defer uServer.Close()
   875  			as, err := New(
   876  				Config{UserServiceHost: uServer.URL[7:]},
   877  				ds,
   878  				nil,
   879  			)
   880  			assert.NoError(t, err)
   881  			err = as.AuthorizeTarget(tc.ctx, tc.target)
   882  			tc.errorAssertion(t, err)
   883  		})
   884  	}
   885  }
   886  
   887  func TestAuthorizeUser(t *testing.T) {
   888  	t.Parallel()
   889  
   890  	tests := map[string]struct {
   891  		user      types.User
   892  		errAssert assert.ErrorAssertionFunc
   893  	}{
   894  		"Valid": {
   895  			user:      types.User{Roles: []string{"role1", "role2", "role3"}},
   896  			errAssert: assert.NoError,
   897  		},
   898  		"Invalid No Roles": {
   899  			user:      types.User{},
   900  			errAssert: EqualError("60002: User Authorization Failure - User does not have the required roles. Error: no roles returned from userservice"),
   901  		},
   902  		"Invalid No Return": {
   903  			user:      types.User{Roles: []string{""}},
   904  			errAssert: EqualError("60002: User Authorization Failure - User does not have the required roles. Error: no roles returned from userservice"),
   905  		},
   906  	}
   907  
   908  	for name, tc := range tests {
   909  		tc := tc
   910  		t.Run(name, func(t *testing.T) {
   911  			t.Parallel()
   912  
   913  			// setup
   914  			userServer := userServiceServer()
   915  			defer userServer.Close()
   916  
   917  			ctx := types.UserIntoContext(context.Background(), tc.user)
   918  
   919  			ds := mockDataset{}
   920  			as, err := New(
   921  				Config{UserServiceHost: userServer.URL[7:]},
   922  				ds,
   923  				nil,
   924  			)
   925  			assert.NoError(t, err)
   926  
   927  			err = as.AuthorizeUser(ctx)
   928  			tc.errAssert(t, err)
   929  		})
   930  	}
   931  }
   932  

View as plain text