...

Source file src/github.com/DATA-DOG/go-sqlmock/rows_go18_test.go

Documentation: github.com/DATA-DOG/go-sqlmock

     1  // +build go1.8
     2  
     3  package sqlmock
     4  
     5  import (
     6  	"database/sql"
     7  	"encoding/json"
     8  	"fmt"
     9  	"reflect"
    10  	"testing"
    11  	"time"
    12  )
    13  
    14  func TestQueryMultiRows(t *testing.T) {
    15  	t.Parallel()
    16  	db, mock, err := New()
    17  	if err != nil {
    18  		t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
    19  	}
    20  	defer db.Close()
    21  
    22  	rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
    23  	rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error"))
    24  
    25  	mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users").
    26  		WithArgs(5).
    27  		WillReturnRows(rs1, rs2)
    28  
    29  	rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5)
    30  	if err != nil {
    31  		t.Errorf("error was not expected, but got: %v", err)
    32  	}
    33  	defer rows.Close()
    34  
    35  	if !rows.Next() {
    36  		t.Error("expected a row to be available in first result set")
    37  	}
    38  
    39  	var id int
    40  	var name string
    41  
    42  	err = rows.Scan(&id, &name)
    43  	if err != nil {
    44  		t.Errorf("error was not expected, but got: %v", err)
    45  	}
    46  
    47  	if id != 5 || name != "hello world" {
    48  		t.Errorf("unexpected row values id: %v name: %v", id, name)
    49  	}
    50  
    51  	if rows.Next() {
    52  		t.Error("was not expecting next row in first result set")
    53  	}
    54  
    55  	if !rows.NextResultSet() {
    56  		t.Error("had to have next result set")
    57  	}
    58  
    59  	if !rows.Next() {
    60  		t.Error("expected a row to be available in second result set")
    61  	}
    62  
    63  	err = rows.Scan(&name)
    64  	if err != nil {
    65  		t.Errorf("error was not expected, but got: %v", err)
    66  	}
    67  
    68  	if name != "gopher" {
    69  		t.Errorf("unexpected row name: %v", name)
    70  	}
    71  
    72  	if !rows.Next() {
    73  		t.Error("expected a row to be available in second result set")
    74  	}
    75  
    76  	err = rows.Scan(&name)
    77  	if err != nil {
    78  		t.Errorf("error was not expected, but got: %v", err)
    79  	}
    80  
    81  	if name != "john" {
    82  		t.Errorf("unexpected row name: %v", name)
    83  	}
    84  
    85  	if rows.Next() {
    86  		t.Error("expected next row to produce error")
    87  	}
    88  
    89  	if rows.Err() == nil {
    90  		t.Error("expected an error, but there was none")
    91  	}
    92  
    93  	if err := mock.ExpectationsWereMet(); err != nil {
    94  		t.Errorf("there were unfulfilled expectations: %s", err)
    95  	}
    96  }
    97  
    98  func TestQueryRowBytesInvalidatedByNext_jsonRawMessageIntoRawBytes(t *testing.T) {
    99  	t.Parallel()
   100  	replace := []byte(invalid)
   101  	rows := NewRows([]string{"raw"}).
   102  		AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
   103  		AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
   104  	scan := func(rs *sql.Rows) ([]byte, error) {
   105  		var raw sql.RawBytes
   106  		return raw, rs.Scan(&raw)
   107  	}
   108  	want := []struct {
   109  		Initial  []byte
   110  		Replaced []byte
   111  	}{
   112  		{Initial: []byte(`{"thing": "one", "thing2": "two"}`), Replaced: replace[:len(replace)-6]},
   113  		{Initial: []byte(`{"that": "foo", "this": "bar"}`), Replaced: replace[:len(replace)-9]},
   114  	}
   115  	queryRowBytesInvalidatedByNext(t, rows, scan, want)
   116  }
   117  
   118  func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoBytes(t *testing.T) {
   119  	t.Parallel()
   120  	rows := NewRows([]string{"raw"}).
   121  		AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
   122  		AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
   123  	scan := func(rs *sql.Rows) ([]byte, error) {
   124  		var b []byte
   125  		return b, rs.Scan(&b)
   126  	}
   127  	want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
   128  	queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
   129  }
   130  
   131  func TestQueryRowBytesNotInvalidatedByNext_bytesIntoCustomBytes(t *testing.T) {
   132  	t.Parallel()
   133  	rows := NewRows([]string{"raw"}).
   134  		AddRow([]byte(`one binary value with some text!`)).
   135  		AddRow([]byte(`two binary value with even more text than the first one`))
   136  	scan := func(rs *sql.Rows) ([]byte, error) {
   137  		type customBytes []byte
   138  		var b customBytes
   139  		return b, rs.Scan(&b)
   140  	}
   141  	want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)}
   142  	queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
   143  }
   144  
   145  func TestQueryRowBytesNotInvalidatedByNext_jsonRawMessageIntoCustomBytes(t *testing.T) {
   146  	t.Parallel()
   147  	rows := NewRows([]string{"raw"}).
   148  		AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`)).
   149  		AddRow(json.RawMessage(`{"that": "foo", "this": "bar"}`))
   150  	scan := func(rs *sql.Rows) ([]byte, error) {
   151  		type customBytes []byte
   152  		var b customBytes
   153  		return b, rs.Scan(&b)
   154  	}
   155  	want := [][]byte{[]byte(`{"thing": "one", "thing2": "two"}`), []byte(`{"that": "foo", "this": "bar"}`)}
   156  	queryRowBytesNotInvalidatedByNext(t, rows, scan, want)
   157  }
   158  
   159  func TestQueryRowBytesNotInvalidatedByClose_bytesIntoCustomBytes(t *testing.T) {
   160  	t.Parallel()
   161  	rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`))
   162  	scan := func(rs *sql.Rows) ([]byte, error) {
   163  		type customBytes []byte
   164  		var b customBytes
   165  		return b, rs.Scan(&b)
   166  	}
   167  	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`))
   168  }
   169  
   170  func TestQueryRowBytesInvalidatedByClose_jsonRawMessageIntoRawBytes(t *testing.T) {
   171  	t.Parallel()
   172  	replace := []byte(invalid)
   173  	rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
   174  	scan := func(rs *sql.Rows) ([]byte, error) {
   175  		var raw sql.RawBytes
   176  		return raw, rs.Scan(&raw)
   177  	}
   178  	want := struct {
   179  		Initial  []byte
   180  		Replaced []byte
   181  	}{
   182  		Initial:  []byte(`{"thing": "one", "thing2": "two"}`),
   183  		Replaced: replace[:len(replace)-6],
   184  	}
   185  	queryRowBytesInvalidatedByClose(t, rows, scan, want)
   186  }
   187  
   188  func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoBytes(t *testing.T) {
   189  	t.Parallel()
   190  	rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
   191  	scan := func(rs *sql.Rows) ([]byte, error) {
   192  		var b []byte
   193  		return b, rs.Scan(&b)
   194  	}
   195  	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
   196  }
   197  
   198  func TestQueryRowBytesNotInvalidatedByClose_jsonRawMessageIntoCustomBytes(t *testing.T) {
   199  	t.Parallel()
   200  	rows := NewRows([]string{"raw"}).AddRow(json.RawMessage(`{"thing": "one", "thing2": "two"}`))
   201  	scan := func(rs *sql.Rows) ([]byte, error) {
   202  		type customBytes []byte
   203  		var b customBytes
   204  		return b, rs.Scan(&b)
   205  	}
   206  	queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`{"thing": "one", "thing2": "two"}`))
   207  }
   208  
   209  func TestNewColumnWithDefinition(t *testing.T) {
   210  	now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z")
   211  
   212  	t.Run("with one ResultSet", func(t *testing.T) {
   213  		db, mock, _ := New()
   214  		column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
   215  		column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
   216  		column3 := mock.NewColumn("when").OfType("TIMESTAMP", now)
   217  		rows := mock.NewRowsWithColumnDefinition(column1, column2, column3)
   218  		rows.AddRow("foo.bar", float64(10.123), now)
   219  
   220  		mQuery := mock.ExpectQuery("SELECT test, number, when from dummy")
   221  		isQuery := mQuery.WillReturnRows(rows)
   222  		isQueryClosed := mQuery.RowsWillBeClosed()
   223  		isDbClosed := mock.ExpectClose()
   224  
   225  		query, _ := db.Query("SELECT test, number, when from dummy")
   226  
   227  		if false == isQuery.fulfilled() {
   228  			t.Error("Query is not executed")
   229  		}
   230  
   231  		if query.Next() {
   232  			var test string
   233  			var number float64
   234  			var when time.Time
   235  
   236  			if queryError := query.Scan(&test, &number, &when); queryError != nil {
   237  				t.Error(queryError)
   238  			} else if test != "foo.bar" {
   239  				t.Error("field test is not 'foo.bar'")
   240  			} else if number != float64(10.123) {
   241  				t.Error("field number is not '10.123'")
   242  			} else if when != now {
   243  				t.Errorf("field when is not %v", now)
   244  			}
   245  
   246  			if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil {
   247  				t.Error(colTypErr)
   248  			} else if len(columnTypes) != 3 {
   249  				t.Error("number of columnTypes")
   250  			} else if name := columnTypes[0].Name(); name != "test" {
   251  				t.Errorf("field 'test' has a wrong name '%s'", name)
   252  			} else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" {
   253  				t.Errorf("field 'test' has a wrong db type '%s'", dbType)
   254  			} else if columnTypes[0].ScanType().Kind() != reflect.String {
   255  				t.Error("field 'test' has a wrong scanType")
   256  			} else if _, _, ok := columnTypes[0].DecimalSize(); ok {
   257  				t.Error("field 'test' should have not precision, scale")
   258  			} else if length, ok := columnTypes[0].Length(); length != 100 || !ok {
   259  				t.Errorf("field 'test' has a wrong length '%d'", length)
   260  			} else if name := columnTypes[1].Name(); name != "number" {
   261  				t.Errorf("field 'number' has a wrong name '%s'", name)
   262  			} else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" {
   263  				t.Errorf("field 'number' has a wrong db type '%s'", dbType)
   264  			} else if columnTypes[1].ScanType().Kind() != reflect.Float64 {
   265  				t.Error("field 'number' has a wrong scanType")
   266  			} else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok {
   267  				t.Error("field 'number' has a wrong precision, scale")
   268  			} else if _, ok := columnTypes[1].Length(); ok {
   269  				t.Error("field 'number' is not variable length type")
   270  			} else if _, ok := columnTypes[2].Nullable(); ok {
   271  				t.Error("field 'when' should have nullability unknown")
   272  			}
   273  		} else {
   274  			t.Error("no result set")
   275  		}
   276  
   277  		query.Close()
   278  		if false == isQueryClosed.fulfilled() {
   279  			t.Error("Query is not executed")
   280  		}
   281  
   282  		db.Close()
   283  		if false == isDbClosed.fulfilled() {
   284  			t.Error("Db is not closed")
   285  		}
   286  	})
   287  
   288  	t.Run("with more then one ResultSet", func(t *testing.T) {
   289  		db, mock, _ := New()
   290  		column1 := mock.NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
   291  		column2 := mock.NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
   292  		column3 := mock.NewColumn("when").OfType("TIMESTAMP", now)
   293  		rows1 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
   294  		rows1.AddRow("foo.bar", float64(10.123), now)
   295  		rows2 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
   296  		rows2.AddRow("bar.foo", float64(123.10), now.Add(time.Second*10))
   297  		rows3 := mock.NewRowsWithColumnDefinition(column1, column2, column3)
   298  		rows3.AddRow("lollipop", float64(10.321), now.Add(time.Second*20))
   299  
   300  		mQuery := mock.ExpectQuery("SELECT test, number, when from dummy")
   301  		isQuery := mQuery.WillReturnRows(rows1, rows2, rows3)
   302  		isQueryClosed := mQuery.RowsWillBeClosed()
   303  		isDbClosed := mock.ExpectClose()
   304  
   305  		query, _ := db.Query("SELECT test, number, when from dummy")
   306  
   307  		if false == isQuery.fulfilled() {
   308  			t.Error("Query is not executed")
   309  		}
   310  
   311  		rowsSi := 0
   312  
   313  		for query.Next() {
   314  			var test string
   315  			var number float64
   316  			var when time.Time
   317  
   318  			if queryError := query.Scan(&test, &number, &when); queryError != nil {
   319  				t.Error(queryError)
   320  
   321  			} else if rowsSi == 0 && test != "foo.bar" {
   322  				t.Error("field test is not 'foo.bar'")
   323  			} else if rowsSi == 0 && number != float64(10.123) {
   324  				t.Error("field number is not '10.123'")
   325  			} else if rowsSi == 0 && when != now {
   326  				t.Errorf("field when is not %v", now)
   327  
   328  			} else if rowsSi == 1 && test != "bar.foo" {
   329  				t.Error("field test is not 'bar.bar'")
   330  			} else if rowsSi == 1 && number != float64(123.10) {
   331  				t.Error("field number is not '123.10'")
   332  			} else if rowsSi == 1 && when != now.Add(time.Second*10) {
   333  				t.Errorf("field when is not %v", now)
   334  
   335  			} else if rowsSi == 2 && test != "lollipop" {
   336  				t.Error("field test is not 'lollipop'")
   337  			} else if rowsSi == 2 && number != float64(10.321) {
   338  				t.Error("field number is not '10.321'")
   339  			} else if rowsSi == 2 && when != now.Add(time.Second*20) {
   340  				t.Errorf("field when is not %v", now)
   341  			}
   342  
   343  			rowsSi++
   344  
   345  			if columnTypes, colTypErr := query.ColumnTypes(); colTypErr != nil {
   346  				t.Error(colTypErr)
   347  			} else if len(columnTypes) != 3 {
   348  				t.Error("number of columnTypes")
   349  			} else if name := columnTypes[0].Name(); name != "test" {
   350  				t.Errorf("field 'test' has a wrong name '%s'", name)
   351  			} else if dbType := columnTypes[0].DatabaseTypeName(); dbType != "VARCHAR" {
   352  				t.Errorf("field 'test' has a wrong db type '%s'", dbType)
   353  			} else if columnTypes[0].ScanType().Kind() != reflect.String {
   354  				t.Error("field 'test' has a wrong scanType")
   355  			} else if _, _, ok := columnTypes[0].DecimalSize(); ok {
   356  				t.Error("field 'test' should not have precision, scale")
   357  			} else if length, ok := columnTypes[0].Length(); length != 100 || !ok {
   358  				t.Errorf("field 'test' has a wrong length '%d'", length)
   359  			} else if name := columnTypes[1].Name(); name != "number" {
   360  				t.Errorf("field 'number' has a wrong name '%s'", name)
   361  			} else if dbType := columnTypes[1].DatabaseTypeName(); dbType != "DECIMAL" {
   362  				t.Errorf("field 'number' has a wrong db type '%s'", dbType)
   363  			} else if columnTypes[1].ScanType().Kind() != reflect.Float64 {
   364  				t.Error("field 'number' has a wrong scanType")
   365  			} else if precision, scale, ok := columnTypes[1].DecimalSize(); precision != int64(10) || scale != int64(4) || !ok {
   366  				t.Error("field 'number' has a wrong precision, scale")
   367  			} else if _, ok := columnTypes[1].Length(); ok {
   368  				t.Error("field 'number' is not variable length type")
   369  			} else if _, ok := columnTypes[2].Nullable(); ok {
   370  				t.Error("field 'when' should have nullability unknown")
   371  			}
   372  		}
   373  		if rowsSi == 0 {
   374  			t.Error("no result set")
   375  		}
   376  
   377  		query.Close()
   378  		if false == isQueryClosed.fulfilled() {
   379  			t.Error("Query is not executed")
   380  		}
   381  
   382  		db.Close()
   383  		if false == isDbClosed.fulfilled() {
   384  			t.Error("Db is not closed")
   385  		}
   386  	})
   387  }
   388  

View as plain text