package cliservice import ( "context" "fmt" "testing" "time" "github.com/stretchr/testify/assert" "edge-infra.dev/pkg/sds/emergencyaccess/msgdata" "edge-infra.dev/pkg/sds/emergencyaccess/remotecli" ) type helper interface { Helper() } func EqualError(message string) assert.ErrorAssertionFunc { return func(t assert.TestingT, err error, i ...interface{}) bool { if help, ok := t.(helper); ok { help.Helper() } return assert.EqualError(t, err, message, i...) } } type MockRemoteCLI struct { sessionID string outputIdentity string outputCommID string outputCommand string outputSeshID string sendOpts []remotecli.RCLIOption startSessionOpts []remotecli.RCLIOption } func (mrcli *MockRemoteCLI) Send(_ context.Context, userID string, sessionID string, commandID string, command msgdata.Request, opts ...remotecli.RCLIOption) error { mrcli.outputIdentity = userID mrcli.outputSeshID = sessionID mrcli.outputCommID = commandID data, _ := command.Data() mrcli.outputCommand = string(data) mrcli.sendOpts = opts return nil } func (mrcli *MockRemoteCLI) StartSession(_ context.Context, sessionID string, _ chan<- msgdata.CommandResponse, _ remotecli.Target, opts ...remotecli.RCLIOption) error { mrcli.sessionID = sessionID mrcli.startSessionOpts = opts return nil } func (mrcli *MockRemoteCLI) EndSession(_ context.Context, _ string) error { mrcli.sessionID = "" return nil } // shamelessly "borrowed" these two structs from Joshua Reyes-Traverso's test framework var ( defaultTarget = target{ projectID: "project", bannerID: "banner", storeID: "store", terminalID: "terminal", } defaultUserID = "user" ) type messageService struct { // return retErr error // parameters sessionID string subscriptionID string projectID string responseTopic string } func (ms messageService) Subscribe(context.Context, string, string, func(context.Context, msgdata.CommandResponse), map[string]string) error { return nil } func (ms messageService) Publish(context.Context, string, string, msgdata.Request) error { return nil } func (ms messageService) StopPublish(string, string) {} func (ms *messageService) CreateSubscription(_ context.Context, sessionID, subscriptionID, projectID, responseTopic string) error { ms.sessionID = sessionID ms.subscriptionID = subscriptionID ms.projectID = projectID ms.responseTopic = responseTopic return ms.retErr } func (ms *messageService) DeleteSubscription(_ context.Context, subscriptionID string, projectID string) error { ms.subscriptionID = subscriptionID ms.projectID = projectID return ms.retErr } func TestSuccessConnect(t *testing.T) { cls := NewCLIService(context.Background(), &messageService{}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() mcrli := &MockRemoteCLI{} cls.rcli = mcrli err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID) assert.NoError(t, err) assert.Equal(t, cls.sessionID, mcrli.sessionID) } func TestUnsuccessfulConnect(t *testing.T) { cls := NewCLIService(context.Background(), &messageService{}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() mcrli := &MockRemoteCLI{} cls.rcli = mcrli err := cls.Connect(ctx, "", defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID) assert.ErrorContains(t, err, "Project ID is a required field") } func TestCreateSubscription(t *testing.T) { t.Parallel() tests := map[string]struct { DisablePerSessionSubscription bool createErr error sessionID string bannerID string storeID string terminalID string projectID string err assert.ErrorAssertionFunc expOptsLen int expMsgSvc messageService }{ "Default Enabled": { DisablePerSessionSubscription: false, sessionID: "abcd", projectID: "efgh", bannerID: "ijkl", storeID: "mnop", terminalID: "qrst", err: assert.NoError, expOptsLen: 1, expMsgSvc: messageService{ sessionID: "abcd", subscriptionID: "sub.session.abcd.dsds-ea-response", projectID: "efgh", responseTopic: "topic.dsds-ea-response", }, }, "Disabled": { DisablePerSessionSubscription: true, sessionID: "abcd", projectID: "efgh", bannerID: "ijkl", storeID: "mnop", terminalID: "qrst", err: assert.NoError, expOptsLen: 0, expMsgSvc: messageService{}, }, "Error": { DisablePerSessionSubscription: false, createErr: fmt.Errorf("error uvwx"), sessionID: "abcd", projectID: "efgh", bannerID: "ijkl", storeID: "mnop", terminalID: "qrst", err: EqualError("error creating subscription: error uvwx"), expOptsLen: 0, expMsgSvc: messageService{ retErr: fmt.Errorf("error uvwx"), sessionID: "abcd", subscriptionID: "sub.session.abcd.dsds-ea-response", projectID: "efgh", responseTopic: "topic.dsds-ea-response", }, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() ms := messageService{ retErr: tc.createErr, } cls := NewCLIService(context.Background(), &ms) if tc.DisablePerSessionSubscription { cls.DisablePerSessionSubscription() } cls.sessionID = tc.sessionID opts, err := cls.createSubscription(context.Background(), tc.projectID) tc.err(t, err) assert.Equal(t, tc.expOptsLen, len(opts)) assert.Equal(t, tc.expMsgSvc, ms) }) } } func TestSuccessEnd(t *testing.T) { cls := NewCLIService(context.Background(), messageService{}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() mcrli := &MockRemoteCLI{} cls.rcli = mcrli err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID) assert.NoError(t, err) err = cls.End() assert.NoError(t, err) assert.Empty(t, mcrli.sessionID) } func TestDeleteSubscription(t *testing.T) { t.Parallel() tests := map[string]struct { DisablePerSessionSubscription bool sessionID string projectID string retErr error err assert.ErrorAssertionFunc expMsgSvc messageService }{ "Default": { DisablePerSessionSubscription: false, sessionID: "abcd", projectID: "efgh", err: assert.NoError, expMsgSvc: messageService{ subscriptionID: "sub.session.abcd.dsds-ea-response", projectID: "efgh", }, }, "Disabled": { DisablePerSessionSubscription: true, sessionID: "abcd", projectID: "efgh", err: assert.NoError, expMsgSvc: messageService{}, }, "Error": { DisablePerSessionSubscription: false, sessionID: "abcd", projectID: "efgh", retErr: fmt.Errorf("bad"), err: EqualError("error deleting per session subscription: bad"), expMsgSvc: messageService{ retErr: fmt.Errorf("bad"), projectID: "efgh", subscriptionID: "sub.session.abcd.dsds-ea-response", }, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() ms := messageService{ retErr: tc.retErr, } cls := NewCLIService(context.Background(), &ms) if tc.DisablePerSessionSubscription { cls.DisablePerSessionSubscription() } cls.sessionID = tc.sessionID cls.target = target{projectID: tc.projectID} err := cls.deleteSubscription(context.Background()) tc.err(t, err) assert.Equal(t, tc.expMsgSvc, ms) }) } } func TestSuccessSend(t *testing.T) { cls := NewCLIService(context.Background(), messageService{}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() mcrli := &MockRemoteCLI{} cls.rcli = mcrli cls.userID = defaultUserID err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID) assert.NoError(t, err) commandID, err := cls.Send("hello world") assert.NoError(t, err) assert.JSONEq(t, `{"command": "hello world"}`, mcrli.outputCommand) assert.Equal(t, mcrli.outputSeshID, cls.sessionID) assert.Equal(t, mcrli.outputCommID, commandID) } func TestSubscriptionTemplate(t *testing.T) { mcrli := &MockRemoteCLI{} cls := CLIService{ rcli: mcrli, } assert.Nil(t, mcrli.startSessionOpts) _ = cls.Connect(context.Background(), defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID) assert.Len(t, mcrli.startSessionOpts, 0) cls.SetSubscriptionTemplate("TEST_SUBSCRIPTION_TEMPLATE") _ = cls.Connect(context.Background(), defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID) assert.Equal(t, "TEST_SUBSCRIPTION_TEMPLATE", cls.subscriptionTemplate) assert.Len(t, mcrli.startSessionOpts, 1) // TODO: Currently only testing that remotecli.Send is called with some sort // of option, need to test it is called with the correct // TEST_SUBSCRIPTION_TEMPLATE value } func TestTopicTemplate(t *testing.T) { mcrli := &MockRemoteCLI{} cls := CLIService{ rcli: mcrli, seshCtx: context.Background(), } assert.Nil(t, mcrli.sendOpts) cls.userID = defaultUserID _, err := cls.Send("abcd") assert.NoError(t, err) assert.Len(t, mcrli.sendOpts, 0) cls.SetTopicTemplate("TEST_TOPIC_TEMPLATE") _, err = cls.Send("abcd") assert.NoError(t, err) assert.Equal(t, "TEST_TOPIC_TEMPLATE", cls.topicTemplate) assert.Len(t, mcrli.sendOpts, 1) } func TestIdleTimeReset(t *testing.T) { // mocking cls := NewCLIService(context.Background(), messageService{}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() mcrli := &MockRemoteCLI{} cls.rcli = mcrli cls.userID = defaultUserID // test timePreConnect := time.Now() err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID) assert.NoError(t, err) assert.Greater(t, time.Since(timePreConnect), cls.IdleTime()) //checks connect has reset the time oldIdleTime := cls.IdleTime() _, err = cls.Send("a command") assert.Greater(t, oldIdleTime, cls.IdleTime()) // checks send has reset the time assert.Nil(t, err) }