...

Source file src/github.com/jackc/pgx/v5/copy_from_test.go

Documentation: github.com/jackc/pgx/v5

     1  package pgx_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"reflect"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/jackc/pgx/v5"
    12  	"github.com/jackc/pgx/v5/pgconn"
    13  	"github.com/jackc/pgx/v5/pgxtest"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func TestConnCopyWithAllQueryExecModes(t *testing.T) {
    18  	for _, mode := range pgxtest.AllQueryExecModes {
    19  		t.Run(mode.String(), func(t *testing.T) {
    20  			ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
    21  			defer cancel()
    22  
    23  			cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
    24  			cfg.DefaultQueryExecMode = mode
    25  			conn := mustConnect(t, cfg)
    26  			defer closeConn(t, conn)
    27  
    28  			mustExec(t, conn, `create temporary table foo(
    29  			a int2,
    30  			b int4,
    31  			c int8,
    32  			d text,
    33  			e timestamptz
    34  		)`)
    35  
    36  			tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
    37  
    38  			inputRows := [][]any{
    39  				{int16(0), int32(1), int64(2), "abc", tzedTime},
    40  				{nil, nil, nil, nil, nil},
    41  			}
    42  
    43  			copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e"}, pgx.CopyFromRows(inputRows))
    44  			if err != nil {
    45  				t.Errorf("Unexpected error for CopyFrom: %v", err)
    46  			}
    47  			if int(copyCount) != len(inputRows) {
    48  				t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
    49  			}
    50  
    51  			rows, err := conn.Query(ctx, "select * from foo")
    52  			if err != nil {
    53  				t.Errorf("Unexpected error for Query: %v", err)
    54  			}
    55  
    56  			var outputRows [][]any
    57  			for rows.Next() {
    58  				row, err := rows.Values()
    59  				if err != nil {
    60  					t.Errorf("Unexpected error for rows.Values(): %v", err)
    61  				}
    62  				outputRows = append(outputRows, row)
    63  			}
    64  
    65  			if rows.Err() != nil {
    66  				t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
    67  			}
    68  
    69  			if !reflect.DeepEqual(inputRows, outputRows) {
    70  				t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
    71  			}
    72  
    73  			ensureConnValid(t, conn)
    74  		})
    75  	}
    76  }
    77  
    78  func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
    79  
    80  	for _, mode := range pgxtest.KnownOIDQueryExecModes {
    81  		t.Run(mode.String(), func(t *testing.T) {
    82  			ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
    83  			defer cancel()
    84  
    85  			cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
    86  			cfg.DefaultQueryExecMode = mode
    87  			conn := mustConnect(t, cfg)
    88  			defer closeConn(t, conn)
    89  
    90  			mustExec(t, conn, `create temporary table foo(
    91  			a int2,
    92  			b int4,
    93  			c int8,
    94  			d varchar,
    95  			e text,
    96  			f date,
    97  			g timestamptz
    98  		)`)
    99  
   100  			tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
   101  
   102  			inputRows := [][]any{
   103  				{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
   104  				{nil, nil, nil, nil, nil, nil, nil},
   105  			}
   106  
   107  			copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
   108  			if err != nil {
   109  				t.Errorf("Unexpected error for CopyFrom: %v", err)
   110  			}
   111  			if int(copyCount) != len(inputRows) {
   112  				t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
   113  			}
   114  
   115  			rows, err := conn.Query(ctx, "select * from foo")
   116  			if err != nil {
   117  				t.Errorf("Unexpected error for Query: %v", err)
   118  			}
   119  
   120  			var outputRows [][]any
   121  			for rows.Next() {
   122  				row, err := rows.Values()
   123  				if err != nil {
   124  					t.Errorf("Unexpected error for rows.Values(): %v", err)
   125  				}
   126  				outputRows = append(outputRows, row)
   127  			}
   128  
   129  			if rows.Err() != nil {
   130  				t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
   131  			}
   132  
   133  			if !reflect.DeepEqual(inputRows, outputRows) {
   134  				t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
   135  			}
   136  
   137  			ensureConnValid(t, conn)
   138  		})
   139  	}
   140  }
   141  
   142  func TestConnCopyFromSmall(t *testing.T) {
   143  	t.Parallel()
   144  
   145  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   146  	defer cancel()
   147  
   148  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   149  	defer closeConn(t, conn)
   150  
   151  	mustExec(t, conn, `create temporary table foo(
   152  		a int2,
   153  		b int4,
   154  		c int8,
   155  		d varchar,
   156  		e text,
   157  		f date,
   158  		g timestamptz
   159  	)`)
   160  
   161  	tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
   162  
   163  	inputRows := [][]any{
   164  		{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
   165  		{nil, nil, nil, nil, nil, nil, nil},
   166  	}
   167  
   168  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
   169  	if err != nil {
   170  		t.Errorf("Unexpected error for CopyFrom: %v", err)
   171  	}
   172  	if int(copyCount) != len(inputRows) {
   173  		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
   174  	}
   175  
   176  	rows, err := conn.Query(ctx, "select * from foo")
   177  	if err != nil {
   178  		t.Errorf("Unexpected error for Query: %v", err)
   179  	}
   180  
   181  	var outputRows [][]any
   182  	for rows.Next() {
   183  		row, err := rows.Values()
   184  		if err != nil {
   185  			t.Errorf("Unexpected error for rows.Values(): %v", err)
   186  		}
   187  		outputRows = append(outputRows, row)
   188  	}
   189  
   190  	if rows.Err() != nil {
   191  		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
   192  	}
   193  
   194  	if !reflect.DeepEqual(inputRows, outputRows) {
   195  		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
   196  	}
   197  
   198  	ensureConnValid(t, conn)
   199  }
   200  
   201  func TestConnCopyFromSliceSmall(t *testing.T) {
   202  	t.Parallel()
   203  
   204  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   205  	defer cancel()
   206  
   207  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   208  	defer closeConn(t, conn)
   209  
   210  	mustExec(t, conn, `create temporary table foo(
   211  		a int2,
   212  		b int4,
   213  		c int8,
   214  		d varchar,
   215  		e text,
   216  		f date,
   217  		g timestamptz
   218  	)`)
   219  
   220  	tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
   221  
   222  	inputRows := [][]any{
   223  		{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
   224  		{nil, nil, nil, nil, nil, nil, nil},
   225  	}
   226  
   227  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
   228  		pgx.CopyFromSlice(len(inputRows), func(i int) ([]any, error) {
   229  			return inputRows[i], nil
   230  		}))
   231  	if err != nil {
   232  		t.Errorf("Unexpected error for CopyFrom: %v", err)
   233  	}
   234  	if int(copyCount) != len(inputRows) {
   235  		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
   236  	}
   237  
   238  	rows, err := conn.Query(ctx, "select * from foo")
   239  	if err != nil {
   240  		t.Errorf("Unexpected error for Query: %v", err)
   241  	}
   242  
   243  	var outputRows [][]any
   244  	for rows.Next() {
   245  		row, err := rows.Values()
   246  		if err != nil {
   247  			t.Errorf("Unexpected error for rows.Values(): %v", err)
   248  		}
   249  		outputRows = append(outputRows, row)
   250  	}
   251  
   252  	if rows.Err() != nil {
   253  		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
   254  	}
   255  
   256  	if !reflect.DeepEqual(inputRows, outputRows) {
   257  		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
   258  	}
   259  
   260  	ensureConnValid(t, conn)
   261  }
   262  
   263  func TestConnCopyFromLarge(t *testing.T) {
   264  	t.Parallel()
   265  
   266  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   267  	defer cancel()
   268  
   269  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   270  	defer closeConn(t, conn)
   271  
   272  	mustExec(t, conn, `create temporary table foo(
   273  		a int2,
   274  		b int4,
   275  		c int8,
   276  		d varchar,
   277  		e text,
   278  		f date,
   279  		g timestamptz,
   280  		h bytea
   281  	)`)
   282  
   283  	tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
   284  
   285  	inputRows := [][]any{}
   286  
   287  	for i := 0; i < 10000; i++ {
   288  		inputRows = append(inputRows, []any{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
   289  	}
   290  
   291  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
   292  	if err != nil {
   293  		t.Errorf("Unexpected error for CopyFrom: %v", err)
   294  	}
   295  	if int(copyCount) != len(inputRows) {
   296  		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
   297  	}
   298  
   299  	rows, err := conn.Query(ctx, "select * from foo")
   300  	if err != nil {
   301  		t.Errorf("Unexpected error for Query: %v", err)
   302  	}
   303  
   304  	var outputRows [][]any
   305  	for rows.Next() {
   306  		row, err := rows.Values()
   307  		if err != nil {
   308  			t.Errorf("Unexpected error for rows.Values(): %v", err)
   309  		}
   310  		outputRows = append(outputRows, row)
   311  	}
   312  
   313  	if rows.Err() != nil {
   314  		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
   315  	}
   316  
   317  	if !reflect.DeepEqual(inputRows, outputRows) {
   318  		t.Errorf("Input rows and output rows do not equal")
   319  	}
   320  
   321  	ensureConnValid(t, conn)
   322  }
   323  
   324  func TestConnCopyFromEnum(t *testing.T) {
   325  	t.Parallel()
   326  
   327  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   328  	defer cancel()
   329  
   330  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   331  	defer closeConn(t, conn)
   332  
   333  	tx, err := conn.Begin(ctx)
   334  	require.NoError(t, err)
   335  	defer tx.Rollback(ctx)
   336  
   337  	_, err = tx.Exec(ctx, `drop type if exists color`)
   338  	require.NoError(t, err)
   339  
   340  	_, err = tx.Exec(ctx, `drop type if exists fruit`)
   341  	require.NoError(t, err)
   342  
   343  	_, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`)
   344  	require.NoError(t, err)
   345  
   346  	_, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`)
   347  	require.NoError(t, err)
   348  
   349  	// Obviously using conn while a tx is in use and registering a type after the connection has been established are
   350  	// really bad practices, but for the sake of convenience we do it in the test here.
   351  	for _, name := range []string{"fruit", "color"} {
   352  		typ, err := conn.LoadType(ctx, name)
   353  		require.NoError(t, err)
   354  		conn.TypeMap().RegisterType(typ)
   355  	}
   356  
   357  	_, err = tx.Exec(ctx, `create temporary table foo(
   358  		a text,
   359  		b color,
   360  		c fruit,
   361  		d color,
   362  		e fruit,
   363  		f text
   364  	)`)
   365  	require.NoError(t, err)
   366  
   367  	inputRows := [][]any{
   368  		{"abc", "blue", "grape", "orange", "orange", "def"},
   369  		{nil, nil, nil, nil, nil, nil},
   370  	}
   371  
   372  	copyCount, err := tx.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows))
   373  	require.NoError(t, err)
   374  	require.EqualValues(t, len(inputRows), copyCount)
   375  
   376  	rows, err := tx.Query(ctx, "select * from foo")
   377  	require.NoError(t, err)
   378  
   379  	var outputRows [][]any
   380  	for rows.Next() {
   381  		row, err := rows.Values()
   382  		require.NoError(t, err)
   383  		outputRows = append(outputRows, row)
   384  	}
   385  
   386  	require.NoError(t, rows.Err())
   387  
   388  	if !reflect.DeepEqual(inputRows, outputRows) {
   389  		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
   390  	}
   391  
   392  	err = tx.Rollback(ctx)
   393  	require.NoError(t, err)
   394  
   395  	ensureConnValid(t, conn)
   396  }
   397  
   398  func TestConnCopyFromJSON(t *testing.T) {
   399  	t.Parallel()
   400  
   401  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   402  	defer cancel()
   403  
   404  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   405  	defer closeConn(t, conn)
   406  
   407  	for _, typeName := range []string{"json", "jsonb"} {
   408  		if _, ok := conn.TypeMap().TypeForName(typeName); !ok {
   409  			return // No JSON/JSONB type -- must be running against old PostgreSQL
   410  		}
   411  	}
   412  
   413  	mustExec(t, conn, `create temporary table foo(
   414  		a json,
   415  		b jsonb
   416  	)`)
   417  
   418  	inputRows := [][]any{
   419  		{map[string]any{"foo": "bar"}, map[string]any{"bar": "quz"}},
   420  		{nil, nil},
   421  	}
   422  
   423  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
   424  	if err != nil {
   425  		t.Errorf("Unexpected error for CopyFrom: %v", err)
   426  	}
   427  	if int(copyCount) != len(inputRows) {
   428  		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
   429  	}
   430  
   431  	rows, err := conn.Query(ctx, "select * from foo")
   432  	if err != nil {
   433  		t.Errorf("Unexpected error for Query: %v", err)
   434  	}
   435  
   436  	var outputRows [][]any
   437  	for rows.Next() {
   438  		row, err := rows.Values()
   439  		if err != nil {
   440  			t.Errorf("Unexpected error for rows.Values(): %v", err)
   441  		}
   442  		outputRows = append(outputRows, row)
   443  	}
   444  
   445  	if rows.Err() != nil {
   446  		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
   447  	}
   448  
   449  	if !reflect.DeepEqual(inputRows, outputRows) {
   450  		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
   451  	}
   452  
   453  	ensureConnValid(t, conn)
   454  }
   455  
   456  type clientFailSource struct {
   457  	count int
   458  	err   error
   459  }
   460  
   461  func (cfs *clientFailSource) Next() bool {
   462  	cfs.count++
   463  	return cfs.count < 100
   464  }
   465  
   466  func (cfs *clientFailSource) Values() ([]any, error) {
   467  	if cfs.count == 3 {
   468  		cfs.err = fmt.Errorf("client error")
   469  		return nil, cfs.err
   470  	}
   471  	return []any{make([]byte, 100000)}, nil
   472  }
   473  
   474  func (cfs *clientFailSource) Err() error {
   475  	return cfs.err
   476  }
   477  
   478  func TestConnCopyFromFailServerSideMidway(t *testing.T) {
   479  	t.Parallel()
   480  
   481  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   482  	defer cancel()
   483  
   484  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   485  	defer closeConn(t, conn)
   486  
   487  	mustExec(t, conn, `create temporary table foo(
   488  		a int4,
   489  		b varchar not null
   490  	)`)
   491  
   492  	inputRows := [][]any{
   493  		{int32(1), "abc"},
   494  		{int32(2), nil}, // this row should trigger a failure
   495  		{int32(3), "def"},
   496  	}
   497  
   498  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
   499  	if err == nil {
   500  		t.Errorf("Expected CopyFrom return error, but it did not")
   501  	}
   502  	if _, ok := err.(*pgconn.PgError); !ok {
   503  		t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
   504  	}
   505  	if copyCount != 0 {
   506  		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
   507  	}
   508  
   509  	rows, err := conn.Query(ctx, "select * from foo")
   510  	if err != nil {
   511  		t.Errorf("Unexpected error for Query: %v", err)
   512  	}
   513  
   514  	var outputRows [][]any
   515  	for rows.Next() {
   516  		row, err := rows.Values()
   517  		if err != nil {
   518  			t.Errorf("Unexpected error for rows.Values(): %v", err)
   519  		}
   520  		outputRows = append(outputRows, row)
   521  	}
   522  
   523  	if rows.Err() != nil {
   524  		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
   525  	}
   526  
   527  	if len(outputRows) != 0 {
   528  		t.Errorf("Expected 0 rows, but got %v", outputRows)
   529  	}
   530  
   531  	mustExec(t, conn, "truncate foo")
   532  
   533  	ensureConnValid(t, conn)
   534  }
   535  
   536  type failSource struct {
   537  	count int
   538  }
   539  
   540  func (fs *failSource) Next() bool {
   541  	time.Sleep(time.Millisecond * 100)
   542  	fs.count++
   543  	return fs.count < 100
   544  }
   545  
   546  func (fs *failSource) Values() ([]any, error) {
   547  	if fs.count == 3 {
   548  		return []any{nil}, nil
   549  	}
   550  	return []any{make([]byte, 100000)}, nil
   551  }
   552  
   553  func (fs *failSource) Err() error {
   554  	return nil
   555  }
   556  
   557  func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
   558  	t.Parallel()
   559  
   560  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   561  	defer cancel()
   562  
   563  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   564  	defer closeConn(t, conn)
   565  
   566  	pgxtest.SkipCockroachDB(t, conn, "Server copy error does not fail fast")
   567  
   568  	mustExec(t, conn, `create temporary table foo(
   569  		a bytea not null
   570  	)`)
   571  
   572  	startTime := time.Now()
   573  
   574  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
   575  	if err == nil {
   576  		t.Errorf("Expected CopyFrom return error, but it did not")
   577  	}
   578  	if _, ok := err.(*pgconn.PgError); !ok {
   579  		t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
   580  	}
   581  	if copyCount != 0 {
   582  		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
   583  	}
   584  
   585  	endTime := time.Now()
   586  	copyTime := endTime.Sub(startTime)
   587  	if copyTime > time.Second {
   588  		t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime)
   589  	}
   590  
   591  	rows, err := conn.Query(ctx, "select * from foo")
   592  	if err != nil {
   593  		t.Errorf("Unexpected error for Query: %v", err)
   594  	}
   595  
   596  	var outputRows [][]any
   597  	for rows.Next() {
   598  		row, err := rows.Values()
   599  		if err != nil {
   600  			t.Errorf("Unexpected error for rows.Values(): %v", err)
   601  		}
   602  		outputRows = append(outputRows, row)
   603  	}
   604  
   605  	if rows.Err() != nil {
   606  		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
   607  	}
   608  
   609  	if len(outputRows) != 0 {
   610  		t.Errorf("Expected 0 rows, but got %v", outputRows)
   611  	}
   612  
   613  	ensureConnValid(t, conn)
   614  }
   615  
   616  type slowFailRaceSource struct {
   617  	count int
   618  }
   619  
   620  func (fs *slowFailRaceSource) Next() bool {
   621  	time.Sleep(time.Millisecond)
   622  	fs.count++
   623  	return fs.count < 1000
   624  }
   625  
   626  func (fs *slowFailRaceSource) Values() ([]any, error) {
   627  	if fs.count == 500 {
   628  		return []any{nil, nil}, nil
   629  	}
   630  	return []any{1, make([]byte, 1000)}, nil
   631  }
   632  
   633  func (fs *slowFailRaceSource) Err() error {
   634  	return nil
   635  }
   636  
   637  func TestConnCopyFromSlowFailRace(t *testing.T) {
   638  	t.Parallel()
   639  
   640  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   641  	defer cancel()
   642  
   643  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   644  	defer closeConn(t, conn)
   645  
   646  	mustExec(t, conn, `create temporary table foo(
   647  		a int not null,
   648  		b bytea not null
   649  	)`)
   650  
   651  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
   652  	if err == nil {
   653  		t.Errorf("Expected CopyFrom return error, but it did not")
   654  	}
   655  	if _, ok := err.(*pgconn.PgError); !ok {
   656  		t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
   657  	}
   658  	if copyCount != 0 {
   659  		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
   660  	}
   661  
   662  	ensureConnValid(t, conn)
   663  }
   664  
   665  func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
   666  	t.Parallel()
   667  
   668  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   669  	defer cancel()
   670  
   671  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   672  	defer closeConn(t, conn)
   673  
   674  	mustExec(t, conn, `create temporary table foo(
   675  		a bytea not null
   676  	)`)
   677  
   678  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
   679  	if err == nil {
   680  		t.Errorf("Expected CopyFrom return error, but it did not")
   681  	}
   682  	if copyCount != 0 {
   683  		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
   684  	}
   685  
   686  	rows, err := conn.Query(ctx, "select * from foo")
   687  	if err != nil {
   688  		t.Errorf("Unexpected error for Query: %v", err)
   689  	}
   690  
   691  	var outputRows [][]any
   692  	for rows.Next() {
   693  		row, err := rows.Values()
   694  		if err != nil {
   695  			t.Errorf("Unexpected error for rows.Values(): %v", err)
   696  		}
   697  		outputRows = append(outputRows, row)
   698  	}
   699  
   700  	if rows.Err() != nil {
   701  		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
   702  	}
   703  
   704  	if len(outputRows) != 0 {
   705  		t.Errorf("Expected 0 rows, but got %v", len(outputRows))
   706  	}
   707  
   708  	ensureConnValid(t, conn)
   709  }
   710  
   711  type clientFinalErrSource struct {
   712  	count int
   713  }
   714  
   715  func (cfs *clientFinalErrSource) Next() bool {
   716  	cfs.count++
   717  	return cfs.count < 5
   718  }
   719  
   720  func (cfs *clientFinalErrSource) Values() ([]any, error) {
   721  	return []any{make([]byte, 100000)}, nil
   722  }
   723  
   724  func (cfs *clientFinalErrSource) Err() error {
   725  	return fmt.Errorf("final error")
   726  }
   727  
   728  func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
   729  	t.Parallel()
   730  
   731  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   732  	defer cancel()
   733  
   734  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   735  	defer closeConn(t, conn)
   736  
   737  	mustExec(t, conn, `create temporary table foo(
   738  		a bytea not null
   739  	)`)
   740  
   741  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
   742  	if err == nil {
   743  		t.Errorf("Expected CopyFrom return error, but it did not")
   744  	}
   745  	if copyCount != 0 {
   746  		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
   747  	}
   748  
   749  	rows, err := conn.Query(ctx, "select * from foo")
   750  	if err != nil {
   751  		t.Errorf("Unexpected error for Query: %v", err)
   752  	}
   753  
   754  	var outputRows [][]any
   755  	for rows.Next() {
   756  		row, err := rows.Values()
   757  		if err != nil {
   758  			t.Errorf("Unexpected error for rows.Values(): %v", err)
   759  		}
   760  		outputRows = append(outputRows, row)
   761  	}
   762  
   763  	if rows.Err() != nil {
   764  		t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
   765  	}
   766  
   767  	if len(outputRows) != 0 {
   768  		t.Errorf("Expected 0 rows, but got %v", outputRows)
   769  	}
   770  
   771  	ensureConnValid(t, conn)
   772  }
   773  
   774  func TestConnCopyFromAutomaticStringConversion(t *testing.T) {
   775  	t.Parallel()
   776  
   777  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   778  	defer cancel()
   779  
   780  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   781  	defer closeConn(t, conn)
   782  
   783  	mustExec(t, conn, `create temporary table foo(
   784  		a int8
   785  	)`)
   786  
   787  	inputRows := [][]interface{}{
   788  		{"42"},
   789  		{"7"},
   790  		{8},
   791  	}
   792  
   793  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
   794  	require.NoError(t, err)
   795  	require.EqualValues(t, len(inputRows), copyCount)
   796  
   797  	rows, _ := conn.Query(ctx, "select * from foo")
   798  	nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
   799  	require.NoError(t, err)
   800  
   801  	require.Equal(t, []int64{42, 7, 8}, nums)
   802  
   803  	ensureConnValid(t, conn)
   804  }
   805  
   806  // https://github.com/jackc/pgx/discussions/1891
   807  func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) {
   808  	t.Parallel()
   809  
   810  	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
   811  	defer cancel()
   812  
   813  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   814  	defer closeConn(t, conn)
   815  
   816  	mustExec(t, conn, `create temporary table foo(
   817  		a numeric[]
   818  	)`)
   819  
   820  	inputRows := [][]interface{}{
   821  		{[]string{"42"}},
   822  		{[]string{"7"}},
   823  		{[]string{"8", "9"}},
   824  		{[][]string{{"10", "11"}, {"12", "13"}}},
   825  	}
   826  
   827  	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
   828  	require.NoError(t, err)
   829  	require.EqualValues(t, len(inputRows), copyCount)
   830  
   831  	// Test reads as int64 and flattened array for simplicity.
   832  	rows, _ := conn.Query(ctx, "select * from foo")
   833  	nums, err := pgx.CollectRows(rows, pgx.RowTo[[]int64])
   834  	require.NoError(t, err)
   835  	require.Equal(t, [][]int64{{42}, {7}, {8, 9}, {10, 11, 12, 13}}, nums)
   836  
   837  	ensureConnValid(t, conn)
   838  }
   839  
   840  func TestCopyFromFunc(t *testing.T) {
   841  	t.Parallel()
   842  
   843  	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
   844  	defer closeConn(t, conn)
   845  
   846  	mustExec(t, conn, `create temporary table foo(
   847  		a int
   848  	)`)
   849  
   850  	dataCh := make(chan int, 1)
   851  
   852  	const channelItems = 10
   853  	go func() {
   854  		for i := 0; i < channelItems; i++ {
   855  			dataCh <- i
   856  		}
   857  		close(dataCh)
   858  	}()
   859  
   860  	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
   861  		pgx.CopyFromFunc(func() ([]any, error) {
   862  			v, ok := <-dataCh
   863  			if !ok {
   864  				return nil, nil
   865  			}
   866  			return []any{v}, nil
   867  		}))
   868  
   869  	require.ErrorIs(t, err, nil)
   870  	require.EqualValues(t, channelItems, copyCount)
   871  
   872  	rows, err := conn.Query(context.Background(), "select * from foo order by a")
   873  	require.NoError(t, err)
   874  	nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
   875  	require.NoError(t, err)
   876  	require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums)
   877  
   878  	// simulate a failure
   879  	copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
   880  		pgx.CopyFromFunc(func() func() ([]any, error) {
   881  			x := 9
   882  			return func() ([]any, error) {
   883  				x++
   884  				if x > 100 {
   885  					return nil, fmt.Errorf("simulated error")
   886  				}
   887  				return []any{x}, nil
   888  			}
   889  		}()))
   890  	require.NotErrorIs(t, err, nil)
   891  	require.EqualValues(t, 0, copyCount) // no change, due to error
   892  
   893  	ensureConnValid(t, conn)
   894  }
   895  

View as plain text