...

Source file src/github.com/doug-martin/goqu/v9/database.go

Documentation: github.com/doug-martin/goqu/v9

     1  package goqu
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"sync"
     7  
     8  	"github.com/doug-martin/goqu/v9/exec"
     9  )
    10  
    11  type (
    12  	Logger interface {
    13  		Printf(format string, v ...interface{})
    14  	}
    15  	// Interface for sql.DB, an interface is used so you can use with other
    16  	// libraries such as sqlx instead of the native sql.DB
    17  	SQLDatabase interface {
    18  		Begin() (*sql.Tx, error)
    19  		BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
    20  		ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    21  		PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
    22  		QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
    23  		QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
    24  	}
    25  	// This struct is the wrapper for a Db. The struct delegates most calls to either an Exec instance or to the Db
    26  	// passed into the constructor.
    27  	Database struct {
    28  		logger  Logger
    29  		dialect string
    30  		// nolint: stylecheck // keep for backwards compatibility
    31  		Db     SQLDatabase
    32  		qf     exec.QueryFactory
    33  		qfOnce sync.Once
    34  	}
    35  )
    36  
    37  // This is the common entry point into goqu.
    38  //
    39  // dialect: This is the adapter dialect, you should see your database adapter for the string to use. Built in adapters
    40  // can be found at https://github.com/doug-martin/goqu/tree/master/adapters
    41  //
    42  // db: A sql.Db to use for querying the database
    43  //      import (
    44  //          "database/sql"
    45  //          "fmt"
    46  //          "github.com/doug-martin/goqu/v9"
    47  //          _ "github.com/doug-martin/goqu/v9/dialect/postgres"
    48  //          _ "github.com/lib/pq"
    49  //      )
    50  //
    51  //      func main() {
    52  //          sqlDb, err := sql.Open("postgres", "user=postgres dbname=goqupostgres sslmode=disable ")
    53  //          if err != nil {
    54  //              panic(err.Error())
    55  //          }
    56  //          db := goqu.New("postgres", sqlDb)
    57  //      }
    58  // The most commonly used Database method is From, which creates a new Dataset that uses the correct adapter and
    59  // supports queries.
    60  //          var ids []uint32
    61  //          if err := db.From("items").Where(goqu.I("id").Gt(10)).Pluck("id", &ids); err != nil {
    62  //              panic(err.Error())
    63  //          }
    64  //          fmt.Printf("%+v", ids)
    65  func newDatabase(dialect string, db SQLDatabase) *Database {
    66  	return &Database{
    67  		logger:  nil,
    68  		dialect: dialect,
    69  		Db:      db,
    70  		qf:      nil,
    71  		qfOnce:  sync.Once{},
    72  	}
    73  }
    74  
    75  // returns this databases dialect
    76  func (d *Database) Dialect() string {
    77  	return d.dialect
    78  }
    79  
    80  // Starts a new Transaction.
    81  func (d *Database) Begin() (*TxDatabase, error) {
    82  	sqlTx, err := d.Db.Begin()
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	tx := NewTx(d.dialect, sqlTx)
    87  	tx.Logger(d.logger)
    88  	return tx, nil
    89  }
    90  
    91  // Starts a new Transaction. See sql.DB#BeginTx for option description
    92  func (d *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*TxDatabase, error) {
    93  	sqlTx, err := d.Db.BeginTx(ctx, opts)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	tx := NewTx(d.dialect, sqlTx)
    98  	tx.Logger(d.logger)
    99  	return tx, nil
   100  }
   101  
   102  // WithTx starts a new transaction and executes it in Wrap method
   103  func (d *Database) WithTx(fn func(*TxDatabase) error) error {
   104  	tx, err := d.Begin()
   105  	if err != nil {
   106  		return err
   107  	}
   108  	return tx.Wrap(func() error { return fn(tx) })
   109  }
   110  
   111  // Creates a new Dataset that uses the correct adapter and supports queries.
   112  //          var ids []uint32
   113  //          if err := db.From("items").Where(goqu.I("id").Gt(10)).Pluck("id", &ids); err != nil {
   114  //              panic(err.Error())
   115  //          }
   116  //          fmt.Printf("%+v", ids)
   117  //
   118  // from...: Sources for you dataset, could be table names (strings), a goqu.Literal or another goqu.Dataset
   119  func (d *Database) From(from ...interface{}) *SelectDataset {
   120  	return newDataset(d.dialect, d.queryFactory()).From(from...)
   121  }
   122  
   123  func (d *Database) Select(cols ...interface{}) *SelectDataset {
   124  	return newDataset(d.dialect, d.queryFactory()).Select(cols...)
   125  }
   126  
   127  func (d *Database) Update(table interface{}) *UpdateDataset {
   128  	return newUpdateDataset(d.dialect, d.queryFactory()).Table(table)
   129  }
   130  
   131  func (d *Database) Insert(table interface{}) *InsertDataset {
   132  	return newInsertDataset(d.dialect, d.queryFactory()).Into(table)
   133  }
   134  
   135  func (d *Database) Delete(table interface{}) *DeleteDataset {
   136  	return newDeleteDataset(d.dialect, d.queryFactory()).From(table)
   137  }
   138  
   139  func (d *Database) Truncate(table ...interface{}) *TruncateDataset {
   140  	return newTruncateDataset(d.dialect, d.queryFactory()).Table(table...)
   141  }
   142  
   143  // Sets the logger for to use when logging queries
   144  func (d *Database) Logger(logger Logger) {
   145  	d.logger = logger
   146  }
   147  
   148  // Logs a given operation with the specified sql and arguments
   149  func (d *Database) Trace(op, sqlString string, args ...interface{}) {
   150  	if d.logger != nil {
   151  		if sqlString != "" {
   152  			if len(args) != 0 {
   153  				d.logger.Printf("[goqu] %s [query:=`%s` args:=%+v]", op, sqlString, args)
   154  			} else {
   155  				d.logger.Printf("[goqu] %s [query:=`%s`]", op, sqlString)
   156  			}
   157  		} else {
   158  			d.logger.Printf("[goqu] %s", op)
   159  		}
   160  	}
   161  }
   162  
   163  // Uses the db to Execute the query with arguments and return the sql.Result
   164  //
   165  // query: The SQL to execute
   166  //
   167  // args...: for any placeholder parameters in the query
   168  func (d *Database) Exec(query string, args ...interface{}) (sql.Result, error) {
   169  	return d.ExecContext(context.Background(), query, args...)
   170  }
   171  
   172  // Uses the db to Execute the query with arguments and return the sql.Result
   173  //
   174  // query: The SQL to execute
   175  //
   176  // args...: for any placeholder parameters in the query
   177  func (d *Database) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
   178  	d.Trace("EXEC", query, args...)
   179  	return d.Db.ExecContext(ctx, query, args...)
   180  }
   181  
   182  // Can be used to prepare a query.
   183  //
   184  // You can use this in tandem with a dataset by doing the following.
   185  //    sql, args, err := db.From("items").Where(goqu.I("id").Gt(10)).ToSQL(true)
   186  //    if err != nil{
   187  //        panic(err.Error()) //you could gracefully handle the error also
   188  //    }
   189  //    stmt, err := db.Prepare(sql)
   190  //    if err != nil{
   191  //        panic(err.Error()) //you could gracefully handle the error also
   192  //    }
   193  //    defer stmt.Close()
   194  //    rows, err := stmt.Query(args)
   195  //    if err != nil{
   196  //        panic(err.Error()) //you could gracefully handle the error also
   197  //    }
   198  //    defer rows.Close()
   199  //    for rows.Next(){
   200  //              //scan your rows
   201  //    }
   202  //    if rows.Err() != nil{
   203  //        panic(err.Error()) //you could gracefully handle the error also
   204  //    }
   205  //
   206  // query: The SQL statement to prepare.
   207  func (d *Database) Prepare(query string) (*sql.Stmt, error) {
   208  	return d.PrepareContext(context.Background(), query)
   209  }
   210  
   211  // Can be used to prepare a query.
   212  //
   213  // You can use this in tandem with a dataset by doing the following.
   214  //    sql, args, err := db.From("items").Where(goqu.I("id").Gt(10)).ToSQL(true)
   215  //    if err != nil{
   216  //        panic(err.Error()) //you could gracefully handle the error also
   217  //    }
   218  //    stmt, err := db.Prepare(sql)
   219  //    if err != nil{
   220  //        panic(err.Error()) //you could gracefully handle the error also
   221  //    }
   222  //    defer stmt.Close()
   223  //    rows, err := stmt.QueryContext(ctx, args)
   224  //    if err != nil{
   225  //        panic(err.Error()) //you could gracefully handle the error also
   226  //    }
   227  //    defer rows.Close()
   228  //    for rows.Next(){
   229  //              //scan your rows
   230  //    }
   231  //    if rows.Err() != nil{
   232  //        panic(err.Error()) //you could gracefully handle the error also
   233  //    }
   234  //
   235  // query: The SQL statement to prepare.
   236  func (d *Database) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
   237  	d.Trace("PREPARE", query)
   238  	return d.Db.PrepareContext(ctx, query)
   239  }
   240  
   241  // Used to query for multiple rows.
   242  //
   243  // You can use this in tandem with a dataset by doing the following.
   244  //    sql, err := db.From("items").Where(goqu.I("id").Gt(10)).ToSQL()
   245  //    if err != nil{
   246  //        panic(err.Error()) //you could gracefully handle the error also
   247  //    }
   248  //    rows, err := stmt.Query(args)
   249  //    if err != nil{
   250  //        panic(err.Error()) //you could gracefully handle the error also
   251  //    }
   252  //    defer rows.Close()
   253  //    for rows.Next(){
   254  //              //scan your rows
   255  //    }
   256  //    if rows.Err() != nil{
   257  //        panic(err.Error()) //you could gracefully handle the error also
   258  //    }
   259  //
   260  // query: The SQL to execute
   261  //
   262  // args...: for any placeholder parameters in the query
   263  func (d *Database) Query(query string, args ...interface{}) (*sql.Rows, error) {
   264  	return d.QueryContext(context.Background(), query, args...)
   265  }
   266  
   267  // Used to query for multiple rows.
   268  //
   269  // You can use this in tandem with a dataset by doing the following.
   270  //    sql, err := db.From("items").Where(goqu.I("id").Gt(10)).ToSQL()
   271  //    if err != nil{
   272  //        panic(err.Error()) //you could gracefully handle the error also
   273  //    }
   274  //    rows, err := stmt.QueryContext(ctx, args)
   275  //    if err != nil{
   276  //        panic(err.Error()) //you could gracefully handle the error also
   277  //    }
   278  //    defer rows.Close()
   279  //    for rows.Next(){
   280  //              //scan your rows
   281  //    }
   282  //    if rows.Err() != nil{
   283  //        panic(err.Error()) //you could gracefully handle the error also
   284  //    }
   285  //
   286  // query: The SQL to execute
   287  //
   288  // args...: for any placeholder parameters in the query
   289  func (d *Database) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
   290  	d.Trace("QUERY", query, args...)
   291  	return d.Db.QueryContext(ctx, query, args...)
   292  }
   293  
   294  // Used to query for a single row.
   295  //
   296  // You can use this in tandem with a dataset by doing the following.
   297  //    sql, err := db.From("items").Where(goqu.I("id").Gt(10)).Limit(1).ToSQL()
   298  //    if err != nil{
   299  //        panic(err.Error()) //you could gracefully handle the error also
   300  //    }
   301  //    rows, err := stmt.QueryRow(args)
   302  //    if err != nil{
   303  //        panic(err.Error()) //you could gracefully handle the error also
   304  //    }
   305  //    //scan your row
   306  //
   307  // query: The SQL to execute
   308  //
   309  // args...: for any placeholder parameters in the query
   310  func (d *Database) QueryRow(query string, args ...interface{}) *sql.Row {
   311  	return d.QueryRowContext(context.Background(), query, args...)
   312  }
   313  
   314  // Used to query for a single row.
   315  //
   316  // You can use this in tandem with a dataset by doing the following.
   317  //    sql, err := db.From("items").Where(goqu.I("id").Gt(10)).Limit(1).ToSQL()
   318  //    if err != nil{
   319  //        panic(err.Error()) //you could gracefully handle the error also
   320  //    }
   321  //    rows, err := stmt.QueryRowContext(ctx, args)
   322  //    if err != nil{
   323  //        panic(err.Error()) //you could gracefully handle the error also
   324  //    }
   325  //    //scan your row
   326  //
   327  // query: The SQL to execute
   328  //
   329  // args...: for any placeholder parameters in the query
   330  func (d *Database) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
   331  	d.Trace("QUERY ROW", query, args...)
   332  	return d.Db.QueryRowContext(ctx, query, args...)
   333  }
   334  
   335  func (d *Database) queryFactory() exec.QueryFactory {
   336  	d.qfOnce.Do(func() {
   337  		d.qf = exec.NewQueryFactory(d)
   338  	})
   339  	return d.qf
   340  }
   341  
   342  // Queries the database using the supplied query, and args and uses CrudExec.ScanStructs to scan the results into a
   343  // slice of structs
   344  //
   345  // i: A pointer to a slice of structs
   346  //
   347  // query: The SQL to execute
   348  //
   349  // args...: for any placeholder parameters in the query
   350  func (d *Database) ScanStructs(i interface{}, query string, args ...interface{}) error {
   351  	return d.ScanStructsContext(context.Background(), i, query, args...)
   352  }
   353  
   354  // Queries the database using the supplied context, query, and args and uses CrudExec.ScanStructsContext to scan the
   355  // results into a slice of structs
   356  //
   357  // i: A pointer to a slice of structs
   358  //
   359  // query: The SQL to execute
   360  //
   361  // args...: for any placeholder parameters in the query
   362  func (d *Database) ScanStructsContext(ctx context.Context, i interface{}, query string, args ...interface{}) error {
   363  	return d.queryFactory().FromSQL(query, args...).ScanStructsContext(ctx, i)
   364  }
   365  
   366  // Queries the database using the supplied query, and args and uses CrudExec.ScanStruct to scan the results into a
   367  // struct
   368  //
   369  // i: A pointer to a struct
   370  //
   371  // query: The SQL to execute
   372  //
   373  // args...: for any placeholder parameters in the query
   374  func (d *Database) ScanStruct(i interface{}, query string, args ...interface{}) (bool, error) {
   375  	return d.ScanStructContext(context.Background(), i, query, args...)
   376  }
   377  
   378  // Queries the database using the supplied context, query, and args and uses CrudExec.ScanStructContext to scan the
   379  // results into a struct
   380  //
   381  // i: A pointer to a struct
   382  //
   383  // query: The SQL to execute
   384  //
   385  // args...: for any placeholder parameters in the query
   386  func (d *Database) ScanStructContext(ctx context.Context, i interface{}, query string, args ...interface{}) (bool, error) {
   387  	return d.queryFactory().FromSQL(query, args...).ScanStructContext(ctx, i)
   388  }
   389  
   390  // Queries the database using the supplied query, and args and uses CrudExec.ScanVals to scan the results into a slice
   391  // of primitive values
   392  //
   393  // i: A pointer to a slice of primitive values
   394  //
   395  // query: The SQL to execute
   396  //
   397  // args...: for any placeholder parameters in the query
   398  func (d *Database) ScanVals(i interface{}, query string, args ...interface{}) error {
   399  	return d.ScanValsContext(context.Background(), i, query, args...)
   400  }
   401  
   402  // Queries the database using the supplied context, query, and args and uses CrudExec.ScanValsContext to scan the
   403  // results into a slice of primitive values
   404  //
   405  // i: A pointer to a slice of primitive values
   406  //
   407  // query: The SQL to execute
   408  //
   409  // args...: for any placeholder parameters in the query
   410  func (d *Database) ScanValsContext(ctx context.Context, i interface{}, query string, args ...interface{}) error {
   411  	return d.queryFactory().FromSQL(query, args...).ScanValsContext(ctx, i)
   412  }
   413  
   414  // Queries the database using the supplied query, and args and uses CrudExec.ScanVal to scan the results into a
   415  // primitive value
   416  //
   417  // i: A pointer to a primitive value
   418  //
   419  // query: The SQL to execute
   420  //
   421  // args...: for any placeholder parameters in the query
   422  func (d *Database) ScanVal(i interface{}, query string, args ...interface{}) (bool, error) {
   423  	return d.ScanValContext(context.Background(), i, query, args...)
   424  }
   425  
   426  // Queries the database using the supplied context, query, and args and uses CrudExec.ScanValContext to scan the
   427  // results into a primitive value
   428  //
   429  // i: A pointer to a primitive value
   430  //
   431  // query: The SQL to execute
   432  //
   433  // args...: for any placeholder parameters in the query
   434  func (d *Database) ScanValContext(ctx context.Context, i interface{}, query string, args ...interface{}) (bool, error) {
   435  	return d.queryFactory().FromSQL(query, args...).ScanValContext(ctx, i)
   436  }
   437  
   438  // A wrapper around a sql.Tx and works the same way as Database
   439  type (
   440  	// Interface for sql.Tx, an interface is used so you can use with other
   441  	// libraries such as sqlx instead of the native sql.DB
   442  	SQLTx interface {
   443  		ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
   444  		PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
   445  		QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
   446  		QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
   447  		Commit() error
   448  		Rollback() error
   449  	}
   450  	TxDatabase struct {
   451  		logger  Logger
   452  		dialect string
   453  		Tx      SQLTx
   454  		qf      exec.QueryFactory
   455  		qfOnce  sync.Once
   456  	}
   457  )
   458  
   459  // Creates a new TxDatabase
   460  func NewTx(dialect string, tx SQLTx) *TxDatabase {
   461  	return &TxDatabase{dialect: dialect, Tx: tx}
   462  }
   463  
   464  // returns this databases dialect
   465  func (td *TxDatabase) Dialect() string {
   466  	return td.dialect
   467  }
   468  
   469  // Creates a new Dataset for querying a Database.
   470  func (td *TxDatabase) From(cols ...interface{}) *SelectDataset {
   471  	return newDataset(td.dialect, td.queryFactory()).From(cols...)
   472  }
   473  
   474  func (td *TxDatabase) Select(cols ...interface{}) *SelectDataset {
   475  	return newDataset(td.dialect, td.queryFactory()).Select(cols...)
   476  }
   477  
   478  func (td *TxDatabase) Update(table interface{}) *UpdateDataset {
   479  	return newUpdateDataset(td.dialect, td.queryFactory()).Table(table)
   480  }
   481  
   482  func (td *TxDatabase) Insert(table interface{}) *InsertDataset {
   483  	return newInsertDataset(td.dialect, td.queryFactory()).Into(table)
   484  }
   485  
   486  func (td *TxDatabase) Delete(table interface{}) *DeleteDataset {
   487  	return newDeleteDataset(td.dialect, td.queryFactory()).From(table)
   488  }
   489  
   490  func (td *TxDatabase) Truncate(table ...interface{}) *TruncateDataset {
   491  	return newTruncateDataset(td.dialect, td.queryFactory()).Table(table...)
   492  }
   493  
   494  // Sets the logger
   495  func (td *TxDatabase) Logger(logger Logger) {
   496  	td.logger = logger
   497  }
   498  
   499  func (td *TxDatabase) Trace(op, sqlString string, args ...interface{}) {
   500  	if td.logger != nil {
   501  		if sqlString != "" {
   502  			if len(args) != 0 {
   503  				td.logger.Printf("[goqu - transaction] %s [query:=`%s` args:=%+v] ", op, sqlString, args)
   504  			} else {
   505  				td.logger.Printf("[goqu - transaction] %s [query:=`%s`] ", op, sqlString)
   506  			}
   507  		} else {
   508  			td.logger.Printf("[goqu - transaction] %s", op)
   509  		}
   510  	}
   511  }
   512  
   513  // See Database#Exec
   514  func (td *TxDatabase) Exec(query string, args ...interface{}) (sql.Result, error) {
   515  	return td.ExecContext(context.Background(), query, args...)
   516  }
   517  
   518  // See Database#ExecContext
   519  func (td *TxDatabase) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
   520  	td.Trace("EXEC", query, args...)
   521  	return td.Tx.ExecContext(ctx, query, args...)
   522  }
   523  
   524  // See Database#Prepare
   525  func (td *TxDatabase) Prepare(query string) (*sql.Stmt, error) {
   526  	return td.PrepareContext(context.Background(), query)
   527  }
   528  
   529  // See Database#PrepareContext
   530  func (td *TxDatabase) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
   531  	td.Trace("PREPARE", query)
   532  	return td.Tx.PrepareContext(ctx, query)
   533  }
   534  
   535  // See Database#Query
   536  func (td *TxDatabase) Query(query string, args ...interface{}) (*sql.Rows, error) {
   537  	return td.QueryContext(context.Background(), query, args...)
   538  }
   539  
   540  // See Database#QueryContext
   541  func (td *TxDatabase) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
   542  	td.Trace("QUERY", query, args...)
   543  	return td.Tx.QueryContext(ctx, query, args...)
   544  }
   545  
   546  // See Database#QueryRow
   547  func (td *TxDatabase) QueryRow(query string, args ...interface{}) *sql.Row {
   548  	return td.QueryRowContext(context.Background(), query, args...)
   549  }
   550  
   551  // See Database#QueryRowContext
   552  func (td *TxDatabase) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
   553  	td.Trace("QUERY ROW", query, args...)
   554  	return td.Tx.QueryRowContext(ctx, query, args...)
   555  }
   556  
   557  func (td *TxDatabase) queryFactory() exec.QueryFactory {
   558  	td.qfOnce.Do(func() {
   559  		td.qf = exec.NewQueryFactory(td)
   560  	})
   561  	return td.qf
   562  }
   563  
   564  // See Database#ScanStructs
   565  func (td *TxDatabase) ScanStructs(i interface{}, query string, args ...interface{}) error {
   566  	return td.ScanStructsContext(context.Background(), i, query, args...)
   567  }
   568  
   569  // See Database#ScanStructsContext
   570  func (td *TxDatabase) ScanStructsContext(ctx context.Context, i interface{}, query string, args ...interface{}) error {
   571  	return td.queryFactory().FromSQL(query, args...).ScanStructsContext(ctx, i)
   572  }
   573  
   574  // See Database#ScanStruct
   575  func (td *TxDatabase) ScanStruct(i interface{}, query string, args ...interface{}) (bool, error) {
   576  	return td.ScanStructContext(context.Background(), i, query, args...)
   577  }
   578  
   579  // See Database#ScanStructContext
   580  func (td *TxDatabase) ScanStructContext(ctx context.Context, i interface{}, query string, args ...interface{}) (bool, error) {
   581  	return td.queryFactory().FromSQL(query, args...).ScanStructContext(ctx, i)
   582  }
   583  
   584  // See Database#ScanVals
   585  func (td *TxDatabase) ScanVals(i interface{}, query string, args ...interface{}) error {
   586  	return td.ScanValsContext(context.Background(), i, query, args...)
   587  }
   588  
   589  // See Database#ScanValsContext
   590  func (td *TxDatabase) ScanValsContext(ctx context.Context, i interface{}, query string, args ...interface{}) error {
   591  	return td.queryFactory().FromSQL(query, args...).ScanValsContext(ctx, i)
   592  }
   593  
   594  // See Database#ScanVal
   595  func (td *TxDatabase) ScanVal(i interface{}, query string, args ...interface{}) (bool, error) {
   596  	return td.ScanValContext(context.Background(), i, query, args...)
   597  }
   598  
   599  // See Database#ScanValContext
   600  func (td *TxDatabase) ScanValContext(ctx context.Context, i interface{}, query string, args ...interface{}) (bool, error) {
   601  	return td.queryFactory().FromSQL(query, args...).ScanValContext(ctx, i)
   602  }
   603  
   604  // COMMIT the transaction
   605  func (td *TxDatabase) Commit() error {
   606  	td.Trace("COMMIT", "")
   607  	return td.Tx.Commit()
   608  }
   609  
   610  // ROLLBACK the transaction
   611  func (td *TxDatabase) Rollback() error {
   612  	td.Trace("ROLLBACK", "")
   613  	return td.Tx.Rollback()
   614  }
   615  
   616  // A helper method that will automatically COMMIT or ROLLBACK once the supplied function is done executing
   617  //
   618  //      tx, err := db.Begin()
   619  //      if err != nil{
   620  //           panic(err.Error()) // you could gracefully handle the error also
   621  //      }
   622  //      if err := tx.Wrap(func() error{
   623  //          if _, err := tx.From("test").Insert(Record{"a":1, "b": "b"}).Exec(){
   624  //              // this error will be the return error from the Wrap call
   625  //              return err
   626  //          }
   627  //          return nil
   628  //      }); err != nil{
   629  //           panic(err.Error()) // you could gracefully handle the error also
   630  //      }
   631  func (td *TxDatabase) Wrap(fn func() error) (err error) {
   632  	defer func() {
   633  		if p := recover(); p != nil {
   634  			_ = td.Rollback()
   635  			panic(p)
   636  		}
   637  		if err != nil {
   638  			if rollbackErr := td.Rollback(); rollbackErr != nil {
   639  				err = rollbackErr
   640  			}
   641  		} else {
   642  			if commitErr := td.Commit(); commitErr != nil {
   643  				err = commitErr
   644  			}
   645  		}
   646  	}()
   647  	return fn()
   648  }
   649  

View as plain text