1 package requestservice
2
3 import (
4 "context"
5 "fmt"
6 "testing"
7
8 "github.com/DATA-DOG/go-sqlmock"
9 "github.com/hashicorp/go-version"
10 "github.com/stretchr/testify/assert"
11 "github.com/stretchr/testify/require"
12
13 "edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
14 "edge-infra.dev/pkg/sds/emergencyaccess/types"
15 )
16
17 type helper interface {
18 Helper()
19 }
20
21 func EqualError(message string) assert.ErrorAssertionFunc {
22 return func(t assert.TestingT, err error, i ...interface{}) bool {
23 if help, ok := t.(helper); ok {
24 help.Helper()
25 }
26 return assert.EqualError(t, err, message, i...)
27 }
28 }
29
30 func TestCreateRequest(t *testing.T) {
31 t.Parallel()
32
33 version1_0, err := version.NewVersion(string(eaconst.MessageVersion1_0))
34 require.NoError(t, err)
35 version2_0, err := version.NewVersion(string(eaconst.MessageVersion2_0))
36 require.NoError(t, err)
37 badVersion, err := version.NewVersion("0.0")
38 require.NoError(t, err)
39
40 tests := map[string]struct {
41 payload string
42 messageVersion *version.Version
43 expectedData string
44 expectedAttr map[string]string
45 errAssert assert.ErrorAssertionFunc
46 expNil bool
47 }{
48 "Bad Version": {
49 messageVersion: badVersion,
50 errAssert: EqualError("unsupported message version 0.0"),
51 expNil: true,
52 },
53 "Version 1.0 Command": {
54 payload: "echo hello there",
55 messageVersion: version1_0,
56 expectedData: `{
57 "command": "echo hello there"
58 }`,
59 expectedAttr: map[string]string{
60 eaconst.VersionKey: string(eaconst.MessageVersion1_0),
61 eaconst.RequestTypeKey: string(eaconst.Command),
62 },
63 errAssert: assert.NoError,
64 expNil: false,
65 },
66
67 "Version 1.0 Script": {
68 payload: "./myScript hello there",
69 messageVersion: version1_0,
70 expectedData: `{
71 "command": "./myScript hello there"
72 }`,
73 expectedAttr: map[string]string{
74 eaconst.VersionKey: string(eaconst.MessageVersion1_0),
75 eaconst.RequestTypeKey: string(eaconst.Command),
76 },
77 errAssert: assert.NoError,
78 expNil: false,
79 },
80 "Version 2.0 Command": {
81 payload: "echo hello there",
82 messageVersion: version2_0,
83 expectedData: `{
84 "command": "echo",
85 "args": ["hello", "there"]
86 }`,
87 expectedAttr: map[string]string{
88 eaconst.VersionKey: string(eaconst.MessageVersion2_0),
89 eaconst.RequestTypeKey: string(eaconst.Command),
90 },
91 errAssert: assert.NoError,
92 expNil: false,
93 },
94 "Version 2.0 Script": {
95 payload: "./myScript hello there",
96 messageVersion: version2_0,
97 expectedData: `{
98 "executable": {
99 "name": "myScript",
100 "contents": ""
101 },
102 "args": ["hello", "there"]
103 }`,
104 expectedAttr: map[string]string{
105 eaconst.VersionKey: string(eaconst.MessageVersion2_0),
106 eaconst.RequestTypeKey: string(eaconst.Executable),
107 },
108 errAssert: assert.NoError,
109 expNil: false,
110 },
111 }
112
113 for name, tc := range tests {
114 tc := tc
115 t.Run(name, func(t *testing.T) {
116 t.Parallel()
117
118 config := Config{
119 Target: types.Target{
120 Projectid: "projectID",
121 Bannerid: "bannerID",
122 Storeid: "storeID",
123 Terminalid: "terminalID",
124 },
125 }
126 rs, err := New(nil)
127 require.NoError(t, err)
128 rs.versionCache.cache[config.Target] = tc.messageVersion
129
130 request, err := rs.CreateRequest(context.Background(), tc.payload, config)
131 tc.errAssert(t, err)
132 assert.Equal(t, tc.expNil, request == nil)
133 if !tc.expNil {
134 actualData, err := request.Data()
135 assert.NoError(t, err)
136 assert.JSONEq(t, tc.expectedData, string(actualData))
137 assert.Equal(t, tc.expectedAttr, request.Attributes())
138 }
139 })
140 }
141 }
142
143 func TestGetMessageVersion(t *testing.T) {
144 t.Parallel()
145
146 target := types.Target{
147 Projectid: "projectID",
148 Bannerid: "bannerID",
149 Storeid: "storeID",
150 Terminalid: "terminalID",
151 }
152
153 testVersions := versionsMap{
154 "1.0": "1.0",
155 "1.16": "2.0",
156 }
157
158 expected1_0, err := version.NewVersion("1.0")
159 require.NoError(t, err)
160 expected2_0, err := version.NewVersion("2.0")
161 require.NoError(t, err)
162 expectedMinimum, err := version.NewVersion(string(eaconst.MinimumSupportedMessageVersion))
163 require.NoError(t, err)
164
165 tests := map[string]struct {
166 expectations func(mock sqlmock.Sqlmock)
167 expected *version.Version
168 errAssert assert.ErrorAssertionFunc
169 }{
170 "Success 1.0": {
171 expectations: func(mock sqlmock.Sqlmock) {
172 mock.ExpectQuery(getEdgeOSVersionQuery).
173 WithArgs(target.Storeid, target.Terminalid).
174 WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v1.14.0"))
175 },
176 expected: expected1_0,
177 errAssert: assert.NoError,
178 },
179 "Success With EdgeOS Suffix": {
180 expectations: func(mock sqlmock.Sqlmock) {
181 mock.ExpectQuery(getEdgeOSVersionQuery).
182 WithArgs(target.Storeid, target.Terminalid).
183 WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v1.14.0-085ae663-dev"))
184 },
185 expected: expected1_0,
186 errAssert: assert.NoError,
187 },
188 "Success 2.0": {
189 expectations: func(mock sqlmock.Sqlmock) {
190 mock.ExpectQuery(getEdgeOSVersionQuery).
191 WithArgs(target.Storeid, target.Terminalid).
192 WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v2.0.0"))
193 },
194 expected: expected2_0,
195 errAssert: assert.NoError,
196 },
197 "Query Error": {
198 expectations: func(mock sqlmock.Sqlmock) {
199 mock.ExpectQuery(getEdgeOSVersionQuery).
200 WithArgs(target.Storeid, target.Terminalid).
201 WillReturnError(fmt.Errorf("error"))
202 },
203 errAssert: EqualError("error scanning edgeOS version results: error"),
204 },
205 "No Rows": {
206 expectations: func(mock sqlmock.Sqlmock) {
207 mock.ExpectQuery(getEdgeOSVersionQuery).
208 WithArgs(target.Storeid, target.Terminalid).
209 WillReturnRows(sqlmock.NewRows([]string{"value"}))
210 },
211 expected: expectedMinimum,
212 errAssert: assert.NoError,
213 },
214 }
215
216 for name, tc := range tests {
217 tc := tc
218 t.Run(name, func(t *testing.T) {
219 t.Parallel()
220
221 db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
222 require.NoError(t, err)
223 defer db.Close()
224 tc.expectations(mock)
225
226 rs, err := New(db)
227 require.NoError(t, err)
228 rs.versions, err = messageVersionMappings(testVersions)
229 require.NoError(t, err)
230 v, err := rs.getMessageVersion(context.Background(), target)
231 tc.errAssert(t, err)
232 assert.Equal(t, tc.expected, v)
233 assert.NoError(t, mock.ExpectationsWereMet())
234 })
235 }
236 }
237
238 func TestGetMessageVersionCache(t *testing.T) {
239 t.Parallel()
240
241 target := types.Target{
242 Projectid: "projectID",
243 Bannerid: "bannerID",
244 Storeid: "storeID",
245 Terminalid: "terminalID",
246 }
247
248 db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
249 require.NoError(t, err)
250 defer db.Close()
251 mock.ExpectQuery(getEdgeOSVersionQuery).
252 WithArgs(target.Storeid, target.Terminalid).
253 WillReturnRows(sqlmock.NewRows([]string{"value"}).AddRow("v1.14.0"))
254
255 rs, err := New(db)
256 require.NoError(t, err)
257 rs.versions, err = messageVersionMappings(versionsMap{"1.0": "1.0"})
258 require.NoError(t, err)
259
260 _, ok := rs.versionCache.Get(target)
261 assert.False(t, ok)
262
263
264 messageVersion, err := rs.getMessageVersion(context.Background(), target)
265 assert.NoError(t, err)
266 expected, err := version.NewVersion("1.0")
267 require.NoError(t, err)
268 assert.Equal(t, expected, messageVersion)
269
270 val, ok := rs.versionCache.Get(target)
271 assert.True(t, ok)
272 assert.Equal(t, messageVersion, val)
273
274
275 messageVersion, err = rs.getMessageVersion(context.Background(), target)
276 assert.NoError(t, err)
277 assert.Equal(t, expected, messageVersion)
278
279 assert.NoError(t, mock.ExpectationsWereMet())
280 }
281
282 func TestPickMessageVersionSuccess(t *testing.T) {
283 t.Parallel()
284
285 testVersions := versionsMap{
286 "1.0": "0",
287 "2.0": "0",
288 "2.1": "0",
289 "2.5": "0",
290 "2.99": "0",
291 "10.10": "0",
292 }
293
294 success, err := version.NewVersion("0.1")
295 require.NoError(t, err)
296
297 minimumVersion, err := version.NewVersion(string(eaconst.MinimumSupportedMessageVersion))
298 require.NoError(t, err)
299
300 tests := map[string]struct {
301 version string
302 keyToMatchWith string
303 expected *version.Version
304 }{
305 "1.0": {
306 version: "1.0",
307 keyToMatchWith: "1.0",
308 expected: success,
309 },
310 "1.5": {
311 version: "1.5",
312 keyToMatchWith: "1.0",
313 expected: success,
314 },
315 "2.0": {
316 version: "2.0",
317 keyToMatchWith: "2.0",
318 expected: success,
319 },
320 "2.1": {
321 version: "2.1",
322 keyToMatchWith: "2.1",
323 expected: success,
324 },
325 "2.3": {
326 version: "2.3",
327 keyToMatchWith: "2.1",
328 expected: success,
329 },
330 "2.98": {
331 version: "2.98",
332 keyToMatchWith: "2.5",
333 expected: success,
334 },
335 "2.100": {
336 version: "2.100",
337 keyToMatchWith: "2.99",
338 expected: success,
339 },
340 "5.0": {
341 version: "5.0",
342 keyToMatchWith: "2.99",
343 expected: success,
344 },
345 "10.10": {
346 version: "10.10",
347 keyToMatchWith: "10.10",
348 expected: success,
349 },
350 "11.0": {
351 version: "11.0",
352 keyToMatchWith: "10.10",
353 expected: success,
354 },
355 "Not a version": {
356 version: "notaversion",
357 expected: minimumVersion,
358 },
359 }
360
361 for name, tc := range tests {
362 tc := tc
363 t.Run(name, func(t *testing.T) {
364 t.Parallel()
365
366
367
368 versions := make(versionsMap)
369 for key, value := range testVersions {
370 if key == tc.keyToMatchWith {
371 value = "0.1"
372 }
373 versions[key] = value
374 }
375
376 versionMappings, err := messageVersionMappings(versions)
377 require.NoError(t, err)
378 rs := RequestService{versions: versionMappings, minimumVersion: minimumVersion}
379 actual, err := rs.pickMessageVersion(context.Background(), tc.version)
380 assert.NoError(t, err)
381 assert.Equal(t, tc.expected, actual)
382 })
383 }
384 }
385
386 func TestPickMessageVersionNoMatch(t *testing.T) {
387 t.Parallel()
388
389 testVersions := versionsMap{
390 "2.5": "0",
391 "4.0": "0",
392 }
393
394 tests := map[string]struct {
395 version string
396 }{
397 "Less than any keys": {
398 version: "1.0",
399 },
400 "Less than first matching major": {
401 version: "2.0",
402 },
403 }
404
405 for name, tc := range tests {
406 tc := tc
407 t.Run(name, func(t *testing.T) {
408 t.Parallel()
409
410 versions := make(versionsMap)
411 for key, value := range testVersions {
412 versions[key] = value
413 }
414
415 versionMappings, err := messageVersionMappings(versions)
416 require.NoError(t, err)
417 rs := RequestService{versions: versionMappings}
418 _, err = rs.pickMessageVersion(context.Background(), tc.version)
419 assert.Error(t, err)
420 })
421 }
422 }
423
View as plain text