package database

import (
	"context"
	"database/sql"
	"fmt"
	"testing"

	"edge-infra.dev/pkg/lib/fog"
	"edge-infra.dev/pkg/lib/uuid"
	datasql "edge-infra.dev/pkg/sds/emergencyaccess/authservice/storage/database/sql"
	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"
)

// EqualError returns an assert.ErrorAssertionFunc that asserts the error is
// non-nil and that its message equals message exactly. It lets table-driven
// tests specify an expected error message in the same field that otherwise
// holds assert.NoError.
func EqualError(message string) assert.ErrorAssertionFunc {
	return func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool {
		return assert.EqualError(t, err, message, msgAndArgs...)
	}
}

// initMockDB opens a sqlmock-backed *sql.DB whose expected queries are matched
// by exact string equality (QueryMatcherEqual) rather than by regular
// expression. The test fails immediately if the stub connection cannot be
// created. Callers are responsible for closing the returned db.
func initMockDB(t *testing.T) (db *sql.DB, mock sqlmock.Sqlmock) {
	t.Helper()

	db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	return db, mock
}

// TestNew checks that New returns a Dataset holding the given logger and
// database handle.
func TestNew(t *testing.T) {
	log := fog.New()
	db, _, err := sqlmock.New()
	// Stop here so the deferred Close below is never called on a nil db.
	if !assert.NoError(t, err) {
		return
	}
	defer db.Close()

	expected := Dataset{log, db}
	actual := New(log, db)
	assert.Equal(t, expected, actual)
}

// TestGetProjectAndBannerID exercises lookup by banner UUID and by banner name.
// It also covers query errors, multi-row results (reported as an error) and
// empty results (reported as empty IDs with no error).
func TestGetProjectAndBannerID(t *testing.T) {
	t.Parallel()

	validBannerUUID, validProjectUUID := uuid.New().UUID, uuid.New().UUID

	tests := map[string]struct {
		banner         string
		expectations   func(mock sqlmock.Sqlmock)
		expProjectID   string
		expBannerID    string
		errorAssertion assert.ErrorAssertionFunc
	}{
		"Using Banner ID": {
			banner: validBannerUUID,
			expectations: func(mock sqlmock.Sqlmock) {
				// A UUID input is passed as both query arguments.
				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
					WithArgs(validBannerUUID, validBannerUUID).
					WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}).AddRow(validProjectUUID, validBannerUUID))
			},
			expProjectID:   validProjectUUID,
			expBannerID:    validBannerUUID,
			errorAssertion: assert.NoError,
		},
		"Using Banner Name": {
			banner: "name",
			expectations: func(mock sqlmock.Sqlmock) {
				// A non-UUID input is passed with a nil first argument.
				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
					WithArgs(nil, "name").
					WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}).AddRow(validProjectUUID, validBannerUUID))
			},
			expProjectID:   validProjectUUID,
			expBannerID:    validBannerUUID,
			errorAssertion: assert.NoError,
		},
		"Query Error": {
			banner: "name",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
					WithArgs(nil, "name").
					WillReturnError(fmt.Errorf("error"))
			},
			expProjectID:   "",
			expBannerID:    "",
			errorAssertion: EqualError("error querying db in data:GetProjectIDAndBannerID: error"),
		},
		"Multiple Rows": {
			banner: "name",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
					WithArgs(nil, "name").
					WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}).AddRow(validProjectUUID, validBannerUUID).AddRow("oops", "double-oops"))
			},
			expProjectID:   "",
			expBannerID:    "",
			errorAssertion: EqualError("error scanning rows in data:GetProjectIDAndBannerID: error multiple rows returned in data:scanRows"),
		},
		"No Rows Returned": {
			banner: "name",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectProjectIDAndBannerID).
					WithArgs(nil, "name").
					WillReturnRows(sqlmock.NewRows([]string{"project_id", "banner_edge_id"}))
			},
			expProjectID:   "",
			expBannerID:    "",
			errorAssertion: assert.NoError,
		},
	}

	for name, tc := range tests {
		tc := tc // capture per-iteration value for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			db, mock := initMockDB(t)
			defer db.Close()

			ds := Dataset{db: db}
			tc.expectations(mock)

			projectID, bannerID, err := ds.GetProjectAndBannerID(context.Background(), tc.banner)
			tc.errorAssertion(t, err)
			assert.Equal(t, tc.expProjectID, projectID)
			assert.Equal(t, tc.expBannerID, bannerID)
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}

