1 package remotecli
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "strings"
8 "sync"
9 "testing"
10 "time"
11
12 "github.com/go-logr/logr"
13 "github.com/go-logr/logr/funcr"
14 "github.com/google/uuid"
15 "github.com/stretchr/testify/assert"
16
17 "edge-infra.dev/pkg/lib/fog"
18 "edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
19 "edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
20 )
21
22 var (
23 defaultTarget = target{
24 projectID: "project",
25 bannerID: "banner",
26 storeID: "store",
27 terminalID: "terminal",
28 }
29 defaultIdentity = "user"
30
31 defaultWaitTime = 20 * time.Millisecond
32 defaultTickTime = 1 * time.Microsecond
33 )
34
// target is a minimal test implementation of the target passed to
// StartSession. It holds the four IDs identifying a terminal and exposes them
// through accessor methods.
type target struct {
	projectID, bannerID, storeID, terminalID string
}

// ProjectID returns the project ID of the target.
func (tg target) ProjectID() string { return tg.projectID }

// BannerID returns the banner ID of the target.
func (tg target) BannerID() string { return tg.bannerID }

// StoreID returns the store ID of the target.
func (tg target) StoreID() string { return tg.storeID }

// TerminalID returns the terminal ID of the target.
func (tg target) TerminalID() string { return tg.terminalID }
46
// messageService is a fake message service for RemoteCLI tests. Its behaviour
// is controlled by the fields below and by special project IDs:
// "subscribe_error" makes Subscribe fail and "publish_error" makes Publish
// fail.
type messageService struct {
	// subscriptionFunc, if non-nil, is called at the start of every
	// Subscribe with the subscription and project IDs.
	subscriptionFunc func(subscriptionID, projectID string)

	// stopPublishFunc, if non-nil, is called with the topic and project IDs
	// passed to StopPublish.
	stopPublishFunc func(topicID, projectID string)

	// subscribeWatchCtx makes Subscribe block until its context is done.
	subscribeWatchCtx bool

	// pubCommandID holds the command ID attribute of the last request passed
	// to Publish.
	pubCommandID string
}
54
55 func (ms *messageService) Subscribe(
56 ctx context.Context,
57 subscriptionID string,
58 projectID string,
59 _ func(context.Context, msgdata.CommandResponse),
60 _ map[string]string,
61 ) error {
62 if ms.subscriptionFunc != nil {
63 ms.subscriptionFunc(subscriptionID, projectID)
64 }
65
66
67 if ms.subscribeWatchCtx {
68 <-ctx.Done()
69 }
70
71 if projectID == "subscribe_error" {
72 return fmt.Errorf("subscribe returned error")
73 }
74 return nil
75 }
76
77 func (ms *messageService) Publish(_ context.Context, _ string, projectID string, msg msgdata.Request) error {
78 ms.pubCommandID = msg.Attributes()[eaconst.CommandIDKey]
79
80 if projectID == "publish_error" {
81 return fmt.Errorf("publish returned error")
82 }
83 return nil
84 }
85
86 func (ms messageService) StopPublish(topicID string, projectID string) {
87 if ms.stopPublishFunc != nil {
88 ms.stopPublishFunc(topicID, projectID)
89 }
90 }
91
92 func TestValidateTarget(t *testing.T) {
93 assert.NoError(t, validateTarget(defaultTarget))
94
95 badTarget := target{}
96 assert.Len(t, validateTarget(badTarget), 4)
97 }
98
99 func TestNewRemoteCLI(t *testing.T) {
100 ctx := context.Background()
101 ms := messageService{}
102 expected := &RemoteCLI{
103 msgService: &ms,
104 sessionLock: &sync.RWMutex{},
105 context: ctx,
106
107 sessionData: map[string]sessionData{},
108 subscriptionData: map[string]subscriptionData{},
109 topicData: map[string]topicData{},
110 }
111
112 rcli := New(ctx, &ms)
113
114 assert.Equal(t, expected, rcli)
115 }
116
117 func TestStartSession(t *testing.T) {
118 buf := bytes.Buffer{}
119 ctx, cancelFunc := context.WithCancel(fog.IntoContext(context.Background(), createLogger(&buf)))
120 ch := make(chan msgdata.CommandResponse)
121
122 ms := messageService{
123 subscribeWatchCtx: true,
124 }
125
126 rcli := New(ctx, &ms)
127 sessionID := uuid.NewString()
128
129 target := defaultTarget
130 target.projectID = "subscribe_error"
131
132 err := rcli.StartSession(ctx, sessionID, ch, target)
133 assert.NoError(t, err)
134
135 assert.Equal(t, target, rcli.sessionData[sessionID].target)
136
137 expSubscriptionID := "sub.store.dsds-ea-response"
138 assert.True(t, rcli.subscriptionData[expSubscriptionID].sessions.HasMember(sessionID))
139
140 cancelFunc()
141
142 assert.Eventually(t, func() bool {
143 return strings.Contains(buf.String(), "subscribe returned error")
144 }, 4*time.Second, 1*time.Millisecond, "logs:\n%s\ndoes not contain expected string: %s", buf.String(), "subscribe returned error")
145 }
146
147 func TestStartSession_ListensToContextDone(t *testing.T) {
148 buf := bytes.Buffer{}
149 ctx, cancelFunc := context.WithCancel(fog.IntoContext(context.Background(), createLogger(&buf)))
150 target := defaultTarget
151 ch := make(chan msgdata.CommandResponse)
152 sessionID := uuid.NewString()
153
154
155
156 endSubscriptionChan := make(chan struct{})
157 subFunc := func(_, _ string) {
158 <-endSubscriptionChan
159 }
160 ms := messageService{
161 subscriptionFunc: subFunc,
162 }
163
164 rcli := New(ctx, &ms)
165
166 err := rcli.StartSession(ctx, sessionID, ch, target)
167 assert.NoError(t, err)
168
169
170 assert.Never(t, func() bool {
171 select {
172 case _, ok := <-ch:
173 if ok {
174 return false
175 }
176 return true
177 default:
178 return false
179 }
180 }, defaultWaitTime, defaultTickTime)
181
182
183 cancelFunc()
184
185 assert.Eventually(t, func() bool {
186 select {
187 case _, ok := <-ch:
188 if ok {
189 return false
190 }
191 return true
192 default:
193 return false
194 }
195 }, defaultWaitTime, defaultTickTime)
196
197
198 endSubscriptionChan <- struct{}{}
199 }
200
201 func TestSubscribeCalledOnce(t *testing.T) {
202
203
204 ctx := fog.IntoContext(context.Background(), logr.Discard())
205 target := defaultTarget
206 ch := make(chan msgdata.CommandResponse)
207 sessionID := uuid.NewString()
208
209 var subscriptionCalledCounter int
210 subFunc := func(_, _ string) {
211 subscriptionCalledCounter = subscriptionCalledCounter + 1
212 }
213
214 ms := messageService{
215 subscriptionFunc: subFunc,
216 subscribeWatchCtx: true,
217 }
218
219 rcli := New(ctx, &ms)
220
221 err := rcli.StartSession(ctx, sessionID, ch, target)
222 assert.NoError(t, err)
223 assert.Eventually(t, func() bool {
224 return subscriptionCalledCounter == 1
225 }, defaultWaitTime, defaultTickTime)
226
227
228 ch = make(chan msgdata.CommandResponse)
229 sessionID = uuid.NewString()
230 err = rcli.StartSession(ctx, sessionID, ch, target)
231 assert.NoError(t, err)
232 assert.Never(t, func() bool {
233 return subscriptionCalledCounter == 2
234 }, defaultWaitTime, defaultTickTime)
235
236
237 ch = make(chan msgdata.CommandResponse)
238 sessionID = uuid.NewString()
239 target.storeID = "anotherstore"
240 err = rcli.StartSession(ctx, sessionID, ch, target)
241 assert.NoError(t, err)
242 assert.Eventually(t, func() bool {
243 return subscriptionCalledCounter == 2
244 }, defaultWaitTime, defaultTickTime)
245 }
246
247 func TestStartSessionInvalidTarget(t *testing.T) {
248 buf := bytes.Buffer{}
249 logger := createLogger(&buf)
250 sessionID := uuid.NewString()
251
252 ctx := fog.IntoContext(context.Background(), logger)
253 ch, ms := make(chan msgdata.CommandResponse), messageService{}
254 rcli := New(ctx, &ms)
255 target := target{}
256
257 err := rcli.StartSession(ctx, sessionID, ch, target)
258 assert.Contains(t, err.Error(), validateTarget(target).Error())
259 }
260
261 func TestSend(t *testing.T) {
262 ctx := fog.IntoContext(context.Background(), logr.Discard())
263 ch := make(chan msgdata.CommandResponse)
264 ms := messageService{subscribeWatchCtx: true}
265 sessionID := uuid.NewString()
266 rcli := New(ctx, &ms)
267
268 identity, target := defaultIdentity, defaultTarget
269 target.projectID = "publish_error"
270 command := "echo hello"
271 request, err := msgdata.NewV1_0Request(command)
272 assert.NoError(t, err)
273 commandID := uuid.NewString()
274
275 err = rcli.StartSession(ctx, sessionID, ch, target)
276 assert.NoError(t, err)
277
278 err = rcli.Send(ctx, identity, sessionID, commandID, request)
279 assert.Contains(t, err.Error(), fmt.Errorf("publish returned error").Error())
280
281 assert.NotEmpty(t, ms.pubCommandID)
282 assert.Equal(t, commandID, ms.pubCommandID)
283 }
284
285 func TestSendNoSessionStarted(t *testing.T) {
286 buf := bytes.Buffer{}
287 logger := createLogger(&buf)
288
289 ctx := fog.IntoContext(context.Background(), logger)
290 ms := messageService{}
291 rcli := New(ctx, &ms)
292 identity, command, sessionID := defaultIdentity, "echo hello", "invalid-session-id"
293 request, err := msgdata.NewV1_0Request(command)
294 assert.NoError(t, err)
295
296 err = rcli.Send(ctx, identity, sessionID, uuid.NewString(), request)
297 assert.Contains(t, err.Error(), "invalid session id")
298 }
299
300 func TestEndSession(t *testing.T) {
301 buf := bytes.Buffer{}
302 logger := createLogger(&buf)
303 sessionID := uuid.NewString()
304
305 ms := messageService{
306 subscribeWatchCtx: true,
307 }
308
309 ctx := fog.IntoContext(context.Background(), logger)
310 ch, target := make(chan msgdata.CommandResponse), defaultTarget
311 rcli := New(ctx, &ms)
312 logOK := "Session stopped"
313
314 var err error
315 assert.NotPanics(t, func() { err = rcli.EndSession(ctx, "nonexistant_session") })
316 assert.ErrorContains(t, err, "unknown session ID")
317 buf.Reset()
318
319 err = rcli.StartSession(ctx, sessionID, ch, target)
320 assert.NoError(t, err)
321 assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
322 assert.NoError(t, err)
323 assert.Contains(t, buf.String(), logOK)
324 buf.Reset()
325
326
327 assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
328 assert.ErrorContains(t, err, "unknown session ID")
329 buf.Reset()
330
331
332 assert.Eventually(t, func() bool {
333 select {
334 case v, ok := <-ch:
335 if ok {
336 t.Errorf("unexpected value on display channel: %s", v)
337 return false
338 }
339 return true
340 default:
341 return false
342 }
343 }, 5*time.Second, time.Microsecond)
344 }
345
346 func TestEndSession_Topics(t *testing.T) {
347 buf := bytes.Buffer{}
348 logger := createLogger(&buf)
349 sessionID := uuid.NewString()
350
351
352 type publishInfo struct {
353 topicID string
354 projectID string
355 }
356 endSessionCallMap := make(map[publishInfo]int)
357
358
359
360 stopPublishFunc := func(topicID string, projectID string) {
361 pInfo := publishInfo{topicID, projectID}
362 endSessionCallMap[pInfo] = endSessionCallMap[pInfo] + 1
363 }
364
365 ms := messageService{
366 stopPublishFunc: stopPublishFunc,
367 subscribeWatchCtx: true,
368 }
369
370 ctx := fog.IntoContext(context.Background(), logger)
371 ch, target := make(chan msgdata.CommandResponse), defaultTarget
372 rcli := New(ctx, &ms)
373
374 var err error
375 assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
376 assert.ErrorContains(t, err, "unknown session ID")
377
378 request, err := msgdata.NewV1_0Request("command")
379 assert.NoError(t, err)
380
381
382 assert.Len(t, endSessionCallMap, 0)
383
384
385 assert.NoError(t, rcli.StartSession(ctx, sessionID, ch, target))
386 assert.NoError(t, rcli.Send(ctx, defaultIdentity, sessionID, uuid.NewString(), request, WithOptionalTemplate("a")))
387 assert.NoError(t, rcli.Send(ctx, defaultIdentity, sessionID, uuid.NewString(), request))
388 assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
389 assert.NoError(t, err)
390
391
392 assert.Len(t, endSessionCallMap, 2)
393 assert.Equal(t, 1, endSessionCallMap[publishInfo{"a", target.projectID}])
394 assert.Equal(t, 1, endSessionCallMap[publishInfo{eaconst.DefaultTopTemplate, target.projectID}])
395
396
397 assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
398 assert.ErrorContains(t, err, "unknown session ID")
399
400
401 assert.Len(t, endSessionCallMap, 2)
402 assert.Equal(t, 1, endSessionCallMap[publishInfo{"a", target.projectID}])
403 assert.Equal(t, 1, endSessionCallMap[publishInfo{eaconst.DefaultTopTemplate, target.projectID}])
404 }
405
// handlerAttrs returns the message attributes of a command response for the
// given session. The banner, store and terminal IDs match defaultTarget; the
// identity, version, signature and request UUID are fixed placeholders.
func handlerAttrs(sessionID string) map[string]string {
	attrs := make(map[string]string, 8)
	attrs["bannerId"] = "banner"
	attrs["storeId"] = "store"
	attrs["terminalId"] = "terminal"
	attrs["sessionId"] = sessionID
	attrs["identity"] = "identity"
	attrs["version"] = "1.0"
	attrs["signature"] = "signature"
	attrs["request-message-uuid"] = "id"
	return attrs
}
418
// handlerData returns the JSON body of a successful "Output" command response
// (exit code 0) whose "output" field is output. The output string is
// substituted verbatim, without JSON escaping, so callers should pass plain
// text with no quotes, backslashes or control characters.
func handlerData(output string) []byte {
	const body = `
{
"type": "Output",
"exitCode": 0,
"output": "%s",
"timestamp": "01-01-2023 00:00:00",
"duration": 0.1
}`
	return []byte(fmt.Sprintf(body, output))
}
429
430 func TestHandler(t *testing.T) {
431 ch1 := make(chan msgdata.CommandResponse, 3)
432 ch2 := make(chan msgdata.CommandResponse, 1)
433 ch3 := make(chan msgdata.CommandResponse)
434
435 rcli := RemoteCLI{
436 sessionLock: &sync.RWMutex{},
437 sessionData: map[string]sessionData{
438 "orderingKey": {
439 displayChan: ch1,
440 },
441 "orderingKey2": {
442 displayChan: ch2,
443 },
444 "orderingKey3": {
445 displayChan: ch3,
446 },
447 },
448 }
449 ctx, cancelFunc := context.WithCancel(context.Background())
450 fn := rcli.handler()
451
452 data1 := handlerData("message 1")
453 data2 := handlerData("message 2")
454 data3 := handlerData("other ordering key")
455 data4 := handlerData("this message should not be received")
456
457 msg1, _ := msgdata.NewCommandResponse(data1, handlerAttrs("orderingKey"))
458 msg2, _ := msgdata.NewCommandResponse(data2, handlerAttrs("orderingKey"))
459 msg3, _ := msgdata.NewCommandResponse(data3, handlerAttrs("orderingKey2"))
460 msg4, _ := msgdata.NewCommandResponse(data4, handlerAttrs("orderingKey3"))
461
462 fn(ctx, msg1)
463 assert.Empty(t, ch2)
464 assert.Equal(t, msg1, <-ch1)
465
466 fn(ctx, msg2)
467 assert.Empty(t, ch2)
468 assert.Equal(t, msg2, <-ch1)
469
470
471 fn(context.Background(), msg3)
472 assert.Empty(t, ch1)
473 assert.Equal(t, msg3, <-ch2)
474
475 cancelFunc()
476 time.Sleep(time.Millisecond * 100)
477
478 fn(ctx, msg4)
479 assert.Empty(t, ch1)
480 assert.Empty(t, ch2)
481 assert.Empty(t, ch3)
482 }
483
484 func TestCreateOptionalConfig(t *testing.T) {
485 template1, template2, template3 := "template-string-1", "template-string-2", "template-string-3"
486 opts := []RCLIOption{WithOptionalTemplate(template1), WithOptionalTemplate(template2), WithOptionalTemplate(template3)}
487
488 expected := &templateConfig{template: &template3}
489 assert.Equal(t, expected, createOptionalConfig(opts))
490
491 assert.Nil(t, createOptionalConfig(nil))
492 }
493
494 func TestFillTemplate(t *testing.T) {
495 defaultTemplate := "default.<PROJECT_ID>.<BANNER_ID>.<STORE_ID>.<TERMINAL_ID>"
496 optionalTemplate := "optional.<PROJECT_ID>.<BANNER_ID>.<STORE_ID>.<TERMINAL_ID>"
497
498 target := defaultTarget
499 config := &templateConfig{template: &optionalTemplate}
500
501 expected := fmt.Sprintf("%s.%s.%s.%s", target.projectID, target.bannerID, target.storeID, target.terminalID)
502 assert.Equal(t, "default."+expected, fillTemplate(target, defaultTemplate, nil))
503 assert.Equal(t, "optional."+expected, fillTemplate(target, defaultTemplate, config))
504 }
505
506 func createLogger(buf *bytes.Buffer) logr.Logger {
507 return funcr.New(func(prefix, args string) {
508 if prefix != "" {
509 fmt.Fprintf(buf, "%s: %s\n", prefix, args)
510 } else {
511 fmt.Fprintln(buf, args)
512 }
513 }, funcr.Options{})
514 }
515
View as plain text