...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/cliservice/cliservice_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/cliservice

     1  package cliservice
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  
    11  	"edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
    12  	"edge-infra.dev/pkg/sds/emergencyaccess/remotecli"
    13  )
    14  
    15  type helper interface {
    16  	Helper()
    17  }
    18  
    19  func EqualError(message string) assert.ErrorAssertionFunc {
    20  	return func(t assert.TestingT, err error, i ...interface{}) bool {
    21  		if help, ok := t.(helper); ok {
    22  			help.Helper()
    23  		}
    24  
    25  		return assert.EqualError(t, err, message, i...)
    26  	}
    27  }
    28  
    29  type MockRemoteCLI struct {
    30  	sessionID      string
    31  	outputIdentity string
    32  	outputCommID   string
    33  	outputCommand  string
    34  	outputSeshID   string
    35  
    36  	sendOpts         []remotecli.RCLIOption
    37  	startSessionOpts []remotecli.RCLIOption
    38  }
    39  
    40  func (mrcli *MockRemoteCLI) Send(_ context.Context, userID string, sessionID string, commandID string, command msgdata.Request, opts ...remotecli.RCLIOption) error {
    41  	mrcli.outputIdentity = userID
    42  	mrcli.outputSeshID = sessionID
    43  	mrcli.outputCommID = commandID
    44  	data, _ := command.Data()
    45  	mrcli.outputCommand = string(data)
    46  	mrcli.sendOpts = opts
    47  	return nil
    48  }
    49  
    50  func (mrcli *MockRemoteCLI) StartSession(_ context.Context, sessionID string, _ chan<- msgdata.CommandResponse, _ remotecli.Target, opts ...remotecli.RCLIOption) error {
    51  	mrcli.sessionID = sessionID
    52  	mrcli.startSessionOpts = opts
    53  	return nil
    54  }
    55  
    56  func (mrcli *MockRemoteCLI) EndSession(_ context.Context, _ string) error {
    57  	mrcli.sessionID = ""
    58  
    59  	return nil
    60  }
    61  
    62  // shamelessly "borrowed" these two structs from Joshua Reyes-Traverso's test framework
    63  var (
    64  	defaultTarget = target{
    65  		projectID:  "project",
    66  		bannerID:   "banner",
    67  		storeID:    "store",
    68  		terminalID: "terminal",
    69  	}
    70  
    71  	defaultUserID = "user"
    72  )
    73  
    74  type messageService struct {
    75  	// return
    76  	retErr error
    77  
    78  	// parameters
    79  	sessionID      string
    80  	subscriptionID string
    81  	projectID      string
    82  	responseTopic  string
    83  }
    84  
    85  func (ms messageService) Subscribe(context.Context, string, string,
    86  	func(context.Context, msgdata.CommandResponse), map[string]string) error {
    87  	return nil
    88  }
    89  
    90  func (ms messageService) Publish(context.Context, string, string, msgdata.Request) error {
    91  	return nil
    92  }
    93  
    94  func (ms messageService) StopPublish(string, string) {}
    95  
    96  func (ms *messageService) CreateSubscription(_ context.Context, sessionID, subscriptionID, projectID, responseTopic string) error {
    97  	ms.sessionID = sessionID
    98  	ms.subscriptionID = subscriptionID
    99  	ms.projectID = projectID
   100  	ms.responseTopic = responseTopic
   101  	return ms.retErr
   102  }
   103  
   104  func (ms *messageService) DeleteSubscription(_ context.Context, subscriptionID string, projectID string) error {
   105  	ms.subscriptionID = subscriptionID
   106  	ms.projectID = projectID
   107  	return ms.retErr
   108  }
   109  
   110  func TestSuccessConnect(t *testing.T) {
   111  	cls := NewCLIService(context.Background(), &messageService{})
   112  	ctx, cancel := context.WithCancel(context.Background())
   113  	defer cancel()
   114  
   115  	mcrli := &MockRemoteCLI{}
   116  	cls.rcli = mcrli
   117  	err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
   118  	assert.NoError(t, err)
   119  	assert.Equal(t, cls.sessionID, mcrli.sessionID)
   120  }
   121  
   122  func TestUnsuccessfulConnect(t *testing.T) {
   123  	cls := NewCLIService(context.Background(), &messageService{})
   124  	ctx, cancel := context.WithCancel(context.Background())
   125  	defer cancel()
   126  
   127  	mcrli := &MockRemoteCLI{}
   128  	cls.rcli = mcrli
   129  	err := cls.Connect(ctx, "", defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
   130  	assert.ErrorContains(t, err, "Project ID is a required field")
   131  }
   132  
   133  func TestCreateSubscription(t *testing.T) {
   134  	t.Parallel()
   135  
   136  	tests := map[string]struct {
   137  		DisablePerSessionSubscription bool
   138  		createErr                     error
   139  		sessionID                     string
   140  		bannerID                      string
   141  		storeID                       string
   142  		terminalID                    string
   143  		projectID                     string
   144  
   145  		err        assert.ErrorAssertionFunc
   146  		expOptsLen int
   147  		expMsgSvc  messageService
   148  	}{
   149  		"Default Enabled": {
   150  			DisablePerSessionSubscription: false,
   151  			sessionID:                     "abcd",
   152  			projectID:                     "efgh",
   153  			bannerID:                      "ijkl",
   154  			storeID:                       "mnop",
   155  			terminalID:                    "qrst",
   156  			err:                           assert.NoError,
   157  			expOptsLen:                    1,
   158  			expMsgSvc: messageService{
   159  				sessionID:      "abcd",
   160  				subscriptionID: "sub.session.abcd.dsds-ea-response",
   161  				projectID:      "efgh",
   162  				responseTopic:  "topic.dsds-ea-response",
   163  			},
   164  		},
   165  		"Disabled": {
   166  			DisablePerSessionSubscription: true,
   167  			sessionID:                     "abcd",
   168  			projectID:                     "efgh",
   169  			bannerID:                      "ijkl",
   170  			storeID:                       "mnop",
   171  			terminalID:                    "qrst",
   172  			err:                           assert.NoError,
   173  			expOptsLen:                    0,
   174  			expMsgSvc:                     messageService{},
   175  		},
   176  		"Error": {
   177  			DisablePerSessionSubscription: false,
   178  			createErr:                     fmt.Errorf("error uvwx"),
   179  			sessionID:                     "abcd",
   180  			projectID:                     "efgh",
   181  			bannerID:                      "ijkl",
   182  			storeID:                       "mnop",
   183  			terminalID:                    "qrst",
   184  			err:                           EqualError("error creating subscription: error uvwx"),
   185  			expOptsLen:                    0,
   186  			expMsgSvc: messageService{
   187  				retErr:         fmt.Errorf("error uvwx"),
   188  				sessionID:      "abcd",
   189  				subscriptionID: "sub.session.abcd.dsds-ea-response",
   190  				projectID:      "efgh",
   191  				responseTopic:  "topic.dsds-ea-response",
   192  			},
   193  		},
   194  	}
   195  
   196  	for name, tc := range tests {
   197  		tc := tc
   198  		t.Run(name, func(t *testing.T) {
   199  			t.Parallel()
   200  
   201  			ms := messageService{
   202  				retErr: tc.createErr,
   203  			}
   204  
   205  			cls := NewCLIService(context.Background(), &ms)
   206  
   207  			if tc.DisablePerSessionSubscription {
   208  				cls.DisablePerSessionSubscription()
   209  			}
   210  
   211  			cls.sessionID = tc.sessionID
   212  
   213  			opts, err := cls.createSubscription(context.Background(), tc.projectID)
   214  			tc.err(t, err)
   215  			assert.Equal(t, tc.expOptsLen, len(opts))
   216  			assert.Equal(t, tc.expMsgSvc, ms)
   217  		})
   218  	}
   219  }
   220  
   221  func TestSuccessEnd(t *testing.T) {
   222  	cls := NewCLIService(context.Background(), messageService{})
   223  	ctx, cancel := context.WithCancel(context.Background())
   224  	defer cancel()
   225  
   226  	mcrli := &MockRemoteCLI{}
   227  	cls.rcli = mcrli
   228  	err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
   229  	assert.NoError(t, err)
   230  
   231  	err = cls.End()
   232  	assert.NoError(t, err)
   233  	assert.Empty(t, mcrli.sessionID)
   234  }
   235  
   236  func TestDeleteSubscription(t *testing.T) {
   237  	t.Parallel()
   238  
   239  	tests := map[string]struct {
   240  		DisablePerSessionSubscription bool
   241  		sessionID                     string
   242  		projectID                     string
   243  		retErr                        error
   244  
   245  		err       assert.ErrorAssertionFunc
   246  		expMsgSvc messageService
   247  	}{
   248  		"Default": {
   249  			DisablePerSessionSubscription: false,
   250  			sessionID:                     "abcd",
   251  			projectID:                     "efgh",
   252  			err:                           assert.NoError,
   253  			expMsgSvc: messageService{
   254  				subscriptionID: "sub.session.abcd.dsds-ea-response",
   255  				projectID:      "efgh",
   256  			},
   257  		},
   258  		"Disabled": {
   259  			DisablePerSessionSubscription: true,
   260  			sessionID:                     "abcd",
   261  			projectID:                     "efgh",
   262  			err:                           assert.NoError,
   263  			expMsgSvc:                     messageService{},
   264  		},
   265  		"Error": {
   266  			DisablePerSessionSubscription: false,
   267  			sessionID:                     "abcd",
   268  			projectID:                     "efgh",
   269  			retErr:                        fmt.Errorf("bad"),
   270  			err:                           EqualError("error deleting per session subscription: bad"),
   271  			expMsgSvc: messageService{
   272  				retErr:         fmt.Errorf("bad"),
   273  				projectID:      "efgh",
   274  				subscriptionID: "sub.session.abcd.dsds-ea-response",
   275  			},
   276  		},
   277  	}
   278  
   279  	for name, tc := range tests {
   280  		tc := tc
   281  		t.Run(name, func(t *testing.T) {
   282  			t.Parallel()
   283  
   284  			ms := messageService{
   285  				retErr: tc.retErr,
   286  			}
   287  
   288  			cls := NewCLIService(context.Background(), &ms)
   289  
   290  			if tc.DisablePerSessionSubscription {
   291  				cls.DisablePerSessionSubscription()
   292  			}
   293  
   294  			cls.sessionID = tc.sessionID
   295  			cls.target = target{projectID: tc.projectID}
   296  
   297  			err := cls.deleteSubscription(context.Background())
   298  			tc.err(t, err)
   299  			assert.Equal(t, tc.expMsgSvc, ms)
   300  		})
   301  	}
   302  }
   303  
   304  func TestSuccessSend(t *testing.T) {
   305  	cls := NewCLIService(context.Background(), messageService{})
   306  	ctx, cancel := context.WithCancel(context.Background())
   307  	defer cancel()
   308  
   309  	mcrli := &MockRemoteCLI{}
   310  	cls.rcli = mcrli
   311  	cls.userID = defaultUserID
   312  	err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
   313  	assert.NoError(t, err)
   314  
   315  	commandID, err := cls.Send("hello world")
   316  	assert.NoError(t, err)
   317  
   318  	assert.JSONEq(t, `{"command": "hello world"}`, mcrli.outputCommand)
   319  	assert.Equal(t, mcrli.outputSeshID, cls.sessionID)
   320  	assert.Equal(t, mcrli.outputCommID, commandID)
   321  }
   322  
   323  func TestSubscriptionTemplate(t *testing.T) {
   324  	mcrli := &MockRemoteCLI{}
   325  	cls := CLIService{
   326  		rcli: mcrli,
   327  	}
   328  
   329  	assert.Nil(t, mcrli.startSessionOpts)
   330  	_ = cls.Connect(context.Background(), defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
   331  
   332  	assert.Len(t, mcrli.startSessionOpts, 0)
   333  
   334  	cls.SetSubscriptionTemplate("TEST_SUBSCRIPTION_TEMPLATE")
   335  	_ = cls.Connect(context.Background(), defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
   336  
   337  	assert.Equal(t, "TEST_SUBSCRIPTION_TEMPLATE", cls.subscriptionTemplate)
   338  	assert.Len(t, mcrli.startSessionOpts, 1)
   339  	// TODO: Currently only testing that remotecli.Send is called with some sort
   340  	//       of option, need to test it is called with the correct
   341  	//       TEST_SUBSCRIPTION_TEMPLATE value
   342  }
   343  
   344  func TestTopicTemplate(t *testing.T) {
   345  	mcrli := &MockRemoteCLI{}
   346  	cls := CLIService{
   347  		rcli:    mcrli,
   348  		seshCtx: context.Background(),
   349  	}
   350  
   351  	assert.Nil(t, mcrli.sendOpts)
   352  	cls.userID = defaultUserID
   353  	_, err := cls.Send("abcd")
   354  	assert.NoError(t, err)
   355  
   356  	assert.Len(t, mcrli.sendOpts, 0)
   357  
   358  	cls.SetTopicTemplate("TEST_TOPIC_TEMPLATE")
   359  	_, err = cls.Send("abcd")
   360  	assert.NoError(t, err)
   361  
   362  	assert.Equal(t, "TEST_TOPIC_TEMPLATE", cls.topicTemplate)
   363  	assert.Len(t, mcrli.sendOpts, 1)
   364  }
   365  
   366  func TestIdleTimeReset(t *testing.T) {
   367  	// mocking
   368  	cls := NewCLIService(context.Background(), messageService{})
   369  	ctx, cancel := context.WithCancel(context.Background())
   370  	defer cancel()
   371  	mcrli := &MockRemoteCLI{}
   372  	cls.rcli = mcrli
   373  	cls.userID = defaultUserID
   374  	// test
   375  	timePreConnect := time.Now()
   376  	err := cls.Connect(ctx, defaultTarget.projectID, defaultTarget.bannerID, defaultTarget.storeID, defaultTarget.terminalID)
   377  	assert.NoError(t, err)
   378  	assert.Greater(t, time.Since(timePreConnect), cls.IdleTime()) //checks connect has reset the time
   379  	oldIdleTime := cls.IdleTime()
   380  	_, err = cls.Send("a command")
   381  	assert.Greater(t, oldIdleTime, cls.IdleTime()) // checks send has reset the time
   382  	assert.Nil(t, err)
   383  }
   384  

View as plain text