1 package msgsvc
2
3 import (
4 "context"
5 "fmt"
6 "testing"
7 "time"
8
9 "github.com/go-logr/logr"
10 "github.com/google/uuid"
11 "github.com/stretchr/testify/assert"
12
13 "edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
14 )
15
// Shared fixtures used by the tests in this file.
var (
	// defaultReqAttr is the attribute set that populateAttrs copies onto a
	// request before it is published.
	defaultReqAttr = map[string]string{
		"bannerId":   "a",
		"storeId":    "b",
		"terminalId": "c",
		"identity":   "d",
		"sessionId":  "e",
		"signature":  "g",
		"commandId":  "h",
	}
	// defaultReqData is the command text handed to msgdata.NewV1_0Request.
	defaultReqData = "echo 4"
	// defaultRespAttr is the attribute set attached to mock response
	// messages. The first five keys carry the same values as in
	// defaultReqAttr.
	defaultRespAttr = map[string]string{
		"bannerId":   "a",
		"storeId":    "b",
		"terminalId": "c",
		"identity":   "d",
		"sessionId":  "e",

		"request-message-uuid": "f",
	}
	// defaultRespData is the JSON payload carried by mock response messages.
	defaultRespData = []byte(`
{
"type": "Output",
"exitCode": 0,
"output": "4\n",
"timestamp": "01-01-2023 00:00:00",
"duration": 0.1
}`,
	)
)
46
47 func populateAttrs(request msgdata.Request) {
48 for k, v := range defaultReqAttr {
49 request.AddAttribute(k, v)
50 }
51 }
52
53
54 type mockClientForTopicInProject struct {
55 clientItfc
56 callCount int
57 mockTopics []*mockTopic
58 }
59
60 func (mT *mockClientForTopicInProject) TopicInProject(_ string, _ string) topicItfc {
61 top := mT.mockTopics[mT.callCount]
62 mT.callCount++
63 return top
64 }
65
66 type mockTopic struct {
67 callCount int
68 stopCallCount int
69 mockResults []*mockResult
70 calledWith []messageItfc
71 }
72
73 func (m *mockTopic) ID() string {
74 return "mock topic"
75 }
76
77 func (m *mockTopic) Publish(_ context.Context, msg messageItfc) publishResultItfc {
78 m.calledWith = append(m.calledWith, msg)
79 res := m.mockResults[m.callCount]
80 m.callCount++
81 return res
82 }
83
84 func (m *mockTopic) Stop() {
85 m.stopCallCount++
86 }
87
88 func (m *mockTopic) SetOrdering(bool) {
89
90 }
91
// mockResult is a stateless stand-in for a publish result.
type mockResult struct{}

