
Source file src/edge-infra.dev/pkg/sds/emergencyaccess/authservice/server/server_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/authservice/server

     1  package server
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"strconv"
    13  	"strings"
    14  	"testing"
    16  	"edge-infra.dev/pkg/lib/fog"
    17  	"edge-infra.dev/pkg/sds/emergencyaccess/apierror"
    18  	apierrorhandler "edge-infra.dev/pkg/sds/emergencyaccess/apierror/handler"
    19  	"edge-infra.dev/pkg/sds/emergencyaccess/authservice"
    20  	"edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
    21  	"edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
    22  	"edge-infra.dev/pkg/sds/emergencyaccess/types"
    24  	"github.com/gin-gonic/gin"
    25  	"github.com/go-logr/logr"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  )
    30  // testing helper type
    31  type helper interface {
    32  	Helper()
    33  }
    35  type StringAssertionFunc func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool
    37  func JSONEq(expected string) StringAssertionFunc {
    38  	return func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool {
    39  		return assert.JSONEq(t, expected, actual, msgAndArgs...)
    40  	}
    41  }
    43  func StringEqual(expected string) StringAssertionFunc {
    44  	return func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool {
    45  		return assert.Equal(t, expected, actual, msgAndArgs...)
    46  	}
    47  }
    49  func JSONEmpty() StringAssertionFunc {
    50  	return func(t assert.TestingT, actual string, msgAndArgs ...interface{}) bool {
    51  		return assert.Empty(t, actual, msgAndArgs...)
    52  	}
    53  }
    55  // assert.ErrorAssertionFunc that asserts the error is an api error with the given
    56  // code, and contains the given message in the error string
    57  func APIError(code apierror.ErrorCode, message string) assert.ErrorAssertionFunc {
    58  	return func(tt assert.TestingT, err error, i ...interface{}) bool {
    59  		if help, ok := tt.(helper); ok {
    60  			help.Helper()
    61  		}
    63  		if !assert.ErrorContains(tt, err, message, i...) {
    64  			return false
    65  		}
    67  		if !assert.Implements(tt, (*apierror.APIError)(nil), err, i...) {
    68  			return false
    69  		}
    71  		e := err.(apierror.APIError)
    72  		return assert.Equal(tt, code, e.Code(), i...)
    73  	}
    74  }
    76  // helper function which sets well known auth headers to any request
    77  func setAuthHeaders(req *http.Request) {
    78  	req.Header.Set(eaconst.HeaderAuthKeyUsername, "username")
    79  	req.Header.Set(eaconst.HeaderAuthKeyEmail, "email")
    80  	req.Header.Set(eaconst.HeaderAuthKeyRoles, "role")
    81  	req.Header.Set(eaconst.HeaderAuthKeyBanners, "banner")
    82  }
    84  // helper function to generate auth request.
    85  func newAuthRequest(method string, url string, body io.Reader) (*http.Request, error) {
    86  	req, err := http.NewRequest(method, url, body)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  	setAuthHeaders(req)
    91  	return req, nil
    92  }
    94  type mockDataset struct {
    95  	authservice.Dataset
    96  }
    98  func userServiceServer() *httptest.Server {
    99  	mux := http.NewServeMux()
   100  	mux.HandleFunc("/eaRoles", func(w http.ResponseWriter, r *http.Request) {
   101  		values := r.URL.Query()
   103  		role := values.Get("role")
   104  		// role may be an empty string if no role was present
   105  		var roles []string
   106  		if role != "" {
   107  			roles = []string{role}
   108  		}
   110  		b, err := json.Marshal(roles)
   111  		if err != nil {
   112  			return
   113  		}
   114  		_, err = w.Write(b)
   115  		if err != nil {
   116  			return
   117  		}
   118  	})
   119  	server := httptest.NewServer(mux)
   120  	return server
   121  }
   123  func TestAuthorizeCommand(t *testing.T) {
   124  	tests := map[string]struct {
   125  		command string
   127  		expStatus int
   128  		expValid  bool
   129  	}{
   130  		"Pass StatusOK": {
   131  			`ls`,
   132  			http.StatusOK,
   133  			true,
   134  		},
   135  		"Fail StatusOK": {
   136  			`rm`,
   137  			http.StatusOK,
   138  			false,
   139  		},
   140  		"Pass StatusOK env var": {
   141  			`A=b ls`,
   142  			http.StatusOK,
   143  			true,
   144  		},
   145  		"Fail StatusOK env var": {
   146  			`A=b rm`,
   147  			http.StatusOK,
   148  			false,
   149  		},
   150  		"Pass StatusOK escaped quotation": {
   151  			`echo \"`,
   152  			http.StatusOK,
   153  			true,
   154  		},
   155  	}
   157  	for name, tc := range tests {
   158  		t.Run(name, func(t *testing.T) {
   159  			r := httptest.NewRecorder()
   161  			// Create Gin context in test mode
   162  			gin.SetMode(gin.TestMode)
   163  			_, ginEngine := gin.CreateTestContext(r)
   165  			// create mockservers
   166  			server := mockRulesEngineServer(tc.expStatus, tc.expValid)
   167  			userServer := userServiceServer()
   168  			defer server.Close()
   169  			defer userServer.Close()
   171  			// set up authservice
   172  			ds := mockDataset{}
   173  			as, err := authservice.New(
   174  				authservice.Config{RulesEngineHost: server.URL[7:], UserServiceHost: userServer.URL[7:]},
   175  				ds,
   176  				nil,
   177  			)
   178  			assert.NoError(t, err)
   179  			log := fog.New()
   180  			_ = New(ginEngine, log, as)
   182  			// Send test query
   183  			payload, _ := json.Marshal(map[string]interface{}{
   184  				"Command": tc.command,
   185  				"Target": types.Target{
   186  					Bannerid: "a-banner-id",
   187  				},
   188  			})
   189  			req, err := newAuthRequest(http.MethodPost,
   190  				"/authorizeCommand",
   191  				bytes.NewBuffer(payload))
   193  			req.Header.Add(eaconst.HeaderAuthKeyBanners, "a-banner-id")
   194  			assert.NoError(t, err)
   195  			ginEngine.ServeHTTP(r, req)
   197  			//retrieve response
   198  			assert.Equal(t, tc.expStatus, r.Result().StatusCode)
   200  			var respData authservice.Validation
   201  			err = unmarshalBody(r.Body, &respData)
   202  			assert.NoError(t, err)
   203  			assert.Equal(t, tc.expValid, respData.Valid)
   204  		})
   205  	}
   206  }
   208  func TestAuthorizeCommandForbidden(t *testing.T) {
   209  	t.Parallel()
   210  	tests := map[string]struct {
   211  		headerBanners []string
   212  		payloadBanner string
   213  	}{
   214  		"No Banners in header": {
   215  			headerBanners: []string{},
   216  			payloadBanner: "a-banner-id",
   217  		},
   218  		"Wrong banner in payload": {
   219  			headerBanners: []string{"a-banner-id"},
   220  			payloadBanner: "another-banner-id",
   221  		},
   222  	}
   223  	for name, tc := range tests {
   224  		tc := tc
   225  		t.Run(name, func(t *testing.T) {
   226  			t.Parallel()
   227  			r := httptest.NewRecorder()
   229  			// Create Gin context in test mode
   230  			gin.SetMode(gin.TestMode)
   231  			_, ginEngine := gin.CreateTestContext(r)
   233  			// create mockservers
   234  			server := mockRulesEngineServer(200, true)
   235  			userServer := userServiceServer()
   236  			defer server.Close()
   237  			defer userServer.Close()
   239  			// set up authservice
   240  			ds := mockDataset{}
   241  			as, err := authservice.New(
   242  				authservice.Config{RulesEngineHost: server.URL[7:], UserServiceHost: userServer.URL[7:]},
   243  				ds,
   244  				nil,
   245  			)
   246  			assert.NoError(t, err)
   247  			log := fog.New()
   248  			_ = New(ginEngine, log, as)
   250  			// Send test query
   251  			payload, _ := json.Marshal(map[string]interface{}{
   252  				"Command": "ls",
   253  				"Target": types.Target{
   254  					Bannerid: tc.payloadBanner,
   255  				},
   256  			})
   257  			req, err := http.NewRequest(http.MethodPost,
   258  				"/authorizeCommand",
   259  				bytes.NewBuffer(payload))
   260  			for _, banner := range tc.headerBanners {
   261  				req.Header.Add(eaconst.HeaderAuthKeyBanners, banner)
   262  			}
   263  			req.Header.Add(eaconst.HeaderAuthKeyEmail, "user@ncr.com")
   264  			req.Header.Add(eaconst.HeaderAuthKeyUsername, "username")
   266  			assert.NoError(t, err)
   267  			ginEngine.ServeHTTP(r, req)
   269  			//retrieve response
   270  			assert.Equal(t, http.StatusForbidden, r.Result().StatusCode)
   271  		})
   272  	}
   273  }
   275  func TestAuthCommandAudit(t *testing.T) {
   276  	r := httptest.NewRecorder()
   278  	// Create Gin context in test mode
   279  	gin.SetMode(gin.TestMode)
   280  	_, ginEngine := gin.CreateTestContext(r)
   282  	// create mockservers. needed as we want a successful authorization to be logged.
   283  	rServer, uServer := mockRulesEngineServer(200, true), userServiceServer()
   284  	defer rServer.Close()
   285  	defer uServer.Close()
   287  	// set up authservice with logging to memory
   288  	ds := mockDataset{}
   289  	as, err := authservice.New(
   290  		authservice.Config{RulesEngineHost: rServer.URL[7:], UserServiceHost: uServer.URL[7:]},
   291  		ds,
   292  		nil,
   293  	)
   294  	assert.NoError(t, err)
   295  	b := bytes.Buffer{}
   296  	log := fog.New(fog.To(&b))
   297  	_ = New(ginEngine, log, as)
   299  	// Send test query
   300  	payload, _ := json.Marshal(map[string]interface{}{
   301  		"Command": "someCommand",
   302  		"EARoles": []string{"test"},
   303  		"Target": types.Target{
   304  			Projectid:  "a-project-id",
   305  			Bannerid:   "a-banner-id",
   306  			Storeid:    "a-store-id",
   307  			Terminalid: "a-terminal-id",
   308  		},
   309  	})
   310  	req, err := newAuthRequest(http.MethodPost,
   311  		"/authorizeCommand",
   312  		bytes.NewBuffer(payload))
   313  	assert.NoError(t, err)
   314  	req.Header.Add("X-Correlation-ID", "a-command-id")
   316  	ginEngine.ServeHTTP(r, req)
   318  	// Test.
   319  	validateAuditLog(t, &b, "Authorize Command Called", map[string]string{
   320  		"command":            "someCommand",
   321  		"requestID":          "a-command-id",
   322  		"userID":             "username",
   323  		"targetProjectID":    "a-project-id",
   324  		"targetBannerUUID":   "a-banner-id",
   325  		"targetStoreUUID":    "a-store-id",
   326  		"targetTerminalUUID": "a-terminal-id",
   327  	})
   328  }
   330  func validateAuditLog(t *testing.T, b *bytes.Buffer, logmsg string, keyVals map[string]string) {
   331  	// split the log on newline char to test whether the condition is satisfied in a single entry rather than across all logs
   332  	lst := strings.Split(b.String(), "\n")
   333  	var ok bool
   334  	for _, str := range lst {
   335  		// select the log with the correct log message
   336  		if ok = strings.Contains(str, logmsg); ok {
   337  			validateKeyValPairs(t, str, keyVals)
   338  			break
   339  		}
   340  	}
   341  	assert.True(t, ok, "log with message %q not found", logmsg)
   342  }
   343  func validateKeyValPairs(t *testing.T, logString string, keyVals map[string]string) {
   344  	for name, val := range keyVals {
   345  		// Bools aren't printed with quotes
   346  		if _, err := strconv.ParseBool(val); err == nil || name == "request" {
   347  			assert.Contains(t, logString, fmt.Sprintf("%q:%s", name, val))
   348  		} else {
   349  			assert.Contains(t, logString, fmt.Sprintf("%q:%q", name, val))
   350  		}
   351  	}
   352  }
   354  // only compares the command for each payload as we are not testing anything to do with target/identity
   355  func darkmodeServer(t *testing.T, expPayload authservice.RulesEnginePayload) *httptest.Server {
   356  	mux := http.NewServeMux()
   357  	mux.HandleFunc("/validatecommand", func(w http.ResponseWriter, r *http.Request) {
   358  		// read the incoming data
   359  		data, err := io.ReadAll(r.Body)
   360  		assert.NoError(t, err)
   361  		var in authservice.RulesEnginePayload
   362  		assert.NoError(t, json.Unmarshal(data, &in))
   363  		// test
   364  		assert.Equal(t, expPayload.Command, in.Command)
   365  		//write header after tests have completed
   366  		w.WriteHeader(200)
   367  		// write response
   368  		res := authservice.Response{Valid: true}
   369  		b, err := json.Marshal(res)
   370  		assert.NoError(t, err)
   371  		_, err = w.Write(b)
   372  		assert.NoError(t, err)
   373  	})
   374  	return httptest.NewServer(mux)
   375  }
   377  // checks the payload being sent from authservice to the reng matches expected payload
   378  func TestAuthorizeCommandDarkMode(t *testing.T) {
   379  	tests := map[string]struct {
   380  		payload    authservice.CommandAuthPayload
   381  		expPayload authservice.RulesEnginePayload
   382  	}{
   383  		"Darkmode true": {
   384  			payload: authservice.CommandAuthPayload{
   385  				Command:     "ls",
   386  				Target:      authservice.Target{BannerID: "a-banner-id"},
   387  				AuthDetails: authservice.AuthDetails{DarkMode: true},
   388  			},
   389  			expPayload: authservice.RulesEnginePayload{
   390  				Command: authservice.RulesEngineCommand{
   391  					Name: "dark",
   392  					Type: "command",
   393  				},
   394  			},
   395  		},
   396  		"Darkmode false": {
   397  			payload: authservice.CommandAuthPayload{
   398  				Command:     "ls",
   399  				Target:      authservice.Target{BannerID: "a-banner-id"},
   400  				AuthDetails: authservice.AuthDetails{DarkMode: false},
   401  			},
   402  			expPayload: authservice.RulesEnginePayload{
   403  				Command: authservice.RulesEngineCommand{
   404  					Name: "ls",
   405  					Type: "command",
   406  				},
   407  			},
   408  		},
   409  	}
   410  	for name, tc := range tests {
   411  		tc := tc
   412  		t.Run(name, func(t *testing.T) {
   413  			t.Parallel()
   414  			r := httptest.NewRecorder()
   416  			// Create Gin context in test mode
   417  			gin.SetMode(gin.TestMode)
   418  			_, ginEngine := gin.CreateTestContext(r)
   420  			// create mockservers
   421  			server := darkmodeServer(t, tc.expPayload)
   422  			userServer := userServiceServer()
   423  			defer server.Close()
   424  			defer userServer.Close()
   426  			// set up authservice
   427  			ds := mockDataset{}
   428  			as, err := authservice.New(
   429  				authservice.Config{
   430  					RulesEngineHost: server.URL[7:],
   431  					UserServiceHost: userServer.URL[7:],
   432  				},
   433  				ds,
   434  				nil,
   435  			)
   436  			assert.NoError(t, err)
   437  			log := fog.New()
   438  			_ = New(ginEngine, log, as)
   440  			// Send test query
   441  			payload, err := json.Marshal(tc.payload)
   442  			assert.NoError(t, err)
   443  			req, err := newAuthRequest(http.MethodPost,
   444  				"/authorizeCommand",
   445  				bytes.NewBuffer(payload))
   446  			assert.NoError(t, err)
   447  			req.Header.Add(eaconst.HeaderAuthKeyBanners, "a-banner-id")
   448  			ginEngine.ServeHTTP(r, req)
   450  			// wait for result from server
   451  			assert.Equal(t, 200, r.Result().StatusCode)
   452  		})
   453  	}
   454  }
   455  func TestAuthorizeCommandBadPayload(t *testing.T) {
   456  	tests := map[string]struct {
   457  		payload string
   459  		expStatus int
   460  		expErr    apierror.ErrorCode
   461  	}{
   462  		"Fail Status 400 invalid json": {
   463  			`{"command":"rm","earoles":["test"]`,
   464  			http.StatusBadRequest,
   465  			apierror.ErrPayloadStructure,
   466  		},
   467  		"Fail Status 400 no command with env var": {
   468  			`{"command":"A=b","target":{"bannerid":"a-banner-id"}}`,
   469  			http.StatusBadRequest,
   470  			apierror.ErrInvalidCommand,
   471  		},
   472  		"Fail Status 400 no command with env var darkmode": {
   473  			`{"command":"A=b","target":{"bannerid":"a-banner-id"},"authDetails":{"darkmode":true}}`,
   474  			http.StatusBadRequest,
   475  			apierror.ErrInvalidCommand,
   476  		},
   477  		"Fail Status 400 no command": {
   478  			`{"target":{"bannerid":"a-banner-id"}}`,
   479  			http.StatusBadRequest,
   480  			apierror.ErrPayloadProperties,
   481  		},
   482  		"Fail Status 400 no target": {
   483  			`{"command":"rm"}`,
   484  			http.StatusBadRequest,
   485  			apierror.ErrPayloadProperties,
   486  		},
   487  	}
   489  	for name, tc := range tests {
   490  		t.Run(name, func(t *testing.T) {
   491  			r := httptest.NewRecorder()
   493  			// Create Gin context in test mode
   494  			gin.SetMode(gin.TestMode)
   495  			_, ginEngine := gin.CreateTestContext(r)
   497  			// no need to set up rules engine server since it's never reached
   499  			// set up authservice
   500  			ds := mockDataset{}
   501  			as, err := authservice.New(
   502  				authservice.Config{},
   503  				ds,
   504  				nil,
   505  			)
   506  			assert.NoError(t, err)
   507  			_ = New(ginEngine, fog.New(), as)
   509  			// Send test query
   510  			req, err := newAuthRequest(http.MethodPost,
   511  				"/authorizeCommand",
   512  				strings.NewReader(tc.payload))
   513  			assert.NoError(t, err)
   514  			ginEngine.ServeHTTP(r, req)
   516  			//retrieve response
   517  			assert.Equal(t, tc.expStatus, r.Result().StatusCode)
   519  			var e apierrorhandler.ErrorResponse
   520  			err = unmarshalBody(r.Body, &e)
   521  			assert.NoError(t, err)
   522  			assert.Equal(t, tc.expErr, e.ErrorCode)
   523  		})
   524  	}
   525  }
   527  func TestAuthorizeCommandBadCommand(t *testing.T) {
   528  	tests := map[string]struct {
   529  		command string
   531  		expStatus int
   532  		expErr    apierror.ErrorCode
   533  	}{
   534  		"Fail Status 400 quotation mark": {
   535  			`echo "`,
   536  			http.StatusBadRequest,
   537  			apierror.ErrInvalidCommand,
   538  		},
   539  		"Fail Status 400 double escaped quotation mark": {
   540  			`echo \\"`,
   541  			http.StatusBadRequest,
   542  			apierror.ErrInvalidCommand,
   543  		},
   544  	}
   546  	for name, tc := range tests {
   547  		t.Run(name, func(t *testing.T) {
   548  			r := httptest.NewRecorder()
   550  			// Create Gin context in test mode
   551  			gin.SetMode(gin.TestMode)
   552  			_, ginEngine := gin.CreateTestContext(r)
   554  			// set up authservice
   555  			ds := mockDataset{}
   556  			as, err := authservice.New(authservice.Config{}, ds, nil)
   557  			assert.NoError(t, err)
   558  			_ = New(ginEngine, fog.New(), as)
   560  			// Send test query
   561  			payload, _ := json.Marshal(map[string]interface{}{
   562  				"Command": tc.command,
   563  				"EARoles": []string{"test"},
   564  				"Target": types.Target{
   565  					Bannerid: "a-banner-id",
   566  				},
   567  			})
   568  			req, err := newAuthRequest(http.MethodPost,
   569  				"/authorizeCommand",
   570  				bytes.NewBuffer(payload))
   571  			assert.NoError(t, err)
   572  			ginEngine.ServeHTTP(r, req)
   574  			//retrieve response
   575  			assert.Equal(t, tc.expStatus, r.Result().StatusCode)
   577  			var e apierrorhandler.ErrorResponse
   578  			err = unmarshalBody(r.Body, &e)
   579  			assert.NoError(t, err)
   580  			assert.Equal(t, tc.expErr, e.ErrorCode)
   581  		})
   582  	}
   583  }
   585  func unmarshalBody(body *bytes.Buffer, v any) error {
   586  	data, err := io.ReadAll(body)
   587  	if err != nil {
   588  		return err
   589  	}
   590  	return json.Unmarshal(data, v)
   591  }
   593  type mockAuthService struct {
   594  	authorizeCommand func(ctx context.Context, payload authservice.CommandAuthPayload) (authservice.Validation, error)
   595  	authorizeRequest func(ctx context.Context, payload authservice.AuthorizeRequestPayload) (msgdata.Request, error)
   596  	authorizeTarget  func(ctx context.Context, target authservice.Target) error
   597  	authorizeUser    func(ctx context.Context) error
   598  	resolveTarget    func(ctx context.Context, payload authservice.ResolveTargetPayload) (authservice.Target, error)
   599  }
   601  func (mas mockAuthService) AuthorizeCommand(ctx context.Context, payload authservice.CommandAuthPayload) (authservice.Validation, error) {
   602  	return mas.authorizeCommand(ctx, payload)
   603  }
   605  func (mas mockAuthService) AuthorizeRequest(ctx context.Context, payload authservice.AuthorizeRequestPayload) (msgdata.Request, error) {
   606  	return mas.authorizeRequest(ctx, payload)
   607  }
   609  func (mas mockAuthService) AuthorizeTarget(ctx context.Context, target authservice.Target) error {
   610  	return mas.authorizeTarget(ctx, target)
   611  }
   613  func (mas mockAuthService) AuthorizeUser(ctx context.Context) error {
   614  	return mas.authorizeUser(ctx)
   615  }
   617  func (mas mockAuthService) ResolveTarget(ctx context.Context, payload authservice.ResolveTargetPayload) (authservice.Target, error) {
   618  	return mas.resolveTarget(ctx, payload)
   619  }
   621  func TestAuthorizeRequestSuccess(t *testing.T) {
   622  	t.Parallel()
   624  	tests := map[string]struct {
   625  		payload            []byte
   626  		expectedData       string
   627  		expectedAttributes map[string]string
   628  	}{
   629  		"1.0 Command": {
   630  			payload: []byte(`{
   631  				"request": {
   632  					"data": {
   633  						"command": "echo hello there"
   634  					},
   635  					"attributes": {
   636  						"version": "1.0",
   637  						"type": "command"
   638  					}
   639  				},
   640  				"target": {
   641  					"projectID": "project",
   642  					"bannerID": "banner",
   643  					"storeID": "store",
   644  					"terminalID": "terminal"
   645  				}
   646  			}`),
   647  			expectedData: `{
   648  				"command": "echo hello there"
   649  			}`,
   650  			expectedAttributes: map[string]string{
   651  				eaconst.VersionKey:     string(eaconst.MessageVersion1_0),
   652  				eaconst.RequestTypeKey: string(eaconst.Command),
   653  			},
   654  		},
   655  		"2.0 Command": {
   656  			payload: []byte(`{
   657  				"request": {
   658  					"data": {
   659  						"command": "echo",
   660  						"args": ["hello", "there"]
   661  					},
   662  					"attributes": {
   663  						"version": "2.0",
   664  						"type": "command"
   665  					}
   666  				},
   667  				"target": {
   668  					"projectID": "project",
   669  					"bannerID": "banner",
   670  					"storeID": "store",
   671  					"terminalID": "terminal"
   672  				}
   673  			}`),
   674  			expectedData: `{
   675  				"command": "echo",
   676  				"args": ["hello", "there"]
   677  			}`,
   678  			expectedAttributes: map[string]string{
   679  				eaconst.VersionKey:     string(eaconst.MessageVersion2_0),
   680  				eaconst.RequestTypeKey: string(eaconst.Command),
   681  			},
   682  		},
   683  	}
   685  	for name, tc := range tests {
   686  		tc := tc
   687  		t.Run(name, func(t *testing.T) {
   688  			t.Parallel()
   690  			r := httptest.NewRecorder()
   692  			// Create Gin context in test mode
   693  			gin.SetMode(gin.TestMode)
   694  			_, ginEngine := gin.CreateTestContext(r)
   696  			ruleServer := mockRulesEngineServer(http.StatusOK, true)
   697  			userServer := userServiceServer()
   698  			defer ruleServer.Close()
   699  			defer userServer.Close()
   701  			// set up authservice
   702  			ds := mockDataset{}
   703  			as, err := authservice.New(
   704  				authservice.Config{RulesEngineHost: ruleServer.URL[7:], UserServiceHost: userServer.URL[7:]},
   705  				ds,
   706  				nil,
   707  			)
   708  			require.NoError(t, err)
   709  			log := fog.New()
   710  			_ = New(ginEngine, log, as)
   712  			req, err := newAuthRequest(http.MethodPost, "/authorizeRequest", bytes.NewBuffer(tc.payload))
   713  			require.NoError(t, err)
   714  			ginEngine.ServeHTTP(r, req)
   716  			//retrieve response
   717  			assert.Equal(t, http.StatusOK, r.Result().StatusCode)
   719  			var resp struct {
   720  				Request authservice.Request
   721  			}
   722  			err = unmarshalBody(r.Body, &resp)
   723  			assert.NoError(t, err)
   724  			data, err := json.Marshal(resp.Request.Data)
   725  			assert.NoError(t, err)
   727  			assert.JSONEq(t, tc.expectedData, string(data))
   728  			assert.Equal(t, tc.expectedAttributes, resp.Request.Attributes)
   729  		})
   730  	}
   731  }
   733  func TestAuthorizeRequestFail(t *testing.T) {
   734  	t.Parallel()
   736  	tests := map[string]struct {
   737  		payload          []byte
   738  		authorizeRequest func(context.Context, authservice.AuthorizeRequestPayload) (msgdata.Request, error)
   739  		expStatus        int
   740  		expError         apierror.ErrorCode
   741  	}{
   742  		"Invalid Payload Structure": {
   743  			payload: []byte(`{
   744  				"request": {
   745  					"data": {
   746  						"command": "echo he}`),
   747  			expStatus: http.StatusBadRequest,
   748  			expError:  apierror.ErrPayloadStructure,
   749  		},
   750  		"Invalid Payload Details": {
   751  			payload: []byte(`{
   752  				"request": {
   753  					"data": {
   754  						"command": ""
   755  					},
   756  					"attributes": {
   757  						"version": "1.0",
   758  						"type": "command"
   759  					}
   760  				},
   761  				"target": {
   762  					"projectID": "project",
   763  					"bannerID": "",
   764  					"storeID": "store",
   765  					"terminalID": "terminal"
   766  				}
   767  			}`),
   768  			expStatus: http.StatusBadRequest,
   769  			expError:  apierror.ErrPayloadProperties,
   770  		},
   771  		"Send Failure": {
   772  			payload: []byte(`{
   773  				"request": {
   774  					"data": {
   775  						"command": "echo hello there"
   776  					},
   777  					"attributes": {
   778  						"version": "1.0",
   779  						"type": "command"
   780  					}
   781  				},
   782  				"target": {
   783  					"projectID": "project",
   784  					"bannerID": "banner",
   785  					"storeID": "store",
   786  					"terminalID": "terminal"
   787  				}
   788  			}`),
   789  			authorizeRequest: func(_ context.Context, _ authservice.AuthorizeRequestPayload) (msgdata.Request, error) {
   790  				return nil, errors.New("error")
   791  			},
   792  			expStatus: http.StatusInternalServerError,
   793  			expError:  apierror.ErrSendFailure,
   794  		},
   795  		"Unauthorized Command": {
   796  			payload: []byte(`{
   797  				"request": {
   798  					"data": {
   799  						"command": "echo hello there"
   800  					},
   801  					"attributes": {
   802  						"version": "1.0",
   803  						"type": "command"
   804  					}
   805  				},
   806  				"target": {
   807  					"projectID": "project",
   808  					"bannerID": "banner",
   809  					"storeID": "store",
   810  					"terminalID": "terminal"
   811  				}
   812  			}`),
   813  			authorizeRequest: func(_ context.Context, _ authservice.AuthorizeRequestPayload) (msgdata.Request, error) {
   814  				return nil, apierror.E(apierror.ErrUnauthorizedCommand, errors.New("error"))
   815  			},
   816  			expStatus: http.StatusForbidden,
   817  			expError:  apierror.ErrUnauthorizedCommand,
   818  		},
   819  	}
   821  	for name, tc := range tests {
   822  		tc := tc
   823  		t.Run(name, func(t *testing.T) {
   824  			t.Parallel()
   826  			r := httptest.NewRecorder()
   828  			// Create Gin context in test mode
   829  			gin.SetMode(gin.TestMode)
   830  			_, ginEngine := gin.CreateTestContext(r)
   832  			// set up authservice
   833  			as := mockAuthService{
   834  				authorizeRequest: tc.authorizeRequest,
   835  			}
   836  			log := fog.New()
   837  			_ = New(ginEngine, log, as)
   839  			req, err := newAuthRequest(http.MethodPost, "/authorizeRequest", bytes.NewBuffer(tc.payload))
   840  			require.NoError(t, err)
   841  			ginEngine.ServeHTTP(r, req)
   843  			//retrieve response
   844  			assert.Equal(t, tc.expStatus, r.Result().StatusCode)
   846  			var e apierrorhandler.ErrorResponse
   847  			err = unmarshalBody(r.Body, &e)
   848  			assert.NoError(t, err)
   849  			assert.Equal(t, tc.expError, e.ErrorCode)
   850  		})
   851  	}
   852  }
   854  func TestAuthorizeRequestAudit(t *testing.T) {
   855  	t.Parallel()
   857  	tests := map[string]struct {
   858  		authorizeRequest func(context.Context, authservice.AuthorizeRequestPayload) (msgdata.Request, error)
   859  		expAuth          bool
   860  	}{
   861  		"Success": {
   862  			authorizeRequest: func(_ context.Context, _ authservice.AuthorizeRequestPayload) (msgdata.Request, error) {
   863  				return nil, nil
   864  			},
   865  			expAuth: true,
   866  		},
   867  		"Unauthorized": {
   868  			authorizeRequest: func(_ context.Context, _ authservice.AuthorizeRequestPayload) (msgdata.Request, error) {
   869  				return nil, errors.New("error")
   870  			},
   871  			expAuth: false,
   872  		},
   873  	}
   875  	for name, tc := range tests {
   876  		tc := tc
   877  		t.Run(name, func(t *testing.T) {
   878  			t.Parallel()
   880  			// Create Gin context in test mode
   881  			r := httptest.NewRecorder()
   882  			gin.SetMode(gin.TestMode)
   883  			_, ginEngine := gin.CreateTestContext(r)
   885  			// set up authservice with logging to memory
   886  			b := bytes.Buffer{}
   887  			log := fog.New(fog.To(&b))
   888  			as := mockAuthService{authorizeRequest: tc.authorizeRequest}
   889  			_ = New(ginEngine, log, as)
   891  			// Create request map for log comparison later
   892  			requestMap := map[string]map[string]string{
   893  				"data": {
   894  					"command": "echo hello there",
   895  				},
   896  				"attributes": {
   897  					"type":    "command",
   898  					"version": "1.0",
   899  				},
   900  			}
   901  			requestBytes, err := json.Marshal(requestMap)
   902  			assert.NoError(t, err)
   903  			payload, err := json.Marshal(map[string]interface{}{
   904  				"request": requestMap,
   905  				"target": authservice.Target{
   906  					ProjectID:  "project",
   907  					BannerID:   "banner",
   908  					StoreID:    "store",
   909  					TerminalID: "terminal",
   910  				},
   911  			})
   912  			require.NoError(t, err)
   913  			req, err := newAuthRequest(http.MethodPost, "/authorizeRequest", bytes.NewBuffer(payload))
   914  			require.NoError(t, err)
   915  			req.Header.Add("X-Correlation-ID", "a-command-id")
   917  			ginEngine.ServeHTTP(r, req)
   919  			// Retrieve audit log line and convert to map
   920  			auditLog, err := getAuditLogString(&b, "Authorize Request Called")
   921  			assert.NoError(t, err)
   922  			var auditMap map[string]interface{}
   923  			err = json.Unmarshal([]byte(auditLog), &auditMap)
   924  			assert.NoError(t, err)
   926  			// Convert audit log "request" value to JSON and compare with actual request
   927  			auditReqMap, ok := auditMap["request"].(map[string]interface{})
   928  			assert.True(t, ok)
   929  			auditReqBytes, err := json.Marshal(auditReqMap)
   930  			assert.NoError(t, err)
   931  			assert.JSONEq(t, string(requestBytes), string(auditReqBytes))
   933  			// Iterate through the rest of the expected audit values
   934  			expectedLogKeyVals := map[string]interface{}{
   935  				"requestID":          "a-command-id",
   936  				"userID":             "username",
   937  				"targetProjectID":    "project",
   938  				"targetBannerUUID":   "banner",
   939  				"targetStoreUUID":    "store",
   940  				"targetTerminalUUID": "terminal",
   941  				"commandAuthorized":  tc.expAuth,
   942  			}
   943  			for expKey, expVal := range expectedLogKeyVals {
   944  				auditVal, ok := auditMap[expKey]
   945  				assert.True(t, ok)
   946  				assert.Equal(t, expVal, auditVal)
   947  			}
   948  		})
   949  	}
   950  }
   952  func getAuditLogString(b *bytes.Buffer, logmsg string) (string, error) {
   953  	// split the log on newline char to test whether the condition is satisfied in a single entry rather than across all logs
   954  	lst := strings.Split(b.String(), "\n")
   955  	var ok bool
   956  	for _, str := range lst {
   957  		// select the log with the correct log message
   958  		if ok = strings.Contains(str, logmsg); ok {
   959  			return str, nil
   960  		}
   961  	}
   962  	return "", errors.New("could not find matching log message")
   963  }
   965  func mockRulesEngineServer(statusCode int, valid bool) *httptest.Server {
   966  	mux := http.NewServeMux()
   967  	mux.HandleFunc("/validatecommand", func(w http.ResponseWriter, _ *http.Request) {
   968  		w.WriteHeader(statusCode)
   969  		res := authservice.Response{Valid: valid}
   970  		b, err := json.Marshal(res)
   971  		if err != nil {
   972  			return
   973  		}
   974  		_, err = w.Write(b)
   975  		if err != nil {
   976  			return
   977  		}
   978  	})
   979  	return httptest.NewServer(mux)
   980  }
   982  type mockDatasetTestResolveTarget struct {
   983  	authservice.Dataset
   985  	projectID  string
   986  	bannerID   string
   987  	storeID    string
   988  	terminalID string
   989  }
