package remotecli

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/go-logr/logr"
	"github.com/go-logr/logr/funcr"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"

	"edge-infra.dev/pkg/lib/fog"
	"edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
	"edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
)

var (
	defaultTarget = target{
		projectID:  "project",
		bannerID:   "banner",
		storeID:    "store",
		terminalID: "terminal",
	}
	defaultIdentity = "user"
	// defaultWaitTime and defaultTickTime bound the assert.Eventually /
	// assert.Never polling used throughout these tests.
	defaultWaitTime = 20 * time.Millisecond
	defaultTickTime = 1 * time.Microsecond
)

// target is a simple test implementation of the target accessors used by
// RemoteCLI (ProjectID, BannerID, StoreID, TerminalID).
type target struct {
	projectID  string
	bannerID   string
	storeID    string
	terminalID string
}

func (t target) ProjectID() string  { return t.projectID }
func (t target) BannerID() string   { return t.bannerID }
func (t target) StoreID() string    { return t.storeID }
func (t target) TerminalID() string { return t.terminalID }

// messageService is a test double for the message service used by RemoteCLI.
// Its behaviour is driven by the optional hooks below and by magic project IDs:
// "subscribe_error" makes Subscribe fail and "publish_error" makes Publish fail.
type messageService struct {
	// subscriptionFunc, when set, is invoked at the start of every Subscribe call.
	subscriptionFunc func(subscriptionID, projectID string)
	// stopPublishFunc, when set, is invoked on every StopPublish call.
	stopPublishFunc func(topicID, projectID string)
	// subscribeWatchCtx makes Subscribe block until its context is cancelled.
	subscribeWatchCtx bool
	// pubCommandID records the command ID attribute of the last published request.
	pubCommandID string
}

// Subscribe invokes subscriptionFunc (if set), optionally blocks until ctx is
// cancelled, and then returns an error only when projectID is "subscribe_error".
func (ms *messageService) Subscribe(
	ctx context.Context,
	subscriptionID string,
	projectID string,
	_ func(context.Context, msgdata.CommandResponse),
	_ map[string]string,
) error {
	if ms.subscriptionFunc != nil {
		ms.subscriptionFunc(subscriptionID, projectID)
	}
	// Wait for context cancellation when requested, otherwise return immediately.
	if ms.subscribeWatchCtx {
		<-ctx.Done()
	}
	if projectID == "subscribe_error" {
		return errors.New("subscribe returned error")
	}
	return nil
}

// Publish records the request's command ID attribute and returns an error only
// when projectID is "publish_error".
func (ms *messageService) Publish(_ context.Context, _ string, projectID string, msg msgdata.Request) error {
	ms.pubCommandID = msg.Attributes()[eaconst.CommandIDKey]
	if projectID == "publish_error" {
		return errors.New("publish returned error")
	}
	return nil
}

// StopPublish invokes stopPublishFunc (if set). It uses a pointer receiver for
// consistency with the other messageService methods.
func (ms *messageService) StopPublish(topicID string, projectID string) {
	if ms.stopPublishFunc != nil {
		ms.stopPublishFunc(topicID, projectID)
	}
}

// syncBuffer is a bytes.Buffer guarded by a mutex. Loggers backed by it may be
// written from RemoteCLI's background goroutines while the test reads the
// contents, which would be a data race on a bare bytes.Buffer.
type syncBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends p to the buffer; it satisfies io.Writer.
func (b *syncBuffer) Write(p []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.buf.Write(p)
}

// String returns a snapshot of the buffered contents.
func (b *syncBuffer) String() string {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.buf.String()
}

// Reset discards the buffered contents.
func (b *syncBuffer) Reset() {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.buf.Reset()
}

// isChanClosed reports, without blocking, whether ch has been closed. If a
// value is pending on ch it is consumed and false is returned.
func isChanClosed(ch <-chan msgdata.CommandResponse) bool {
	select {
	case _, ok := <-ch:
		return !ok
	default:
		return false
	}
}

func TestValidateTarget(t *testing.T) {
	assert.NoError(t, validateTarget(defaultTarget))

	badTarget := target{}
	assert.Len(t, validateTarget(badTarget), 4)
}

func TestNewRemoteCLI(t *testing.T) {
	ctx := context.Background()
	ms := messageService{}

	expected := &RemoteCLI{
		msgService:       &ms,
		sessionLock:      &sync.RWMutex{},
		context:          ctx,
		sessionData:      map[string]sessionData{},
		subscriptionData: map[string]subscriptionData{},
		topicData:        map[string]topicData{},
	}

	rcli := New(ctx, &ms)
	assert.Equal(t, expected, rcli)
}

