package service

import (
	"context"
	"database/sql/driver"
	"fmt"
	"strings"
	"testing"
	"time"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"
)

// helper is the subset of *testing.T used to mark a function as a test
// helper, so assertion failures are reported at the caller's line.
type helper interface {
	Helper()
}

// EqualError returns an assert.ErrorAssertionFunc that passes only when the
// error is non-nil and its message equals message exactly.
func EqualError(message string) assert.ErrorAssertionFunc {
	return func(t assert.TestingT, err error, i ...interface{}) bool {
		if tt, ok := t.(helper); ok {
			tt.Helper()
		}
		return assert.EqualError(t, err, message, i...)
	}
}

// StringSliceValueConverter converts a []string into a PostgreSQL array
// literal (e.g. []string{"a", "b"} -> "{a,b}") so that sqlmock can accept and
// compare []string query arguments, which driver.DefaultParameterConverter
// rejects. Every other value is delegated to driver.DefaultParameterConverter.
//
// Elements are joined verbatim without quoting or escaping. That is adequate
// for the simple role names used in these tests, but not for production data.
type StringSliceValueConverter struct{}

// ConvertValue implements the driver.ValueConverter interface.
func (c StringSliceValueConverter) ConvertValue(v interface{}) (driver.Value, error) {
	if vv, ok := v.([]string); ok {
		return "{" + strings.Join(vv, ",") + "}", nil
	}
	// Fallback for other types.
	return driver.DefaultParameterConverter.ConvertValue(v)
}

func TestGetEARoles(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		roles        []string
		expectations func(sqlmock.Sqlmock)
		expResult    []string
		expErr       assert.ErrorAssertionFunc
	}{
		"Single role query": {
			roles: []string{"role1"},
			expectations: func(s sqlmock.Sqlmock) {
				s.ExpectQuery(getEARolesQuery).
					WithArgs([]string{"role1"}).
					RowsWillBeClosed().
					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
						AddRow("role1", "priv1"),
					)
			},
			expResult: []string{"priv1"},
			expErr:    assert.NoError,
		},
		"Multiple rows repeated privilege": {
			roles: []string{"role1", "role2"},
			expectations: func(s sqlmock.Sqlmock) {
				s.ExpectQuery(getEARolesQuery).
					WithArgs([]string{"role1", "role2"}).
					RowsWillBeClosed().
					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
						AddRow("role1", "priv1").
						AddRow("role2", "priv1"),
					)
			},
			expResult: []string{"priv1"},
			expErr:    assert.NoError,
		},
		"Multiple rows different privileges": {
			roles: []string{"role1", "role2"},
			expectations: func(s sqlmock.Sqlmock) {
				s.ExpectQuery(getEARolesQuery).
					WithArgs([]string{"role1", "role2"}).
					RowsWillBeClosed().
					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
						AddRow("role1", "priv1").
						AddRow("role2", "priv2"),
					)
			},
			expResult: []string{"priv1", "priv2"},
			expErr:    assert.NoError,
		},
		"Missing ea-role query": {
			roles: []string{},
			expectations: func(_ sqlmock.Sqlmock) {
				// Should not execute any queries
			},
			expResult: []string{},
			expErr:    assert.NoError,
		},
		"Unknown ea-role mapping": {
			roles: []string{"unknown"},
			expectations: func(s sqlmock.Sqlmock) {
				s.ExpectQuery(getEARolesQuery).
					WithArgs([]string{"unknown"}).
					RowsWillBeClosed().
					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}))
			},
			expResult: []string{},
			expErr:    assert.NoError,
		},
		"Query error": {
			roles: []string{"role1"},
			expectations: func(s sqlmock.Sqlmock) {
				s.ExpectQuery(getEARolesQuery).
					WithArgs([]string{"role1"}).
					WillReturnError(fmt.Errorf("an error"))
			},
			expErr: EqualError("error querying earoles: an error"),
		},
		// RowError makes row iteration fail (surfaced via the rows' error),
		// which is distinct from a Scan failure tested in "Scan error".
		"Row iteration error": {
			roles: []string{"role1"},
			expectations: func(s sqlmock.Sqlmock) {
				s.ExpectQuery(getEARolesQuery).
					WithArgs([]string{"role1"}).
					RowsWillBeClosed().
					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
						AddRow("role1", "priv1").
						RowError(0, fmt.Errorf("error on row 0")),
					)
			},
			expErr: EqualError("error while reading earoles: error on row 0"),
		},
		// An extra result column makes Scan fail on a destination-count mismatch.
		"Scan error": {
			roles: []string{"role1"},
			expectations: func(s sqlmock.Sqlmock) {
				s.ExpectQuery(getEARolesQuery).
					WithArgs([]string{"role1"}).
					RowsWillBeClosed().
					WillReturnRows(sqlmock.NewRows([]string{"role_name", "name", "invalid"}).
						AddRow("role1", "priv1", "invalid column"),
					)
			},
			expResult: nil,
			expErr:    EqualError("error scanning earoles row: sql: expected 3 destination arguments in Scan, not 2"),
		},
	}

	for name, tc := range tests {
		tc := tc
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			ctx := context.Background()

			db, mock, err := sqlmock.New(
				sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
				sqlmock.ValueConverterOption(StringSliceValueConverter{}),
			)
			// Stop here on setup failure instead of panicking on an unusable db.
			if !assert.NoError(t, err) {
				return
			}
			defer db.Close()

			tc.expectations(mock)

			us, err := NewUser(db)
			if !assert.NoError(t, err) {
				return
			}

			res, err := us.GetEARoles(ctx, tc.roles)
			assert.ElementsMatch(t, tc.expResult, res)
			tc.expErr(t, err)
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}

