package msgsvc

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/go-logr/logr"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"

	"edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
)

var (
	// defaultReqAttr is the attribute set that populateAttrs adds to outgoing
	// test requests.
	defaultReqAttr = map[string]string{
		"bannerId":   "a",
		"storeId":    "b",
		"terminalId": "c",
		"identity":   "d",
		"sessionId":  "e",
		"signature":  "g",
		"commandId":  "h",
	}
	// defaultReqData is the command payload used to build test requests.
	defaultReqData = "echo 4"

	// defaultRespAttr is the attribute set attached to mocked incoming
	// response messages.
	defaultRespAttr = map[string]string{
		"bannerId":             "a",
		"storeId":              "b",
		"terminalId":           "c",
		"identity":             "d",
		"sessionId":            "e",
		"request-message-uuid": "f",
	}
	// defaultRespData is the JSON body of a mocked incoming response message.
	defaultRespData = []byte(`
	{
		"type": "Output",
		"exitCode": 0,
		"output": "4\n",
		"timestamp": "01-01-2023 00:00:00",
		"duration": 0.1
	}`,
	)
)

// populateAttrs adds every entry of defaultReqAttr to request as an attribute.
func populateAttrs(request msgdata.Request) {
	for k, v := range defaultReqAttr {
		request.AddAttribute(k, v)
	}
}

// mockClientForTopicInProject implements the client interface (by embedding)
// but is only used for testing topics. Each TopicInProject call returns the
// next element of mockTopics, and callCount records how many calls were made.
type mockClientForTopicInProject struct {
	clientItfc
	callCount  int
	mockTopics []*mockTopic
}

// TopicInProject ignores its arguments and returns the next mock topic. It
// panics if called more times than there are entries in mockTopics.
func (mT *mockClientForTopicInProject) TopicInProject(_ string, _ string) topicItfc {
	top := mT.mockTopics[mT.callCount]
	mT.callCount++
	return top
}

// mockTopic is a topic test double. It records every published message and
// counts calls to Publish and Stop.
type mockTopic struct {
	callCount     int           // number of Publish calls
	stopCallCount int           // number of Stop calls
	mockResults   []*mockResult // result returned by the n-th Publish call
	calledWith    []messageItfc // messages passed to Publish, in call order
}

func (m *mockTopic) ID() string {
	return "mock topic"
}

// Publish records msg and returns the next entry of mockResults. It panics if
// called more times than there are entries in mockResults.
func (m *mockTopic) Publish(_ context.Context, msg messageItfc) publishResultItfc {
	m.calledWith = append(m.calledWith, msg)
	res := m.mockResults[m.callCount]
	m.callCount++
	return res
}

func (m *mockTopic) Stop() {
	m.stopCallCount++
}

func (m *mockTopic) SetOrdering(bool) {
}

// mockResult is a publish result whose Get always succeeds.
type mockResult struct {
	// publishResulter
}

// Get always returns server ID "a" and a nil error.
func (r mockResult) Get(_ context.Context) (serverID string, err error) {
	return "a", nil
}

