...

Source file src/github.com/DATA-DOG/go-sqlmock/rows.go

Documentation: github.com/DATA-DOG/go-sqlmock

     1  package sqlmock
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql/driver"
     6  	"encoding/csv"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"strings"
    11  )
    12  
    13  const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "
    14  
    15  // CSVColumnParser is a function which converts trimmed csv
    16  // column string to a []byte representation. Currently
    17  // transforms NULL to nil
    18  var CSVColumnParser = func(s string) interface{} {
    19  	switch {
    20  	case strings.ToLower(s) == "null":
    21  		return nil
    22  	}
    23  	return []byte(s)
    24  }
    25  
// rowSets is the value returned for a mocked query result. It provides the
// Columns, Close and Next methods of driver.Rows over one or more configured
// result sets.
type rowSets struct {
	// sets are the result sets configured for the query.
	sets []*Rows
	// pos indexes into sets: the result set currently being read.
	pos  int
	// ex is the expectation whose rowsWereClosed flag Close sets.
	ex   *ExpectedQuery
	// raw tracks the byte-slice copies handed out by Next so invalidateRaw
	// can overwrite them once they are no longer valid.
	raw  [][]byte
}
    32  
    33  func (rs *rowSets) Columns() []string {
    34  	return rs.sets[rs.pos].cols
    35  }
    36  
    37  func (rs *rowSets) Close() error {
    38  	rs.invalidateRaw()
    39  	rs.ex.rowsWereClosed = true
    40  	return rs.sets[rs.pos].closeErr
    41  }
    42  
    43  // advances to next row
    44  func (rs *rowSets) Next(dest []driver.Value) error {
    45  	r := rs.sets[rs.pos]
    46  	r.pos++
    47  	rs.invalidateRaw()
    48  	if r.pos > len(r.rows) {
    49  		return io.EOF // per interface spec
    50  	}
    51  
    52  	for i, col := range r.rows[r.pos-1] {
    53  		if b, ok := rawBytes(col); ok {
    54  			rs.raw = append(rs.raw, b)
    55  			dest[i] = b
    56  			continue
    57  		}
    58  		dest[i] = col
    59  	}
    60  
    61  	return r.nextErr[r.pos-1]
    62  }
    63  
    64  // transforms to debuggable printable string
    65  func (rs *rowSets) String() string {
    66  	if rs.empty() {
    67  		return "with empty rows"
    68  	}
    69  
    70  	msg := "should return rows:\n"
    71  	if len(rs.sets) == 1 {
    72  		for n, row := range rs.sets[0].rows {
    73  			msg += fmt.Sprintf("    row %d - %+v\n", n, row)
    74  		}
    75  		return strings.TrimSpace(msg)
    76  	}
    77  	for i, set := range rs.sets {
    78  		msg += fmt.Sprintf("    result set: %d\n", i)
    79  		for n, row := range set.rows {
    80  			msg += fmt.Sprintf("      row %d - %+v\n", n, row)
    81  		}
    82  	}
    83  	return strings.TrimSpace(msg)
    84  }
    85  
    86  func (rs *rowSets) empty() bool {
    87  	for _, set := range rs.sets {
    88  		if len(set.rows) > 0 {
    89  			return false
    90  		}
    91  	}
    92  	return true
    93  }
    94  
    95  func rawBytes(col driver.Value) (_ []byte, ok bool) {
    96  	val, ok := col.([]byte)
    97  	if !ok || len(val) == 0 {
    98  		return nil, false
    99  	}
   100  	// Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later
   101  	// This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close()
   102  	b := make([]byte, len(val))
   103  	copy(b, val)
   104  	return b, true
   105  }
   106  
   107  // Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close.
   108  // If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes
   109  func (rs *rowSets) invalidateRaw() {
   110  	// Replace the content of slices previously returned
   111  	b := []byte(invalidate)
   112  	for _, r := range rs.raw {
   113  		copy(r, bytes.Repeat(b, len(r)/len(b)+1))
   114  	}
   115  	// Start with new slices for the next scan
   116  	rs.raw = nil
   117  }
   118  
