1 package cliservice
2
3 import (
4 "context"
5 "fmt"
6 "testing"
7 "time"
8
9 "github.com/stretchr/testify/assert"
10
11 "edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
12 "edge-infra.dev/pkg/sds/emergencyaccess/remotecli"
13 )
14
// helper is the part of *testing.T that marks the calling function as a test
// helper, so assertion failures are attributed to the caller's line.
// assert.TestingT only requires Errorf, so this is checked for dynamically.
type helper interface{ Helper() }
18
19 func EqualError(message string) assert.ErrorAssertionFunc {
20 return func(t assert.TestingT, err error, i ...interface{}) bool {
21 if help, ok := t.(helper); ok {
22 help.Helper()
23 }
24
25 return assert.EqualError(t, err, message, i...)
26 }
27 }
28
29 type MockRemoteCLI struct {
30 sessionID string
31 outputIdentity string
32 outputCommID string
33 outputCommand string
34 outputSeshID string
35
36 sendOpts []remotecli.RCLIOption
37 startSessionOpts []remotecli.RCLIOption
38 }
39
40 func (mrcli *MockRemoteCLI) Send(_ context.Context, userID string, sessionID string, commandID string, command msgdata.Request, opts ...remotecli.RCLIOption) error {
41 mrcli.outputIdentity = userID
42 mrcli.outputSeshID = sessionID
43 mrcli.outputCommID = commandID
44 data, _ := command.Data()
45 mrcli.outputCommand = string(data)
46 mrcli.sendOpts = opts
47 return nil
48 }
49
50 func (mrcli *MockRemoteCLI) StartSession(_ context.Context, sessionID string, _ chan<- msgdata.CommandResponse, _ remotecli.Target, opts ...remotecli.RCLIOption) error {
51 mrcli.sessionID = sessionID
52 mrcli.startSessionOpts = opts
53 return nil
54 }
55
56 func (mrcli *MockRemoteCLI) EndSession(_ context.Context, _ string) error {
57 mrcli.sessionID = ""
58
59 return nil
60 }
61
62
63 var (
64 defaultTarget = target{
65 projectID: "project",
66 bannerID: "banner",
67 storeID: "store",
68 terminalID: "terminal",
69 }
70
71 defaultUserID = "user"
72 )
73
// messageService is a fake message service for CLIService tests.
//
// Subscribe, Publish and StopPublish are value-receiver no-ops. Some tests pass
// a plain messageService{} value, so these presumably form the method set
// NewCLIService requires (NOTE(review): confirm against NewCLIService).
//
// The pointer-receiver subscription methods record their arguments in the
// fields below and return retErr, so a test can compare the whole struct
// against an expected value.
type messageService struct {
	// retErr is returned by CreateSubscription and DeleteSubscription.
	retErr error

	// Arguments recorded by the most recent subscription call.
	sessionID      string
	subscriptionID string
	projectID      string
	responseTopic  string
}
84
85 func (ms messageService) Subscribe(context.Context, string, string,
86 func(context.Context, msgdata.CommandResponse), map[string]string) error {
87 return nil
88 }
89
90 func (ms messageService) Publish(context.Context, string, string, msgdata.Request) error {
91 return nil
92 }
93
94 func (ms messageService) StopPublish(string, string) {}
95
96 func (ms *messageService) CreateSubscription(_ context.Context, sessionID, subscriptionID, projectID, responseTopic string) error {
97 ms.sessionID = sessionID
98 ms.subscriptionID = subscriptionID
99 ms.projectID = projectID
100 ms.responseTopic = responseTopic
101 return ms.retErr
102 }
103
104 func (ms *messageService) DeleteSubscription(_ context.Context, subscriptionID string, projectID string) error {
105 ms.subscriptionID = subscriptionID
106 ms.projectID = projectID
107 return ms.retErr
108 }
109
110 func TestSuccessConnect(t *testing.T) {
111 cls := NewCLIService(context.Background(), &messageService{})
112 ctx, cancel := context.WithCancel(context.Background())
113 defer cancel()
114
115 mcrli := &MockRemoteCLI{}
116 cls.rcli = mcrli
117 err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
118 assert.NoError(t, err)
119 assert.Equal(t, cls.sessionID, mcrli.sessionID)
120 }
121
122 func TestUnsuccessfulConnect(t *testing.T) {
123 cls := NewCLIService(context.Background(), &messageService{})
124 ctx, cancel := context.WithCancel(context.Background())
125 defer cancel()
126
127 mcrli := &MockRemoteCLI{}
128 cls.rcli = mcrli
129 err := cls.Connect(ctx, "", defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
130 assert.ErrorContains(t, err, "Project ID is a required field")
131 }
132
133 func TestCreateSubscription(t *testing.T) {
134 t.Parallel()
135
136 tests := map[string]struct {
137 DisablePerSessionSubscription bool
138 createErr error
139 sessionID string
140 bannerID string
141 storeID string
142 terminalID string
143 projectID string
144
145 err assert.ErrorAssertionFunc
146 expOptsLen int
147 expMsgSvc messageService
148 }{
149 "Default Enabled": {
150 DisablePerSessionSubscription: false,
151 sessionID: "abcd",
152 projectID: "efgh",
153 bannerID: "ijkl",
154 storeID: "mnop",
155 terminalID: "qrst",
156 err: assert.NoError,
157 expOptsLen: 1,
158 expMsgSvc: messageService{
159 sessionID: "abcd",
160 subscriptionID: "sub.session.abcd.dsds-ea-response",
161 projectID: "efgh",
162 responseTopic: "topic.dsds-ea-response",
163 },
164 },
165 "Disabled": {
166 DisablePerSessionSubscription: true,
167 sessionID: "abcd",
168 projectID: "efgh",
169 bannerID: "ijkl",
170 storeID: "mnop",
171 terminalID: "qrst",
172 err: assert.NoError,
173 expOptsLen: 0,
174 expMsgSvc: messageService{},
175 },
176 "Error": {
177 DisablePerSessionSubscription: false,
178 createErr: fmt.Errorf("error uvwx"),
179 sessionID: "abcd",
180 projectID: "efgh",
181 bannerID: "ijkl",
182 storeID: "mnop",
183 terminalID: "qrst",
184 err: EqualError("error creating subscription: error uvwx"),
185 expOptsLen: 0,
186 expMsgSvc: messageService{
187 retErr: fmt.Errorf("error uvwx"),
188 sessionID: "abcd",
189 subscriptionID: "sub.session.abcd.dsds-ea-response",
190 projectID: "efgh",
191 responseTopic: "topic.dsds-ea-response",
192 },
193 },
194 }
195
196 for name, tc := range tests {
197 tc := tc
198 t.Run(name, func(t *testing.T) {
199 t.Parallel()
200
201 ms := messageService{
202 retErr: tc.createErr,
203 }
204
205 cls := NewCLIService(context.Background(), &ms)
206
207 if tc.DisablePerSessionSubscription {
208 cls.DisablePerSessionSubscription()
209 }
210
211 cls.sessionID = tc.sessionID
212
213 opts, err := cls.createSubscription(context.Background(), tc.projectID)
214 tc.err(t, err)
215 assert.Equal(t, tc.expOptsLen, len(opts))
216 assert.Equal(t, tc.expMsgSvc, ms)
217 })
218 }
219 }
220
221 func TestSuccessEnd(t *testing.T) {
222 cls := NewCLIService(context.Background(), messageService{})
223 ctx, cancel := context.WithCancel(context.Background())
224 defer cancel()
225
226 mcrli := &MockRemoteCLI{}
227 cls.rcli = mcrli
228 err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
229 assert.NoError(t, err)
230
231 err = cls.End()
232 assert.NoError(t, err)
233 assert.Empty(t, mcrli.sessionID)
234 }
235
236 func TestDeleteSubscription(t *testing.T) {
237 t.Parallel()
238
239 tests := map[string]struct {
240 DisablePerSessionSubscription bool
241 sessionID string
242 projectID string
243 retErr error
244
245 err assert.ErrorAssertionFunc
246 expMsgSvc messageService
247 }{
248 "Default": {
249 DisablePerSessionSubscription: false,
250 sessionID: "abcd",
251 projectID: "efgh",
252 err: assert.NoError,
253 expMsgSvc: messageService{
254 subscriptionID: "sub.session.abcd.dsds-ea-response",
255 projectID: "efgh",
256 },
257 },
258 "Disabled": {
259 DisablePerSessionSubscription: true,
260 sessionID: "abcd",
261 projectID: "efgh",
262 err: assert.NoError,
263 expMsgSvc: messageService{},
264 },
265 "Error": {
266 DisablePerSessionSubscription: false,
267 sessionID: "abcd",
268 projectID: "efgh",
269 retErr: fmt.Errorf("bad"),
270 err: EqualError("error deleting per session subscription: bad"),
271 expMsgSvc: messageService{
272 retErr: fmt.Errorf("bad"),
273 projectID: "efgh",
274 subscriptionID: "sub.session.abcd.dsds-ea-response",
275 },
276 },
277 }
278
279 for name, tc := range tests {
280 tc := tc
281 t.Run(name, func(t *testing.T) {
282 t.Parallel()
283
284 ms := messageService{
285 retErr: tc.retErr,
286 }
287
288 cls := NewCLIService(context.Background(), &ms)
289
290 if tc.DisablePerSessionSubscription {
291 cls.DisablePerSessionSubscription()
292 }
293
294 cls.sessionID = tc.sessionID
295 cls.target = target{projectID: tc.projectID}
296
297 err := cls.deleteSubscription(context.Background())
298 tc.err(t, err)
299 assert.Equal(t, tc.expMsgSvc, ms)
300 })
301 }
302 }
303
304 func TestSuccessSend(t *testing.T) {
305 cls := NewCLIService(context.Background(), messageService{})
306 ctx, cancel := context.WithCancel(context.Background())
307 defer cancel()
308
309 mcrli := &MockRemoteCLI{}
310 cls.rcli = mcrli
311 cls.userID = defaultUserID
312 err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
313 assert.NoError(t, err)
314
315 commandID, err := cls.Send("hello world")
316 assert.NoError(t, err)
317
318 assert.JSONEq(t, `{"command": "hello world"}`, mcrli.outputCommand)
319 assert.Equal(t, mcrli.outputSeshID, cls.sessionID)
320 assert.Equal(t, mcrli.outputCommID, commandID)
321 }
322
323 func TestSubscriptionTemplate(t *testing.T) {
324 mcrli := &MockRemoteCLI{}
325 cls := CLIService{
326 rcli: mcrli,
327 }
328
329 assert.Nil(t, mcrli.startSessionOpts)
330 _ = cls.Connect(context.Background(), defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
331
332 assert.Len(t, mcrli.startSessionOpts, 0)
333
334 cls.SetSubscriptionTemplate("TEST_SUBSCRIPTION_TEMPLATE")
335 _ = cls.Connect(context.Background(), defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
336
337 assert.Equal(t, "TEST_SUBSCRIPTION_TEMPLATE", cls.subscriptionTemplate)
338 assert.Len(t, mcrli.startSessionOpts, 1)
339
340
341
342 }
343
344 func TestTopicTemplate(t *testing.T) {
345 mcrli := &MockRemoteCLI{}
346 cls := CLIService{
347 rcli: mcrli,
348 seshCtx: context.Background(),
349 }
350
351 assert.Nil(t, mcrli.sendOpts)
352 cls.userID = defaultUserID
353 _, err := cls.Send("abcd")
354 assert.NoError(t, err)
355
356 assert.Len(t, mcrli.sendOpts, 0)
357
358 cls.SetTopicTemplate("TEST_TOPIC_TEMPLATE")
359 _, err = cls.Send("abcd")
360 assert.NoError(t, err)
361
362 assert.Equal(t, "TEST_TOPIC_TEMPLATE", cls.topicTemplate)
363 assert.Len(t, mcrli.sendOpts, 1)
364 }
365
366 func TestIdleTimeReset(t *testing.T) {
367
368 cls := NewCLIService(context.Background(), messageService{})
369 ctx, cancel := context.WithCancel(context.Background())
370 defer cancel()
371 mcrli := &MockRemoteCLI{}
372 cls.rcli = mcrli
373 cls.userID = defaultUserID
374
375 timePreConnect := time.Now()
376 err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
377 assert.NoError(t, err)
378 assert.Greater(t, time.Since(timePreConnect), cls.IdleTime())
379 oldIdleTime := cls.IdleTime()
380 _, err = cls.Send("a command")
381 assert.Greater(t, oldIdleTime, cls.IdleTime())
382 assert.Nil(t, err)
383 }
384
View as plain text