// errVal is the input value that makes the mockDatasetTestResolveTarget lookup
// methods return an error.
const errVal = "err"
   993  func (ds mockDatasetTestResolveTarget) GetProjectAndBannerID(_ context.Context, banner string) (projectID string, bannerID string, err error) {
   994  	if banner == errVal {
   995  		err = fmt.Errorf("error GetProjectIDAndBannerID")
   996  	}
   997  	return ds.projectID, ds.bannerID, err
   998  }
  1000  func (ds mockDatasetTestResolveTarget) GetStoreID(_ context.Context, store, _ string) (storeID string, err error) {
  1001  	if store == errVal {
  1002  		err = fmt.Errorf("error GetStoreID")
  1003  	}
  1004  	return ds.storeID, err
  1005  }
  1007  func (ds mockDatasetTestResolveTarget) GetTerminalID(_ context.Context, terminal, _ string) (terminalID string, err error) {
  1008  	if terminal == errVal {
  1009  		err = fmt.Errorf("error GetTerminalID")
  1010  	}
  1011  	return ds.terminalID, err
  1012  }
  1014  func TestResolveTarget(t *testing.T) {
  1015  	t.Parallel()
  1017  	tests := map[string]struct {
  1018  		data      []byte
  1019  		ds        mockDatasetTestResolveTarget
  1020  		expCode   int
  1021  		expOutput StringAssertionFunc
  1022  	}{
  1023  		"Valid": {
  1024  			data: []byte(`{
  1025  				"target": {
  1026  					"bannerid": "b",
  1027  					"storeid": "s",
  1028  					"terminalid": "t"
  1029  				}
  1030  			}`),
  1031  			ds: mockDatasetTestResolveTarget{
  1032  				projectID:  "projectID",
  1033  				bannerID:   "bannerID",
  1034  				storeID:    "storeID",
  1035  				terminalID: "terminalID",
  1036  			},
  1037  			expCode: http.StatusOK,
  1038  			expOutput: JSONEq(`{
  1039  				"target": {
  1040  					"projectid": "projectID",
  1041  					"bannerid": "bannerID",
  1042  					"storeid": "storeID",
  1043  					"terminalid": "terminalID"
  1044  				}
  1045  			}`),
  1046  		},
  1047  		"Bad Payload": {
  1048  			data:      []byte(`{"targ}`),
  1049  			expCode:   http.StatusBadRequest,
  1050  			expOutput: JSONEq(`{"errorCode":60201, "errorMessage":"Request Error - Invalid payload structure"}`),
  1051  		},
  1052  		"Invalid Payload": {
  1053  			data: []byte(`{
  1054  				"target": {}
  1055  			}`),
  1056  			expCode:   http.StatusBadRequest,
  1057  			expOutput: JSONEq(`{"errorCode":60202,"errorMessage":"Request Error - Invalid payload properties","details":["Payload missing banner ID","Payload missing store ID","Payload missing terminal ID"]}`),
  1058  		},
  1059  		"Failed To Authorize": {
  1060  			data: []byte(`{
  1061  				"target": {
  1062  					"bannerid": "err",
  1063  					"storeid": "err",
  1064  					"terminalid": "err"
  1065  				}
  1066  			}`),
  1067  			expCode:   http.StatusInternalServerError,
  1068  			expOutput: JSONEq(`{"errorCode":60101, "errorMessage":"User Authorization Failure - Failed to authorize user"}`),
  1069  		},
  1070  	}
  1072  	for name, tc := range tests {
  1073  		tc := tc
  1074  		t.Run(name, func(t *testing.T) {
  1075  			t.Parallel()
  1077  			r := httptest.NewRecorder()
  1078  			gin.SetMode(gin.TestMode)
  1079  			_, ginEngine := gin.CreateTestContext(r)
  1081  			as, err := authservice.New(
  1082  				authservice.Config{},
  1083  				tc.ds,
  1084  				nil,
  1085  			)
  1086  			assert.NoError(t, err)
  1087  			_ = New(ginEngine, logr.Discard(), as)
  1089  			req, err := newAuthRequest(http.MethodPost,
  1090  				"/resolveTarget",
  1091  				bytes.NewBuffer(tc.data))
  1092  			assert.NoError(t, err)
  1093  			ginEngine.ServeHTTP(r, req)
  1095  			assert.Equal(t, tc.expCode, r.Result().StatusCode)
  1097  			data, err := io.ReadAll(r.Body)
  1098  			assert.NoError(t, err)
  1099  			tc.expOutput(t, string(data))
  1100  		})
  1101  	}
  1102  }
  1104  const (
  1105  	uuid1         = "78587bb1-6ca2-4d2d-a223-1ee642514b97"
  1106  	uuid2         = "35cc70eb-689d-49d4-8bd8-fa1cb8b0928f"
  1107  	uuid3         = "79bf815d-8e64-4b01-b12e-1f173a322766"
  1108  	uuid4         = "113f6c32-5501-44ba-9cd5-76530be5aa67"
  1109  	payloadString = `
  1110  	{
  1111  	"target": {
  1112  		"projectid":"%s",
  1113  		"bannerid": "%s",
  1114  		"storeid": "%s",
  1115  		"terminalid": "%s"
  1116  	}
  1117  	}`
  1118  )
  1120  func TestAuthorizeTarget(t *testing.T) {
  1121  	// Testing whether the endpoint returns correct error codes. There are no strict requirements on UUIDs for this endpoint,
  1122  	// but the UUID/string parse structure has been left incase this changes.
  1123  	tests := map[string]struct {
  1124  		data     []byte
  1125  		expCode  int
  1126  		bannerID string
  1127  	}{
  1128  		"Valid": {
  1129  			data:     []byte(fmt.Sprintf(payloadString, uuid1, uuid2, uuid3, uuid4)),
  1130  			expCode:  200,
  1131  			bannerID: uuid2,
  1132  		},
  1133  		"Wrong bannerID (forbidden)": {
  1134  			data:     []byte(fmt.Sprintf(payloadString, uuid1, uuid2, uuid3, uuid4)),
  1135  			expCode:  403,
  1136  			bannerID: uuid1,
  1137  		},
  1138  		"Bad payload": {
  1139  			data:     []byte("{"),
  1140  			expCode:  400,
  1141  			bannerID: uuid1,
  1142  		},
  1143  	}
  1144  	for name, tc := range tests {
  1145  		tc := tc
  1146  		t.Run(name, func(t *testing.T) {
  1147  			t.Parallel()
  1149  			r := httptest.NewRecorder()
  1150  			gin.SetMode(gin.TestMode)
  1151  			_, ginEngine := gin.CreateTestContext(r)
  1152  			// userservice server
  1153  			uServer := userServiceServer()
  1154  			defer uServer.Close()
  1155  			as, err := authservice.New(
  1156  				authservice.Config{UserServiceHost: uServer.URL[7:]},
  1157  				mockDataset{},
  1158  				nil,
  1159  			)
  1160  			assert.NoError(t, err)
  1161  			_ = New(ginEngine, logr.Discard(), as)
  1163  			req, err := newAuthRequest(http.MethodPost,
  1164  				"/authorizeTarget",
  1165  				bytes.NewBuffer(tc.data))
  1166  			assert.NoError(t, err)
  1167  			req.Header.Add(eaconst.HeaderAuthKeyBanners, tc.bannerID)
  1168  			ginEngine.ServeHTTP(r, req)
  1170  			assert.Equal(t, tc.expCode, r.Result().StatusCode)
  1171  		})
  1172  	}
  1173  }
  1174  func TestAuthTargetAudit(t *testing.T) {
  1175  	// Setup
  1176  	r := httptest.NewRecorder()
  1177  	gin.SetMode(gin.TestMode)
  1178  	_, ginEngine := gin.CreateTestContext(r)
  1179  	// userservice server
  1180  	uServer := userServiceServer()
  1181  	defer uServer.Close()
  1182  	// new dataset
  1183  	ds := mockDataset{}
  1184  	// new authservice. rules engine not needed for authTarget.
  1185  	as, err := authservice.New(
  1186  		authservice.Config{UserServiceHost: uServer.URL[7:]},
  1187  		ds,
  1188  		nil,
  1189  	)
  1190  	assert.NoError(t, err)
  1192  	// logging to memory so it can be read from test
  1193  	b := bytes.Buffer{}
  1194  	log := fog.New(fog.To(&b))
  1195  	_ = New(ginEngine, log, as)
  1197  	// setup request
  1198  	req, err := newAuthRequest(http.MethodPost,
  1199  		"/authorizeTarget",
  1200  		bytes.NewBuffer([]byte(fmt.Sprintf(payloadString, uuid1, uuid2, uuid3, uuid4))))
  1201  	assert.NoError(t, err)
  1202  	req.Header.Add(eaconst.HeaderAuthKeyBanners, uuid2)
  1203  	// serve request
  1204  	ginEngine.ServeHTTP(r, req)
  1205  	assert.Equal(t, 200, r.Result().StatusCode)
  1207  	// Test
  1208  	validateAuditLog(t, &b, "Authorize Target Called", map[string]string{
  1209  		"userID":             "username",
  1210  		"targetProjectID":    uuid1,
  1211  		"targetBannerUUID":   uuid2,
  1212  		"targetStoreUUID":    uuid3,
  1213  		"targetTerminalUUID": uuid4,
  1214  	})
  1215  }
  1217  func TestAuthorizeUser(t *testing.T) {
  1218  	t.Parallel()
  1220  	tests := map[string]struct {
  1221  		ctx            context.Context
  1222  		setAuthHeaders func(req *http.Request)
  1223  		expCode        int
  1224  	}{
  1225  		"Valid": {
  1226  			setAuthHeaders: setAuthHeaders,
  1227  			expCode:        http.StatusOK,
  1228  		},
  1229  		"No User": {
  1230  			ctx:            context.Background(),
  1231  			setAuthHeaders: func(_ *http.Request) {},
  1232  			expCode:        http.StatusForbidden,
  1233  		},
  1234  		"Invalid Roles": {
  1235  			setAuthHeaders: func(req *http.Request) {
  1236  				req.Header.Set(eaconst.HeaderAuthKeyUsername, "username")
  1237  				req.Header.Set(eaconst.HeaderAuthKeyEmail, "email")
  1238  				req.Header.Set(eaconst.HeaderAuthKeyBanners, "banner")
  1239  			},
  1240  			expCode: http.StatusForbidden,
  1241  		},
  1242  	}
  1244  	for name, tc := range tests {
  1245  		tc := tc
  1246  		t.Run(name, func(t *testing.T) {
  1247  			t.Parallel()
  1249  			r := httptest.NewRecorder()
  1250  			gin.SetMode(gin.TestMode)
  1251  			_, ginEngine := gin.CreateTestContext(r)
  1253  			// Setup
  1254  			userServer := userServiceServer()
  1255  			defer userServer.Close()
  1257  			ds := mockDataset{}
  1258  			as, err := authservice.New(
  1259  				authservice.Config{UserServiceHost: userServer.URL[7:]},
  1260  				ds,
  1261  				nil,
  1262  			)
  1263  			assert.NoError(t, err)
  1264  			_ = New(ginEngine, fog.New(), as)
  1266  			// Test
  1267  			req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/authorizeUser", nil)
  1268  			assert.NoError(t, err)
  1269  			tc.setAuthHeaders(req)
  1270  			ginEngine.ServeHTTP(r, req)
  1272  			assert.Equal(t, tc.expCode, r.Result().StatusCode)
  1273  		})
  1274  	}
  1275  }
  1277  func TestHealth(t *testing.T) {
  1278  	t.Parallel()
  1280  	tests := map[string]struct {
  1281  		checks []func() error
  1283  		expCode int
  1284  		expData string
  1285  	}{
  1286  		"No checks": {
  1287  			checks:  nil,
  1288  			expCode: http.StatusOK,
  1289  			expData: "ok",
  1290  		},
  1291  		"Passing check": {
  1292  			checks:  []func() error{func() error { return nil }},
  1293  			expCode: http.StatusOK,
  1294  			expData: "ok",
  1295  		},
  1296  		"Failing check": {
  1297  			checks:  []func() error{func() error { return fmt.Errorf("this is bad") }},
  1298  			expCode: http.StatusServiceUnavailable,
  1299  			expData: "failed health check: this is bad",
  1300  		},
  1301  		"Two checks": {
  1302  			checks:  []func() error{func() error { return nil }, func() error { return fmt.Errorf("this is bad") }},
  1303  			expCode: http.StatusServiceUnavailable,
  1304  			expData: "failed health check: this is bad",
  1305  		},
  1306  	}
  1308  	for name, tc := range tests {
  1309  		tc := tc
  1310  		t.Run(name, func(t *testing.T) {
  1311  			t.Parallel()
  1313  			r := httptest.NewRecorder()
  1314  			gin.SetMode(gin.TestMode)
  1315  			_, ginEngine := gin.CreateTestContext(r)
  1317  			as, err := authservice.New(
  1318  				authservice.Config{},
  1319  				nil,
  1320  				nil,
  1321  			)
  1322  			assert.NoError(t, err)
  1324  			_ = New(ginEngine, logr.Discard(), as, tc.checks...)
  1326  			req, err := newAuthRequest(http.MethodGet, "/health", nil)
  1327  			assert.NoError(t, err)
  1329  			ginEngine.ServeHTTP(r, req)
  1331  			assert.Equal(t, tc.expCode, r.Result().StatusCode)
  1333  			data, err := io.ReadAll(r.Body)
  1334  			assert.NoError(t, err)
  1335  			assert.Equal(t, tc.expData, string(data))
  1336  		})
  1337  	}
  1338  }