func TestStartSession(t *testing.T) {
	buf := &syncBuffer{}
	ctx, cancelFunc := context.WithCancel(fog.IntoContext(context.Background(), createLogger(buf)))

	ch := make(chan msgdata.CommandResponse)
	ms := messageService{subscribeWatchCtx: true}
	rcli := New(ctx, &ms)
	sessionID := uuid.NewString()

	// The mock Subscribe blocks on ctx and only then returns the
	// "subscribe_error" error, so StartSession itself succeeds.
	tgt := defaultTarget
	tgt.projectID = "subscribe_error"

	err := rcli.StartSession(ctx, sessionID, ch, tgt)
	assert.NoError(t, err)
	assert.Equal(t, tgt, rcli.sessionData[sessionID].target)

	expSubscriptionID := "sub.store.dsds-ea-response"
	assert.True(t, rcli.subscriptionData[expSubscriptionID].sessions.HasMember(sessionID))

	cancelFunc()

	// Read the logs lazily: they are only meaningful once the check has failed.
	const wantLog = "subscribe returned error"
	if !assert.Eventually(t, func() bool {
		return strings.Contains(buf.String(), wantLog)
	}, 4*time.Second, 1*time.Millisecond, "logs do not contain expected string: %s", wantLog) {
		t.Logf("logs:\n%s", buf.String())
	}
}

func TestStartSession_ListensToContextDone(t *testing.T) {
	buf := &syncBuffer{}
	ctx, cancelFunc := context.WithCancel(fog.IntoContext(context.Background(), createLogger(buf)))

	tgt := defaultTarget
	ch := make(chan msgdata.CommandResponse)
	sessionID := uuid.NewString()

	// Keep the mock subscription running until the test ends so the
	// subscription-end cleanup does not run and close ch for another reason.
	// Closing (rather than sending on) the channel releases the mock without
	// blocking, even if Subscribe was never reached.
	endSubscription := make(chan struct{})
	defer close(endSubscription)

	ms := messageService{
		subscriptionFunc: func(_, _ string) { <-endSubscription },
	}
	rcli := New(ctx, &ms)

	err := rcli.StartSession(ctx, sessionID, ch, tgt)
	assert.NoError(t, err)

	// The display channel must stay open while the context is live.
	assert.Never(t, func() bool { return isChanClosed(ch) }, defaultWaitTime, defaultTickTime)

	cancelFunc()

	// Cancelling the context must close the display channel.
	assert.Eventually(t, func() bool { return isChanClosed(ch) }, defaultWaitTime, defaultTickTime)
}

func TestSubscribeCalledOnce(t *testing.T) {
	// msgsvc.Subscribe must only be called once for sessions sharing a subscription ID.
	ctx := fog.IntoContext(context.Background(), logr.Discard())
	tgt := defaultTarget

	// Subscribe runs on a background goroutine, so the counter must be atomic.
	var subscribeCalls atomic.Int32
	ms := messageService{
		subscriptionFunc:  func(_, _ string) { subscribeCalls.Add(1) },
		subscribeWatchCtx: true,
	}
	rcli := New(ctx, &ms)

	// The first session creates the subscription.
	err := rcli.StartSession(ctx, uuid.NewString(), make(chan msgdata.CommandResponse), tgt)
	assert.NoError(t, err)
	assert.Eventually(t, func() bool {
		return subscribeCalls.Load() == 1
	}, defaultWaitTime, defaultTickTime)

	// A second session for the same store must not subscribe again.
	err = rcli.StartSession(ctx, uuid.NewString(), make(chan msgdata.CommandResponse), tgt)
	assert.NoError(t, err)
	assert.Never(t, func() bool {
		return subscribeCalls.Load() != 1
	}, defaultWaitTime, defaultTickTime)

	// A session for a different store needs its own subscription.
	tgt.storeID = "anotherstore"
	err = rcli.StartSession(ctx, uuid.NewString(), make(chan msgdata.CommandResponse), tgt)
	assert.NoError(t, err)
	assert.Eventually(t, func() bool {
		return subscribeCalls.Load() == 2
	}, defaultWaitTime, defaultTickTime)
}

func TestStartSessionInvalidTarget(t *testing.T) {
	ctx := fog.IntoContext(context.Background(), createLogger(&syncBuffer{}))
	ms := messageService{}
	rcli := New(ctx, &ms)

	tgt := target{}
	err := rcli.StartSession(ctx, uuid.NewString(), make(chan msgdata.CommandResponse), tgt)
	// ErrorContains fails cleanly (instead of panicking) if err is nil.
	assert.ErrorContains(t, err, validateTarget(tgt).Error())
}