// testCase describes one GetEARoles call: the roles to query, the privileges
// expected back (order-insensitive), and the assertion applied to the error.
type testCase struct {
	query     []string
	expResult []string
	expErr    assert.ErrorAssertionFunc
}

// testServerEndpoint calls us.GetEARoles with tc.query and asserts that the
// result matches tc.expResult (in any order) and that the error satisfies
// tc.expErr.
func testServerEndpoint(t *testing.T, tc testCase, us User) {
	t.Helper()

	ctx := context.Background()
	res, err := us.GetEARoles(ctx, tc.query)
	assert.ElementsMatch(t, tc.expResult, res)
	tc.expErr(t, err)
}

func TestCache(t *testing.T) {
	t.Parallel()

	role := "role1"

	db, mock, err := sqlmock.New(
		sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
		sqlmock.ValueConverterOption(StringSliceValueConverter{}),
	)
	// Stop here on setup failure instead of panicking on an unusable db.
	if !assert.NoError(t, err) {
		return
	}
	defer db.Close()

	expiration := 500 * time.Millisecond
	us := User{
		db: db,
		cache: &cache{
			cache:      make(map[string][]string),
			expiration: expiration,
		},
	}

	mock.ExpectQuery(getEARolesQuery).
		WithArgs([]string{role}).
		RowsWillBeClosed().
		WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
			AddRow(role, "priv1"),
		)

	tc := testCase{
		query:     []string{role},
		expResult: []string{"priv1"},
		expErr:    assert.NoError,
	}

	// Test that the db is called and results are cached
	t.Run("Initial DB Call", func(t *testing.T) {
		testServerEndpoint(t, tc, us)
	})
	assert.Equal(t, []string{"priv1"}, us.cache.Get(role))

	// Update the db. This expectation must stay unconsumed while the cached
	// value is still valid; a premature query would return priv2 and fail the
	// next subtest.
	mock.ExpectQuery(getEARolesQuery).
		WithArgs([]string{role}).
		RowsWillBeClosed().
		WillReturnRows(sqlmock.NewRows([]string{"role_name", "name"}).
			AddRow(role, "priv2"),
		)

	// Test that the cache value is used while not expired
	t.Run("Cache Value Is Used", func(t *testing.T) {
		testServerEndpoint(t, tc, us)
	})

	// Wait for the cached entry to disappear after the expiration period.
	assert.Eventually(t, func() bool {
		return len(us.cache.Get(role)) == 0
	}, expiration+(100*time.Millisecond), 10*time.Millisecond)

	tc = testCase{
		query:     []string{role},
		expResult: []string{"priv2"},
		expErr:    assert.NoError,
	}

	// Test that after expiry the db is called again
	t.Run("New DB Value Is Returned", func(t *testing.T) {
		testServerEndpoint(t, tc, us)
	})
	assert.Equal(t, []string{"priv2"}, us.cache.Get(role))

	assert.NoError(t, mock.ExpectationsWereMet())
}

func TestCheckCache(t *testing.T) {
	t.Parallel()

	// NOTE(review): this map is shared by the parallel subtests below. That is
	// only race-free if checkCache never writes to it — confirm.
	testCache := map[string][]string{
		"role1": {"priv1", "priv2"},
		"role2": {"priv2", "priv3"},
		"role3": {"priv1", "priv3"},
	}

	tests := map[string]struct {
		roles             []string
		expRemainingRoles []string
		expPrivileges     privilegeSet
	}{
		"Roles In Cache": {
			roles: []string{"role1", "role3"},
			expPrivileges: privilegeSet{
				"priv1": struct{}{},
				"priv2": struct{}{},
				"priv3": struct{}{},
			},
		},
		"Roles Not In Cache": {
			roles:             []string{"rolenotincache1", "rolenotincache2"},
			expRemainingRoles: []string{"rolenotincache1", "rolenotincache2"},
			expPrivileges:     privilegeSet{},
		},
		"Mix of Roles In And Not In Cache": {
			roles:             []string{"role1", "rolenotincache1"},
			expRemainingRoles: []string{"rolenotincache1"},
			expPrivileges: privilegeSet{
				"priv1": struct{}{},
				"priv2": struct{}{},
			},
		},
	}

	for name, tc := range tests {
		tc := tc
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			us := User{
				cache: &cache{
					cache: testCache,
				},
			}

			remainingRoles, privileges := us.checkCache(tc.roles)
			assert.Equal(t, tc.expRemainingRoles, remainingRoles)
			assert.Equal(t, tc.expPrivileges, privileges)
		})
	}
}