package requestservice import ( "context" "fmt" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/hashicorp/go-version" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "edge-infra.dev/pkg/sds/emergencyaccess/eaconst" "edge-infra.dev/pkg/sds/emergencyaccess/types" ) type helper interface { Helper() } func EqualError(message string) assert.ErrorAssertionFunc { return func(t assert.TestingT, err error, i ...interface{}) bool { if help, ok := t.(helper); ok { help.Helper() } return assert.EqualError(t, err, message, i...) } } func TestCreateRequest(t *testing.T) { t.Parallel() version1_0, err := version.NewVersion(string(eaconst.MessageVersion1_0)) require.NoError(t, err) version2_0, err := version.NewVersion(string(eaconst.MessageVersion2_0)) require.NoError(t, err) badVersion, err := version.NewVersion("0.0") require.NoError(t, err) tests := map[string]struct { payload string messageVersion *version.Version expectedData string expectedAttr map[string]string errAssert assert.ErrorAssertionFunc expNil bool }{ "Bad Version": { messageVersion: badVersion, errAssert: EqualError("unsupported message version 0.0"), expNil: true, }, "Version 1.0 Command": { payload: "echo hello there", messageVersion: version1_0, expectedData: `{ "command": "echo hello there" }`, expectedAttr: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion1_0), eaconst.RequestTypeKey: string(eaconst.Command), }, errAssert: assert.NoError, expNil: false, }, // This will ultimately not be valid, but v1 messages will interpret all payloads as commands "Version 1.0 Script": { payload: "./myScript hello there", messageVersion: version1_0, expectedData: `{ "command": "./myScript hello there" }`, expectedAttr: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion1_0), eaconst.RequestTypeKey: string(eaconst.Command), }, errAssert: assert.NoError, expNil: false, }, "Version 2.0 Command": { payload: "echo hello there", messageVersion: version2_0, expectedData: `{ "command": "echo", "args": ["hello", "there"] }`, expectedAttr: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion2_0), eaconst.RequestTypeKey: string(eaconst.Command), }, errAssert: assert.NoError, expNil: false, }, "Version 2.0 Script": { payload: "./myScript hello there", messageVersion: version2_0, expectedData: `{ "executable": { "name": "myScript", "contents": "" }, "args": ["hello", "there"] }`, expectedAttr: map[string]string{ eaconst.VersionKey: string(eaconst.MessageVersion2_0), eaconst.RequestTypeKey: string(eaconst.Executable), }, errAssert: assert.NoError, expNil: false, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() config := Config{ Target: types.Target{ Projectid: "projectID", Bannerid: "bannerID", Storeid: "storeID", Terminalid: "terminalID", }, } rs, err := New(nil) require.NoError(t, err) rs.versionCache.cache[config.Target] = tc.messageVersion request, err := rs.CreateRequest(context.Background(), tc.payload, config) tc.errAssert(t, err) assert.Equal(t, tc.expNil, request == nil) if !tc.expNil { actualData, err := request.Data() assert.NoError(t, err) assert.JSONEq(t, tc.expectedData, string(actualData)) assert.Equal(t, tc.expectedAttr, request.Attributes()) } }) } } func TestGetMessageVersion(t *testing.T) { t.Parallel() target := types.Target{ Projectid: "projectID", Bannerid: "bannerID", Storeid: "storeID", Terminalid: "terminalID", } testVersions := versionsMap{ "1.0": "1.0", "1.16": "2.0", } expected1_0, err := version.NewVersion("1.0") require.NoError(t, err) expected2_0, err := version.NewVersion("2.0") require.NoError(t, err) expectedMinimum, err := version.NewVersion(string(eaconst.MinimumSupportedMessageVersion)) require.NoError(t, err) tests := map[string]struct { expectations func(mock sqlmock.Sqlmock) expected *version.Version errAssert assert.ErrorAssertionFunc }{ "Success 1.0": { expectations: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(getEdgeOSVersionQuery). WithArgs(target.Storeid, target.Terminalid). WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v1.14.0")) }, expected: expected1_0, errAssert: assert.NoError, }, "Success With EdgeOS Suffix": { expectations: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(getEdgeOSVersionQuery). WithArgs(target.Storeid, target.Terminalid). WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v1.14.0-085ae663-dev")) }, expected: expected1_0, errAssert: assert.NoError, }, "Success 2.0": { expectations: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(getEdgeOSVersionQuery). WithArgs(target.Storeid, target.Terminalid). WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v2.0.0")) }, expected: expected2_0, errAssert: assert.NoError, }, "Query Error": { expectations: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(getEdgeOSVersionQuery). WithArgs(target.Storeid, target.Terminalid). WillReturnError(fmt.Errorf("error")) }, errAssert: EqualError("error scanning edgeOS version results: error"), }, "No Rows": { expectations: func(mock sqlmock.Sqlmock) { mock.ExpectQuery(getEdgeOSVersionQuery). WithArgs(target.Storeid, target.Terminalid). WillReturnRows(sqlmock.NewRows([]string{"value"})) }, expected: expectedMinimum, errAssert: assert.NoError, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) require.NoError(t, err) defer db.Close() tc.expectations(mock) rs, err := New(db) require.NoError(t, err) rs.versions, err = messageVersionMappings(testVersions) require.NoError(t, err) v, err := rs.getMessageVersion(context.Background(), target) tc.errAssert(t, err) assert.Equal(t, tc.expected, v) assert.NoError(t, mock.ExpectationsWereMet()) }) } } func TestGetMessageVersionCache(t *testing.T) { t.Parallel() target := types.Target{ Projectid: "projectID", Bannerid: "bannerID", Storeid: "storeID", Terminalid: "terminalID", } db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) require.NoError(t, err) defer db.Close() mock.ExpectQuery(getEdgeOSVersionQuery). WithArgs(target.Storeid, target.Terminalid). WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v1.14.0")) rs, err := New(db) require.NoError(t, err) rs.versions, err = messageVersionMappings(versionsMap{"1.0": "1.0"}) require.NoError(t, err) // Assert that cache is empty _, ok := rs.versionCache.Get(target) assert.False(t, ok) // On first call with target, we will query the db and insert the result into cache messageVersion, err := rs.getMessageVersion(context.Background(), target) assert.NoError(t, err) expected, err := version.NewVersion("1.0") require.NoError(t, err) assert.Equal(t, expected, messageVersion) val, ok := rs.versionCache.Get(target) assert.True(t, ok) assert.Equal(t, messageVersion, val) // On second call with target, we will use the value stored in cache and not call db messageVersion, err = rs.getMessageVersion(context.Background(), target) assert.NoError(t, err) assert.Equal(t, expected, messageVersion) assert.NoError(t, mock.ExpectationsWereMet()) } func TestPickMessageVersionSuccess(t *testing.T) { t.Parallel() testVersions := versionsMap{ "1.0": "0", "2.0": "0", "2.1": "0", "2.5": "0", "2.99": "0", "10.10": "0", } success, err := version.NewVersion("0.1") require.NoError(t, err) minimumVersion, err := version.NewVersion(string(eaconst.MinimumSupportedMessageVersion)) require.NoError(t, err) tests := map[string]struct { version string keyToMatchWith string expected *version.Version }{ "1.0": { version: "1.0", keyToMatchWith: "1.0", expected: success, }, "1.5": { version: "1.5", keyToMatchWith: "1.0", expected: success, }, "2.0": { version: "2.0", keyToMatchWith: "2.0", expected: success, }, "2.1": { version: "2.1", keyToMatchWith: "2.1", expected: success, }, "2.3": { version: "2.3", keyToMatchWith: "2.1", expected: success, }, "2.98": { version: "2.98", keyToMatchWith: "2.5", expected: success, }, "2.100": { version: "2.100", keyToMatchWith: "2.99", expected: success, }, "5.0": { version: "5.0", keyToMatchWith: "2.99", expected: success, }, "10.10": { version: "10.10", keyToMatchWith: "10.10", expected: success, }, "11.0": { version: "11.0", keyToMatchWith: "10.10", expected: success, }, "Not a version": { version: "notaversion", expected: minimumVersion, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() // Make a copy of testVersions and set the version we expect to // be picked to not return an error. versions := make(versionsMap) for key, value := range testVersions { if key == tc.keyToMatchWith { value = "0.1" } versions[key] = value } versionMappings, err := messageVersionMappings(versions) require.NoError(t, err) rs := RequestService{versions: versionMappings, minimumVersion: minimumVersion} actual, err := rs.pickMessageVersion(context.Background(), tc.version) assert.NoError(t, err) assert.Equal(t, tc.expected, actual) }) } } func TestPickMessageVersionNoMatch(t *testing.T) { t.Parallel() testVersions := versionsMap{ "2.5": "0", "4.0": "0", } tests := map[string]struct { version string }{ "Less than any keys": { version: "1.0", }, "Less than first matching major": { version: "2.0", }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() versions := make(versionsMap) for key, value := range testVersions { versions[key] = value } versionMappings, err := messageVersionMappings(versions) require.NoError(t, err) rs := RequestService{versions: versionMappings} _, err = rs.pickMessageVersion(context.Background(), tc.version) assert.Error(t, err) }) } }