...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/remotecli/remotecli_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/remotecli

     1  package remotecli
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"strings"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/go-logr/logr"
    13  	"github.com/go-logr/logr/funcr"
    14  	"github.com/google/uuid"
    15  	"github.com/stretchr/testify/assert"
    16  
    17  	"edge-infra.dev/pkg/lib/fog"
    18  	"edge-infra.dev/pkg/sds/emergencyaccess/eaconst"
    19  	"edge-infra.dev/pkg/sds/emergencyaccess/msgdata"
    20  )
    21  
    22  var (
    23  	defaultTarget = target{
    24  		projectID:  "project",
    25  		bannerID:   "banner",
    26  		storeID:    "store",
    27  		terminalID: "terminal",
    28  	}
    29  	defaultIdentity = "user"
    30  
    31  	defaultWaitTime = 20 * time.Millisecond
    32  	defaultTickTime = 1 * time.Microsecond
    33  )
    34  
    35  type target struct {
    36  	projectID  string
    37  	bannerID   string
    38  	storeID    string
    39  	terminalID string
    40  }
    41  
    42  func (t target) ProjectID() string  { return t.projectID }
    43  func (t target) BannerID() string   { return t.bannerID }
    44  func (t target) StoreID() string    { return t.storeID }
    45  func (t target) TerminalID() string { return t.terminalID }
    46  
    47  type messageService struct {
    48  	subscriptionFunc  func(subscriptioID, projectID string)
    49  	stopPublishFunc   func(string, string)
    50  	subscribeWatchCtx bool
    51  
    52  	pubCommandID string
    53  }
    54  
    55  func (ms *messageService) Subscribe(
    56  	ctx context.Context,
    57  	subscriptionID string,
    58  	projectID string,
    59  	_ func(context.Context, msgdata.CommandResponse),
    60  	_ map[string]string,
    61  ) error {
    62  	if ms.subscriptionFunc != nil {
    63  		ms.subscriptionFunc(subscriptionID, projectID)
    64  	}
    65  
    66  	// Wait for context cancellation when requested, otherwise return immediately
    67  	if ms.subscribeWatchCtx {
    68  		<-ctx.Done()
    69  	}
    70  
    71  	if projectID == "subscribe_error" {
    72  		return fmt.Errorf("subscribe returned error")
    73  	}
    74  	return nil
    75  }
    76  
    77  func (ms *messageService) Publish(_ context.Context, _ string, projectID string, msg msgdata.Request) error {
    78  	ms.pubCommandID = msg.Attributes()[eaconst.CommandIDKey]
    79  
    80  	if projectID == "publish_error" {
    81  		return fmt.Errorf("publish returned error")
    82  	}
    83  	return nil
    84  }
    85  
    86  func (ms messageService) StopPublish(topicID string, projectID string) {
    87  	if ms.stopPublishFunc != nil {
    88  		ms.stopPublishFunc(topicID, projectID)
    89  	}
    90  }
    91  
    92  func TestValidateTarget(t *testing.T) {
    93  	assert.NoError(t, validateTarget(defaultTarget))
    94  
    95  	badTarget := target{}
    96  	assert.Len(t, validateTarget(badTarget), 4)
    97  }
    98  
    99  func TestNewRemoteCLI(t *testing.T) {
   100  	ctx := context.Background()
   101  	ms := messageService{}
   102  	expected := &RemoteCLI{
   103  		msgService:  &ms,
   104  		sessionLock: &sync.RWMutex{},
   105  		context:     ctx,
   106  
   107  		sessionData:      map[string]sessionData{},
   108  		subscriptionData: map[string]subscriptionData{},
   109  		topicData:        map[string]topicData{},
   110  	}
   111  
   112  	rcli := New(ctx, &ms)
   113  
   114  	assert.Equal(t, expected, rcli)
   115  }
   116  
   117  func TestStartSession(t *testing.T) {
   118  	buf := bytes.Buffer{}
   119  	ctx, cancelFunc := context.WithCancel(fog.IntoContext(context.Background(), createLogger(&buf)))
   120  	ch := make(chan msgdata.CommandResponse)
   121  
   122  	ms := messageService{
   123  		subscribeWatchCtx: true,
   124  	}
   125  
   126  	rcli := New(ctx, &ms)
   127  	sessionID := uuid.NewString()
   128  
   129  	target := defaultTarget
   130  	target.projectID = "subscribe_error"
   131  
   132  	err := rcli.StartSession(ctx, sessionID, ch, target)
   133  	assert.NoError(t, err)
   134  
   135  	assert.Equal(t, target, rcli.sessionData[sessionID].target)
   136  
   137  	expSubscriptionID := "sub.store.dsds-ea-response"
   138  	assert.True(t, rcli.subscriptionData[expSubscriptionID].sessions.HasMember(sessionID))
   139  
   140  	cancelFunc()
   141  
   142  	assert.Eventually(t, func() bool {
   143  		return strings.Contains(buf.String(), "subscribe returned error")
   144  	}, 4*time.Second, 1*time.Millisecond, "logs:\n%s\ndoes not contain expected string: %s", buf.String(), "subscribe returned error")
   145  }
   146  
   147  func TestStartSession_ListensToContextDone(t *testing.T) {
   148  	buf := bytes.Buffer{}
   149  	ctx, cancelFunc := context.WithCancel(fog.IntoContext(context.Background(), createLogger(&buf)))
   150  	target := defaultTarget
   151  	ch := make(chan msgdata.CommandResponse)
   152  	sessionID := uuid.NewString()
   153  
   154  	// Make sure the mock subscription does not exit as we don't want the subscription
   155  	// end cleanup to run
   156  	endSubscriptionChan := make(chan struct{})
   157  	subFunc := func(_, _ string) {
   158  		<-endSubscriptionChan
   159  	}
   160  	ms := messageService{
   161  		subscriptionFunc: subFunc,
   162  	}
   163  
   164  	rcli := New(ctx, &ms)
   165  
   166  	err := rcli.StartSession(ctx, sessionID, ch, target)
   167  	assert.NoError(t, err)
   168  
   169  	// Ensure the channel is not closed
   170  	assert.Never(t, func() bool {
   171  		select {
   172  		case _, ok := <-ch:
   173  			if ok {
   174  				return false
   175  			}
   176  			return true
   177  		default:
   178  			return false
   179  		}
   180  	}, defaultWaitTime, defaultTickTime)
   181  
   182  	// Cancel the context
   183  	cancelFunc()
   184  	// Confirm the channel is closed
   185  	assert.Eventually(t, func() bool {
   186  		select {
   187  		case _, ok := <-ch:
   188  			if ok {
   189  				return false
   190  			}
   191  			return true
   192  		default:
   193  			return false
   194  		}
   195  	}, defaultWaitTime, defaultTickTime)
   196  
   197  	// test cleanup of goroutines
   198  	endSubscriptionChan <- struct{}{}
   199  }
   200  
   201  func TestSubscribeCalledOnce(t *testing.T) {
   202  	// Test msgsvc.Subscribe only called once for startsession with same subscriptionID
   203  
   204  	ctx := fog.IntoContext(context.Background(), logr.Discard())
   205  	target := defaultTarget
   206  	ch := make(chan msgdata.CommandResponse)
   207  	sessionID := uuid.NewString()
   208  
   209  	var subscriptionCalledCounter int
   210  	subFunc := func(_, _ string) {
   211  		subscriptionCalledCounter = subscriptionCalledCounter + 1
   212  	}
   213  
   214  	ms := messageService{
   215  		subscriptionFunc:  subFunc,
   216  		subscribeWatchCtx: true,
   217  	}
   218  
   219  	rcli := New(ctx, &ms)
   220  
   221  	err := rcli.StartSession(ctx, sessionID, ch, target)
   222  	assert.NoError(t, err)
   223  	assert.Eventually(t, func() bool {
   224  		return subscriptionCalledCounter == 1
   225  	}, defaultWaitTime, defaultTickTime)
   226  
   227  	// Test ms.Subscribe call not duplicated
   228  	ch = make(chan msgdata.CommandResponse)
   229  	sessionID = uuid.NewString()
   230  	err = rcli.StartSession(ctx, sessionID, ch, target)
   231  	assert.NoError(t, err)
   232  	assert.Never(t, func() bool {
   233  		return subscriptionCalledCounter == 2
   234  	}, defaultWaitTime, defaultTickTime)
   235  
   236  	// Test ms.subscribe called for different subscription
   237  	ch = make(chan msgdata.CommandResponse)
   238  	sessionID = uuid.NewString()
   239  	target.storeID = "anotherstore"
   240  	err = rcli.StartSession(ctx, sessionID, ch, target)
   241  	assert.NoError(t, err)
   242  	assert.Eventually(t, func() bool {
   243  		return subscriptionCalledCounter == 2
   244  	}, defaultWaitTime, defaultTickTime)
   245  }
   246  
   247  func TestStartSessionInvalidTarget(t *testing.T) {
   248  	buf := bytes.Buffer{}
   249  	logger := createLogger(&buf)
   250  	sessionID := uuid.NewString()
   251  
   252  	ctx := fog.IntoContext(context.Background(), logger)
   253  	ch, ms := make(chan msgdata.CommandResponse), messageService{}
   254  	rcli := New(ctx, &ms)
   255  	target := target{}
   256  
   257  	err := rcli.StartSession(ctx, sessionID, ch, target)
   258  	assert.Contains(t, err.Error(), validateTarget(target).Error())
   259  }
   260  
   261  func TestSend(t *testing.T) {
   262  	ctx := fog.IntoContext(context.Background(), logr.Discard())
   263  	ch := make(chan msgdata.CommandResponse)
   264  	ms := messageService{subscribeWatchCtx: true}
   265  	sessionID := uuid.NewString()
   266  	rcli := New(ctx, &ms)
   267  
   268  	identity, target := defaultIdentity, defaultTarget
   269  	target.projectID = "publish_error"
   270  	command := "echo hello"
   271  	request, err := msgdata.NewV1_0Request(command)
   272  	assert.NoError(t, err)
   273  	commandID := uuid.NewString()
   274  
   275  	err = rcli.StartSession(ctx, sessionID, ch, target)
   276  	assert.NoError(t, err)
   277  
   278  	err = rcli.Send(ctx, identity, sessionID, commandID, request)
   279  	assert.Contains(t, err.Error(), fmt.Errorf("publish returned error").Error())
   280  
   281  	assert.NotEmpty(t, ms.pubCommandID)
   282  	assert.Equal(t, commandID, ms.pubCommandID)
   283  }
   284  
   285  func TestSendNoSessionStarted(t *testing.T) {
   286  	buf := bytes.Buffer{}
   287  	logger := createLogger(&buf)
   288  
   289  	ctx := fog.IntoContext(context.Background(), logger)
   290  	ms := messageService{}
   291  	rcli := New(ctx, &ms)
   292  	identity, command, sessionID := defaultIdentity, "echo hello", "invalid-session-id"
   293  	request, err := msgdata.NewV1_0Request(command)
   294  	assert.NoError(t, err)
   295  
   296  	err = rcli.Send(ctx, identity, sessionID, uuid.NewString(), request)
   297  	assert.Contains(t, err.Error(), "invalid session id")
   298  }
   299  
   300  func TestEndSession(t *testing.T) {
   301  	buf := bytes.Buffer{}
   302  	logger := createLogger(&buf)
   303  	sessionID := uuid.NewString()
   304  
   305  	ms := messageService{
   306  		subscribeWatchCtx: true,
   307  	}
   308  
   309  	ctx := fog.IntoContext(context.Background(), logger)
   310  	ch, target := make(chan msgdata.CommandResponse), defaultTarget
   311  	rcli := New(ctx, &ms)
   312  	logOK := "Session stopped"
   313  
   314  	var err error
   315  	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, "nonexistant_session") })
   316  	assert.ErrorContains(t, err, "unknown session ID")
   317  	buf.Reset()
   318  
   319  	err = rcli.StartSession(ctx, sessionID, ch, target)
   320  	assert.NoError(t, err)
   321  	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
   322  	assert.NoError(t, err)
   323  	assert.Contains(t, buf.String(), logOK)
   324  	buf.Reset()
   325  
   326  	// Call EndSession with same SessionID again
   327  	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
   328  	assert.ErrorContains(t, err, "unknown session ID")
   329  	buf.Reset()
   330  
   331  	// verifies the context was cancelled
   332  	assert.Eventually(t, func() bool {
   333  		select {
   334  		case v, ok := <-ch:
   335  			if ok {
   336  				t.Errorf("unexpected value on display channel: %s", v)
   337  				return false
   338  			}
   339  			return true
   340  		default:
   341  			return false
   342  		}
   343  	}, 5*time.Second, time.Microsecond)
   344  }
   345  
   346  func TestEndSession_Topics(t *testing.T) {
   347  	buf := bytes.Buffer{}
   348  	logger := createLogger(&buf)
   349  	sessionID := uuid.NewString()
   350  
   351  	// Info struct and map to track calls of messageService StopPublish function
   352  	type publishInfo struct {
   353  		topicID   string
   354  		projectID string
   355  	}
   356  	endSessionCallMap := make(map[publishInfo]int)
   357  
   358  	// Called whenever the messageService StopPublish function is called
   359  	// Tracks number of times func is called with each info
   360  	stopPublishFunc := func(topicID string, projectID string) {
   361  		pInfo := publishInfo{topicID, projectID}
   362  		endSessionCallMap[pInfo] = endSessionCallMap[pInfo] + 1
   363  	}
   364  
   365  	ms := messageService{
   366  		stopPublishFunc:   stopPublishFunc,
   367  		subscribeWatchCtx: true,
   368  	}
   369  
   370  	ctx := fog.IntoContext(context.Background(), logger)
   371  	ch, target := make(chan msgdata.CommandResponse), defaultTarget
   372  	rcli := New(ctx, &ms)
   373  
   374  	var err error
   375  	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
   376  	assert.ErrorContains(t, err, "unknown session ID")
   377  
   378  	request, err := msgdata.NewV1_0Request("command")
   379  	assert.NoError(t, err)
   380  
   381  	// Session has not been posted to, so topic cleanup should not have been run
   382  	assert.Len(t, endSessionCallMap, 0)
   383  
   384  	// Start a session and send message to two different topics
   385  	assert.NoError(t, rcli.StartSession(ctx, sessionID, ch, target))
   386  	assert.NoError(t, rcli.Send(ctx, defaultIdentity, sessionID, uuid.NewString(), request, WithOptionalTemplate("a")))
   387  	assert.NoError(t, rcli.Send(ctx, defaultIdentity, sessionID, uuid.NewString(), request))
   388  	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
   389  	assert.NoError(t, err)
   390  
   391  	// Check Stop Publish is called once for each of the topics published to
   392  	assert.Len(t, endSessionCallMap, 2)
   393  	assert.Equal(t, 1, endSessionCallMap[publishInfo{"a", target.projectID}])
   394  	assert.Equal(t, 1, endSessionCallMap[publishInfo{eaconst.DefaultTopTemplate, target.projectID}])
   395  
   396  	// Repeat endsession call with same session ID
   397  	assert.NotPanics(t, func() { err = rcli.EndSession(ctx, sessionID) })
   398  	assert.ErrorContains(t, err, "unknown session ID")
   399  
   400  	// Check StopPublish has not been called again
   401  	assert.Len(t, endSessionCallMap, 2)
   402  	assert.Equal(t, 1, endSessionCallMap[publishInfo{"a", target.projectID}])
   403  	assert.Equal(t, 1, endSessionCallMap[publishInfo{eaconst.DefaultTopTemplate, target.projectID}])
   404  }
   405  
   406  func handlerAttrs(sessionID string) map[string]string {
   407  	return map[string]string{
   408  		"bannerId":             "banner",
   409  		"storeId":              "store",
   410  		"terminalId":           "terminal",
   411  		"sessionId":            sessionID,
   412  		"identity":             "identity",
   413  		"version":              "1.0",
   414  		"signature":            "signature",
   415  		"request-message-uuid": "id",
   416  	}
   417  }
   418  
   419  func handlerData(output string) []byte {
   420  	return []byte(fmt.Sprintf(`
   421  {
   422  	"type": "Output",
   423  	"exitCode": 0,
   424  	"output": "%s",
   425  	"timestamp": "01-01-2023 00:00:00",
   426  	"duration": 0.1
   427  }`, output))
   428  }
   429  
   430  func TestHandler(t *testing.T) {
   431  	ch1 := make(chan msgdata.CommandResponse, 3)
   432  	ch2 := make(chan msgdata.CommandResponse, 1)
   433  	ch3 := make(chan msgdata.CommandResponse)
   434  
   435  	rcli := RemoteCLI{
   436  		sessionLock: &sync.RWMutex{},
   437  		sessionData: map[string]sessionData{
   438  			"orderingKey": {
   439  				displayChan: ch1,
   440  			},
   441  			"orderingKey2": {
   442  				displayChan: ch2,
   443  			},
   444  			"orderingKey3": {
   445  				displayChan: ch3,
   446  			},
   447  		},
   448  	}
   449  	ctx, cancelFunc := context.WithCancel(context.Background())
   450  	fn := rcli.handler()
   451  
   452  	data1 := handlerData("message 1")
   453  	data2 := handlerData("message 2")
   454  	data3 := handlerData("other ordering key")
   455  	data4 := handlerData("this message should not be received")
   456  
   457  	msg1, _ := msgdata.NewCommandResponse(data1, handlerAttrs("orderingKey"))
   458  	msg2, _ := msgdata.NewCommandResponse(data2, handlerAttrs("orderingKey"))
   459  	msg3, _ := msgdata.NewCommandResponse(data3, handlerAttrs("orderingKey2"))
   460  	msg4, _ := msgdata.NewCommandResponse(data4, handlerAttrs("orderingKey3"))
   461  
   462  	fn(ctx, msg1)
   463  	assert.Empty(t, ch2)
   464  	assert.Equal(t, msg1, <-ch1)
   465  
   466  	fn(ctx, msg2)
   467  	assert.Empty(t, ch2)
   468  	assert.Equal(t, msg2, <-ch1)
   469  
   470  	// New sessionID
   471  	fn(context.Background(), msg3)
   472  	assert.Empty(t, ch1)
   473  	assert.Equal(t, msg3, <-ch2)
   474  
   475  	cancelFunc()
   476  	time.Sleep(time.Millisecond * 100)
   477  
   478  	fn(ctx, msg4)
   479  	assert.Empty(t, ch1)
   480  	assert.Empty(t, ch2)
   481  	assert.Empty(t, ch3)
   482  }
   483  
   484  func TestCreateOptionalConfig(t *testing.T) {
   485  	template1, template2, template3 := "template-string-1", "template-string-2", "template-string-3"
   486  	opts := []RCLIOption{WithOptionalTemplate(template1), WithOptionalTemplate(template2), WithOptionalTemplate(template3)}
   487  
   488  	expected := &templateConfig{template: &template3}
   489  	assert.Equal(t, expected, createOptionalConfig(opts))
   490  
   491  	assert.Nil(t, createOptionalConfig(nil))
   492  }
   493  
   494  func TestFillTemplate(t *testing.T) {
   495  	defaultTemplate := "default.<PROJECT_ID>.<BANNER_ID>.<STORE_ID>.<TERMINAL_ID>"
   496  	optionalTemplate := "optional.<PROJECT_ID>.<BANNER_ID>.<STORE_ID>.<TERMINAL_ID>"
   497  
   498  	target := defaultTarget
   499  	config := &templateConfig{template: &optionalTemplate}
   500  
   501  	expected := fmt.Sprintf("%s.%s.%s.%s", target.projectID, target.bannerID, target.storeID, target.terminalID)
   502  	assert.Equal(t, "default."+expected, fillTemplate(target, defaultTemplate, nil))
   503  	assert.Equal(t, "optional."+expected, fillTemplate(target, defaultTemplate, config))
   504  }
   505  
   506  func createLogger(buf *bytes.Buffer) logr.Logger {
   507  	return funcr.New(func(prefix, args string) {
   508  		if prefix != "" {
   509  			fmt.Fprintf(buf, "%s: %s\n", prefix, args)
   510  		} else {
   511  			fmt.Fprintln(buf, args)
   512  		}
   513  	}, funcr.Options{})
   514  }
   515  

View as plain text