package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"edge-infra.dev/pkg/lib/fog"
	"edge-infra.dev/pkg/sds/emergencyaccess/apierror"
	errorhandler "edge-infra.dev/pkg/sds/emergencyaccess/apierror/handler"
	"edge-infra.dev/pkg/sds/emergencyaccess/eagateway"
	"edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
	"edge-infra.dev/pkg/sds/emergencyaccess/remotecli"
	"edge-infra.dev/pkg/sds/emergencyaccess/types"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
)

// startSessionTestRCLI is a test double for eagateway.RemoteCLI.
//
// StartSession records the display channel and target it is handed, and Send
// pushes a canned command response onto that recorded channel so the test can
// observe what the server streams back to the client. The embedded RemoteCLI
// is left nil in these tests, so calling any method not overridden here panics.
type startSessionTestRCLI struct {
	eagateway.RemoteCLI
	displayCh chan<- msgdata.CommandResponse
	target    remotecli.Target
}

// cloneAttrs returns a shallow copy of m, preserving its map type, so callers
// can modify the copy without mutating shared fixtures such as defaultAttrMap.
func cloneAttrs[M ~map[K]V, K comparable, V any](m M) M {
	out := make(M, len(m))
	for k, v := range m {
		out[k] = v
	}
	return out
}

// Send ignores the request and writes a response built from defaultBytes and
// a copy of defaultAttrMap onto the display channel recorded by StartSession.
// The second string argument is echoed back as the "bannerId" attribute. Send
// blocks until the display channel accepts the response.
func (rcli *startSessionTestRCLI) Send(_ context.Context, _, bannerID, _ string, _ msgdata.Request, _ ...remotecli.RCLIOption) error {
	// Copy rather than alias: defaultAttrMap is package-level state and writing
	// into it directly would leak this bannerId into every other test.
	attr := cloneAttrs(defaultAttrMap)
	attr["bannerId"] = bannerID
	response, err := msgdata.NewCommandResponse(defaultBytes, attr)
	if err != nil {
		return err
	}
	rcli.displayCh <- response
	return nil
}

// StartSession returns errTestRCLIStartSessionFail when sessionID is "fail".
// Otherwise it records target (for later assertions) and displayCh (for Send)
// and returns nil.
func (rcli *startSessionTestRCLI) StartSession(_ context.Context, sessionID string, displayCh chan<- msgdata.CommandResponse, target remotecli.Target, _ ...remotecli.RCLIOption) error {
	if sessionID == "fail" {
		return errTestRCLIStartSessionFail
	}
	rcli.target = target
	rcli.displayCh = displayCh
	return nil
}

// commandResponse mirrors the JSON shape of a streamed command response.
type commandResponse struct {
	Data       msgdata.ResponseData       `json:"data"`
	Attributes msgdata.ResponseAttributes `json:"attributes"`
}

// ConnectionPayload is a single JSON message decoded from the start-session
// response stream.
type ConnectionPayload struct {
	Error   string          `json:"error"`
	Message commandResponse `json:"message"`
}

// createStartSessionRequest builds a POST /ea/startSession request whose body
// is a JSON StartSessionPayload for sessionID and target. It sets the bearer
// token and event-stream headers, and applies setAuthHeaders.
//
// The request is bound to a child of ctx. The returned cancelFunc cancels it,
// which the tests use to end the long-lived session. On error the request and
// cancel func are both nil, and any context created here has been released.
func createStartSessionRequest(ctx context.Context, sessionID string, target types.Target) (req *http.Request, cancelFunc context.CancelFunc, err error) {
	payload := types.StartSessionPayload{
		SessionID: sessionID,
		Target:    target,
	}
	message, err := json.Marshal(payload)
	if err != nil {
		return nil, nil, err
	}

	ctx, cancelFunc = context.WithCancel(ctx)
	req, err = http.NewRequestWithContext(ctx, http.MethodPost, "/ea/startSession", bytes.NewReader(message))
	if err != nil {
		// Release the child context here; callers only receive a cancel func on success.
		cancelFunc()
		return nil, nil, err
	}

	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "my_jwt_token"))
	req.Header.Set("Cache-Control", "no-cache")
	req.Header.Set("Accept", "text/event-stream")
	req.Header.Set("Connection", "keep-alive")
	setAuthHeaders(req)
	return req, cancelFunc, nil
}