// Rows is a mocked collection of rows to return as a Query result.
//
// Create it with NewRows (or Sqlmock.NewRows when a custom converter is
// needed) and fill it with AddRow, AddRows or FromCSVString.
type Rows struct {
	// converter turns values passed to AddRow into driver.Value
	// (NewRows uses driver.DefaultParameterConverter).
	converter driver.ValueConverter
	// cols holds the column names; AddRow requires exactly this many values.
	cols      []string
	// def is not referenced in this part of the file; presumably column
	// definitions populated elsewhere — TODO confirm.
	def       []*Column
	// rows holds the stored row values in insertion order.
	rows      [][]driver.Value
	// pos counts calls to rowSets.Next on this set; the row most recently
	// returned is rows[pos-1].
	pos       int
	// nextErr maps a 0-based row index to the error Next returns for it
	// (set via RowError).
	nextErr   map[int]error
	// closeErr is returned by rowSets.Close (set via CloseError).
	closeErr  error
}
   130  
   131  // NewRows allows Rows to be created from a
   132  // sql driver.Value slice or from the CSV string and
   133  // to be used as sql driver.Rows.
   134  // Use Sqlmock.NewRows instead if using a custom converter
   135  func NewRows(columns []string) *Rows {
   136  	return &Rows{
   137  		cols:      columns,
   138  		nextErr:   make(map[int]error),
   139  		converter: driver.DefaultParameterConverter,
   140  	}
   141  }
   142  
   143  // CloseError allows to set an error
   144  // which will be returned by rows.Close
   145  // function.
   146  //
   147  // The close error will be triggered only in cases
   148  // when rows.Next() EOF was not yet reached, that is
   149  // a default sql library behavior
   150  func (r *Rows) CloseError(err error) *Rows {
   151  	r.closeErr = err
   152  	return r
   153  }
   154  
   155  // RowError allows to set an error
   156  // which will be returned when a given
   157  // row number is read
   158  func (r *Rows) RowError(row int, err error) *Rows {
   159  	r.nextErr[row] = err
   160  	return r
   161  }
   162  
   163  // AddRow composed from database driver.Value slice
   164  // return the same instance to perform subsequent actions.
   165  // Note that the number of values must match the number
   166  // of columns
   167  func (r *Rows) AddRow(values ...driver.Value) *Rows {
   168  	if len(values) != len(r.cols) {
   169  		panic(fmt.Sprintf("Expected number of values to match number of columns: expected %d, actual %d", len(values), len(r.cols)))
   170  	}
   171  
   172  	row := make([]driver.Value, len(r.cols))
   173  	for i, v := range values {
   174  		// Convert user-friendly values (such as int or driver.Valuer)
   175  		// to database/sql native value (driver.Value such as int64)
   176  		var err error
   177  		v, err = r.converter.ConvertValue(v)
   178  		if err != nil {
   179  			panic(fmt.Errorf(
   180  				"row #%d, column #%d (%q) type %T: %s",
   181  				len(r.rows)+1, i, r.cols[i], values[i], err,
   182  			))
   183  		}
   184  
   185  		row[i] = v
   186  	}
   187  
   188  	r.rows = append(r.rows, row)
   189  	return r
   190  }
   191  
   192  // AddRows adds multiple rows composed from database driver.Value slice and
   193  // returns the same instance to perform subsequent actions.
   194  func (r *Rows) AddRows(values ...[]driver.Value) *Rows {
   195  	for _, value := range values {
   196  		r.AddRow(value...)
   197  	}
   198  
   199  	return r
   200  }
   201  
   202  // FromCSVString build rows from csv string.
   203  // return the same instance to perform subsequent actions.
   204  // Note that the number of values must match the number
   205  // of columns
   206  func (r *Rows) FromCSVString(s string) *Rows {
   207  	res := strings.NewReader(strings.TrimSpace(s))
   208  	csvReader := csv.NewReader(res)
   209  
   210  	for {
   211  		res, err := csvReader.Read()
   212  		if err != nil {
   213  			if errors.Is(err, io.EOF) {
   214  				break
   215  			}
   216  			panic(fmt.Sprintf("Parsing CSV string failed: %s", err.Error()))
   217  		}
   218  
   219  		row := make([]driver.Value, len(r.cols))
   220  		for i, v := range res {
   221  			row[i] = CSVColumnParser(strings.TrimSpace(v))
   222  		}
   223  		r.rows = append(r.rows, row)
   224  	}
   225  	return r
   226  }
   227  

View as plain text