func TestSubscriptionFilter(t *testing.T) {
	defaultAttr := map[string]string{
		"bannerId":   "banner",
		"storeId":    "store",
		"terminalId": "terminal",
		"sessionId":  "orderingKey",
		"identity":   "identity",
		"version":    "1.0",
		"signature":  "signature",
	}

	tests := map[string]struct {
		attr   map[string]string
		filter map[string]string
		want   bool
	}{
		"Success": {
			attr:   defaultAttr,
			filter: map[string]string{"bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
			want:   true,
		},
		"Elements don't match": {
			attr:   defaultAttr,
			filter: map[string]string{"bannerId": "store", "storeId": "banner", "terminalId": "terminal"},
			want:   false,
		},
		"Empty elements": {
			attr:   defaultAttr,
			filter: map[string]string{"bannerId": "", "storeId": "", "terminalId": ""},
			want:   false,
		},
		"Key not in attr": {
			attr:   defaultAttr,
			filter: map[string]string{"keyNotInAttr": "something", "bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
			want:   false,
		},
		"Key not in attr and empty": {
			attr:   defaultAttr,
			filter: map[string]string{"keyNotInAttr": "", "bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
			want:   false,
		},
		"Empty attr": {
			attr:   nil,
			filter: map[string]string{"bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
			want:   false,
		},
		"Empty filter": {
			attr:   defaultAttr,
			filter: map[string]string{},
			want:   true,
		},
		"Nil Filter": {
			attr:   defaultAttr,
			filter: nil,
			want:   true,
		},
	}

	for name, tc := range tests {
		tc := tc
		t.Run(name, func(t *testing.T) {
			assert.Equal(t, tc.want, isFilterMatch(tc.attr, tc.filter))
		})
	}
}

func TestPublish(t *testing.T) {
	mTopic := mockTopic{mockResults: []*mockResult{{}}}
	mockClient := mockClientForTopicInProject{
		mockTopics: []*mockTopic{&mTopic},
	}

	message, err := msgdata.NewV1_0Request(defaultReqData)
	if !assert.NoError(t, err) {
		return
	}
	populateAttrs(message)

	// Initialise message service using the NewMessageService public api, then
	// replace the pubsub client with the mock client
	ms, err := NewMessageService(context.Background())
	if !assert.NoError(t, err) {
		return
	}
	ms.ps = &mockClient

	err = ms.Publish(context.Background(), "abcd", "efgh", message)
	assert.NoError(t, err)

	assert.Equal(t, 1, mockClient.callCount)
	// Guard the indexing below so a missing publish fails the test rather
	// than panicking.
	if !assert.Len(t, mTopic.calledWith, 1) {
		return
	}

	expectedAttr := map[string]string{
		"bannerId":   "a",
		"storeId":    "b",
		"terminalId": "c",
		"identity":   "d",
		"sessionId":  "e",
		"signature":  "g",
		"commandId":  "h",
		"type":       "command",
		"version":    "1.0",
	}
	assert.Equal(t, expectedAttr, mTopic.calledWith[0].Attributes())
	assert.JSONEq(t, `{"command": "echo 4"}`, string(mTopic.calledWith[0].Data()))
	assert.Equal(t, defaultReqAttr["commandId"], mTopic.calledWith[0].OrderingKey())

	// TODO test message.Get called
}

func TestPublishCaching(t *testing.T) {
	// Test that unique (topic, project) tuples are cached and the topic is not
	// recreated from the client
	// Test correct topic is published to when retrieving from cache

	// List of topics to return each time TopicInProject is called
	topics := []*mockTopic{
		{mockResults: []*mockResult{{}, {}}},
		{mockResults: []*mockResult{{}, {}}},
		{mockResults: []*mockResult{{}, {}}},
		{mockResults: []*mockResult{{}, {}}},
	}
	mockClient := mockClientForTopicInProject{
		mockTopics: topics,
	}

	message, err := msgdata.NewV1_0Request(defaultReqData)
	if !assert.NoError(t, err) {
		return
	}

	ms := MessageService{
		ps:         &mockClient,
		topicCache: make(map[topicEntry]topicItfc),
		logger:     logr.Discard(),
	}

	// The cases run sequentially against the same ms, mockClient and topics,
	// so every expected count is cumulative over the preceding cases: order
	// matters.
	tests := []struct {
		testName        string
		topic           string
		project         string
		clientCallCount int // Count of times the TopicInProject method should have been called
		// List of number of times the Publish method on each topic should be
		// called for the given test case
		topicsCallCount []int
	}{
		{
			"Initial Entry",
			"abcd", "efgh",
			1, []int{1, 0, 0, 0},
		},
		{
			"Add Topic with new topic ID", // Test new topic instance created and used
			"ijkl", "efgh",
			2, []int{1, 1, 0, 0},
		},
		{
			"Add Topic with new project ID", // Test new topic instance created and used
			"abcd", "mnop",
			3, []int{1, 1, 1, 0},
		},
		{
			"Reuse entry from cache", // Test existing topic instance used
			"abcd", "efgh",
			3, []int{2, 1, 1, 0},
		},
	}

	for _, tc := range tests {
		// capture range variables
		tc := tc
		t.Run(tc.testName, func(t *testing.T) {
			// Publish on the given topic; use a subtest-local err rather than
			// writing to the enclosing test's variable.
			err := ms.Publish(context.Background(), tc.topic, tc.project, message)
			assert.NoError(t, err)

			// Check the TopicInProject method is called the appropriate number
			// of times
			assert.Equal(t, tc.clientCallCount, mockClient.callCount)

			// Check each topic has been published to the correct number of times
			for i, val := range tc.topicsCallCount {
				assert.Equal(t, val, topics[i].callCount)
			}
		})
	}
}

func TestStopPublishing(t *testing.T) {
	// Test StopPublish only applies to a single topic in the cache,
	// and that an already stopped topic is not stopped again
	// TODO should we test that TopicInProject is called again to refill the
	// cache once Publish is called on a stopped topic?

	topics := []*mockTopic{
		{mockResults: []*mockResult{{}, {}}},
		{mockResults: []*mockResult{{}, {}}},
	}
	mockClient := mockClientForTopicInProject{
		mockTopics: topics,
	}

	message, err := msgdata.NewV1_0Request(defaultReqData)
	if !assert.NoError(t, err) {
		return
	}

	ms := MessageService{
		ps:         &mockClient,
		topicCache: make(map[topicEntry]topicItfc),
		logger:     logr.Discard(),
	}

	err = ms.Publish(context.Background(), "abcd", "efgh", message)
	assert.NoError(t, err)
	err = ms.Publish(context.Background(), "ijkl", "mnop", message)
	assert.NoError(t, err)

	// Stopping a cached topic stops only that topic
	ms.StopPublish("abcd", "efgh")
	assert.Equal(t, 1, topics[0].stopCallCount)
	assert.Equal(t, 0, topics[1].stopCallCount)

	// Stopping the same topic again does not call Stop a second time
	ms.StopPublish("abcd", "efgh")
	assert.Equal(t, 1, topics[0].stopCallCount)
	assert.Equal(t, 0, topics[1].stopCallCount)

	// Stopping a (topic, project) pair that was never published to stops nothing
	ms.StopPublish("qrst", "uvwx")
	assert.Equal(t, 1, topics[0].stopCallCount)
	assert.Equal(t, 0, topics[1].stopCallCount)
}

// mockClientForSubscribeInProject implements the client interface (by
// embedding) but is only used for testing subscribing. Initialise it with a
// list of mockSubscriptions; each SubscriptionInProject call returns the next
// element in the list. callCount tracks how many times SubscriptionInProject
// has been called.
type mockClientForSubscribeInProject struct {
	clientItfc
	callCount         int
	mockSubscriptions []*mockSubscriptionSynch
}

// SubscriptionInProject ignores its arguments and returns the next mock
// subscription. It panics if called more times than there are entries in
// mockSubscriptions.
func (mC *mockClientForSubscribeInProject) SubscriptionInProject(_ string, _ string) subscriptionItfc {
	subs := mC.mockSubscriptions[mC.callCount]
	mC.callCount++
	return subs
}

// mockSubscriptionSynch implements the subscriptionItfc interface. Initialise
// it with the list of messages to be processed by successive calls to Receive;
// it panics if Receive is called more times than there are messages.
// Receive calls the handler synchronously, so no goroutine management is
// needed.
type mockSubscriptionSynch struct {
	subscriptionItfc
	callCount int
	messages  []*mockMessage
}

// Receive passes the next queued message to handler and always returns nil.
func (mS *mockSubscriptionSynch) Receive(ctx context.Context, handler func(ctx context.Context, msg messageItfc)) error {
	callNo := mS.callCount
	mS.callCount++

	message := mS.messages[callNo]
	handler(ctx, message)
	// TODO return error?
	return nil
}

// mockMessage is a message test double that counts how often its data and
// attributes are read and how often it is Ack'ed or Nack'ed.
type mockMessage struct {
	data       []byte
	attributes map[string]string

	dataCallCount       int
	attributesCallCount int
	ackCallCount        int
	nackCallCount       int
}

// ID returns a fresh random UUID on every call.
func (mM *mockMessage) ID() string {
	return uuid.NewString()
}

func (mM *mockMessage) Ack() {
	mM.ackCallCount++
}

func (mM *mockMessage) Nack() {
	mM.nackCallCount++
}

func (mM *mockMessage) Attributes() map[string]string {
	mM.attributesCallCount++
	return mM.attributes
}

func (mM *mockMessage) Data() []byte {
	mM.dataCallCount++
	return mM.data
}

func (mM *mockMessage) OrderingKey() string {
	return ""
}

func (mM *mockMessage) SetOrderingKey(_ string) {
}

// messageOnlyAckedOnce fails t unless the message was Ack'ed exactly once and
// never Nack'ed.
func (mM *mockMessage) messageOnlyAckedOnce(t *testing.T) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	if mM.nackCallCount != 0 {
		t.Errorf("Message should be Ack'ed not Nack'ed: message Nack'ed %v times, Ack'ed %v times", mM.nackCallCount, mM.ackCallCount)
	}
	if mM.ackCallCount != 1 {
		t.Errorf("Message should be Ack'ed once, message Ack'ed %v times", mM.ackCallCount)
	}
}

// messageOnlyNackedOnce fails t unless the message was Nack'ed exactly once
// and never Ack'ed.
func (mM *mockMessage) messageOnlyNackedOnce(t *testing.T) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	if mM.ackCallCount != 0 {
		t.Errorf("Message should be Nack'ed not Ack'ed: message Nack'ed %v times, Ack'ed %v times", mM.nackCallCount, mM.ackCallCount)
	}
	if mM.nackCallCount != 1 {
		t.Errorf("Message should be Nack'ed once, message Nack'ed %v times", mM.nackCallCount)
	}
}

func TestFilterSubscribeSkip(t *testing.T) {
	filter := map[string]string{
		"filterkeynotinattrs": "val",
	}

	handlerCalls := 0
	testHandler := func(_ context.Context, _ msgdata.CommandResponse) {
		handlerCalls++
	}

	testMockMessage := mockMessage{
		data:       defaultRespData,
		attributes: defaultRespAttr,
	}
	mockSub := mockSubscriptionSynch{
		messages: []*mockMessage{&testMockMessage},
	}
	mockClient := mockClientForSubscribeInProject{
		mockSubscriptions: []*mockSubscriptionSynch{&mockSub},
	}

	ms := MessageService{
		ps:     &mockClient,
		logger: logr.Discard(),
	}

	err := ms.Subscribe(context.Background(), "abcd", "efgh", testHandler, filter)
	assert.NoError(t, err)

	// Test message is Nack'ed, no data is accessed and the handler never runs
	testMockMessage.messageOnlyNackedOnce(t)
	assert.Equal(t, 1, testMockMessage.attributesCallCount)
	assert.Equal(t, 0, testMockMessage.dataCallCount)
	assert.Equal(t, 0, handlerCalls)
}

func TestFilterSubscribe(t *testing.T) {
	filter := defaultRespAttr
	testMockMessage := mockMessage{
		data:       defaultRespData,
		attributes: defaultRespAttr,
	}

	expMessage, err := msgdata.NewCommandResponse(defaultRespData, defaultRespAttr)
	if !assert.NoError(t, err) {
		return
	}

	handlerCalls := 0
	testHandler := func(_ context.Context, msg msgdata.CommandResponse) {
		handlerCalls++
		assert.Equal(t, expMessage, msg)
	}

	mockSub := mockSubscriptionSynch{
		messages: []*mockMessage{&testMockMessage},
	}
	mockClient := mockClientForSubscribeInProject{
		mockSubscriptions: []*mockSubscriptionSynch{&mockSub},
	}

	// Initialise message service using the NewMessageService public api, then
	// replace the pubsub client with the mock client
	ms, err := NewMessageService(context.Background())
	if !assert.NoError(t, err) {
		return
	}
	ms.ps = &mockClient

	err = ms.Subscribe(context.Background(), "abcd", "efgh", testHandler, filter)
	assert.NoError(t, err)

	// Test message is Ack'ed and data is accessed
	testMockMessage.messageOnlyAckedOnce(t)
	assert.Equal(t, 2, testMockMessage.attributesCallCount)
	assert.Equal(t, 1, testMockMessage.dataCallCount)
	// The handler must actually have run; otherwise the assertion inside it
	// is vacuous.
	assert.Equal(t, 1, handlerCalls)
}

func TestInvalidMessageNack(t *testing.T) {
	testMockMessage := mockMessage{
		data:       []byte(`{"invalid json": "stri`),
		attributes: defaultRespAttr,
	}

	handlerCalls := 0
	testHandler := func(_ context.Context, _ msgdata.CommandResponse) {
		handlerCalls++
	}

	mockSub := mockSubscriptionSynch{
		messages: []*mockMessage{&testMockMessage},
	}
	mockClient := mockClientForSubscribeInProject{
		mockSubscriptions: []*mockSubscriptionSynch{&mockSub},
	}

	// Initialise message service using the NewMessageService public api, then
	// replace the pubsub client with the mock client
	ms, err := NewMessageService(context.Background())
	if !assert.NoError(t, err) {
		return
	}
	ms.ps = &mockClient

	err = ms.Subscribe(context.Background(), "abcd", "efgh", testHandler, nil)
	assert.NoError(t, err)

	// Test the undecodable message is Nack'ed and never reaches the handler
	testMockMessage.messageOnlyNackedOnce(t)
	assert.Equal(t, 2, testMockMessage.attributesCallCount)
	assert.Equal(t, 1, testMockMessage.dataCallCount)
	assert.Equal(t, 0, handlerCalls)
}

// createSubscriptionClient records the arguments passed to CreateSubscription
// and returns a nil subscription with a nil error.
type createSubscriptionClient struct {
	clientItfc
	subscriptionID string
	cfg            subscriptionCfg
}

func (cl *createSubscriptionClient) CreateSubscription(_ context.Context, subscriptionID string, cfg subscriptionCfg) (subscriptionItfc, error) {
	cl.subscriptionID = subscriptionID
	cl.cfg = cfg
	return nil, nil
}

func TestCreateSubscription(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		sessionID      string
		subscriptionID string
		projectID      string
		topicID        string

		expCfg subscriptionCfg
	}{
		"Create Subscription": {
			sessionID:      "abcd",
			subscriptionID: "efgh",
			projectID:      "ijkl",
			topicID:        "mnop",

			expCfg: subscriptionCfg{
				topicName:         "mnop",
				projectID:         "ijkl",
				retentionDuration: time.Hour,
				expirationPolicy:  24 * time.Hour,
				filter:            `attributes.sessionId="abcd"`,
			},
		},
	}

	for name, tc := range tests {
		tc := tc
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			// Initialise message service using the NewMessageService public api, then
			// replace the pubsub client with the mock client
			ms, err := NewMessageService(context.Background())
			if !assert.NoError(t, err) {
				return
			}

			mockCl := createSubscriptionClient{}
			ms.ps = &mockCl

			err = ms.CreateSubscription(context.Background(), tc.sessionID, tc.subscriptionID, tc.projectID, tc.topicID)
			assert.NoError(t, err)

			assert.Equal(t, tc.subscriptionID, mockCl.subscriptionID)
			assert.Equal(t, tc.expCfg, mockCl.cfg)
		})
	}
}

