1 package pgtype_test
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "testing"
8
9 pgx "github.com/jackc/pgx/v5"
10 "github.com/jackc/pgx/v5/pgtype"
11 "github.com/jackc/pgx/v5/pgxtest"
12 "github.com/stretchr/testify/require"
13 )
14
// isExpectedEqBytes returns a result checker for pgxtest.ValueRoundTripTest
// that reports whether a round-tripped value is a []byte equal to a.
//
// Nil and empty slices are deliberately distinguished: a nil expectation only
// matches a nil result and an empty expectation only matches a non-nil empty
// result. bytes.Equal alone would treat them as equal.
//
// a must be a []byte. The returned function reports false (rather than
// panicking) when the value it receives is not a []byte, so a wrong result
// type surfaces as an ordinary test failure.
func isExpectedEqBytes(a any) func(any) bool {
	return func(v any) bool {
		ab := a.([]byte)
		vb, ok := v.([]byte)
		if !ok {
			return false
		}

		// bytes.Equal(nil, []byte{}) is true, so compare nil-ness explicitly.
		if (ab == nil) != (vb == nil) {
			return false
		}

		return bytes.Equal(ab, vb)
	}
}
31
32 func TestByteaCodec(t *testing.T) {
33 pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bytea", []pgxtest.ValueRoundTripTest{
34 {[]byte{1, 2, 3}, new([]byte), isExpectedEqBytes([]byte{1, 2, 3})},
35 {[]byte{}, new([]byte), isExpectedEqBytes([]byte{})},
36 {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))},
37 {nil, new([]byte), isExpectedEqBytes([]byte(nil))},
38 })
39 }
40
41 func TestDriverBytesQueryRow(t *testing.T) {
42 defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
43 var buf []byte
44 err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf))
45 require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow")
46 })
47 }
48
49 func TestDriverBytes(t *testing.T) {
50 defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
51 argBuf := make([]byte, 128)
52 for i := range argBuf {
53 argBuf[i] = byte(i)
54 }
55
56 rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf)
57 require.NoError(t, err)
58 defer rows.Close()
59
60 rowCount := 0
61 resultBuf := argBuf
62 detectedResultMutation := false
63 for rows.Next() {
64 rowCount++
65
66
67 if !bytes.Equal(argBuf, resultBuf) {
68 detectedResultMutation = true
69 }
70
71 err = rows.Scan((*pgtype.DriverBytes)(&resultBuf))
72 require.NoError(t, err)
73
74 require.Len(t, resultBuf, len(argBuf))
75 require.Equal(t, resultBuf, argBuf)
76 require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)")
77 }
78
79 require.True(t, detectedResultMutation)
80
81 err = rows.Err()
82 require.NoError(t, err)
83 })
84 }
85
86 func TestPreallocBytes(t *testing.T) {
87 defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
88 origBuf := []byte{5, 6, 7, 8}
89 buf := origBuf
90 err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.PreallocBytes)(&buf))
91 require.NoError(t, err)
92
93 require.Len(t, buf, 2)
94 require.Equal(t, 4, cap(buf))
95 require.Equal(t, buf, []byte{1, 2})
96
97 require.Equal(t, []byte{1, 2, 7, 8}, origBuf)
98
99 err = conn.QueryRow(ctx, `select $1::bytea`, []byte{3, 4, 5, 6, 7}).Scan((*pgtype.PreallocBytes)(&buf))
100 require.NoError(t, err)
101 require.Len(t, buf, 5)
102 require.Equal(t, 5, cap(buf))
103
104 require.Equal(t, []byte{1, 2, 7, 8}, origBuf)
105 })
106 }
107
108 func TestUndecodedBytes(t *testing.T) {
109 defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
110 var buf []byte
111 err := conn.QueryRow(ctx, `select 1::int4`).Scan((*pgtype.UndecodedBytes)(&buf))
112 require.NoError(t, err)
113
114 require.Len(t, buf, 4)
115 require.Equal(t, buf, []byte{0, 0, 0, 1})
116 })
117 }
118
119 func TestByteaCodecDecodeDatabaseSQLValue(t *testing.T) {
120 defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
121 var buf []byte
122 err := conn.QueryRow(ctx, `select '\xa1b2c3d4'::bytea`).Scan(sqlScannerFunc(func(src any) error {
123 switch src := src.(type) {
124 case []byte:
125 buf = make([]byte, len(src))
126 copy(buf, src)
127 return nil
128 default:
129 return fmt.Errorf("expected []byte, got %T", src)
130 }
131 }))
132 require.NoError(t, err)
133
134 require.Len(t, buf, 4)
135 require.Equal(t, buf, []byte{0xa1, 0xb2, 0xc3, 0xd4})
136 })
137 }
138
View as plain text