...

Source file src/github.com/jackc/pgx/v5/pgtype/bytea_test.go

Documentation: github.com/jackc/pgx/v5/pgtype

     1  package pgtype_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"testing"
     8  
     9  	pgx "github.com/jackc/pgx/v5"
    10  	"github.com/jackc/pgx/v5/pgtype"
    11  	"github.com/jackc/pgx/v5/pgxtest"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func isExpectedEqBytes(a any) func(any) bool {
    16  	return func(v any) bool {
    17  		ab := a.([]byte)
    18  		vb := v.([]byte)
    19  
    20  		if (ab == nil) != (vb == nil) {
    21  			return false
    22  		}
    23  
    24  		if ab == nil {
    25  			return true
    26  		}
    27  
    28  		return bytes.Equal(ab, vb)
    29  	}
    30  }
    31  
    32  func TestByteaCodec(t *testing.T) {
    33  	pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bytea", []pgxtest.ValueRoundTripTest{
    34  		{[]byte{1, 2, 3}, new([]byte), isExpectedEqBytes([]byte{1, 2, 3})},
    35  		{[]byte{}, new([]byte), isExpectedEqBytes([]byte{})},
    36  		{[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))},
    37  		{nil, new([]byte), isExpectedEqBytes([]byte(nil))},
    38  	})
    39  }
    40  
    41  func TestDriverBytesQueryRow(t *testing.T) {
    42  	defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
    43  		var buf []byte
    44  		err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf))
    45  		require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow")
    46  	})
    47  }
    48  
    49  func TestDriverBytes(t *testing.T) {
    50  	defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
    51  		argBuf := make([]byte, 128)
    52  		for i := range argBuf {
    53  			argBuf[i] = byte(i)
    54  		}
    55  
    56  		rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf)
    57  		require.NoError(t, err)
    58  		defer rows.Close()
    59  
    60  		rowCount := 0
    61  		resultBuf := argBuf
    62  		detectedResultMutation := false
    63  		for rows.Next() {
    64  			rowCount++
    65  
    66  			// At some point the buffer should be reused and change.
    67  			if !bytes.Equal(argBuf, resultBuf) {
    68  				detectedResultMutation = true
    69  			}
    70  
    71  			err = rows.Scan((*pgtype.DriverBytes)(&resultBuf))
    72  			require.NoError(t, err)
    73  
    74  			require.Len(t, resultBuf, len(argBuf))
    75  			require.Equal(t, resultBuf, argBuf)
    76  			require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)")
    77  		}
    78  
    79  		require.True(t, detectedResultMutation)
    80  
    81  		err = rows.Err()
    82  		require.NoError(t, err)
    83  	})
    84  }
    85  
    86  func TestPreallocBytes(t *testing.T) {
    87  	defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
    88  		origBuf := []byte{5, 6, 7, 8}
    89  		buf := origBuf
    90  		err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.PreallocBytes)(&buf))
    91  		require.NoError(t, err)
    92  
    93  		require.Len(t, buf, 2)
    94  		require.Equal(t, 4, cap(buf))
    95  		require.Equal(t, buf, []byte{1, 2})
    96  
    97  		require.Equal(t, []byte{1, 2, 7, 8}, origBuf)
    98  
    99  		err = conn.QueryRow(ctx, `select $1::bytea`, []byte{3, 4, 5, 6, 7}).Scan((*pgtype.PreallocBytes)(&buf))
   100  		require.NoError(t, err)
   101  		require.Len(t, buf, 5)
   102  		require.Equal(t, 5, cap(buf))
   103  
   104  		require.Equal(t, []byte{1, 2, 7, 8}, origBuf)
   105  	})
   106  }
   107  
   108  func TestUndecodedBytes(t *testing.T) {
   109  	defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   110  		var buf []byte
   111  		err := conn.QueryRow(ctx, `select 1::int4`).Scan((*pgtype.UndecodedBytes)(&buf))
   112  		require.NoError(t, err)
   113  
   114  		require.Len(t, buf, 4)
   115  		require.Equal(t, buf, []byte{0, 0, 0, 1})
   116  	})
   117  }
   118  
   119  func TestByteaCodecDecodeDatabaseSQLValue(t *testing.T) {
   120  	defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   121  		var buf []byte
   122  		err := conn.QueryRow(ctx, `select '\xa1b2c3d4'::bytea`).Scan(sqlScannerFunc(func(src any) error {
   123  			switch src := src.(type) {
   124  			case []byte:
   125  				buf = make([]byte, len(src))
   126  				copy(buf, src)
   127  				return nil
   128  			default:
   129  				return fmt.Errorf("expected []byte, got %T", src)
   130  			}
   131  		}))
   132  		require.NoError(t, err)
   133  
   134  		require.Len(t, buf, 4)
   135  		require.Equal(t, buf, []byte{0xa1, 0xb2, 0xc3, 0xd4})
   136  	})
   137  }
   138  

View as plain text