// deleteSubscriptionClient implements both clientItfc and subscriptionItfc so
// that we can reuse the same type in tests, cutting down on setup boilerplate.
// It records the IDs passed to SubscriptionInProject and returns retErr from
// Delete.
type deleteSubscriptionClient struct {
	clientItfc
	subscriptionItfc

	subscriptionID string
	projectID      string

	retErr error
}

func (cl *deleteSubscriptionClient) SubscriptionInProject(subscriptionID, projectID string) subscriptionItfc {
	cl.subscriptionID = subscriptionID
	cl.projectID = projectID
	// Return itself as a subscription to cut down on test specific types
	return cl
}

func (cl *deleteSubscriptionClient) Delete(_ context.Context) error {
	return cl.retErr
}

func TestDeleteSubscription(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		subscriptionID string
		projectID      string
		retErr         error

		expErr assert.ErrorAssertionFunc
	}{
		"Normal": {
			subscriptionID: "abcd",
			projectID:      "efgh",
			retErr:         nil,
			expErr:         assert.NoError,
		},
		"Error": {
			subscriptionID: "abcd",
			projectID:      "efgh",
			retErr:         fmt.Errorf("bad"),
			expErr:         assert.Error,
		},
	}

	for name, tc := range tests {
		tc := tc
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			ms, err := NewMessageService(context.Background())
			if !assert.NoError(t, err) {
				return
			}

			mockCl := deleteSubscriptionClient{
				retErr: tc.retErr,
			}
			ms.ps = &mockCl

			err = ms.DeleteSubscription(context.Background(), tc.subscriptionID, tc.projectID)
			tc.expErr(t, err)

			assert.Equal(t, tc.subscriptionID, mockCl.subscriptionID)
			assert.Equal(t, tc.projectID, mockCl.projectID)
		})
	}
}