...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/user/service/userservice_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/user/service

     1  package service
     2  
     3  import (
     4  	"context"
     5  	"database/sql/driver"
     6  	"fmt"
     7  	"strings"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/DATA-DOG/go-sqlmock"
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  // testing helper type
    16  type helper interface {
    17  	Helper()
    18  }
    19  
    20  func EqualError(message string) assert.ErrorAssertionFunc {
    21  	return func(t assert.TestingT, err error, i ...interface{}) bool {
    22  		if tt, ok := t.(helper); ok {
    23  			tt.Helper()
    24  		}
    25  
    26  		return assert.EqualError(t, err, message, i...)
    27  	}
    28  }
    29  
    30  // StringSliceValueConverter converts a slice of strings to a PostgreSQL array representation.
    31  type StringSliceValueConverter struct{}
    32  
    33  // ConvertValue implements the driver.ValueConverter interface.
    34  func (c StringSliceValueConverter) ConvertValue(v interface{}) (driver.Value, error) {
    35  	if vv, ok := v.([]string); ok {
    36  		// Convert []string to a PostgreSQL array representation.
    37  		// Note: Proper escaping and handling of special characters is necessary for production code.
    38  		arrayStr := "{" + strings.Join(vv, ",") + "}"
    39  		return arrayStr, nil
    40  	}
    41  	// Fallback for other types.
    42  	return driver.DefaultParameterConverter.ConvertValue(v)
    43  }
    44  
    45  func TestGetEARoles(t *testing.T) {
    46  	t.Parallel()
    47  
    48  	tests := map[string]struct {
    49  		roles        []string
    50  		expectations func(sqlmock.Sqlmock)
    51  		expResult    []string
    52  		expErr       assert.ErrorAssertionFunc
    53  	}{
    54  		"Single role query": {
    55  			roles: []string{"role1"},
    56  			expectations: func(s sqlmock.Sqlmock) {
    57  				s.ExpectQuery(getEARolesQuery).
    58  					WithArgs([]string{"role1"}).
    59  					RowsWillBeClosed().
    60  					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
    61  						AddRow("role1", "priv1"),
    62  					)
    63  			},
    64  			expResult: []string{"priv1"},
    65  			expErr:    assert.NoError,
    66  		},
    67  		"Multiple rows repeated privilege": {
    68  			roles: []string{"role1", "role2"},
    69  			expectations: func(s sqlmock.Sqlmock) {
    70  				s.ExpectQuery(getEARolesQuery).
    71  					WithArgs([]string{"role1", "role2"}).
    72  					RowsWillBeClosed().
    73  					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
    74  						AddRow("role1", "priv1").
    75  						AddRow("role2", "priv1"),
    76  					)
    77  			},
    78  			expResult: []string{"priv1"},
    79  			expErr:    assert.NoError,
    80  		},
    81  		"Multiple rows different privileges": {
    82  			roles: []string{"role1", "role2"},
    83  			expectations: func(s sqlmock.Sqlmock) {
    84  				s.ExpectQuery(getEARolesQuery).
    85  					WithArgs([]string{"role1", "role2"}).
    86  					RowsWillBeClosed().
    87  					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
    88  						AddRow("role1", "priv1").
    89  						AddRow("role2", "priv2"),
    90  					)
    91  			},
    92  			expResult: []string{"priv1", "priv2"},
    93  			expErr:    assert.NoError,
    94  		},
    95  		"Missing ea-role query": {
    96  			roles: []string{},
    97  			expectations: func(_ sqlmock.Sqlmock) {
    98  				// Should not execute any queries
    99  			},
   100  			expResult: []string{},
   101  			expErr:    assert.NoError,
   102  		},
   103  		"Unknown ea-role mapping": {
   104  			roles: []string{"unknown"},
   105  			expectations: func(s sqlmock.Sqlmock) {
   106  				s.ExpectQuery(getEARolesQuery).
   107  					WithArgs([]string{"unknown"}).
   108  					RowsWillBeClosed().
   109  					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}))
   110  			},
   111  			expResult: []string{},
   112  			expErr:    assert.NoError,
   113  		},
   114  		"Query error": {
   115  			roles: []string{"role1"},
   116  			expectations: func(s sqlmock.Sqlmock) {
   117  				s.ExpectQuery(getEARolesQuery).
   118  					WithArgs([]string{"role1"}).
   119  					WillReturnError(fmt.Errorf("an error"))
   120  			},
   121  			expErr: EqualError("error querying earoles: an error"),
   122  		},
   123  		"Scan error": {
   124  			roles: []string{"role1"},
   125  			expectations: func(s sqlmock.Sqlmock) {
   126  				s.ExpectQuery(getEARolesQuery).
   127  					WithArgs([]string{"role1"}).
   128  					RowsWillBeClosed().
   129  					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
   130  						AddRow("role1", "priv1").
   131  						RowError(0, fmt.Errorf("error on row 0")),
   132  					)
   133  			},
   134  			expErr: EqualError("error while reading earoles: error on row 0"),
   135  		},
   136  		"Scan Error": {
   137  			roles: []string{"role1"},
   138  			expectations: func(s sqlmock.Sqlmock) {
   139  				s.ExpectQuery(getEARolesQuery).
   140  					WithArgs([]string{"role1"}).
   141  					RowsWillBeClosed().
   142  					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name", "invalid"}).
   143  						AddRow("role1", "priv1", "invalid column"),
   144  					)
   145  			},
   146  			expResult: nil,
   147  			expErr:    EqualError("error scanning earoles row: sql: expected 3 destination arguments in Scan, not 2"),
   148  		},
   149  	}
   150  
   151  	for name, tc := range tests {
   152  		tc := tc
   153  		t.Run(name, func(t *testing.T) {
   154  			t.Parallel()
   155  
   156  			ctx := context.Background()
   157  
   158  			db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual), sqlmock.ValueConverterOption(StringSliceValueConverter{}))
   159  			assert.NoError(t, err)
   160  
   161  			tc.expectations(mock)
   162  
   163  			us, err := NewUser(db)
   164  			assert.NoError(t, err)
   165  
   166  			res, err := us.GetEARoles(ctx, tc.roles)
   167  
   168  			assert.ElementsMatch(t, tc.expResult, res)
   169  
   170  			tc.expErr(t, err)
   171  
   172  			assert.NoError(t, mock.ExpectationsWereMet())
   173  		})
   174  	}
   175  }
   176  
   177  type testCase struct {
   178  	query     []string
   179  	expResult []string
   180  	expErr    assert.ErrorAssertionFunc
   181  }
   182  
   183  func testServerEndpoint(t *testing.T, tc testCase, us User) {
   184  	t.Helper()
   185  
   186  	ctx := context.Background()
   187  	res, err := us.GetEARoles(ctx, tc.query)
   188  	assert.ElementsMatch(t, tc.expResult, res)
   189  	tc.expErr(t, err)
   190  }
   191  
   192  func TestCache(t *testing.T) {
   193  	t.Parallel()
   194  
   195  	role := "role1"
   196  
   197  	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual), sqlmock.ValueConverterOption(StringSliceValueConverter{}))
   198  	assert.NoError(t, err)
   199  
   200  	expiration := 500 * time.Millisecond
   201  	us := User{
   202  		db: db,
   203  		cache: &cache{
   204  			cache:      make(map[string][]string),
   205  			expiration: expiration,
   206  		},
   207  	}
   208  
   209  	mock.ExpectQuery(getEARolesQuery).
   210  		WithArgs([]string{role}).
   211  		RowsWillBeClosed().
   212  		WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
   213  			AddRow(role, "priv1"),
   214  		)
   215  
   216  	tc := testCase{
   217  		[]string{role},
   218  		[]string{"priv1"},
   219  		assert.NoError,
   220  	}
   221  
   222  	// Test that the db is called and results are cached
   223  	t.Run("Initial DB Call", func(t *testing.T) {
   224  		testServerEndpoint(t, tc, us)
   225  	})
   226  	assert.Equal(t, []string{"priv1"}, us.cache.Get(role))
   227  
   228  	// Update the db
   229  	mock.ExpectQuery(getEARolesQuery).
   230  		WithArgs([]string{role}).
   231  		RowsWillBeClosed().
   232  		WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
   233  			AddRow(role, "priv2"),
   234  		)
   235  
   236  	// Test that the cache value is used while not expired
   237  	t.Run("Cache Value Is Used", func(t *testing.T) {
   238  		testServerEndpoint(t, tc, us)
   239  	})
   240  
   241  	assert.Eventually(t, func() bool {
   242  		return len(us.cache.Get(role)) == 0
   243  	}, expiration+(100*time.Millisecond), 10*time.Millisecond)
   244  	tc = testCase{
   245  		[]string{role},
   246  		[]string{"priv2"},
   247  		assert.NoError,
   248  	}
   249  	// Test that after expiry the db is called again
   250  	t.Run("New DB Value Is Returned", func(t *testing.T) {
   251  		testServerEndpoint(t, tc, us)
   252  	})
   253  	assert.Equal(t, []string{"priv2"}, us.cache.Get(role))
   254  
   255  	assert.NoError(t, mock.ExpectationsWereMet())
   256  }
   257  
   258  func TestCheckCache(t *testing.T) {
   259  	t.Parallel()
   260  
   261  	testCache := map[string][]string{
   262  		"role1": {"priv1", "priv2"},
   263  		"role2": {"priv2", "priv3"},
   264  		"role3": {"priv1", "priv3"},
   265  	}
   266  
   267  	tests := map[string]struct {
   268  		roles             []string
   269  		expRemainingRoles []string
   270  		expPrivileges     privilegeSet
   271  	}{
   272  		"Roles In Cache": {
   273  			roles: []string{"role1", "role3"},
   274  			expPrivileges: privilegeSet{
   275  				"priv1": struct{}{},
   276  				"priv2": struct{}{},
   277  				"priv3": struct{}{},
   278  			},
   279  		},
   280  		"Roles Not In Cache": {
   281  			roles:             []string{"rolenotincache1", "rolenotincache2"},
   282  			expRemainingRoles: []string{"rolenotincache1", "rolenotincache2"},
   283  			expPrivileges:     privilegeSet{},
   284  		},
   285  		"Mix of Roles In And Not In Cache": {
   286  			roles:             []string{"role1", "rolenotincache1"},
   287  			expRemainingRoles: []string{"rolenotincache1"},
   288  			expPrivileges: privilegeSet{
   289  				"priv1": struct{}{},
   290  				"priv2": struct{}{},
   291  			},
   292  		},
   293  	}
   294  
   295  	for name, tc := range tests {
   296  		tc := tc
   297  		t.Run(name, func(t *testing.T) {
   298  			t.Parallel()
   299  
   300  			us := User{
   301  				cache: &cache{
   302  					cache: testCache,
   303  				},
   304  			}
   305  			remainingRoles, privileges := us.checkCache(tc.roles)
   306  			assert.Equal(t, tc.expRemainingRoles, remainingRoles)
   307  			assert.Equal(t, tc.expPrivileges, privileges)
   308  		})
   309  	}
   310  }
   311  

View as plain text