...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/requestservice/requestservice_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/requestservice

     1  package requestservice
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/DATA-DOG/go-sqlmock"
     9  	"github.com/hashicorp/go-version"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
    14  	"edge-infra.dev/pkg/sds/emergencyaccess/types"
    15  )
    16  
    17  type helper interface {
    18  	Helper()
    19  }
    20  
    21  func EqualError(message string) assert.ErrorAssertionFunc {
    22  	return func(t assert.TestingT, err error, i ...interface{}) bool {
    23  		if help, ok := t.(helper); ok {
    24  			help.Helper()
    25  		}
    26  		return assert.EqualError(t, err, message, i...)
    27  	}
    28  }
    29  
    30  func TestCreateRequest(t *testing.T) {
    31  	t.Parallel()
    32  
    33  	version1_0, err := version.NewVersion(string(eaconst.MessageVersion1_0))
    34  	require.NoError(t, err)
    35  	version2_0, err := version.NewVersion(string(eaconst.MessageVersion2_0))
    36  	require.NoError(t, err)
    37  	badVersion, err := version.NewVersion("0.0")
    38  	require.NoError(t, err)
    39  
    40  	tests := map[string]struct {
    41  		payload        string
    42  		messageVersion *version.Version
    43  		expectedData   string
    44  		expectedAttr   map[string]string
    45  		errAssert      assert.ErrorAssertionFunc
    46  		expNil         bool
    47  	}{
    48  		"Bad Version": {
    49  			messageVersion: badVersion,
    50  			errAssert:      EqualError("unsupported message version 0.0"),
    51  			expNil:         true,
    52  		},
    53  		"Version 1.0 Command": {
    54  			payload:        "echo hello there",
    55  			messageVersion: version1_0,
    56  			expectedData: `{
    57  				"command": "echo hello there"
    58  			}`,
    59  			expectedAttr: map[string]string{
    60  				eaconst.VersionKey:     string(eaconst.MessageVersion1_0),
    61  				eaconst.RequestTypeKey: string(eaconst.Command),
    62  			},
    63  			errAssert: assert.NoError,
    64  			expNil:    false,
    65  		},
    66  		// This will ultimately not be valid, but v1 messages will interpret all payloads as commands
    67  		"Version 1.0 Script": {
    68  			payload:        "./myScript hello there",
    69  			messageVersion: version1_0,
    70  			expectedData: `{
    71  				"command": "./myScript hello there"
    72  			}`,
    73  			expectedAttr: map[string]string{
    74  				eaconst.VersionKey:     string(eaconst.MessageVersion1_0),
    75  				eaconst.RequestTypeKey: string(eaconst.Command),
    76  			},
    77  			errAssert: assert.NoError,
    78  			expNil:    false,
    79  		},
    80  		"Version 2.0 Command": {
    81  			payload:        "echo hello there",
    82  			messageVersion: version2_0,
    83  			expectedData: `{
    84  				"command": "echo",
    85  				"args": ["hello", "there"]
    86  			}`,
    87  			expectedAttr: map[string]string{
    88  				eaconst.VersionKey:     string(eaconst.MessageVersion2_0),
    89  				eaconst.RequestTypeKey: string(eaconst.Command),
    90  			},
    91  			errAssert: assert.NoError,
    92  			expNil:    false,
    93  		},
    94  		"Version 2.0 Script": {
    95  			payload:        "./myScript hello there",
    96  			messageVersion: version2_0,
    97  			expectedData: `{
    98  				"executable": {
    99  					"name": "myScript",
   100  					"contents": ""
   101  				},
   102  				"args": ["hello", "there"]
   103  			}`,
   104  			expectedAttr: map[string]string{
   105  				eaconst.VersionKey:     string(eaconst.MessageVersion2_0),
   106  				eaconst.RequestTypeKey: string(eaconst.Executable),
   107  			},
   108  			errAssert: assert.NoError,
   109  			expNil:    false,
   110  		},
   111  	}
   112  
   113  	for name, tc := range tests {
   114  		tc := tc
   115  		t.Run(name, func(t *testing.T) {
   116  			t.Parallel()
   117  
   118  			config := Config{
   119  				Target: types.Target{
   120  					Projectid:  "projectID",
   121  					Bannerid:   "bannerID",
   122  					Storeid:    "storeID",
   123  					Terminalid: "terminalID",
   124  				},
   125  			}
   126  			rs, err := New(nil)
   127  			require.NoError(t, err)
   128  			rs.versionCache.cache[config.Target] = tc.messageVersion
   129  
   130  			request, err := rs.CreateRequest(context.Background(), tc.payload, config)
   131  			tc.errAssert(t, err)
   132  			assert.Equal(t, tc.expNil, request == nil)
   133  			if !tc.expNil {
   134  				actualData, err := request.Data()
   135  				assert.NoError(t, err)
   136  				assert.JSONEq(t, tc.expectedData, string(actualData))
   137  				assert.Equal(t, tc.expectedAttr, request.Attributes())
   138  			}
   139  		})
   140  	}
   141  }
   142  
   143  func TestGetMessageVersion(t *testing.T) {
   144  	t.Parallel()
   145  
   146  	target := types.Target{
   147  		Projectid:  "projectID",
   148  		Bannerid:   "bannerID",
   149  		Storeid:    "storeID",
   150  		Terminalid: "terminalID",
   151  	}
   152  
   153  	testVersions := versionsMap{
   154  		"1.0":  "1.0",
   155  		"1.16": "2.0",
   156  	}
   157  
   158  	expected1_0, err := version.NewVersion("1.0")
   159  	require.NoError(t, err)
   160  	expected2_0, err := version.NewVersion("2.0")
   161  	require.NoError(t, err)
   162  	expectedMinimum, err := version.NewVersion(string(eaconst.MinimumSupportedMessageVersion))
   163  	require.NoError(t, err)
   164  
   165  	tests := map[string]struct {
   166  		expectations func(mock sqlmock.Sqlmock)
   167  		expected     *version.Version
   168  		errAssert    assert.ErrorAssertionFunc
   169  	}{
   170  		"Success 1.0": {
   171  			expectations: func(mock sqlmock.Sqlmock) {
   172  				mock.ExpectQuery(getEdgeOSVersionQuery).
   173  					WithArgs(target.Storeid, target.Terminalid).
   174  					WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v1.14.0"))
   175  			},
   176  			expected:  expected1_0,
   177  			errAssert: assert.NoError,
   178  		},
   179  		"Success With EdgeOS Suffix": {
   180  			expectations: func(mock sqlmock.Sqlmock) {
   181  				mock.ExpectQuery(getEdgeOSVersionQuery).
   182  					WithArgs(target.Storeid, target.Terminalid).
   183  					WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v1.14.0-085ae663-dev"))
   184  			},
   185  			expected:  expected1_0,
   186  			errAssert: assert.NoError,
   187  		},
   188  		"Success 2.0": {
   189  			expectations: func(mock sqlmock.Sqlmock) {
   190  				mock.ExpectQuery(getEdgeOSVersionQuery).
   191  					WithArgs(target.Storeid, target.Terminalid).
   192  					WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v2.0.0"))
   193  			},
   194  			expected:  expected2_0,
   195  			errAssert: assert.NoError,
   196  		},
   197  		"Query Error": {
   198  			expectations: func(mock sqlmock.Sqlmock) {
   199  				mock.ExpectQuery(getEdgeOSVersionQuery).
   200  					WithArgs(target.Storeid, target.Terminalid).
   201  					WillReturnError(fmt.Errorf("error"))
   202  			},
   203  			errAssert: EqualError("error scanning edgeOS version results: error"),
   204  		},
   205  		"No Rows": {
   206  			expectations: func(mock sqlmock.Sqlmock) {
   207  				mock.ExpectQuery(getEdgeOSVersionQuery).
   208  					WithArgs(target.Storeid, target.Terminalid).
   209  					WillReturnRows(sqlmock.NewRows([]string{"value"}))
   210  			},
   211  			expected:  expectedMinimum,
   212  			errAssert: assert.NoError,
   213  		},
   214  	}
   215  
   216  	for name, tc := range tests {
   217  		tc := tc
   218  		t.Run(name, func(t *testing.T) {
   219  			t.Parallel()
   220  
   221  			db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
   222  			require.NoError(t, err)
   223  			defer db.Close()
   224  			tc.expectations(mock)
   225  
   226  			rs, err := New(db)
   227  			require.NoError(t, err)
   228  			rs.versions, err = messageVersionMappings(testVersions)
   229  			require.NoError(t, err)
   230  			v, err := rs.getMessageVersion(context.Background(), target)
   231  			tc.errAssert(t, err)
   232  			assert.Equal(t, tc.expected, v)
   233  			assert.NoError(t, mock.ExpectationsWereMet())
   234  		})
   235  	}
   236  }
   237  
   238  func TestGetMessageVersionCache(t *testing.T) {
   239  	t.Parallel()
   240  
   241  	target := types.Target{
   242  		Projectid:  "projectID",
   243  		Bannerid:   "bannerID",
   244  		Storeid:    "storeID",
   245  		Terminalid: "terminalID",
   246  	}
   247  
   248  	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
   249  	require.NoError(t, err)
   250  	defer db.Close()
   251  	mock.ExpectQuery(getEdgeOSVersionQuery).
   252  		WithArgs(target.Storeid, target.Terminalid).
   253  		WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v1.14.0"))
   254  
   255  	rs, err := New(db)
   256  	require.NoError(t, err)
   257  	rs.versions, err = messageVersionMappings(versionsMap{"1.0": "1.0"})
   258  	require.NoError(t, err)
   259  	// Assert that cache is empty
   260  	_, ok := rs.versionCache.Get(target)
   261  	assert.False(t, ok)
   262  
   263  	// On first call with target, we will query the db and insert the result into cache
   264  	messageVersion, err := rs.getMessageVersion(context.Background(), target)
   265  	assert.NoError(t, err)
   266  	expected, err := version.NewVersion("1.0")
   267  	require.NoError(t, err)
   268  	assert.Equal(t, expected, messageVersion)
   269  
   270  	val, ok := rs.versionCache.Get(target)
   271  	assert.True(t, ok)
   272  	assert.Equal(t, messageVersion, val)
   273  
   274  	// On second call with target, we will use the value stored in cache and not call db
   275  	messageVersion, err = rs.getMessageVersion(context.Background(), target)
   276  	assert.NoError(t, err)
   277  	assert.Equal(t, expected, messageVersion)
   278  
   279  	assert.NoError(t, mock.ExpectationsWereMet())
   280  }
   281  
   282  func TestPickMessageVersionSuccess(t *testing.T) {
   283  	t.Parallel()
   284  
   285  	testVersions := versionsMap{
   286  		"1.0":   "0",
   287  		"2.0":   "0",
   288  		"2.1":   "0",
   289  		"2.5":   "0",
   290  		"2.99":  "0",
   291  		"10.10": "0",
   292  	}
   293  
   294  	success, err := version.NewVersion("0.1")
   295  	require.NoError(t, err)
   296  
   297  	minimumVersion, err := version.NewVersion(string(eaconst.MinimumSupportedMessageVersion))
   298  	require.NoError(t, err)
   299  
   300  	tests := map[string]struct {
   301  		version        string
   302  		keyToMatchWith string
   303  		expected       *version.Version
   304  	}{
   305  		"1.0": {
   306  			version:        "1.0",
   307  			keyToMatchWith: "1.0",
   308  			expected:       success,
   309  		},
   310  		"1.5": {
   311  			version:        "1.5",
   312  			keyToMatchWith: "1.0",
   313  			expected:       success,
   314  		},
   315  		"2.0": {
   316  			version:        "2.0",
   317  			keyToMatchWith: "2.0",
   318  			expected:       success,
   319  		},
   320  		"2.1": {
   321  			version:        "2.1",
   322  			keyToMatchWith: "2.1",
   323  			expected:       success,
   324  		},
   325  		"2.3": {
   326  			version:        "2.3",
   327  			keyToMatchWith: "2.1",
   328  			expected:       success,
   329  		},
   330  		"2.98": {
   331  			version:        "2.98",
   332  			keyToMatchWith: "2.5",
   333  			expected:       success,
   334  		},
   335  		"2.100": {
   336  			version:        "2.100",
   337  			keyToMatchWith: "2.99",
   338  			expected:       success,
   339  		},
   340  		"5.0": {
   341  			version:        "5.0",
   342  			keyToMatchWith: "2.99",
   343  			expected:       success,
   344  		},
   345  		"10.10": {
   346  			version:        "10.10",
   347  			keyToMatchWith: "10.10",
   348  			expected:       success,
   349  		},
   350  		"11.0": {
   351  			version:        "11.0",
   352  			keyToMatchWith: "10.10",
   353  			expected:       success,
   354  		},
   355  		"Not a version": {
   356  			version:  "notaversion",
   357  			expected: minimumVersion,
   358  		},
   359  	}
   360  
   361  	for name, tc := range tests {
   362  		tc := tc
   363  		t.Run(name, func(t *testing.T) {
   364  			t.Parallel()
   365  
   366  			// Make a copy of testVersions and set the version we expect to
   367  			// be picked to not return an error.
   368  			versions := make(versionsMap)
   369  			for key, value := range testVersions {
   370  				if key == tc.keyToMatchWith {
   371  					value = "0.1"
   372  				}
   373  				versions[key] = value
   374  			}
   375  
   376  			versionMappings, err := messageVersionMappings(versions)
   377  			require.NoError(t, err)
   378  			rs := RequestService{versions: versionMappings, minimumVersion: minimumVersion}
   379  			actual, err := rs.pickMessageVersion(context.Background(), tc.version)
   380  			assert.NoError(t, err)
   381  			assert.Equal(t, tc.expected, actual)
   382  		})
   383  	}
   384  }
   385  
   386  func TestPickMessageVersionNoMatch(t *testing.T) {
   387  	t.Parallel()
   388  
   389  	testVersions := versionsMap{
   390  		"2.5": "0",
   391  		"4.0": "0",
   392  	}
   393  
   394  	tests := map[string]struct {
   395  		version string
   396  	}{
   397  		"Less than any keys": {
   398  			version: "1.0",
   399  		},
   400  		"Less than first matching major": {
   401  			version: "2.0",
   402  		},
   403  	}
   404  
   405  	for name, tc := range tests {
   406  		tc := tc
   407  		t.Run(name, func(t *testing.T) {
   408  			t.Parallel()
   409  
   410  			versions := make(versionsMap)
   411  			for key, value := range testVersions {
   412  				versions[key] = value
   413  			}
   414  
   415  			versionMappings, err := messageVersionMappings(versions)
   416  			require.NoError(t, err)
   417  			rs := RequestService{versions: versionMappings}
   418  			_, err = rs.pickMessageVersion(context.Background(), tc.version)
   419  			assert.Error(t, err)
   420  		})
   421  	}
   422  }
   423  

View as plain text