// TestStartSession opens a streaming start-session request and pushes two
// command responses through the test RemoteCLI. It checks that each response
// is streamed back in the response body. It then cancels the request and
// verifies two things about the target returned by the auth service's
// resolve-target handler: the RemoteCLI session was started with it, and the
// X-EA-* response headers report it.
func TestStartSession(t *testing.T) {
	r := httptest.NewRecorder()
	gin.SetMode(gin.TestMode)
	c, ginEngine := gin.CreateTestContext(r)

	var authServerFailures []error
	// Setup
	authServer, url := authserviceServer(http.StatusOK,
		WithMiddleware(verifyUserAuthHeaders(t)),
		// Returns the default target
		WithResolveTarget(func(w http.ResponseWriter, _ *http.Request) {
			// TODO validate passed in target is correct?
			data, err := json.Marshal(map[string]types.Target{
				"target": defaultTarget,
			})
			if err != nil {
				authServerFailures = append(authServerFailures, err)
			}
			_, err = w.Write(data)
			if err != nil {
				authServerFailures = append(authServerFailures, err)
			}
		}),
	)
	t.Cleanup(authServer.Close)

	rcli := &startSessionTestRCLI{}
	_, err := New(eagateway.Config{AuthServiceHost: url}, ginEngine, newLogger(), rcli, nil)
	assert.NoError(t, err)

	// Test
	sessionID := "TestStartSession"
	target := types.Target{
		Bannerid:   "a-banner-id",
		Storeid:    "a-store-id",
		Terminalid: "a-terminal-id",
	}

	req, cancelFunc, err := createStartSessionRequest(c, sessionID, target)
	if !assert.NoError(t, err) {
		return
	}

	// done is closed once ServeHTTP returns. A channel (rather than a shared
	// bool) avoids a data race and gives the reads of r and rcli after the wait
	// a happens-before edge with the server goroutine's writes.
	done := make(chan struct{})
	go func() {
		defer close(done)
		ginEngine.ServeHTTP(r, req)
		assert.Equal(t, http.StatusOK, r.Result().StatusCode)
	}()

	// Give the server time to call StartSession and register the display
	// channel before Send writes to it.
	time.Sleep(10 * time.Millisecond)

	for i := 0; i < 2; i++ {
		bannerID := fmt.Sprintf("%d", i)

		attr := cloneAttrs(defaultAttrMap)
		attr["bannerId"] = bannerID
		expected, err := msgdata.NewCommandResponse(defaultBytes, attr)
		assert.NoError(t, err)

		msgReq, err := msgdata.NewV1_0Request("echo")
		assert.NoError(t, err)
		err = rcli.Send(c, "", bannerID, uuid.NewString(), msgReq)
		assert.NoError(t, err)

		// NOTE(review): r.Body is read (and reset below) while the server
		// goroutine may still write to it. httptest.ResponseRecorder is not safe
		// for concurrent use, so this can still be flagged by -race.
		var buf []byte
		assert.Eventually(t, func() bool {
			buf = r.Body.Bytes()
			return len(buf) != 0
		}, 100*time.Millisecond, 20*time.Millisecond)

		dec := json.NewDecoder(bytes.NewBuffer(buf))
		var received ConnectionPayload
		err = dec.Decode(&received)
		assert.NoError(t, err)
		assert.Equal(t, expected.Data(), received.Message.Data)
		assert.Equal(t, expected.Attributes(), received.Message.Attributes)
		r.Body.Reset()
	}

	cancelFunc()
	closed := assert.Eventually(t, func() bool {
		select {
		case <-done:
			return true
		default:
			return false
		}
	}, 1*time.Second, 50*time.Millisecond)
	if !closed {
		// The handler is still running; reading r or rcli now would race with it.
		return
	}

	assert.Equal(t, http.StatusOK, r.Result().StatusCode)
	assert.Empty(t, authServerFailures)
	assert.Equal(t, defaultTarget.Projectid, rcli.target.ProjectID())
	assert.Equal(t, defaultTarget.Bannerid, rcli.target.BannerID())
	assert.Equal(t, defaultTarget.Storeid, rcli.target.StoreID())
	assert.Equal(t, defaultTarget.Terminalid, rcli.target.TerminalID())

	// Validate target is returned via api
	assert.Equal(t, defaultTarget.Projectid, r.Result().Header.Get("X-EA-ProjectID"))
	assert.Equal(t, defaultTarget.Bannerid, r.Result().Header.Get("X-EA-BannerID"))
	assert.Equal(t, defaultTarget.Storeid, r.Result().Header.Get("X-EA-StoreID"))
	assert.Equal(t, defaultTarget.Terminalid, r.Result().Header.Get("X-EA-TerminalID"))
}