// TestGetStoreID exercises store lookup by UUID and by name within a banner.
// It also covers query errors, multi-row results and empty results.
//
//nolint:dupl
func TestGetStoreID(t *testing.T) {
	t.Parallel()

	validUUID := uuid.New().UUID

	tests := map[string]struct {
		store          string
		bannerID       string
		expectations   func(mock sqlmock.Sqlmock)
		expID          string
		errorAssertion assert.ErrorAssertionFunc
	}{
		"Using Store ID": {
			store:    validUUID,
			bannerID: "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectStoreID).
					WithArgs(validUUID, validUUID, "id").
					WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}).AddRow(validUUID))
			},
			expID:          validUUID,
			errorAssertion: assert.NoError,
		},
		"Using Store Name": {
			store:    "name",
			bannerID: "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectStoreID).
					WithArgs(nil, "name", "id").
					WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}).AddRow(validUUID))
			},
			expID:          validUUID,
			errorAssertion: assert.NoError,
		},
		"Query Error": {
			store:    "name",
			bannerID: "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectStoreID).
					WithArgs(nil, "name", "id").
					WillReturnError(fmt.Errorf("error"))
			},
			expID:          "",
			errorAssertion: EqualError("error querying db in data:GetStoreID: error"),
		},
		"Multiple Rows": {
			store:    "name",
			bannerID: "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectStoreID).
					WithArgs(nil, "name", "id").
					WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}).AddRow(validUUID).AddRow("oops"))
			},
			expID:          "",
			errorAssertion: EqualError("error scanning rows in data:GetStoreID: error multiple rows returned in data:scanRows"),
		},
		"No Rows Returned": {
			store:    "name",
			bannerID: "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectStoreID).
					WithArgs(nil, "name", "id").
					WillReturnRows(sqlmock.NewRows([]string{"cluster_edge_id"}))
			},
			expID:          "",
			errorAssertion: assert.NoError,
		},
	}

	for name, tc := range tests {
		tc := tc // capture per-iteration value for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			db, mock := initMockDB(t)
			defer db.Close()

			ds := Dataset{db: db}
			tc.expectations(mock)

			storeID, err := ds.GetStoreID(context.Background(), tc.store, tc.bannerID)
			tc.errorAssertion(t, err)
			assert.Equal(t, tc.expID, storeID)
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}

// TestGetTerminalID exercises terminal lookup by UUID and by name within a
// store. It also covers query errors, multi-row results and empty results.
//
//nolint:dupl
func TestGetTerminalID(t *testing.T) {
	t.Parallel()

	validUUID := uuid.New().UUID

	tests := map[string]struct {
		terminal       string
		storeID        string
		expectations   func(mock sqlmock.Sqlmock)
		expID          string
		errorAssertion assert.ErrorAssertionFunc
	}{
		"Using Terminal ID": {
			terminal: validUUID,
			storeID:  "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectTerminalID).
					WithArgs(validUUID, validUUID, "id").
					WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}).AddRow(validUUID))
			},
			expID:          validUUID,
			errorAssertion: assert.NoError,
		},
		"Using Terminal Name": {
			terminal: "name",
			storeID:  "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectTerminalID).
					WithArgs(nil, "name", "id").
					WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}).AddRow(validUUID))
			},
			expID:          validUUID,
			errorAssertion: assert.NoError,
		},
		"Query Error": {
			terminal: "name",
			storeID:  "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectTerminalID).
					WithArgs(nil, "name", "id").
					WillReturnError(fmt.Errorf("error"))
			},
			expID:          "",
			errorAssertion: EqualError("error querying db in data:GetTerminalID: error"),
		},
		"Multiple Rows": {
			terminal: "name",
			storeID:  "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectTerminalID).
					WithArgs(nil, "name", "id").
					WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}).AddRow(validUUID).AddRow("oops"))
			},
			expID:          "",
			errorAssertion: EqualError("error scanning rows in data:GetTerminalID: error multiple rows returned in data:scanRows"),
		},
		"No Rows Returned": {
			terminal: "name",
			storeID:  "id",
			expectations: func(mock sqlmock.Sqlmock) {
				mock.ExpectQuery(datasql.SelectTerminalID).
					WithArgs(nil, "name", "id").
					WillReturnRows(sqlmock.NewRows([]string{"terminal_id"}))
			},
			expID:          "",
			errorAssertion: assert.NoError,
		},
	}

	for name, tc := range tests {
		tc := tc // capture per-iteration value for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			db, mock := initMockDB(t)
			defer db.Close()

			ds := Dataset{db: db}
			tc.expectations(mock)

			terminalID, err := ds.GetTerminalID(context.Background(), tc.terminal, tc.storeID)
			tc.errorAssertion(t, err)
			assert.Equal(t, tc.expID, terminalID)
			assert.NoError(t, mock.ExpectationsWereMet())
		})
	}
}

