package database import ( "context" "database/sql" "fmt" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" rulesengine "edge-infra.dev/pkg/sds/emergencyaccess/rules" datasql "edge-infra.dev/pkg/sds/emergencyaccess/rules/storage/database/sql" ) func initMockDB(t *testing.T) (db *sql.DB, mock sqlmock.Sqlmock) { db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } return db, mock } func TestAddNamesDefer(t *testing.T) { t.Parallel() tests := map[string]struct { names []string expRes rulesengine.AddNameResult expectations func(mock sqlmock.Sqlmock) errorAssertion assert.ErrorAssertionFunc }{ "Successful commit": { names: nil, expectations: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectCommit().WillReturnError(nil) }, errorAssertion: assert.NoError, }, "Commit Returns Error": { names: nil, expectations: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectCommit().WillReturnError(fmt.Errorf("commit error")) }, errorAssertion: EqualError("error committing transaction: commit error"), }, "Begin and do nothing on conflict": { names: []string{"a"}, expRes: rulesengine.AddNameResult{}, expectations: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectExec(datasql.InsertCommand). WithArgs("a").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectCommit() }, errorAssertion: assert.NoError, }, "Begin and add name, no conflict": { names: []string{"a"}, expRes: rulesengine.AddNameResult{}, expectations: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectExec(datasql.InsertCommand). WithArgs("a").WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, errorAssertion: assert.NoError, }, } for name, tc := range tests { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() db, mock := initMockDB(t) defer db.Close() ds := Dataset{db: db} tc.expectations(mock) res, err := ds.addNames(context.Background(), tc.names, datasql.InsertCommand) tc.errorAssertion(t, err) assert.Equal(t, tc.expRes, res) assert.NoError(t, mock.ExpectationsWereMet()) }) } }