func TestSend(t *testing.T) {
	ctx := fog.IntoContext(context.Background(), logr.Discard())
	ch := make(chan msgdata.CommandResponse)
	ms := messageService{subscribeWatchCtx: true}
	sessionID := uuid.NewString()
	rcli := New(ctx, &ms)

	tgt := defaultTarget
	tgt.projectID = "publish_error"

	request, err := msgdata.NewV1_0Request("echo hello")
	assert.NoError(t, err)
	commandID := uuid.NewString()

	err = rcli.StartSession(ctx, sessionID, ch, tgt)
	assert.NoError(t, err)

	// The mock Publish records the command ID before returning its error.
	err = rcli.Send(ctx, defaultIdentity, sessionID, commandID, request)
	assert.ErrorContains(t, err, "publish returned error")
	assert.NotEmpty(t, ms.pubCommandID)
	assert.Equal(t, commandID, ms.pubCommandID)
}

func TestSendNoSessionStarted(t *testing.T) {
	ctx := fog.IntoContext(context.Background(), createLogger(&syncBuffer{}))
	ms := messageService{}
	rcli := New(ctx, &ms)

	request, err := msgdata.NewV1_0Request("echo hello")
	assert.NoError(t, err)

	err = rcli.Send(ctx, defaultIdentity, "invalid-session-id", uuid.NewString(), request)
	assert.ErrorContains(t, err, "invalid session id")
}

func TestEndSession(t *testing.T) {
	buf := &syncBuffer{}
	ctx := fog.IntoContext(context.Background(), createLogger(buf))
	sessionID := uuid.NewString()
	ms := messageService{subscribeWatchCtx: true}
	ch := make(chan msgdata.CommandResponse)
	rcli := New(ctx, &ms)
	const logOK = "Session stopped"

	var err error

	// Ending an unknown session returns an error rather than panicking.
	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, "nonexistant_session") })
	assert.ErrorContains(t, err, "unknown session ID")
	buf.Reset()

	err = rcli.StartSession(ctx, sessionID, ch, defaultTarget)
	assert.NoError(t, err)

	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
	assert.NoError(t, err)
	assert.Contains(t, buf.String(), logOK)
	buf.Reset()

	// Ending the same session a second time reports it as unknown.
	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
	assert.ErrorContains(t, err, "unknown session ID")

	// The ended session's display channel is eventually closed; this is how the
	// test observes that the session's context was cancelled.
	assert.Eventually(t, func() bool {
		select {
		case v, ok := <-ch:
			if ok {
				t.Errorf("unexpected value on display channel: %v", v)
				return false
			}
			return true
		default:
			return false
		}
	}, 5*time.Second, time.Microsecond)
}

func TestEndSession_Topics(t *testing.T) {
	ctx := fog.IntoContext(context.Background(), createLogger(&syncBuffer{}))
	sessionID := uuid.NewString()

	// publishInfo identifies a (topic, project) pair passed to StopPublish.
	type publishInfo struct {
		topicID   string
		projectID string
	}
	// stopPublishCalls counts StopPublish invocations per (topic, project) pair.
	stopPublishCalls := make(map[publishInfo]int)

	ms := messageService{
		stopPublishFunc: func(topicID string, projectID string) {
			stopPublishCalls[publishInfo{topicID, projectID}]++
		},
		subscribeWatchCtx: true,
	}
	ch, tgt := make(chan msgdata.CommandResponse), defaultTarget
	rcli := New(ctx, &ms)

	var err error
	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
	assert.ErrorContains(t, err, "unknown session ID")

	request, err := msgdata.NewV1_0Request("command")
	assert.NoError(t, err)

	// Session has not been posted to, so topic cleanup should not have been run.
	assert.Empty(t, stopPublishCalls)

	// Start a session and send one message to each of two different topics.
	assert.NoError(t, rcli.StartSession(ctx, sessionID, ch, tgt))
	assert.NoError(t, rcli.Send(ctx, defaultIdentity, sessionID, uuid.NewString(), request, WithOptionalTemplate("a")))
	assert.NoError(t, rcli.Send(ctx, defaultIdentity, sessionID, uuid.NewString(), request))

	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
	assert.NoError(t, err)

	// StopPublish is called exactly once for each topic published to.
	wantCalls := map[publishInfo]int{
		{"a", tgt.projectID}:                        1,
		{eaconst.DefaultTopTemplate, tgt.projectID}: 1,
	}
	assert.Equal(t, wantCalls, stopPublishCalls)

	// Ending the same session again fails and does not call StopPublish again.
	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
	assert.ErrorContains(t, err, "unknown session ID")
	assert.Equal(t, wantCalls, stopPublishCalls)
}