// mockRowsToSQLRows converts sqlmock rows into a real *sql.Rows. It runs a
// throwaway "select" query against a fresh sqlmock database that uses the
// default regexp query matcher.
//
// NOTE(review): the database is closed (via defer) before the caller reads
// the returned rows. This relies on database/sql keeping the connection that
// backs an open *sql.Rows alive until rows.Close. Confirm before reusing this
// helper elsewhere.
func mockRowsToSQLRows(mockRows *sqlmock.Rows) (*sql.Rows, error) {
	db, mock, err := sqlmock.New()
	if err != nil {
		return nil, err
	}
	defer db.Close()

	mock.ExpectQuery("select").WillReturnRows(mockRows)

	rows, err := db.Query("select")
	if err != nil {
		return nil, err
	}
	return rows, nil
}

// TestScanRowsSingleColumn checks scanRowsForIDs with a single destination.
// One row fills it, no rows leave it untouched, and multiple rows produce an
// error after the first row has been scanned.
func TestScanRowsSingleColumn(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		mockRows  *sqlmock.Rows
		expRes    string
		errAssert assert.ErrorAssertionFunc
	}{
		"Valid": {
			mockRows:  sqlmock.NewRows([]string{"col"}).AddRow("val"),
			expRes:    "val",
			errAssert: assert.NoError,
		},
		"No Rows": {
			mockRows:  sqlmock.NewRows([]string{}),
			expRes:    "",
			errAssert: assert.NoError,
		},
		"Multiple Rows": {
			mockRows:  sqlmock.NewRows([]string{"col"}).AddRow("val").AddRow("oops"),
			expRes:    "val",
			errAssert: EqualError("error multiple rows returned in data:scanRows"),
		},
	}

	for name, tc := range tests {
		tc := tc // capture per-iteration value for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			rows, err := mockRowsToSQLRows(tc.mockRows)
			// Bail out rather than deferring Close on (and scanning) a nil *sql.Rows.
			if !assert.NoError(t, err) {
				return
			}
			defer rows.Close()

			empty := ""
			result := &empty
			err = scanRowsForIDs(rows, &result)
			tc.errAssert(t, err)
			assert.Equal(t, tc.expRes, *result)
		})
	}
}

// TestScanRowsMultiColumn checks scanRowsForIDs with two destinations, using
// the same row-count cases as TestScanRowsSingleColumn.
func TestScanRowsMultiColumn(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		mockRows  *sqlmock.Rows
		expRes1   string
		expRes2   string
		errAssert assert.ErrorAssertionFunc
	}{
		"Valid": {
			mockRows:  sqlmock.NewRows([]string{"col1", "col2"}).AddRow("val1", "val2"),
			expRes1:   "val1",
			expRes2:   "val2",
			errAssert: assert.NoError,
		},
		"No Rows": {
			mockRows:  sqlmock.NewRows([]string{}),
			expRes1:   "",
			expRes2:   "",
			errAssert: assert.NoError,
		},
		"Multiple Rows": {
			mockRows:  sqlmock.NewRows([]string{"col1", "col2"}).AddRow("val1", "val2").AddRow("oops1", "oops2"),
			expRes1:   "val1",
			expRes2:   "val2",
			errAssert: EqualError("error multiple rows returned in data:scanRows"),
		},
	}

	for name, tc := range tests {
		tc := tc // capture per-iteration value for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			rows, err := mockRowsToSQLRows(tc.mockRows)
			// Bail out rather than deferring Close on (and scanning) a nil *sql.Rows.
			if !assert.NoError(t, err) {
				return
			}
			defer rows.Close()

			// Give each destination its own backing string so the two results
			// cannot alias each other if scanRowsForIDs writes through them.
			first, second := "", ""
			result1, result2 := &first, &second
			err = scanRowsForIDs(rows, &result1, &result2)
			tc.errAssert(t, err)
			assert.Equal(t, tc.expRes1, *result1)
			assert.Equal(t, tc.expRes2, *result2)
		})
	}
}

// TestIsUUID checks isUUID against a freshly generated UUID, a malformed
// value and the empty string.
func TestIsUUID(t *testing.T) {
	t.Parallel()

	tests := map[string]struct {
		val string
		exp bool
	}{
		"Valid UUID": {
			val: uuid.New().UUID,
			exp: true,
		},
		"Invalid UUID": {
			val: "an-invalid-uuid",
			exp: false,
		},
		"Empty String": {
			val: "",
			exp: false,
		},
	}

	for name, tc := range tests {
		tc := tc // capture per-iteration value for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, tc.exp, isUUID(tc.val))
		})
	}
}

// TestSafeStringDereference checks that safeStringDereference returns the
// pointed-to value, and returns "" for a nil pointer.
func TestSafeStringDereference(t *testing.T) {
	t.Parallel()

	s := "a-string"

	tests := map[string]struct {
		input    *string
		expected string
	}{
		"Valid": {
			input:    &s,
			expected: s,
		},
		"Empty String": {
			input:    new(string),
			expected: "",
		},
		"Nil": {
			input:    nil,
			expected: "",
		},
	}

	for name, tc := range tests {
		tc := tc // capture per-iteration value for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, tc.expected, safeStringDereference(tc.input))
		})
	}
}