// TestStartSessionFail drives /ea/startSession through failure scenarios. Each
// case sends a payload and can override the auth service's resolve-target or
// authorize-target handlers. It checks the HTTP status and the JSON error body.
func TestStartSessionFail(t *testing.T) {
	tests := map[string]struct {
		payload             any
		status              int
		resolveTargetFunc   func(w http.ResponseWriter, r *http.Request)
		authorizeTargetFunc func(w http.ResponseWriter, r *http.Request)
		err                 string
	}{
		"Payload JSON Bind Fail": {
			payload: "not valid",
			status:  http.StatusBadRequest,
			err:     `{"errorCode":60201, "errorMessage":"Request Error - Invalid payload structure"}`,
		},
		"resolveTarget returns error": {
			payload: types.StartSessionPayload{
				SessionID: "session-ID",
				Target:    defaultTarget,
			},
			resolveTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusInternalServerError)
				errResp := errorhandler.NewErrorResponse(apierror.E(apierror.ErrAuthFailure))
				body, _ := json.Marshal(errResp)
				_, _ = w.Write(body)
			},
			status: http.StatusInternalServerError,
			err:    `{"errorCode":60101, "errorMessage":"User Authorization Failure - Failed to authorize user"}`,
		},
		"resolveTarget returns bad request error": {
			payload: types.StartSessionPayload{
				SessionID: "session-ID",
				Target:    defaultTarget,
			},
			resolveTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusInternalServerError)
				errResp := errorhandler.NewErrorResponse(apierror.E(apierror.ErrInvalidTarget))
				body, _ := json.Marshal(errResp)
				_, _ = w.Write(body)
			},
			status: http.StatusBadRequest,
			err:    `{"errorCode": 61202, "errorMessage": "Request Error - Invalid Target properties"}`,
		},
		"resolveTarget Doesn't return target": {
			payload: types.StartSessionPayload{
				SessionID: "not valid",
				Target: types.Target{
					Projectid:  "",
					Bannerid:   "",
					Storeid:    "storeID",
					Terminalid: "terminalID",
				},
			},
			resolveTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
				data, _ := json.Marshal(map[string]types.Target{
					"target": {
						Projectid:  "",
						Bannerid:   "",
						Storeid:    "a-store-iD",
						Terminalid: "a-terminalID",
					},
				})
				_, _ = w.Write(data)
			},
			status: http.StatusBadRequest,
			err:    `{"errorCode":60202, "details": ["Target missing project ID", "Target missing Banner ID"], "errorMessage":"Request Error - Invalid payload properties"}`,
		},
		"authorizeTarget returns non-ok status": {
			payload: types.StartSessionPayload{
				Target:    defaultTarget,
				SessionID: "session-ID",
			},
			status: http.StatusInternalServerError,
			authorizeTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusInternalServerError)
				errResp := errorhandler.NewErrorResponse(apierror.E(apierror.ErrAuthFailure))
				body, _ := json.Marshal(errResp)
				_, _ = w.Write(body)
			},
			err: `{"errorCode":60101, "errorMessage":"User Authorization Failure - Failed to authorize user"}`,
		},
		"Payload JSON missing properties": {
			payload: types.StartSessionPayload{
				SessionID: "",
				Target: types.Target{
					Projectid:  "",
					Bannerid:   "",
					Storeid:    "storeID",
					Terminalid: "terminalID",
				},
			},
			status: http.StatusBadRequest,
			err:    `{"errorCode":60202, "details":["Payload missing Session ID"], "errorMessage":"Request Error - Invalid payload properties"}`,
		},
		"RCLI Start Session Fail": {
			payload: types.StartSessionPayload{
				SessionID: "fail",
				Target: types.Target{
					Projectid:  "projectID",
					Bannerid:   "bannerID",
					Storeid:    "storeID",
					Terminalid: "terminalID",
				},
			},
			status: http.StatusInternalServerError,
			err:    `{"errorCode":61101, "errorMessage":"Subscription failure - Failed to initialize subscription"}`,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			r := httptest.NewRecorder()
			gin.SetMode(gin.TestMode)
			_, ginEngine := gin.CreateTestContext(r)

			// Setup: use the default auth-service handlers unless the case overrides them.
			resolveTargetFunc := defaultResolveTarget()
			if tc.resolveTargetFunc != nil {
				resolveTargetFunc = tc.resolveTargetFunc
			}
			authorizeTargetFunc := defaultAuthorizeTarget(http.StatusOK)
			if tc.authorizeTargetFunc != nil {
				authorizeTargetFunc = tc.authorizeTargetFunc
			}
			authServer, url := authserviceServer(
				http.StatusOK,
				WithResolveTarget(resolveTargetFunc),
				WithAuthorizeTarget(authorizeTargetFunc),
			)
			defer authServer.Close()

			var rcli eagateway.RemoteCLI = &startSessionTestRCLI{}
			_, err := New(eagateway.Config{AuthServiceHost: url}, ginEngine, newLogger(), rcli, nil)
			assert.NoError(t, err)

			message, err := json.Marshal(tc.payload)
			assert.NoError(t, err)
			req, err := http.NewRequest(http.MethodPost, "/ea/startSession", bytes.NewReader(message))
			assert.NoError(t, err)
			req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "my_jwt_token"))

			ginEngine.ServeHTTP(r, req)

			res := r.Result()
			assert.Equal(t, tc.status, res.StatusCode)

			buf := strings.Builder{}
			_, err = io.Copy(&buf, res.Body)
			assert.NoError(t, err)
			assert.JSONEq(t, tc.err, buf.String())
		})
	}
}