// Get ignores the context and always returns the server ID "a" with a nil
// error.
func (mockResult) Get(context.Context) (serverID string, err error) {
	const fixedID = "a"
	return fixedID, nil
}
99
100 func TestSubscriptionFilter(t *testing.T) {
101 defaultAttr := map[string]string{
102 "bannerId": "banner",
103 "storeId": "store",
104 "terminalId": "terminal",
105 "sessionId": "orderingKey",
106 "identity": "identity",
107 "version": "1.0",
108 "signature": "signature",
109 }
110
111 tests := map[string]struct {
112 attr map[string]string
113 filter map[string]string
114 want bool
115 }{
116 "Success": {
117 attr: defaultAttr,
118 filter: map[string]string{"bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
119 want: true,
120 },
121 "Elements don't match": {
122 attr: defaultAttr,
123 filter: map[string]string{"bannerId": "store", "storeId": "banner", "terminalId": "terminal"},
124 want: false,
125 },
126 "Empty elements": {
127 attr: defaultAttr,
128 filter: map[string]string{"bannerId": "", "storeId": "", "terminalId": ""},
129 want: false,
130 },
131 "Key not in attr": {
132 attr: defaultAttr,
133 filter: map[string]string{"keyNotInAttr": "something", "bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
134 want: false,
135 },
136 "Key not in attr and empty": {
137 attr: defaultAttr,
138 filter: map[string]string{"keyNotInAttr": "", "bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
139 want: false,
140 },
141 "Empty attr": {
142 attr: nil,
143 filter: map[string]string{"bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
144 want: false,
145 },
146 "Empty filter": {
147 attr: defaultAttr,
148 filter: map[string]string{},
149 want: true,
150 },
151 "Nil Filter": {
152 attr: defaultAttr,
153 filter: nil,
154 want: true,
155 },
156 }
157
158 for name, tc := range tests {
159 t.Run(name, func(t *testing.T) {
160 assert.Equal(t, tc.want, isFilterMatch(tc.attr, tc.filter))
161 })
162 }
163 }
164
165 func TestPublish(t *testing.T) {
166 mTopic := mockTopic{mockResults: []*mockResult{{}}}
167 mockClient := mockClientForTopicInProject{
168 mockTopics: []*mockTopic{&mTopic},
169 }
170
171 message, err := msgdata.NewV1_0Request(defaultReqData)
172 assert.NoError(t, err)
173 populateAttrs(message)
174
175
176
177 ms, err := NewMessageService(context.Background())
178 assert.NoError(t, err)
179 ms.ps = &mockClient
180
181 err = ms.Publish(context.Background(), "abcd", "efgh", message)
182 assert.NoError(t, err)
183
184 assert.Equal(t, 1, mockClient.callCount)
185 expectedAttr := map[string]string{
186 "bannerId": "a",
187 "storeId": "b",
188 "terminalId": "c",
189 "identity": "d",
190 "sessionId": "e",
191 "signature": "g",
192 "commandId": "h",
193 "type": "command",
194 "version": "1.0",
195 }
196 assert.Equal(t, expectedAttr, mTopic.calledWith[0].Attributes())
197 assert.JSONEq(t, `{"command": "echo 4"}`, string(mTopic.calledWith[0].Data()))
198 assert.Equal(t, defaultReqAttr["commandId"], mTopic.calledWith[0].OrderingKey())
199
200 }
201
202 func TestPublishCaching(t *testing.T) {
203
204
205
206
207
208 topics := []*mockTopic{
209 {
210 mockResults: []*mockResult{{}, {}},
211 },
212 {
213 mockResults: []*mockResult{{}, {}},
214 },
215 {
216 mockResults: []*mockResult{{}, {}},
217 },
218 {
219 mockResults: []*mockResult{{}, {}},
220 },
221 }
222
223 mockClient := mockClientForTopicInProject{
224 mockTopics: []*mockTopic{topics[0], topics[1], topics[2], topics[3]},
225 }
226
227 message, err := msgdata.NewV1_0Request(defaultReqData)
228 assert.NoError(t, err)
229
230 ms := MessageService{
231 ps: &mockClient,
232 topicCache: make(map[topicEntry]topicItfc),
233 logger: logr.Discard(),
234 }
235
236 tests := []struct {
237 testName string
238 topic string
239 project string
240 clientCallCount int
241
242
243 topicsCallCount []int
244 }{
245 {"Initial Entry",
246 "abcd",
247 "efgh",
248 1,
249 []int{1, 0, 0, 0},
250 },
251 {
252 "Add Topic with new topic ID",
253 "ijkl",
254 "efgh",
255 2,
256 []int{1, 1, 0, 0},
257 },
258 {
259 "Add Topic with new project ID",
260 "abcd",
261 "mnop",
262 3,
263 []int{1, 1, 1, 0},
264 },
265 {
266 "Reuse entry from cache",
267 "abcd",
268 "efgh",
269 3,
270 []int{2, 1, 1, 0},
271 },
272 }
273
274 for _, tc := range tests {
275
276 tc := tc
277
278 t.Run(tc.testName, func(t *testing.T) {
279
280 err = ms.Publish(context.Background(), tc.topic, tc.project, message)
281 assert.NoError(t, err)
282
283
284
285 assert.Equal(t, tc.clientCallCount, mockClient.callCount)
286
287
288 for i, val := range tc.topicsCallCount {
289 assert.Equal(t, val, topics[i].callCount)
290 }
291 })
292 }
293 }
294
295 func TestStopPublishing(t *testing.T) {
296
297
298
299
300 topics := []*mockTopic{
301 {
302 mockResults: []*mockResult{{}, {}},
303 },
304 {
305 mockResults: []*mockResult{{}, {}},
306 },
307 }
308
309 mockClient := mockClientForTopicInProject{
310 mockTopics: []*mockTopic{topics[0], topics[1]},
311 }
312
313 message, err := msgdata.NewV1_0Request(defaultReqData)
314 assert.NoError(t, err)
315
316 ms := MessageService{
317 ps: &mockClient,
318 topicCache: make(map[topicEntry]topicItfc),
319 logger: logr.Discard(),
320 }
321
322 err = ms.Publish(context.Background(), "abcd", "efgh", message)
323 assert.NoError(t, err)
324 err = ms.Publish(context.Background(), "ijkl", "mnop", message)
325 assert.NoError(t, err)
326
327 ms.StopPublish("abcd", "efgh")
328 assert.Equal(t, 1, topics[0].stopCallCount)
329 assert.Equal(t, 0, topics[1].stopCallCount)
330
331 ms.StopPublish("abcd", "efgh")
332 assert.Equal(t, 1, topics[0].stopCallCount)
333 assert.Equal(t, 0, topics[1].stopCallCount)
334
335 ms.StopPublish("qrst", "uvwx")
336 assert.Equal(t, 1, topics[0].stopCallCount)
337 assert.Equal(t, 0, topics[1].stopCallCount)
338 }
339
340
341
342
343
344 type mockClientForSubscribeInProject struct {
345 clientItfc
346 callCount int
347 mockSubscriptions []*mockSubscriptionSynch
348 }
349
350 func (mC *mockClientForSubscribeInProject) SubscriptionInProject(_ string, _ string) subscriptionItfc {
351 subs := mC.mockSubscriptions[mC.callCount]
352 mC.callCount++
353 return subs
354 }
355
356
357
358
359
360 type mockSubscriptionSynch struct {
361 subscriptionItfc
362 callCount int
363 messages []*mockMessage
364 }
365
366 func (mS *mockSubscriptionSynch) Receive(ctx context.Context, handler func(ctx context.Context, msg messageItfc)) error {
367 callNo := mS.callCount
368 mS.callCount++
369
370 message := mS.messages[callNo]
371
372 handler(ctx, message)
373
374
375 return nil
376 }
377
378 type mockMessage struct {
379 data []byte
380 attributes map[string]string
381
382 dataCallCount int
383 attributesCallCount int
384 ackCallCount int
385 nackCallCount int
386 }
387
388 func (mM *mockMessage) ID() string {
389 return uuid.NewString()
390 }
391
392 func (mM *mockMessage) Ack() {
393 mM.ackCallCount++
394 }
395
396 func (mM *mockMessage) Nack() {
397 mM.nackCallCount++
398 }
399
400 func (mM *mockMessage) Attributes() map[string]string {
401 mM.attributesCallCount++
402 return mM.attributes
403 }
404
405 func (mM *mockMessage) Data() []byte {
406 mM.dataCallCount++
407 return mM.data
408 }
409
410 func (mM *mockMessage) OrderingKey() string {
411 return ""
412 }
413
414 func (mM *mockMessage) SetOrderingKey(_ string) {
415
416 }
417 func (mM *mockMessage) messageOnlyAckedOnce(t *testing.T) {
418 if mM.nackCallCount != 0 {
419 t.Errorf("Message should be Ack'ed not Nack'ed: message Nack'ed %v times, Ack'ed %v times", mM.nackCallCount, mM.ackCallCount)
420 }
421 if mM.ackCallCount != 1 {
422 t.Errorf("Message should be Ack'ed once, message Ack'ed %v times", mM.ackCallCount)
423 }
424 }
425
426 func (mM *mockMessage) messageOnlyNackedOnce(t *testing.T) {
427 if mM.ackCallCount != 0 {
428 t.Errorf("Message should be Nack'ed not Ack'ed: message Nack'ed %v times, Ack'ed %v times", mM.nackCallCount, mM.ackCallCount)
429 }
430 if mM.nackCallCount != 1 {
431 t.Errorf("Message should be Nack'ed once, message Nack'ed %v times", mM.nackCallCount)
432 }
433 }
434
435 func TestFilterSubscribeSkip(t *testing.T) {
436 filter := map[string]string{
437 "filterkeynotinattrs": "val",
438 }
439
440 testHandler := func(_ context.Context, _ msgdata.CommandResponse) {}
441
442 testMockMessage := mockMessage{
443 data: defaultRespData,
444 attributes: defaultRespAttr,
445 }
446
447 mockSub := mockSubscriptionSynch{
448 messages: []*mockMessage{&testMockMessage},
449 }
450
451 mockClient := mockClientForSubscribeInProject{
452 mockSubscriptions: []*mockSubscriptionSynch{&mockSub},
453 }
454
455 ms := MessageService{
456 ps: &mockClient,
457 logger: logr.Discard(),
458 }
459
460 err := ms.Subscribe(context.Background(), "abcd", "efgh", testHandler, filter)
461 assert.NoError(t, err)
462
463
464 testMockMessage.messageOnlyNackedOnce(t)
465 assert.Equal(t, 1, testMockMessage.attributesCallCount)
466 assert.Equal(t, 0, testMockMessage.dataCallCount)
467 }
468
469 func TestFilterSubscribe(t *testing.T) {
470 filter := defaultRespAttr
471
472 testMockMessage := mockMessage{
473 data: defaultRespData,
474 attributes: defaultRespAttr,
475 }
476
477 expMessage, err := msgdata.NewCommandResponse(defaultRespData, defaultRespAttr)
478 assert.NoError(t, err)
479
480 testHandler := func(_ context.Context, msg msgdata.CommandResponse) {
481 assert.Equal(t, expMessage, msg)
482 }
483
484 mockSub := mockSubscriptionSynch{
485 messages: []*mockMessage{&testMockMessage},
486 }
487
488 mockClient := mockClientForSubscribeInProject{
489 mockSubscriptions: []*mockSubscriptionSynch{&mockSub},
490 }
491
492
493
494 ms, err := NewMessageService(context.Background())
495 assert.NoError(t, err)
496 ms.ps = &mockClient
497
498 err = ms.Subscribe(context.Background(), "abcd", "efgh", testHandler, filter)
499 assert.NoError(t, err)
500
501
502 testMockMessage.messageOnlyAckedOnce(t)
503 assert.Equal(t, 2, testMockMessage.attributesCallCount)
504 assert.Equal(t, 1, testMockMessage.dataCallCount)
505 }
506
507 func TestInvalidMessageNack(t *testing.T) {
508 testMockMessage := mockMessage{
509 data: []byte(`{"invalid json": "stri`),
510 attributes: defaultRespAttr,
511 }
512
513 testHandler := func(_ context.Context, _ msgdata.CommandResponse) {}
514
515 mockSub := mockSubscriptionSynch{
516 messages: []*mockMessage{&testMockMessage},
517 }
518
519 mockClient := mockClientForSubscribeInProject{
520 mockSubscriptions: []*mockSubscriptionSynch{&mockSub},
521 }
522
523
524
525 ms, err := NewMessageService(context.Background())
526 assert.NoError(t, err)
527 ms.ps = &mockClient
528
529 err = ms.Subscribe(context.Background(), "abcd", "efgh", testHandler, nil)
530 assert.NoError(t, err)
531
532 testMockMessage.messageOnlyNackedOnce(t)
533 assert.Equal(t, 2, testMockMessage.attributesCallCount)
534 assert.Equal(t, 1, testMockMessage.dataCallCount)
535 }
536
537 type createSubscriptionClient struct {
538 clientItfc
539 subscriptionID string
540 cfg subscriptionCfg
541 }
542
543 func (cl *createSubscriptionClient) CreateSubscription(_ context.Context, subscriptionID string, cfg subscriptionCfg) (subscriptionItfc, error) {
544 cl.subscriptionID = subscriptionID
545 cl.cfg = cfg
546
547 return nil, nil
548 }
549
550 func TestCreateSubscription(t *testing.T) {
551 t.Parallel()
552
553 tests := map[string]struct {
554 sessionID string
555 subscriptionID string
556 projectID string
557 topicID string
558
559 expCfg subscriptionCfg
560 }{
561 "Create Subscription": {
562 sessionID: "abcd",
563 subscriptionID: "efgh",
564 projectID: "ijkl",
565 topicID: "mnop",
566
567 expCfg: subscriptionCfg{
568 topicName: "mnop",
569 projectID: "ijkl",
570 retentionDuration: time.Hour,
571 expirationPolicy: 24 * time.Hour,
572 filter: `attributes.sessionId="abcd"`,
573 },
574 },
575 }
576
577 for name, tc := range tests {
578 tc := tc
579 t.Run(name, func(t *testing.T) {
580 t.Parallel()
581
582
583
584 ms, err := NewMessageService(context.Background())
585 assert.NoError(t, err)
586
587 mockCl := createSubscriptionClient{}
588 ms.ps = &mockCl
589
590 err = ms.CreateSubscription(context.Background(), tc.sessionID, tc.subscriptionID, tc.projectID, tc.topicID)
591 assert.NoError(t, err)
592
593 assert.Equal(t, tc.subscriptionID, mockCl.subscriptionID)
594 assert.Equal(t, tc.expCfg, mockCl.cfg)
595 })
596 }
597 }
598
599
600
601 type deleteSubscriptionClient struct {
602 clientItfc
603 subscriptionItfc
604 subscriptionID string
605 projectID string
606 retErr error
607 }
608
609 func (cl *deleteSubscriptionClient) SubscriptionInProject(subscriptionID, projectID string) subscriptionItfc {
610 cl.subscriptionID = subscriptionID
611 cl.projectID = projectID
612
613 return cl
614 }
615
616 func (cl *deleteSubscriptionClient) Delete(_ context.Context) error {
617 return cl.retErr
618 }
619
620 func TestDeleteSubscription(t *testing.T) {
621 t.Parallel()
622
623 tests := map[string]struct {
624 subscriptionID string
625 projectID string
626 retErr error
627
628 expErr assert.ErrorAssertionFunc
629 }{
630 "Normal": {
631 subscriptionID: "abcd",
632 projectID: "efgh",
633 retErr: nil,
634 expErr: assert.NoError,
635 },
636 "Error": {
637 subscriptionID: "abcd",
638 projectID: "efgh",
639 retErr: fmt.Errorf("bad"),
640 expErr: assert.Error,
641 },
642 }
643
644 for name, tc := range tests {
645 tc := tc
646 t.Run(name, func(t *testing.T) {
647 t.Parallel()
648
649 ms, err := NewMessageService(context.Background())
650 assert.NoError(t, err)
651
652 mockCl := deleteSubscriptionClient{
653 retErr: tc.retErr,
654 }
655 ms.ps = &mockCl
656
657 err = ms.DeleteSubscription(context.Background(), tc.subscriptionID, tc.projectID)
658 tc.expErr(t, err)
659
660 assert.Equal(t, tc.subscriptionID, mockCl.subscriptionID)
661 assert.Equal(t, tc.projectID, mockCl.projectID)
662 })
663 }
664 }
665
View as plain text