...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/eagateway/server/startsession_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/eagateway/server

     1  package server
     2  
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"maps"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"edge-infra.dev/pkg/lib/fog"
	"edge-infra.dev/pkg/sds/emergencyaccess/apierror"
	errorhandler "edge-infra.dev/pkg/sds/emergencyaccess/apierror/handler"
	"edge-infra.dev/pkg/sds/emergencyaccess/eagateway"
	"edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
	"edge-infra.dev/pkg/sds/emergencyaccess/remotecli"
	"edge-infra.dev/pkg/sds/emergencyaccess/types"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
)
    27  
    28  type startSessionTestRCLI struct {
    29  	eagateway.RemoteCLI
    30  	displayCh chan<- msgdata.CommandResponse
    31  	target    remotecli.Target
    32  }
    33  
    34  func (rcli *startSessionTestRCLI) Send(_ context.Context, _, i, _ string, _ msgdata.Request, _ ...remotecli.RCLIOption) error {
    35  	attr := defaultAttrMap
    36  	attr["bannerId"] = i
    37  	response, err := msgdata.NewCommandResponse(defaultBytes, attr)
    38  	if err != nil {
    39  		return err
    40  	}
    41  	rcli.displayCh <- response
    42  	return nil
    43  }
    44  
    45  func (rcli *startSessionTestRCLI) StartSession(_ context.Context, sessionID string, displayCh chan<- msgdata.CommandResponse, target remotecli.Target, _ ...remotecli.RCLIOption) error {
    46  	if sessionID == "fail" {
    47  		return errTestRCLIStartSessionFail
    48  	}
    49  	rcli.target = target
    50  	rcli.displayCh = displayCh
    51  	return nil
    52  }
    53  
    54  type commandResponse struct {
    55  	Data       msgdata.ResponseData       `json:"data"`
    56  	Attributes msgdata.ResponseAttributes `json:"attributes"`
    57  }
    58  
    59  type ConnectionPayload struct {
    60  	Error   string          `json:"error"`
    61  	Message commandResponse `json:"message"`
    62  }
    63  
    64  func createStartSessionRequest(ctx context.Context, sessionID string, target types.Target) (req *http.Request, cancelFunc context.CancelFunc, err error) {
    65  	payload := types.StartSessionPayload{
    66  		SessionID: sessionID,
    67  		Target:    target,
    68  	}
    69  	message, err := json.Marshal(payload)
    70  	if err != nil {
    71  		return nil, nil, err
    72  	}
    73  	ctx, cancelFunc = context.WithCancel(ctx)
    74  	req, err = http.NewRequestWithContext(ctx, http.MethodPost, "/ea/startSession", bytes.NewReader(message))
    75  	if err != nil {
    76  		return nil, cancelFunc, err
    77  	}
    78  	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "my_jwt_token"))
    79  	req.Header.Set("Cache-Control", "no-cache")
    80  	req.Header.Set("Accept", "text/event-stream")
    81  	req.Header.Set("Connection", "keep-alive")
    82  
    83  	setAuthHeaders(req)
    84  
    85  	return req, cancelFunc, nil
    86  }
    87  
    88  func TestStartSession(t *testing.T) {
    89  	r := httptest.NewRecorder()
    90  	gin.SetMode(gin.TestMode)
    91  	c, ginEngine := gin.CreateTestContext(r)
    92  
    93  	var authServerFailures []error
    94  
    95  	// Setup
    96  	authServer, url := authserviceServer(http.StatusOK,
    97  		WithMiddleware(verifyUserAuthHeaders(t)),
    98  
    99  		// Returns the default target
   100  		WithResolveTarget(func(w http.ResponseWriter, _ *http.Request) {
   101  			// TODO validate passed in target is correct?
   102  
   103  			data, err := json.Marshal(map[string]types.Target{
   104  				"target": defaultTarget,
   105  			})
   106  			if err != nil {
   107  				authServerFailures = append(authServerFailures, err)
   108  			}
   109  
   110  			_, err = w.Write(data)
   111  			if err != nil {
   112  				authServerFailures = append(authServerFailures, err)
   113  			}
   114  		}),
   115  	)
   116  	t.Cleanup(authServer.Close)
   117  
   118  	var rcli = &startSessionTestRCLI{}
   119  	_, err := New(eagateway.Config{AuthServiceHost: url}, ginEngine, newLogger(), rcli, nil)
   120  	assert.NoError(t, err)
   121  
   122  	// Test
   123  	sessionID := "TestStartSession"
   124  	target := types.Target{
   125  		Bannerid:   "a-banner-id",
   126  		Storeid:    "a-store-id",
   127  		Terminalid: "a-terminal-id",
   128  	}
   129  	req, cancelFunc, err := createStartSessionRequest(c, sessionID, target)
   130  	assert.NoError(t, err)
   131  
   132  	isClosed := false
   133  	go func() {
   134  		ginEngine.ServeHTTP(r, req)
   135  		assert.Equal(t, http.StatusOK, r.Result().StatusCode)
   136  		isClosed = true
   137  	}()
   138  
   139  	time.Sleep(10 * time.Millisecond)
   140  
   141  	for i := 0; i < 2; i++ {
   142  		attr := defaultAttrMap
   143  		attr["bannerId"] = fmt.Sprintf("%d", i)
   144  		expected, err := msgdata.NewCommandResponse(defaultBytes, attr)
   145  		assert.NoError(t, err)
   146  
   147  		req, err := msgdata.NewV1_0Request("echo")
   148  		assert.NoError(t, err)
   149  
   150  		err = rcli.Send(c, "", fmt.Sprintf("%d", i), uuid.NewString(), req)
   151  		assert.NoError(t, err)
   152  
   153  		var buf []byte
   154  		assert.Eventually(t, func() bool {
   155  			buf = r.Body.Bytes()
   156  			return len(buf) != 0
   157  		}, 100*time.Millisecond, 20*time.Millisecond)
   158  		dec := json.NewDecoder(bytes.NewBuffer(buf))
   159  		var received ConnectionPayload
   160  		err = dec.Decode(&received)
   161  		assert.NoError(t, err)
   162  
   163  		assert.Equal(t, expected.Data(), received.Message.Data)
   164  		assert.Equal(t, expected.Attributes(), received.Message.Attributes)
   165  
   166  		r.Body.Reset()
   167  	}
   168  
   169  	cancelFunc()
   170  	assert.Eventually(t, func() bool {
   171  		return isClosed
   172  	}, 1*time.Second, 50*time.Millisecond)
   173  	assert.Equal(t, http.StatusOK, r.Result().StatusCode)
   174  
   175  	assert.Empty(t, authServerFailures)
   176  	assert.Equal(t, defaultTarget.Projectid, rcli.target.ProjectID())
   177  	assert.Equal(t, defaultTarget.Bannerid, rcli.target.BannerID())
   178  	assert.Equal(t, defaultTarget.Storeid, rcli.target.StoreID())
   179  	assert.Equal(t, defaultTarget.Terminalid, rcli.target.TerminalID())
   180  
   181  	// Validate target is returned via api
   182  	assert.Equal(t, defaultTarget.Projectid, r.Result().Header.Get("X-EA-ProjectID"))
   183  	assert.Equal(t, defaultTarget.Bannerid, r.Result().Header.Get("X-EA-BannerID"))
   184  	assert.Equal(t, defaultTarget.Storeid, r.Result().Header.Get("X-EA-StoreID"))
   185  	assert.Equal(t, defaultTarget.Terminalid, r.Result().Header.Get("X-EA-TerminalID"))
   186  }
   187  
   188  func TestStartSessionFail(t *testing.T) {
   189  	tests := map[string]struct {
   190  		payload             interface{}
   191  		status              int
   192  		resolveTargetFunc   func(w http.ResponseWriter, r *http.Request)
   193  		authorizeTargetFunc func(w http.ResponseWriter, r *http.Request)
   194  		err                 string
   195  	}{
   196  		"Payload JSON Bind Fail": {
   197  			payload: "not valid",
   198  			status:  http.StatusBadRequest,
   199  			err:     `{"errorCode":60201, "errorMessage":"Request Error - Invalid payload structure"}`,
   200  		},
   201  		"resolveTarget returns error": {
   202  			payload: types.StartSessionPayload{
   203  				SessionID: "session-ID",
   204  				Target:    defaultTarget,
   205  			},
   206  			resolveTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
   207  				w.WriteHeader(http.StatusInternalServerError)
   208  				errResp := errorhandler.NewErrorResponse(apierror.E(apierror.ErrAuthFailure))
   209  				bytes, _ := json.Marshal(errResp)
   210  				_, _ = w.Write(bytes)
   211  			},
   212  			status: http.StatusInternalServerError,
   213  			err:    `{"errorCode":60101, "errorMessage":"User Authorization Failure - Failed to authorize user"}`,
   214  		},
   215  		"resolveTarget returns bad request error": {
   216  			payload: types.StartSessionPayload{
   217  				SessionID: "session-ID",
   218  				Target:    defaultTarget,
   219  			},
   220  			resolveTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
   221  				w.WriteHeader(http.StatusInternalServerError)
   222  				errResp := errorhandler.NewErrorResponse(apierror.E(apierror.ErrInvalidTarget))
   223  				bytes, _ := json.Marshal(errResp)
   224  				_, _ = w.Write(bytes)
   225  			},
   226  			status: http.StatusBadRequest,
   227  			err:    `{"errorCode": 61202, "errorMessage": "Request Error - Invalid Target properties"}`,
   228  		},
   229  
   230  		"resolveTarget Doesn't return target": {
   231  			payload: types.StartSessionPayload{
   232  				SessionID: "not valid",
   233  				Target: types.Target{
   234  					Projectid:  "",
   235  					Bannerid:   "",
   236  					Storeid:    "storeID",
   237  					Terminalid: "terminalID",
   238  				},
   239  			},
   240  			resolveTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
   241  				data, _ := json.Marshal(map[string]types.Target{
   242  					"target": {
   243  						Projectid:  "",
   244  						Bannerid:   "",
   245  						Storeid:    "a-store-iD",
   246  						Terminalid: "a-terminalID",
   247  					},
   248  				})
   249  				_, _ = w.Write(data)
   250  			},
   251  			status: http.StatusBadRequest,
   252  			err:    `{"errorCode":60202, "details": ["Target missing project ID", "Target missing Banner ID"], "errorMessage":"Request Error - Invalid payload properties"}`,
   253  		},
   254  		"authorizeTarget returns non-ok status": {
   255  			payload: types.StartSessionPayload{
   256  				Target:    defaultTarget,
   257  				SessionID: "session-ID",
   258  			},
   259  			status: http.StatusInternalServerError,
   260  			authorizeTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
   261  				w.WriteHeader(http.StatusInternalServerError)
   262  				errResp := errorhandler.NewErrorResponse(apierror.E(apierror.ErrAuthFailure))
   263  				bytes, _ := json.Marshal(errResp)
   264  				_, _ = w.Write(bytes)
   265  			},
   266  			err: `{"errorCode":60101, "errorMessage":"User Authorization Failure - Failed to authorize user"}`,
   267  		},
   268  		"Payload JSON missing properties": {
   269  			payload: types.StartSessionPayload{
   270  				SessionID: "",
   271  				Target: types.Target{
   272  					Projectid:  "",
   273  					Bannerid:   "",
   274  					Storeid:    "storeID",
   275  					Terminalid: "terminalID",
   276  				},
   277  			},
   278  			status: http.StatusBadRequest,
   279  			err:    `{"errorCode":60202, "details":["Payload missing Session ID"], "errorMessage":"Request Error - Invalid payload properties"}`,
   280  		},
   281  		"RCLI Start Session Fail": {
   282  			payload: types.StartSessionPayload{
   283  				SessionID: "fail",
   284  				Target: types.Target{
   285  					Projectid:  "projectID",
   286  					Bannerid:   "bannerID",
   287  					Storeid:    "storeID",
   288  					Terminalid: "terminalID",
   289  				},
   290  			},
   291  			status: http.StatusInternalServerError,
   292  			err:    `{"errorCode":61101, "errorMessage":"Subscription failure - Failed to initialize subscription"}`,
   293  		},
   294  	}
   295  
   296  	for name, tc := range tests {
   297  		t.Run(name, func(t *testing.T) {
   298  			r := httptest.NewRecorder()
   299  			gin.SetMode(gin.TestMode)
   300  			_, ginEngine := gin.CreateTestContext(r)
   301  
   302  			// Setup
   303  			resolveTargetFunc := defaultResolveTarget()
   304  			if tc.resolveTargetFunc != nil {
   305  				resolveTargetFunc = tc.resolveTargetFunc
   306  			}
   307  			authorizeTargetFunc := defaultAuthorizeTarget(http.StatusOK)
   308  			if tc.authorizeTargetFunc != nil {
   309  				authorizeTargetFunc = tc.authorizeTargetFunc
   310  			}
   311  			authServer, url := authserviceServer(
   312  				http.StatusOK,
   313  				WithResolveTarget(resolveTargetFunc),
   314  				WithAuthorizeTarget(authorizeTargetFunc),
   315  			)
   316  			defer authServer.Close()
   317  
   318  			var rcli eagateway.RemoteCLI = &startSessionTestRCLI{}
   319  			_, err := New(eagateway.Config{AuthServiceHost: url}, ginEngine, newLogger(), rcli, nil)
   320  			assert.NoError(t, err)
   321  
   322  			message, err := json.Marshal(tc.payload)
   323  			assert.NoError(t, err)
   324  			req, err := http.NewRequest(http.MethodPost, "/ea/startSession", bytes.NewReader(message))
   325  			assert.NoError(t, err)
   326  			req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "my_jwt_token"))
   327  
   328  			ginEngine.ServeHTTP(r, req)
   329  			res := r.Result()
   330  			assert.Equal(t, tc.status, res.StatusCode)
   331  
   332  			buf := strings.Builder{}
   333  			_, err = io.Copy(&buf, r.Result().Body)
   334  			assert.NoError(t, err)
   335  
   336  			assert.JSONEq(t, tc.err, buf.String())
   337  		})
   338  	}
   339  }
   340  
   341  func TestStartSessionAudit(t *testing.T) {
   342  	r := httptest.NewRecorder()
   343  	gin.SetMode(gin.TestMode)
   344  	c, ginEngine := gin.CreateTestContext(r)
   345  
   346  	// Setup
   347  	resolveTargetFunc := defaultResolveTarget()
   348  	authServer, url := authserviceServer(
   349  		http.StatusOK,
   350  		WithResolveTarget(resolveTargetFunc),
   351  	)
   352  	defer authServer.Close()
   353  
   354  	var rcli eagateway.RemoteCLI = &startSessionTestRCLI{}
   355  	// test logger writes to a byte buffer so it can be read from test
   356  	b := bytes.Buffer{}
   357  	log := fog.New(fog.To(&b))
   358  	_, err := New(eagateway.Config{AuthServiceHost: url}, ginEngine, log, rcli, nil)
   359  	assert.NoError(t, err)
   360  
   361  	payload := types.StartSessionPayload{
   362  		SessionID: "SessionID",
   363  		Target: types.Target{
   364  			Projectid:  "projectID",
   365  			Bannerid:   "bannerID",
   366  			Storeid:    "storeID",
   367  			Terminalid: "terminalID",
   368  		},
   369  	}
   370  	req, cancelFunc, err := createStartSessionRequest(c, payload.SessionID, payload.Target)
   371  	assert.NoError(t, err)
   372  
   373  	go ginEngine.ServeHTTP(r, req) // blocking because of long lived request
   374  	validateAuditLogOnDelay(t, &b, "New session started")
   375  	cancelFunc()
   376  	validateAuditLogOnDelay(t, &b, "Session ended")
   377  }
   378  
   379  // checks an audit log with matching message is added to the byte buffer for 0.5 seconds
   380  func validateAuditLogOnDelay(t *testing.T, b *bytes.Buffer, logmsg string) {
   381  	assert.Eventually(t, func() bool {
   382  		return validateAuditLog(b, logmsg)
   383  	}, 500*time.Millisecond, 20*time.Millisecond)
   384  }
   385  
   386  func validateAuditLog(b *bytes.Buffer, logmsg string) bool {
   387  	// split the log on newline char to test whether the condition is satisfied in a single entry rather than across all logs
   388  	lst := strings.Split(b.String(), "\n")
   389  
   390  	// Assertion funcs. Keep as functions to enable boolean short circuting in
   391  	// if statement
   392  	containsLogMsg := func(str string) bool {
   393  		return strings.Contains(str, logmsg)
   394  	}
   395  	containsUserKey := func(str string) bool {
   396  		return strings.Contains(str, fmt.Sprintf("%q:%q", "userID", "user"))
   397  	}
   398  
   399  	for _, str := range lst {
   400  		if containsLogMsg(str) && containsUserKey(str) {
   401  			return true
   402  		}
   403  	}
   404  
   405  	return false
   406  }
   407  

View as plain text