...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/msgsvc/msgsvc_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/msgsvc

     1  package msgsvc
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/go-logr/logr"
    10  	"github.com/google/uuid"
    11  	"github.com/stretchr/testify/assert"
    12  
    13  	"edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
    14  )
    15  
    16  var (
    17  	defaultReqAttr = map[string]string{
    18  		"bannerId":   "a",
    19  		"storeId":    "b",
    20  		"terminalId": "c",
    21  		"identity":   "d",
    22  		"sessionId":  "e",
    23  		"signature":  "g",
    24  		"commandId":  "h",
    25  	}
    26  	defaultReqData  = "echo 4"
    27  	defaultRespAttr = map[string]string{
    28  		"bannerId":   "a",
    29  		"storeId":    "b",
    30  		"terminalId": "c",
    31  		"identity":   "d",
    32  		"sessionId":  "e",
    33  
    34  		"request-message-uuid": "f",
    35  	}
    36  	defaultRespData = []byte(`
    37  		{
    38  			"type": "Output",
    39  			"exitCode": 0,
    40  			"output": "4\n",
    41  			"timestamp": "01-01-2023 00:00:00",
    42  			"duration": 0.1
    43  		}`,
    44  	)
    45  )
    46  
    47  func populateAttrs(request msgdata.Request) {
    48  	for k, v := range defaultReqAttr {
    49  		request.AddAttribute(k, v)
    50  	}
    51  }
    52  
    53  // Implements client interface (by embedding) but only used for testig topics
    54  type mockClientForTopicInProject struct {
    55  	clientItfc
    56  	callCount  int
    57  	mockTopics []*mockTopic
    58  }
    59  
    60  func (mT *mockClientForTopicInProject) TopicInProject(_ string, _ string) topicItfc {
    61  	top := mT.mockTopics[mT.callCount]
    62  	mT.callCount++
    63  	return top
    64  }
    65  
    66  type mockTopic struct {
    67  	callCount     int
    68  	stopCallCount int
    69  	mockResults   []*mockResult
    70  	calledWith    []messageItfc
    71  }
    72  
    73  func (m *mockTopic) ID() string {
    74  	return "mock topic"
    75  }
    76  
    77  func (m *mockTopic) Publish(_ context.Context, msg messageItfc) publishResultItfc {
    78  	m.calledWith = append(m.calledWith, msg)
    79  	res := m.mockResults[m.callCount]
    80  	m.callCount++
    81  	return res
    82  }
    83  
    84  func (m *mockTopic) Stop() {
    85  	m.stopCallCount++
    86  }
    87  
    88  func (m *mockTopic) SetOrdering(bool) {
    89  
    90  }
    91  
    92  type mockResult struct {
    93  	// publishResulter
    94  }
    95  
    96  func (r mockResult) Get(_ context.Context) (serverID string, err error) {
    97  	return "a", nil
    98  }
    99  
   100  func TestSubscriptionFilter(t *testing.T) {
   101  	defaultAttr := map[string]string{
   102  		"bannerId":   "banner",
   103  		"storeId":    "store",
   104  		"terminalId": "terminal",
   105  		"sessionId":  "orderingKey",
   106  		"identity":   "identity",
   107  		"version":    "1.0",
   108  		"signature":  "signature",
   109  	}
   110  
   111  	tests := map[string]struct {
   112  		attr   map[string]string
   113  		filter map[string]string
   114  		want   bool
   115  	}{
   116  		"Success": {
   117  			attr:   defaultAttr,
   118  			filter: map[string]string{"bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
   119  			want:   true,
   120  		},
   121  		"Elements don't match": {
   122  			attr:   defaultAttr,
   123  			filter: map[string]string{"bannerId": "store", "storeId": "banner", "terminalId": "terminal"},
   124  			want:   false,
   125  		},
   126  		"Empty elements": {
   127  			attr:   defaultAttr,
   128  			filter: map[string]string{"bannerId": "", "storeId": "", "terminalId": ""},
   129  			want:   false,
   130  		},
   131  		"Key not in attr": {
   132  			attr:   defaultAttr,
   133  			filter: map[string]string{"keyNotInAttr": "something", "bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
   134  			want:   false,
   135  		},
   136  		"Key not in attr and empty": {
   137  			attr:   defaultAttr,
   138  			filter: map[string]string{"keyNotInAttr": "", "bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
   139  			want:   false,
   140  		},
   141  		"Empty attr": {
   142  			attr:   nil,
   143  			filter: map[string]string{"bannerId": "banner", "storeId": "store", "terminalId": "terminal"},
   144  			want:   false,
   145  		},
   146  		"Empty filter": {
   147  			attr:   defaultAttr,
   148  			filter: map[string]string{},
   149  			want:   true,
   150  		},
   151  		"Nil Filter": {
   152  			attr:   defaultAttr,
   153  			filter: nil,
   154  			want:   true,
   155  		},
   156  	}
   157  
   158  	for name, tc := range tests {
   159  		t.Run(name, func(t *testing.T) {
   160  			assert.Equal(t, tc.want, isFilterMatch(tc.attr, tc.filter))
   161  		})
   162  	}
   163  }
   164  
   165  func TestPublish(t *testing.T) {
   166  	mTopic := mockTopic{mockResults: []*mockResult{{}}}
   167  	mockClient := mockClientForTopicInProject{
   168  		mockTopics: []*mockTopic{&mTopic},
   169  	}
   170  
   171  	message, err := msgdata.NewV1_0Request(defaultReqData)
   172  	assert.NoError(t, err)
   173  	populateAttrs(message)
   174  
   175  	// Initialise message service using the NewMessageService public api, then
   176  	// replace the pubsub client with the mock client
   177  	ms, err := NewMessageService(context.Background())
   178  	assert.NoError(t, err)
   179  	ms.ps = &mockClient
   180  
   181  	err = ms.Publish(context.Background(), "abcd", "efgh", message)
   182  	assert.NoError(t, err)
   183  
   184  	assert.Equal(t, 1, mockClient.callCount)
   185  	expectedAttr := map[string]string{
   186  		"bannerId":   "a",
   187  		"storeId":    "b",
   188  		"terminalId": "c",
   189  		"identity":   "d",
   190  		"sessionId":  "e",
   191  		"signature":  "g",
   192  		"commandId":  "h",
   193  		"type":       "command",
   194  		"version":    "1.0",
   195  	}
   196  	assert.Equal(t, expectedAttr, mTopic.calledWith[0].Attributes())
   197  	assert.JSONEq(t, `{"command": "echo 4"}`, string(mTopic.calledWith[0].Data()))
   198  	assert.Equal(t, defaultReqAttr["commandId"], mTopic.calledWith[0].OrderingKey())
   199  	// TODO test message.Get called
   200  }
   201  
   202  func TestPublishCaching(t *testing.T) {
   203  	// Test that unique (topic, project) tuples are cached and the topic is not
   204  	// recreated from the client
   205  	// Test correct topic is published to when retreiving from cache
   206  
   207  	// List of topics to return each time TopicInProject is called
   208  	topics := []*mockTopic{
   209  		{
   210  			mockResults: []*mockResult{{}, {}},
   211  		},
   212  		{
   213  			mockResults: []*mockResult{{}, {}},
   214  		},
   215  		{
   216  			mockResults: []*mockResult{{}, {}},
   217  		},
   218  		{
   219  			mockResults: []*mockResult{{}, {}},
   220  		},
   221  	}
   222  
   223  	mockClient := mockClientForTopicInProject{
   224  		mockTopics: []*mockTopic{topics[0], topics[1], topics[2], topics[3]},
   225  	}
   226  
   227  	message, err := msgdata.NewV1_0Request(defaultReqData)
   228  	assert.NoError(t, err)
   229  
   230  	ms := MessageService{
   231  		ps:         &mockClient,
   232  		topicCache: make(map[topicEntry]topicItfc),
   233  		logger:     logr.Discard(),
   234  	}
   235  
   236  	tests := []struct {
   237  		testName        string
   238  		topic           string
   239  		project         string
   240  		clientCallCount int // Count of times the projectInTopic method should have been called
   241  		// List of number of times the Publish method on each topic should be
   242  		// called for the given test case
   243  		topicsCallCount []int
   244  	}{
   245  		{"Initial Entry",
   246  			"abcd",
   247  			"efgh",
   248  			1,
   249  			[]int{1, 0, 0, 0},
   250  		},
   251  		{
   252  			"Add Topic with new topic ID", // Test new topic instance created and used
   253  			"ijkl",
   254  			"efgh",
   255  			2,
   256  			[]int{1, 1, 0, 0},
   257  		},
   258  		{
   259  			"Add Topic with new project ID", // Test new topic instance created and used
   260  			"abcd",
   261  			"mnop",
   262  			3,
   263  			[]int{1, 1, 1, 0},
   264  		},
   265  		{
   266  			"Reuse entry from cache", // Test existing topic instance used
   267  			"abcd",
   268  			"efgh",
   269  			3,
   270  			[]int{2, 1, 1, 0},
   271  		},
   272  	}
   273  
   274  	for _, tc := range tests {
   275  		// capture range variables
   276  		tc := tc
   277  
   278  		t.Run(tc.testName, func(t *testing.T) {
   279  			// Publish on the given topic
   280  			err = ms.Publish(context.Background(), tc.topic, tc.project, message)
   281  			assert.NoError(t, err)
   282  
   283  			// Check the ProjectInTopic method is called the appropriate number
   284  			// of times
   285  			assert.Equal(t, tc.clientCallCount, mockClient.callCount)
   286  
   287  			// Check each topic has been published to the correct number of times
   288  			for i, val := range tc.topicsCallCount {
   289  				assert.Equal(t, val, topics[i].callCount)
   290  			}
   291  		})
   292  	}
   293  }
   294  
   295  func TestStopPublishing(t *testing.T) {
   296  	// Test StopPublish only applies to a single topic in the cache,
   297  	// and that an already stopped topic is not stopped again
   298  	// TODO should we test that ProjectInTopic is called again to refill the
   299  	//      cache once Publish is called on a stopped topic?
   300  	topics := []*mockTopic{
   301  		{
   302  			mockResults: []*mockResult{{}, {}},
   303  		},
   304  		{
   305  			mockResults: []*mockResult{{}, {}},
   306  		},
   307  	}
   308  
   309  	mockClient := mockClientForTopicInProject{
   310  		mockTopics: []*mockTopic{topics[0], topics[1]},
   311  	}
   312  
   313  	message, err := msgdata.NewV1_0Request(defaultReqData)
   314  	assert.NoError(t, err)
   315  
   316  	ms := MessageService{
   317  		ps:         &mockClient,
   318  		topicCache: make(map[topicEntry]topicItfc),
   319  		logger:     logr.Discard(),
   320  	}
   321  
   322  	err = ms.Publish(context.Background(), "abcd", "efgh", message)
   323  	assert.NoError(t, err)
   324  	err = ms.Publish(context.Background(), "ijkl", "mnop", message)
   325  	assert.NoError(t, err)
   326  
   327  	ms.StopPublish("abcd", "efgh")
   328  	assert.Equal(t, 1, topics[0].stopCallCount)
   329  	assert.Equal(t, 0, topics[1].stopCallCount)
   330  
   331  	ms.StopPublish("abcd", "efgh")
   332  	assert.Equal(t, 1, topics[0].stopCallCount)
   333  	assert.Equal(t, 0, topics[1].stopCallCount)
   334  
   335  	ms.StopPublish("qrst", "uvwx")
   336  	assert.Equal(t, 1, topics[0].stopCallCount)
   337  	assert.Equal(t, 0, topics[1].stopCallCount)
   338  }
   339  
   340  // Implements client interface (by embedding) but only used for testig subscribing
   341  // Initialise with a list of mockSubscriptions. On each SubscriptionInProject
   342  // call returns next element in list. Tracks number of times SubscriptionInProject
   343  // is called
   344  type mockClientForSubscribeInProject struct {
   345  	clientItfc
   346  	callCount         int
   347  	mockSubscriptions []*mockSubscriptionSynch
   348  }
   349  
   350  func (mC *mockClientForSubscribeInProject) SubscriptionInProject(_ string, _ string) subscriptionItfc {
   351  	subs := mC.mockSubscriptions[mC.callCount]
   352  	mC.callCount++
   353  	return subs
   354  }
   355  
   356  // implements subscriptionInt interface. Initialise with list of messages to be
   357  // processed in subsequent calls to receive.
   358  // panics if you don't initialise with enough messages
   359  // Calls handler synchronously so no goroutine management is needed
   360  type mockSubscriptionSynch struct {
   361  	subscriptionItfc
   362  	callCount int
   363  	messages  []*mockMessage
   364  }
   365  
   366  func (mS *mockSubscriptionSynch) Receive(ctx context.Context, handler func(ctx context.Context, msg messageItfc)) error {
   367  	callNo := mS.callCount
   368  	mS.callCount++
   369  
   370  	message := mS.messages[callNo]
   371  
   372  	handler(ctx, message)
   373  
   374  	// TODO return error?
   375  	return nil
   376  }
   377  
   378  type mockMessage struct {
   379  	data       []byte
   380  	attributes map[string]string
   381  
   382  	dataCallCount       int
   383  	attributesCallCount int
   384  	ackCallCount        int
   385  	nackCallCount       int
   386  }
   387  
   388  func (mM *mockMessage) ID() string {
   389  	return uuid.NewString()
   390  }
   391  
   392  func (mM *mockMessage) Ack() {
   393  	mM.ackCallCount++
   394  }
   395  
   396  func (mM *mockMessage) Nack() {
   397  	mM.nackCallCount++
   398  }
   399  
   400  func (mM *mockMessage) Attributes() map[string]string {
   401  	mM.attributesCallCount++
   402  	return mM.attributes
   403  }
   404  
   405  func (mM *mockMessage) Data() []byte {
   406  	mM.dataCallCount++
   407  	return mM.data
   408  }
   409  
   410  func (mM *mockMessage) OrderingKey() string {
   411  	return ""
   412  }
   413  
   414  func (mM *mockMessage) SetOrderingKey(_ string) {
   415  
   416  }
   417  func (mM *mockMessage) messageOnlyAckedOnce(t *testing.T) {
   418  	if mM.nackCallCount != 0 {
   419  		t.Errorf("Message should be Ack'ed not Nack'ed: message Nack'ed %v times, Ack'ed %v times", mM.nackCallCount, mM.ackCallCount)
   420  	}
   421  	if mM.ackCallCount != 1 {
   422  		t.Errorf("Message should be Ack'ed once, message Ack'ed %v times", mM.ackCallCount)
   423  	}
   424  }
   425  
   426  func (mM *mockMessage) messageOnlyNackedOnce(t *testing.T) {
   427  	if mM.ackCallCount != 0 {
   428  		t.Errorf("Message should be Nack'ed not Ack'ed: message Nack'ed %v times, Ack'ed %v times", mM.nackCallCount, mM.ackCallCount)
   429  	}
   430  	if mM.nackCallCount != 1 {
   431  		t.Errorf("Message should be Nack'ed once, message Nack'ed %v times", mM.nackCallCount)
   432  	}
   433  }
   434  
   435  func TestFilterSubscribeSkip(t *testing.T) {
   436  	filter := map[string]string{
   437  		"filterkeynotinattrs": "val",
   438  	}
   439  
   440  	testHandler := func(_ context.Context, _ msgdata.CommandResponse) {}
   441  
   442  	testMockMessage := mockMessage{
   443  		data:       defaultRespData,
   444  		attributes: defaultRespAttr,
   445  	}
   446  
   447  	mockSub := mockSubscriptionSynch{
   448  		messages: []*mockMessage{&testMockMessage},
   449  	}
   450  
   451  	mockClient := mockClientForSubscribeInProject{
   452  		mockSubscriptions: []*mockSubscriptionSynch{&mockSub},
   453  	}
   454  
   455  	ms := MessageService{
   456  		ps:     &mockClient,
   457  		logger: logr.Discard(),
   458  	}
   459  
   460  	err := ms.Subscribe(context.Background(), "abcd", "efgh", testHandler, filter)
   461  	assert.NoError(t, err)
   462  
   463  	// Test message is Nack'ed and no data is accessed
   464  	testMockMessage.messageOnlyNackedOnce(t)
   465  	assert.Equal(t, 1, testMockMessage.attributesCallCount)
   466  	assert.Equal(t, 0, testMockMessage.dataCallCount)
   467  }
   468  
   469  func TestFilterSubscribe(t *testing.T) {
   470  	filter := defaultRespAttr
   471  
   472  	testMockMessage := mockMessage{
   473  		data:       defaultRespData,
   474  		attributes: defaultRespAttr,
   475  	}
   476  
   477  	expMessage, err := msgdata.NewCommandResponse(defaultRespData, defaultRespAttr)
   478  	assert.NoError(t, err)
   479  
   480  	testHandler := func(_ context.Context, msg msgdata.CommandResponse) {
   481  		assert.Equal(t, expMessage, msg)
   482  	}
   483  
   484  	mockSub := mockSubscriptionSynch{
   485  		messages: []*mockMessage{&testMockMessage},
   486  	}
   487  
   488  	mockClient := mockClientForSubscribeInProject{
   489  		mockSubscriptions: []*mockSubscriptionSynch{&mockSub},
   490  	}
   491  
   492  	// Initialise message service using the NewMessageService public api, then
   493  	// replace the pubsub client with the mock client
   494  	ms, err := NewMessageService(context.Background())
   495  	assert.NoError(t, err)
   496  	ms.ps = &mockClient
   497  
   498  	err = ms.Subscribe(context.Background(), "abcd", "efgh", testHandler, filter)
   499  	assert.NoError(t, err)
   500  
   501  	// Test message is Ack'ed and data is accessed
   502  	testMockMessage.messageOnlyAckedOnce(t)
   503  	assert.Equal(t, 2, testMockMessage.attributesCallCount)
   504  	assert.Equal(t, 1, testMockMessage.dataCallCount)
   505  }
   506  
   507  func TestInvalidMessageNack(t *testing.T) {
   508  	testMockMessage := mockMessage{
   509  		data:       []byte(`{"invalid json": "stri`),
   510  		attributes: defaultRespAttr,
   511  	}
   512  
   513  	testHandler := func(_ context.Context, _ msgdata.CommandResponse) {}
   514  
   515  	mockSub := mockSubscriptionSynch{
   516  		messages: []*mockMessage{&testMockMessage},
   517  	}
   518  
   519  	mockClient := mockClientForSubscribeInProject{
   520  		mockSubscriptions: []*mockSubscriptionSynch{&mockSub},
   521  	}
   522  
   523  	// Initialise message service using the NewMessageService public api, then
   524  	// replace the pubsub client with the mock client
   525  	ms, err := NewMessageService(context.Background())
   526  	assert.NoError(t, err)
   527  	ms.ps = &mockClient
   528  
   529  	err = ms.Subscribe(context.Background(), "abcd", "efgh", testHandler, nil)
   530  	assert.NoError(t, err)
   531  
   532  	testMockMessage.messageOnlyNackedOnce(t)
   533  	assert.Equal(t, 2, testMockMessage.attributesCallCount)
   534  	assert.Equal(t, 1, testMockMessage.dataCallCount)
   535  }
   536  
   537  type createSubscriptionClient struct {
   538  	clientItfc
   539  	subscriptionID string
   540  	cfg            subscriptionCfg
   541  }
   542  
   543  func (cl *createSubscriptionClient) CreateSubscription(_ context.Context, subscriptionID string, cfg subscriptionCfg) (subscriptionItfc, error) {
   544  	cl.subscriptionID = subscriptionID
   545  	cl.cfg = cfg
   546  
   547  	return nil, nil
   548  }
   549  
   550  func TestCreateSubscription(t *testing.T) {
   551  	t.Parallel()
   552  
   553  	tests := map[string]struct {
   554  		sessionID      string
   555  		subscriptionID string
   556  		projectID      string
   557  		topicID        string
   558  
   559  		expCfg subscriptionCfg
   560  	}{
   561  		"Create Subscription": {
   562  			sessionID:      "abcd",
   563  			subscriptionID: "efgh",
   564  			projectID:      "ijkl",
   565  			topicID:        "mnop",
   566  
   567  			expCfg: subscriptionCfg{
   568  				topicName:         "mnop",
   569  				projectID:         "ijkl",
   570  				retentionDuration: time.Hour,
   571  				expirationPolicy:  24 * time.Hour,
   572  				filter:            `attributes.sessionId="abcd"`,
   573  			},
   574  		},
   575  	}
   576  
   577  	for name, tc := range tests {
   578  		tc := tc
   579  		t.Run(name, func(t *testing.T) {
   580  			t.Parallel()
   581  
   582  			// Initialise message service using the NewMessageService public api, then
   583  			// replace the pubsub client with the mock client
   584  			ms, err := NewMessageService(context.Background())
   585  			assert.NoError(t, err)
   586  
   587  			mockCl := createSubscriptionClient{}
   588  			ms.ps = &mockCl
   589  
   590  			err = ms.CreateSubscription(context.Background(), tc.sessionID, tc.subscriptionID, tc.projectID, tc.topicID)
   591  			assert.NoError(t, err)
   592  
   593  			assert.Equal(t, tc.subscriptionID, mockCl.subscriptionID)
   594  			assert.Equal(t, tc.expCfg, mockCl.cfg)
   595  		})
   596  	}
   597  }
   598  
   599  // Implement both clientItfc and subscriptionItfc so that we can reuse the same
   600  // type in tests, cutting down on setup boilerplate
   601  type deleteSubscriptionClient struct {
   602  	clientItfc
   603  	subscriptionItfc
   604  	subscriptionID string
   605  	projectID      string
   606  	retErr         error
   607  }
   608  
   609  func (cl *deleteSubscriptionClient) SubscriptionInProject(subscriptionID, projectID string) subscriptionItfc {
   610  	cl.subscriptionID = subscriptionID
   611  	cl.projectID = projectID
   612  	// Return itself as a subscription to cut down on test specific types
   613  	return cl
   614  }
   615  
   616  func (cl *deleteSubscriptionClient) Delete(_ context.Context) error {
   617  	return cl.retErr
   618  }
   619  
   620  func TestDeleteSubscription(t *testing.T) {
   621  	t.Parallel()
   622  
   623  	tests := map[string]struct {
   624  		subscriptionID string
   625  		projectID      string
   626  		retErr         error
   627  
   628  		expErr assert.ErrorAssertionFunc
   629  	}{
   630  		"Normal": {
   631  			subscriptionID: "abcd",
   632  			projectID:      "efgh",
   633  			retErr:         nil,
   634  			expErr:         assert.NoError,
   635  		},
   636  		"Error": {
   637  			subscriptionID: "abcd",
   638  			projectID:      "efgh",
   639  			retErr:         fmt.Errorf("bad"),
   640  			expErr:         assert.Error,
   641  		},
   642  	}
   643  
   644  	for name, tc := range tests {
   645  		tc := tc
   646  		t.Run(name, func(t *testing.T) {
   647  			t.Parallel()
   648  
   649  			ms, err := NewMessageService(context.Background())
   650  			assert.NoError(t, err)
   651  
   652  			mockCl := deleteSubscriptionClient{
   653  				retErr: tc.retErr,
   654  			}
   655  			ms.ps = &mockCl
   656  
   657  			err = ms.DeleteSubscription(context.Background(), tc.subscriptionID, tc.projectID)
   658  			tc.expErr(t, err)
   659  
   660  			assert.Equal(t, tc.subscriptionID, mockCl.subscriptionID)
   661  			assert.Equal(t, tc.projectID, mockCl.projectID)
   662  		})
   663  	}
   664  }
   665  

View as plain text