1 package server
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13 "time"
14
15 "edge-infra.dev/pkg/lib/fog"
16 "edge-infra.dev/pkg/sds/emergencyaccess/apierror"
17 errorhandler "edge-infra.dev/pkg/sds/emergencyaccess/apierror/handler"
18 "edge-infra.dev/pkg/sds/emergencyaccess/eagateway"
19 "edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
20 "edge-infra.dev/pkg/sds/emergencyaccess/remotecli"
21 "edge-infra.dev/pkg/sds/emergencyaccess/types"
22
23 "github.com/gin-gonic/gin"
24 "github.com/google/uuid"
25 "github.com/stretchr/testify/assert"
26 )
27
28 type startSessionTestRCLI struct {
29 eagateway.RemoteCLI
30 displayCh chan<- msgdata.CommandResponse
31 target remotecli.Target
32 }
33
34 func (rcli *startSessionTestRCLI) Send(_ context.Context, _, i, _ string, _ msgdata.Request, _ ...remotecli.RCLIOption) error {
35 attr := defaultAttrMap
36 attr["bannerId"] = i
37 response, err := msgdata.NewCommandResponse(defaultBytes, attr)
38 if err != nil {
39 return err
40 }
41 rcli.displayCh <- response
42 return nil
43 }
44
45 func (rcli *startSessionTestRCLI) StartSession(_ context.Context, sessionID string, displayCh chan<- msgdata.CommandResponse, target remotecli.Target, _ ...remotecli.RCLIOption) error {
46 if sessionID == "fail" {
47 return errTestRCLIStartSessionFail
48 }
49 rcli.target = target
50 rcli.displayCh = displayCh
51 return nil
52 }
53
54 type commandResponse struct {
55 Data msgdata.ResponseData `json:"data"`
56 Attributes msgdata.ResponseAttributes `json:"attributes"`
57 }
58
59 type ConnectionPayload struct {
60 Error string `json:"error"`
61 Message commandResponse `json:"message"`
62 }
63
64 func createStartSessionRequest(ctx context.Context, sessionID string, target types.Target) (req *http.Request, cancelFunc context.CancelFunc, err error) {
65 payload := types.StartSessionPayload{
66 SessionID: sessionID,
67 Target: target,
68 }
69 message, err := json.Marshal(payload)
70 if err != nil {
71 return nil, nil, err
72 }
73 ctx, cancelFunc = context.WithCancel(ctx)
74 req, err = http.NewRequestWithContext(ctx, http.MethodPost, "/ea/startSession", bytes.NewReader(message))
75 if err != nil {
76 return nil, cancelFunc, err
77 }
78 req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "my_jwt_token"))
79 req.Header.Set("Cache-Control", "no-cache")
80 req.Header.Set("Accept", "text/event-stream")
81 req.Header.Set("Connection", "keep-alive")
82
83 setAuthHeaders(req)
84
85 return req, cancelFunc, nil
86 }
87
88 func TestStartSession(t *testing.T) {
89 r := httptest.NewRecorder()
90 gin.SetMode(gin.TestMode)
91 c, ginEngine := gin.CreateTestContext(r)
92
93 var authServerFailures []error
94
95
96 authServer, url := authserviceServer(http.StatusOK,
97 WithMiddleware(verifyUserAuthHeaders(t)),
98
99
100 WithResolveTarget(func(w http.ResponseWriter, _ *http.Request) {
101
102
103 data, err := json.Marshal(map[string]types.Target{
104 "target": defaultTarget,
105 })
106 if err != nil {
107 authServerFailures = append(authServerFailures, err)
108 }
109
110 _, err = w.Write(data)
111 if err != nil {
112 authServerFailures = append(authServerFailures, err)
113 }
114 }),
115 )
116 t.Cleanup(authServer.Close)
117
118 var rcli = &startSessionTestRCLI{}
119 _, err := New(eagateway.Config{AuthServiceHost: url}, ginEngine, newLogger(), rcli, nil)
120 assert.NoError(t, err)
121
122
123 sessionID := "TestStartSession"
124 target := types.Target{
125 Bannerid: "a-banner-id",
126 Storeid: "a-store-id",
127 Terminalid: "a-terminal-id",
128 }
129 req, cancelFunc, err := createStartSessionRequest(c, sessionID, target)
130 assert.NoError(t, err)
131
132 isClosed := false
133 go func() {
134 ginEngine.ServeHTTP(r, req)
135 assert.Equal(t, http.StatusOK, r.Result().StatusCode)
136 isClosed = true
137 }()
138
139 time.Sleep(10 * time.Millisecond)
140
141 for i := 0; i < 2; i++ {
142 attr := defaultAttrMap
143 attr["bannerId"] = fmt.Sprintf("%d", i)
144 expected, err := msgdata.NewCommandResponse(defaultBytes, attr)
145 assert.NoError(t, err)
146
147 req, err := msgdata.NewV1_0Request("echo")
148 assert.NoError(t, err)
149
150 err = rcli.Send(c, "", fmt.Sprintf("%d", i), uuid.NewString(), req)
151 assert.NoError(t, err)
152
153 var buf []byte
154 assert.Eventually(t, func() bool {
155 buf = r.Body.Bytes()
156 return len(buf) != 0
157 }, 100*time.Millisecond, 20*time.Millisecond)
158 dec := json.NewDecoder(bytes.NewBuffer(buf))
159 var received ConnectionPayload
160 err = dec.Decode(&received)
161 assert.NoError(t, err)
162
163 assert.Equal(t, expected.Data(), received.Message.Data)
164 assert.Equal(t, expected.Attributes(), received.Message.Attributes)
165
166 r.Body.Reset()
167 }
168
169 cancelFunc()
170 assert.Eventually(t, func() bool {
171 return isClosed
172 }, 1*time.Second, 50*time.Millisecond)
173 assert.Equal(t, http.StatusOK, r.Result().StatusCode)
174
175 assert.Empty(t, authServerFailures)
176 assert.Equal(t, defaultTarget.Projectid, rcli.target.ProjectID())
177 assert.Equal(t, defaultTarget.Bannerid, rcli.target.BannerID())
178 assert.Equal(t, defaultTarget.Storeid, rcli.target.StoreID())
179 assert.Equal(t, defaultTarget.Terminalid, rcli.target.TerminalID())
180
181
182 assert.Equal(t, defaultTarget.Projectid, r.Result().Header.Get("X-EA-ProjectID"))
183 assert.Equal(t, defaultTarget.Bannerid, r.Result().Header.Get("X-EA-BannerID"))
184 assert.Equal(t, defaultTarget.Storeid, r.Result().Header.Get("X-EA-StoreID"))
185 assert.Equal(t, defaultTarget.Terminalid, r.Result().Header.Get("X-EA-TerminalID"))
186 }
187
188 func TestStartSessionFail(t *testing.T) {
189 tests := map[string]struct {
190 payload interface{}
191 status int
192 resolveTargetFunc func(w http.ResponseWriter, r *http.Request)
193 authorizeTargetFunc func(w http.ResponseWriter, r *http.Request)
194 err string
195 }{
196 "Payload JSON Bind Fail": {
197 payload: "not valid",
198 status: http.StatusBadRequest,
199 err: `{"errorCode":60201, "errorMessage":"Request Error - Invalid payload structure"}`,
200 },
201 "resolveTarget returns error": {
202 payload: types.StartSessionPayload{
203 SessionID: "session-ID",
204 Target: defaultTarget,
205 },
206 resolveTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
207 w.WriteHeader(http.StatusInternalServerError)
208 errResp := errorhandler.NewErrorResponse(apierror.E(apierror.ErrAuthFailure))
209 bytes, _ := json.Marshal(errResp)
210 _, _ = w.Write(bytes)
211 },
212 status: http.StatusInternalServerError,
213 err: `{"errorCode":60101, "errorMessage":"User Authorization Failure - Failed to authorize user"}`,
214 },
215 "resolveTarget returns bad request error": {
216 payload: types.StartSessionPayload{
217 SessionID: "session-ID",
218 Target: defaultTarget,
219 },
220 resolveTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
221 w.WriteHeader(http.StatusInternalServerError)
222 errResp := errorhandler.NewErrorResponse(apierror.E(apierror.ErrInvalidTarget))
223 bytes, _ := json.Marshal(errResp)
224 _, _ = w.Write(bytes)
225 },
226 status: http.StatusBadRequest,
227 err: `{"errorCode": 61202, "errorMessage": "Request Error - Invalid Target properties"}`,
228 },
229
230 "resolveTarget Doesn't return target": {
231 payload: types.StartSessionPayload{
232 SessionID: "not valid",
233 Target: types.Target{
234 Projectid: "",
235 Bannerid: "",
236 Storeid: "storeID",
237 Terminalid: "terminalID",
238 },
239 },
240 resolveTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
241 data, _ := json.Marshal(map[string]types.Target{
242 "target": {
243 Projectid: "",
244 Bannerid: "",
245 Storeid: "a-store-iD",
246 Terminalid: "a-terminalID",
247 },
248 })
249 _, _ = w.Write(data)
250 },
251 status: http.StatusBadRequest,
252 err: `{"errorCode":60202, "details": ["Target missing project ID", "Target missing Banner ID"], "errorMessage":"Request Error - Invalid payload properties"}`,
253 },
254 "authorizeTarget returns non-ok status": {
255 payload: types.StartSessionPayload{
256 Target: defaultTarget,
257 SessionID: "session-ID",
258 },
259 status: http.StatusInternalServerError,
260 authorizeTargetFunc: func(w http.ResponseWriter, _ *http.Request) {
261 w.WriteHeader(http.StatusInternalServerError)
262 errResp := errorhandler.NewErrorResponse(apierror.E(apierror.ErrAuthFailure))
263 bytes, _ := json.Marshal(errResp)
264 _, _ = w.Write(bytes)
265 },
266 err: `{"errorCode":60101, "errorMessage":"User Authorization Failure - Failed to authorize user"}`,
267 },
268 "Payload JSON missing properties": {
269 payload: types.StartSessionPayload{
270 SessionID: "",
271 Target: types.Target{
272 Projectid: "",
273 Bannerid: "",
274 Storeid: "storeID",
275 Terminalid: "terminalID",
276 },
277 },
278 status: http.StatusBadRequest,
279 err: `{"errorCode":60202, "details":["Payload missing Session ID"], "errorMessage":"Request Error - Invalid payload properties"}`,
280 },
281 "RCLI Start Session Fail": {
282 payload: types.StartSessionPayload{
283 SessionID: "fail",
284 Target: types.Target{
285 Projectid: "projectID",
286 Bannerid: "bannerID",
287 Storeid: "storeID",
288 Terminalid: "terminalID",
289 },
290 },
291 status: http.StatusInternalServerError,
292 err: `{"errorCode":61101, "errorMessage":"Subscription failure - Failed to initialize subscription"}`,
293 },
294 }
295
296 for name, tc := range tests {
297 t.Run(name, func(t *testing.T) {
298 r := httptest.NewRecorder()
299 gin.SetMode(gin.TestMode)
300 _, ginEngine := gin.CreateTestContext(r)
301
302
303 resolveTargetFunc := defaultResolveTarget()
304 if tc.resolveTargetFunc != nil {
305 resolveTargetFunc = tc.resolveTargetFunc
306 }
307 authorizeTargetFunc := defaultAuthorizeTarget(http.StatusOK)
308 if tc.authorizeTargetFunc != nil {
309 authorizeTargetFunc = tc.authorizeTargetFunc
310 }
311 authServer, url := authserviceServer(
312 http.StatusOK,
313 WithResolveTarget(resolveTargetFunc),
314 WithAuthorizeTarget(authorizeTargetFunc),
315 )
316 defer authServer.Close()
317
318 var rcli eagateway.RemoteCLI = &startSessionTestRCLI{}
319 _, err := New(eagateway.Config{AuthServiceHost: url}, ginEngine, newLogger(), rcli, nil)
320 assert.NoError(t, err)
321
322 message, err := json.Marshal(tc.payload)
323 assert.NoError(t, err)
324 req, err := http.NewRequest(http.MethodPost, "/ea/startSession", bytes.NewReader(message))
325 assert.NoError(t, err)
326 req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "my_jwt_token"))
327
328 ginEngine.ServeHTTP(r, req)
329 res := r.Result()
330 assert.Equal(t, tc.status, res.StatusCode)
331
332 buf := strings.Builder{}
333 _, err = io.Copy(&buf, r.Result().Body)
334 assert.NoError(t, err)
335
336 assert.JSONEq(t, tc.err, buf.String())
337 })
338 }
339 }
340
341 func TestStartSessionAudit(t *testing.T) {
342 r := httptest.NewRecorder()
343 gin.SetMode(gin.TestMode)
344 c, ginEngine := gin.CreateTestContext(r)
345
346
347 resolveTargetFunc := defaultResolveTarget()
348 authServer, url := authserviceServer(
349 http.StatusOK,
350 WithResolveTarget(resolveTargetFunc),
351 )
352 defer authServer.Close()
353
354 var rcli eagateway.RemoteCLI = &startSessionTestRCLI{}
355
356 b := bytes.Buffer{}
357 log := fog.New(fog.To(&b))
358 _, err := New(eagateway.Config{AuthServiceHost: url}, ginEngine, log, rcli, nil)
359 assert.NoError(t, err)
360
361 payload := types.StartSessionPayload{
362 SessionID: "SessionID",
363 Target: types.Target{
364 Projectid: "projectID",
365 Bannerid: "bannerID",
366 Storeid: "storeID",
367 Terminalid: "terminalID",
368 },
369 }
370 req, cancelFunc, err := createStartSessionRequest(c, payload.SessionID, payload.Target)
371 assert.NoError(t, err)
372
373 go ginEngine.ServeHTTP(r, req)
374 validateAuditLogOnDelay(t, &b, "New session started")
375 cancelFunc()
376 validateAuditLogOnDelay(t, &b, "Session ended")
377 }
378
379
380 func validateAuditLogOnDelay(t *testing.T, b *bytes.Buffer, logmsg string) {
381 assert.Eventually(t, func() bool {
382 return validateAuditLog(b, logmsg)
383 }, 500*time.Millisecond, 20*time.Millisecond)
384 }
385
// validateAuditLog reports whether the buffer contains at least one line that
// includes both logmsg and the exact text "userID":"user". The buffer is split
// on newlines, and both substrings must occur on the same line.
func validateAuditLog(b *bytes.Buffer, logmsg string) bool {
	// The user field as it must appear in the log line: "userID":"user".
	userField := fmt.Sprintf("%q:%q", "userID", "user")

	for _, line := range strings.Split(b.String(), "\n") {
		if !strings.Contains(line, logmsg) {
			continue
		}
		if strings.Contains(line, userField) {
			return true
		}
	}

	return false
}
407
View as plain text