// TestStartSessionAudit checks that starting and then cancelling a session
// writes the "New session started" and "Session ended" audit log entries. Both
// entries must carry the user ID.
func TestStartSessionAudit(t *testing.T) {
	r := httptest.NewRecorder()
	gin.SetMode(gin.TestMode)
	c, ginEngine := gin.CreateTestContext(r)

	// Setup
	resolveTargetFunc := defaultResolveTarget()
	authServer, url := authserviceServer(
		http.StatusOK,
		WithResolveTarget(resolveTargetFunc),
	)
	defer authServer.Close()

	var rcli eagateway.RemoteCLI = &startSessionTestRCLI{}

	// test logger writes to a byte buffer so it can be read from test
	b := bytes.Buffer{}
	log := fog.New(fog.To(&b))

	_, err := New(eagateway.Config{AuthServiceHost: url}, ginEngine, log, rcli, nil)
	assert.NoError(t, err)

	payload := types.StartSessionPayload{
		SessionID: "SessionID",
		Target: types.Target{
			Projectid:  "projectID",
			Bannerid:   "bannerID",
			Storeid:    "storeID",
			Terminalid: "terminalID",
		},
	}

	req, cancelFunc, err := createStartSessionRequest(c, payload.SessionID, payload.Target)
	if !assert.NoError(t, err) {
		return
	}

	// ServeHTTP blocks for the lifetime of the long-lived session, so run it in
	// the background and track when it returns.
	done := make(chan struct{})
	go func() {
		defer close(done)
		ginEngine.ServeHTTP(r, req)
	}()

	// NOTE(review): b is read while the handler goroutine's logger may still be
	// writing to it. bytes.Buffer is not safe for concurrent use.
	validateAuditLogOnDelay(t, &b, "New session started")
	cancelFunc()
	validateAuditLogOnDelay(t, &b, "Session ended")

	// Don't let the handler goroutine outlive the test (and the deferred
	// authServer.Close).
	assert.Eventually(t, func() bool {
		select {
		case <-done:
			return true
		default:
			return false
		}
	}, 1*time.Second, 20*time.Millisecond)
}

// validateAuditLogOnDelay asserts that, within 0.5 seconds (polling every
// 20ms), b gains a log line containing logmsg and the expected user ID.
func validateAuditLogOnDelay(t *testing.T, b *bytes.Buffer, logmsg string) {
	assert.Eventually(t, func() bool {
		return validateAuditLog(b, logmsg)
	}, 500*time.Millisecond, 20*time.Millisecond)
}

// validateAuditLog reports whether any single line of b contains both logmsg
// and the `"userID":"user"` key/value pair. Lines are checked individually so
// that both matches must come from the same log entry rather than across logs.
func validateAuditLog(b *bytes.Buffer, logmsg string) bool {
	userKey := fmt.Sprintf("%q:%q", "userID", "user")
	for _, line := range strings.Split(b.String(), "\n") {
		if strings.Contains(line, logmsg) && strings.Contains(line, userKey) {
			return true
		}
	}
	return false
}