...

Source file src/edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database/data_test.go

Documentation: edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database

     1  package database
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"testing"
     8  
     9  	"edge-infra.dev/pkg/lib/fog"
    10  	"edge-infra.dev/pkg/lib/uuid"
    11  	datasql "edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database/sql"
    12  
    13  	"github.com/DATA-DOG/go-sqlmock"
    14  	"github.com/stretchr/testify/assert"
    15  )
    16  
    17  // assert.ErrorAssertionFunc
    18  func EqualError(message string) assert.ErrorAssertionFunc {
    19  	return func(t assert.TestingT, err error, i ...interface{}) bool {
    20  		return assert.EqualError(t, err, message, i...)
    21  	}
    22  }
    23  
    24  func initMockDB(t *testing.T) (db *sql.DB, mock sqlmock.Sqlmock) {
    25  	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
    26  	if err != nil {
    27  		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
    28  	}
    29  	return db, mock
    30  }
    31  
    32  func TestNew(t *testing.T) {
    33  	log := fog.New()
    34  	db, _, err := sqlmock.New()
    35  	assert.NoError(t, err)
    36  
    37  	expected := Dataset{log, db}
    38  	actual := New(log, db)
    39  
    40  	assert.Equal(t, expected, actual)
    41  }
    42  
    43  func TestGetProjectAndBannerID(t *testing.T) {
    44  	t.Parallel()
    45  
    46  	validBannerUUID, validProjectUUID := uuid.New().UUID, uuid.New().UUID
    47  
    48  	tests := map[string]struct {
    49  		banner string
    50  
    51  		expectations   func(mock sqlmock.Sqlmock)
    52  		expProjectID   string
    53  		expBannerID    string
    54  		errorAssertion assert.ErrorAssertionFunc
    55  	}{
    56  		"Using Banner ID": {
    57  			banner: validBannerUUID,
    58  			expectations: func(mock sqlmock.Sqlmock) {
    59  				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
    60  					WithArgs(validBannerUUID, validBannerUUID).
    61  					WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}).AddRow(validProjectUUID, validBannerUUID))
    62  			},
    63  			expProjectID:   validProjectUUID,
    64  			expBannerID:    validBannerUUID,
    65  			errorAssertion: assert.NoError,
    66  		},
    67  		"Using Banner Name": {
    68  			banner: "name",
    69  			expectations: func(mock sqlmock.Sqlmock) {
    70  				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
    71  					WithArgs(nil, "name").
    72  					WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}).AddRow(validProjectUUID, validBannerUUID))
    73  			},
    74  			expProjectID:   validProjectUUID,
    75  			expBannerID:    validBannerUUID,
    76  			errorAssertion: assert.NoError,
    77  		},
    78  		"Query Error": {
    79  			banner: "name",
    80  			expectations: func(mock sqlmock.Sqlmock) {
    81  				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
    82  					WithArgs(nil, "name").
    83  					WillReturnError(fmt.Errorf("error"))
    84  			},
    85  			expProjectID:   "",
    86  			expBannerID:    "",
    87  			errorAssertion: EqualError("error querying db in data:GetProjectIDAndBannerID: error"),
    88  		},
    89  		"Multiple Rows": {
    90  			banner: "name",
    91  			expectations: func(mock sqlmock.Sqlmock) {
    92  				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
    93  					WithArgs(nil, "name").
    94  					WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}).AddRow(validProjectUUID, validBannerUUID).AddRow("oops", "double-oops"))
    95  			},
    96  			expProjectID:   "",
    97  			expBannerID:    "",
    98  			errorAssertion: EqualError("error scanning rows in data:GetProjectIDAndBannerID: error multiple rows returned in data:scanRows"),
    99  		},
   100  		"No Rows Returned": {
   101  			banner: "name",
   102  			expectations: func(mock sqlmock.Sqlmock) {
   103  				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
   104  					WithArgs(nil, "name").
   105  					WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}))
   106  			},
   107  			expProjectID:   "",
   108  			expBannerID:    "",
   109  			errorAssertion: assert.NoError,
   110  		},
   111  	}
   112  
   113  	for name, tc := range tests {
   114  		tc := tc
   115  		t.Run(name, func(t *testing.T) {
   116  			t.Parallel()
   117  
   118  			db, mock := initMockDB(t)
   119  			defer db.Close()
   120  
   121  			ds := Dataset{db: db}
   122  
   123  			tc.expectations(mock)
   124  
   125  			projectID, bannerID, err := ds.GetProjectAndBannerID(context.Background(), tc.banner)
   126  			tc.errorAssertion(t, err)
   127  			assert.Equal(t, tc.expProjectID, projectID)
   128  			assert.Equal(t, tc.expBannerID, bannerID)
   129  
   130  			assert.NoError(t, mock.ExpectationsWereMet())
   131  		})
   132  	}
   133  }
   134  
   135  //nolint:dupl
   136  func TestGetStoreID(t *testing.T) {
   137  	t.Parallel()
   138  
   139  	validUUID := uuid.New().UUID
   140  
   141  	tests := map[string]struct {
   142  		store    string
   143  		bannerID string
   144  
   145  		expectations   func(mock sqlmock.Sqlmock)
   146  		expID          string
   147  		errorAssertion assert.ErrorAssertionFunc
   148  	}{
   149  		"Using Store ID": {
   150  			store:    validUUID,
   151  			bannerID: "id",
   152  			expectations: func(mock sqlmock.Sqlmock) {
   153  				mock.ExpectQuery(datasql.SelectStoreID).
   154  					WithArgs(validUUID, validUUID, "id").
   155  					WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}).AddRow(validUUID))
   156  			},
   157  			expID:          validUUID,
   158  			errorAssertion: assert.NoError,
   159  		},
   160  		"Using Store Name": {
   161  			store:    "name",
   162  			bannerID: "id",
   163  			expectations: func(mock sqlmock.Sqlmock) {
   164  				mock.ExpectQuery(datasql.SelectStoreID).
   165  					WithArgs(nil, "name", "id").
   166  					WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}).AddRow(validUUID))
   167  			},
   168  			expID:          validUUID,
   169  			errorAssertion: assert.NoError,
   170  		},
   171  		"Query Error": {
   172  			store:    "name",
   173  			bannerID: "id",
   174  			expectations: func(mock sqlmock.Sqlmock) {
   175  				mock.ExpectQuery(datasql.SelectStoreID).
   176  					WithArgs(nil, "name", "id").
   177  					WillReturnError(fmt.Errorf("error"))
   178  			},
   179  			expID:          "",
   180  			errorAssertion: EqualError("error querying db in data:GetStoreID: error"),
   181  		},
   182  		"Multiple Rows": {
   183  			store:    "name",
   184  			bannerID: "id",
   185  			expectations: func(mock sqlmock.Sqlmock) {
   186  				mock.ExpectQuery(datasql.SelectStoreID).
   187  					WithArgs(nil, "name", "id").
   188  					WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}).AddRow(validUUID).AddRow("oops"))
   189  			},
   190  			expID:          "",
   191  			errorAssertion: EqualError("error scanning rows in data:GetStoreID: error multiple rows returned in data:scanRows"),
   192  		},
   193  		"No Rows Returned": {
   194  			store:    "name",
   195  			bannerID: "id",
   196  			expectations: func(mock sqlmock.Sqlmock) {
   197  				mock.ExpectQuery(datasql.SelectStoreID).
   198  					WithArgs(nil, "name", "id").
   199  					WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}))
   200  			},
   201  			expID:          "",
   202  			errorAssertion: assert.NoError,
   203  		},
   204  	}
   205  
   206  	for name, tc := range tests {
   207  		tc := tc
   208  		t.Run(name, func(t *testing.T) {
   209  			t.Parallel()
   210  
   211  			db, mock := initMockDB(t)
   212  			defer db.Close()
   213  
   214  			ds := Dataset{db: db}
   215  
   216  			tc.expectations(mock)
   217  
   218  			storeID, err := ds.GetStoreID(context.Background(), tc.store, tc.bannerID)
   219  			tc.errorAssertion(t, err)
   220  			assert.Equal(t, tc.expID, storeID)
   221  
   222  			assert.NoError(t, mock.ExpectationsWereMet())
   223  		})
   224  	}
   225  }
   226  
   227  //nolint:dupl
   228  func TestGetTerminalID(t *testing.T) {
   229  	t.Parallel()
   230  
   231  	validUUID := uuid.New().UUID
   232  
   233  	tests := map[string]struct {
   234  		terminal string
   235  		storeID  string
   236  
   237  		expectations   func(mock sqlmock.Sqlmock)
   238  		expID          string
   239  		errorAssertion assert.ErrorAssertionFunc
   240  	}{
   241  		"Using Terminal ID": {
   242  			terminal: validUUID,
   243  			storeID:  "id",
   244  			expectations: func(mock sqlmock.Sqlmock) {
   245  				mock.ExpectQuery(datasql.SelectTerminalID).
   246  					WithArgs(validUUID, validUUID, "id").
   247  					WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}).AddRow(validUUID))
   248  			},
   249  			expID:          validUUID,
   250  			errorAssertion: assert.NoError,
   251  		},
   252  		"Using Terminal Name": {
   253  			terminal: "name",
   254  			storeID:  "id",
   255  			expectations: func(mock sqlmock.Sqlmock) {
   256  				mock.ExpectQuery(datasql.SelectTerminalID).
   257  					WithArgs(nil, "name", "id").
   258  					WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}).AddRow(validUUID))
   259  			},
   260  			expID:          validUUID,
   261  			errorAssertion: assert.NoError,
   262  		},
   263  		"Query Error": {
   264  			terminal: "name",
   265  			storeID:  "id",
   266  			expectations: func(mock sqlmock.Sqlmock) {
   267  				mock.ExpectQuery(datasql.SelectTerminalID).
   268  					WithArgs(nil, "name", "id").
   269  					WillReturnError(fmt.Errorf("error"))
   270  			},
   271  			expID:          "",
   272  			errorAssertion: EqualError("error querying db in data:GetTerminalID: error"),
   273  		},
   274  		"Multiple Rows": {
   275  			terminal: "name",
   276  			storeID:  "id",
   277  			expectations: func(mock sqlmock.Sqlmock) {
   278  				mock.ExpectQuery(datasql.SelectTerminalID).
   279  					WithArgs(nil, "name", "id").
   280  					WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}).AddRow(validUUID).AddRow("oops"))
   281  			},
   282  			expID:          "",
   283  			errorAssertion: EqualError("error scanning rows in data:GetTerminalID: error multiple rows returned in data:scanRows"),
   284  		},
   285  		"No Rows Returned": {
   286  			terminal: "name",
   287  			storeID:  "id",
   288  			expectations: func(mock sqlmock.Sqlmock) {
   289  				mock.ExpectQuery(datasql.SelectTerminalID).
   290  					WithArgs(nil, "name", "id").
   291  					WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}))
   292  			},
   293  			expID:          "",
   294  			errorAssertion: assert.NoError,
   295  		},
   296  	}
   297  
   298  	for name, tc := range tests {
   299  		tc := tc
   300  		t.Run(name, func(t *testing.T) {
   301  			t.Parallel()
   302  
   303  			db, mock := initMockDB(t)
   304  			defer db.Close()
   305  
   306  			ds := Dataset{db: db}
   307  
   308  			tc.expectations(mock)
   309  
   310  			terminalID, err := ds.GetTerminalID(context.Background(), tc.terminal, tc.storeID)
   311  			tc.errorAssertion(t, err)
   312  			assert.Equal(t, tc.expID, terminalID)
   313  
   314  			assert.NoError(t, mock.ExpectationsWereMet())
   315  		})
   316  	}
   317  }
   318  
   319  func mockRowsToSQLRows(mockRows *sqlmock.Rows) (*sql.Rows, error) {
   320  	db, mock, err := sqlmock.New()
   321  	if err != nil {
   322  		return nil, err
   323  	}
   324  	defer db.Close()
   325  	mock.ExpectQuery("select").WillReturnRows(mockRows)
   326  	rows, err := db.Query("select")
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  	return rows, nil
   331  }
   332  
   333  func TestScanRowsSingleColumn(t *testing.T) {
   334  	t.Parallel()
   335  
   336  	tests := map[string]struct {
   337  		mockRows  *sqlmock.Rows
   338  		expRes    string
   339  		errAssert assert.ErrorAssertionFunc
   340  	}{
   341  		"Valid": {
   342  			mockRows:  sqlmock.NewRows([]string{"col"}).AddRow("val"),
   343  			expRes:    "val",
   344  			errAssert: assert.NoError,
   345  		},
   346  		"No Rows": {
   347  			mockRows:  sqlmock.NewRows([]string{}),
   348  			expRes:    "",
   349  			errAssert: assert.NoError,
   350  		},
   351  		"Multiple Rows": {
   352  			mockRows:  sqlmock.NewRows([]string{"col"}).AddRow("val").AddRow("oops"),
   353  			expRes:    "val",
   354  			errAssert: EqualError("error multiple rows returned in data:scanRows"),
   355  		},
   356  	}
   357  
   358  	for name, tc := range tests {
   359  		tc := tc
   360  		t.Run(name, func(t *testing.T) {
   361  			t.Parallel()
   362  
   363  			rows, err := mockRowsToSQLRows(tc.mockRows)
   364  			assert.NoError(t, err)
   365  			defer rows.Close()
   366  
   367  			empty := ""
   368  			result := &empty
   369  			err = scanRowsForIDs(rows, &result)
   370  			tc.errAssert(t, err)
   371  			assert.Equal(t, tc.expRes, *result)
   372  		})
   373  	}
   374  }
   375  
   376  func TestScanRowsMultiColumn(t *testing.T) {
   377  	t.Parallel()
   378  
   379  	tests := map[string]struct {
   380  		mockRows  *sqlmock.Rows
   381  		expRes1   string
   382  		expRes2   string
   383  		errAssert assert.ErrorAssertionFunc
   384  	}{
   385  		"Valid": {
   386  			mockRows:  sqlmock.NewRows([]string{"col1", "col2"}).AddRow("val1", "val2"),
   387  			expRes1:   "val1",
   388  			expRes2:   "val2",
   389  			errAssert: assert.NoError,
   390  		},
   391  		"No Rows": {
   392  			mockRows:  sqlmock.NewRows([]string{}),
   393  			expRes1:   "",
   394  			expRes2:   "",
   395  			errAssert: assert.NoError,
   396  		},
   397  		"Multiple Rows": {
   398  			mockRows:  sqlmock.NewRows([]string{"col1", "col2"}).AddRow("val1", "val2").AddRow("oops1", "oops2"),
   399  			expRes1:   "val1",
   400  			expRes2:   "val2",
   401  			errAssert: EqualError("error multiple rows returned in data:scanRows"),
   402  		},
   403  	}
   404  
   405  	for name, tc := range tests {
   406  		tc := tc
   407  		t.Run(name, func(t *testing.T) {
   408  			t.Parallel()
   409  
   410  			rows, err := mockRowsToSQLRows(tc.mockRows)
   411  			assert.NoError(t, err)
   412  			defer rows.Close()
   413  
   414  			empty := ""
   415  			result1 := &empty
   416  			result2 := &empty
   417  			err = scanRowsForIDs(rows, &result1, &result2)
   418  			tc.errAssert(t, err)
   419  			assert.Equal(t, tc.expRes1, *result1)
   420  			assert.Equal(t, tc.expRes2, *result2)
   421  		})
   422  	}
   423  }
   424  
   425  func TestIsUUID(t *testing.T) {
   426  	t.Parallel()
   427  
   428  	tests := map[string]struct {
   429  		val string
   430  		exp bool
   431  	}{
   432  		"Valid UUID": {
   433  			val: uuid.New().UUID,
   434  			exp: true,
   435  		},
   436  		"Invalid UUID": {
   437  			val: "an-invalid-uuid",
   438  			exp: false,
   439  		},
   440  		"Empty String": {
   441  			val: "",
   442  			exp: false,
   443  		},
   444  	}
   445  
   446  	for name, tc := range tests {
   447  		tc := tc
   448  		t.Run(name, func(t *testing.T) {
   449  			t.Parallel()
   450  			assert.Equal(t, tc.exp, isUUID(tc.val))
   451  		})
   452  	}
   453  }
   454  
   455  func TestSafeStringDereference(t *testing.T) {
   456  	s := "a-string"
   457  
   458  	tests := map[string]struct {
   459  		input    *string
   460  		expected string
   461  	}{
   462  		"Valid": {
   463  			input:    &s,
   464  			expected: s,
   465  		},
   466  		"Empty String": {
   467  			input:    new(string),
   468  			expected: "",
   469  		},
   470  		"Nil": {
   471  			input:    nil,
   472  			expected: "",
   473  		},
   474  	}
   475  
   476  	for name, tc := range tests {
   477  		tc := tc
   478  		t.Run(name, func(t *testing.T) {
   479  			t.Parallel()
   480  			assert.Equal(t, tc.expected, safeStringDereference(tc.input))
   481  		})
   482  	}
   483  }
   484  

View as plain text