// handlerAttrs returns message attributes for a response addressed to sessionID.
func handlerAttrs(sessionID string) map[string]string {
	return map[string]string{
		"bannerId":             "banner",
		"storeId":              "store",
		"terminalId":           "terminal",
		"sessionId":            sessionID,
		"identity":             "identity",
		"version":              "1.0",
		"signature":            "signature",
		"request-message-uuid": "id",
	}
}

// handlerData returns a JSON "Output" response payload carrying output.
// output is inserted verbatim, so it must not contain characters needing JSON escaping.
func handlerData(output string) []byte {
	return []byte(fmt.Sprintf(`
	{
		"type": "Output",
		"exitCode": 0,
		"output": "%s",
		"timestamp": "01-01-2023 00:00:00",
		"duration": 0.1
	}`, output))
}

func TestHandler(t *testing.T) {
	ch1 := make(chan msgdata.CommandResponse, 3)
	ch2 := make(chan msgdata.CommandResponse, 1)
	ch3 := make(chan msgdata.CommandResponse)

	rcli := RemoteCLI{
		sessionLock: &sync.RWMutex{},
		sessionData: map[string]sessionData{
			"orderingKey":  {displayChan: ch1},
			"orderingKey2": {displayChan: ch2},
			"orderingKey3": {displayChan: ch3},
		},
	}
	ctx, cancelFunc := context.WithCancel(context.Background())
	fn := rcli.handler()

	// newResponse builds a response for sessionID, failing the test if the
	// payload cannot be parsed.
	newResponse := func(output, sessionID string) msgdata.CommandResponse {
		t.Helper()
		msg, err := msgdata.NewCommandResponse(handlerData(output), handlerAttrs(sessionID))
		assert.NoError(t, err)
		return msg
	}

	msg1 := newResponse("message 1", "orderingKey")
	msg2 := newResponse("message 2", "orderingKey")
	msg3 := newResponse("other ordering key", "orderingKey2")
	msg4 := newResponse("this message should not be received", "orderingKey3")

	// Messages are delivered to the display channel of their session.
	fn(ctx, msg1)
	assert.Empty(t, ch2)
	assert.Equal(t, msg1, <-ch1)

	fn(ctx, msg2)
	assert.Empty(t, ch2)
	assert.Equal(t, msg2, <-ch1)

	// A different session ID is delivered to a different channel.
	fn(context.Background(), msg3)
	assert.Empty(t, ch1)
	assert.Equal(t, msg3, <-ch2)

	// Once ctx is cancelled the handler must neither deliver nor block on the
	// (unbuffered, unread) ch3.
	cancelFunc()
	time.Sleep(time.Millisecond * 100)
	fn(ctx, msg4)
	assert.Empty(t, ch1)
	assert.Empty(t, ch2)
	assert.Empty(t, ch3)
}

func TestCreateOptionalConfig(t *testing.T) {
	template1, template2, template3 := "template-string-1", "template-string-2", "template-string-3"
	opts := []RCLIOption{
		WithOptionalTemplate(template1),
		WithOptionalTemplate(template2),
		WithOptionalTemplate(template3),
	}

	// The last template option wins.
	expected := &templateConfig{template: &template3}
	assert.Equal(t, expected, createOptionalConfig(opts))
	assert.Nil(t, createOptionalConfig(nil))
}

func TestFillTemplate(t *testing.T) {
	defaultTemplate := "default...."
	optionalTemplate := "optional...."
	tgt := defaultTarget
	config := &templateConfig{template: &optionalTemplate}

	expected := fmt.Sprintf("%s.%s.%s.%s", tgt.projectID, tgt.bannerID, tgt.storeID, tgt.terminalID)
	assert.Equal(t, "default."+expected, fillTemplate(tgt, defaultTemplate, nil))
	assert.Equal(t, "optional."+expected, fillTemplate(tgt, defaultTemplate, config))
}

// createLogger returns a funcr logger writing one line per log call to w.
// w may be a *bytes.Buffer for single-goroutine tests, or a *syncBuffer when
// the logger is used from background goroutines while the test reads it.
func createLogger(w io.Writer) logr.Logger {
	return funcr.New(func(prefix, args string) {
		if prefix != "" {
			fmt.Fprintf(w, "%s: %s\n", prefix, args)
		} else {
			fmt.Fprintln(w, args)
		}
	}, funcr.Options{})
}