1 package pgx_test
2
3 import (
4 "context"
5 "testing"
6
7 "github.com/jackc/pgx/v5"
8 "github.com/stretchr/testify/assert"
9 "github.com/stretchr/testify/require"
10 )
11
12 func TestNamedArgsRewriteQuery(t *testing.T) {
13 t.Parallel()
14
15 for i, tt := range []struct {
16 sql string
17 args []any
18 namedArgs pgx.NamedArgs
19 expectedSQL string
20 expectedArgs []any
21 }{
22 {
23 sql: "select * from users where id = @id",
24 namedArgs: pgx.NamedArgs{"id": int32(42)},
25 expectedSQL: "select * from users where id = $1",
26 expectedArgs: []any{int32(42)},
27 },
28 {
29 sql: "select * from t where foo < @abc and baz = @def and bar < @abc",
30 namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
31 expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
32 expectedArgs: []any{int32(42), int32(1)},
33 },
34 {
35 sql: "select @a::int, @b::text",
36 namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
37 expectedSQL: "select $1::int, $2::text",
38 expectedArgs: []any{int32(42), "foo"},
39 },
40 {
41 sql: "select @Abc::int, @b_4::text, @_c::int",
42 namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)},
43 expectedSQL: "select $1::int, $2::text, $3::int",
44 expectedArgs: []any{int32(42), "foo", int32(1)},
45 },
46 {
47 sql: "at end @",
48 namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
49 expectedSQL: "at end @",
50 expectedArgs: []any{},
51 },
52 {
53 sql: "ignores without valid character after @ foo bar",
54 namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
55 expectedSQL: "ignores without valid character after @ foo bar",
56 expectedArgs: []any{},
57 },
58 {
59 sql: "name cannot start with number @1 foo bar",
60 namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
61 expectedSQL: "name cannot start with number @1 foo bar",
62 expectedArgs: []any{},
63 },
64 {
65 sql: `select *, '@foo' as "@bar" from users where id = @id`,
66 namedArgs: pgx.NamedArgs{"id": int32(42)},
67 expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`,
68 expectedArgs: []any{int32(42)},
69 },
70 {
71 sql: `select * -- @foo
72 from users -- @single line comments
73 where id = @id;`,
74 namedArgs: pgx.NamedArgs{"id": int32(42)},
75 expectedSQL: `select * -- @foo
76 from users -- @single line comments
77 where id = $1;`,
78 expectedArgs: []any{int32(42)},
79 },
80 {
81 sql: `select * /* @multi line
82 @comment
83 */
84 /* /* with @nesting */ */
85 from users
86 where id = @id;`,
87 namedArgs: pgx.NamedArgs{"id": int32(42)},
88 expectedSQL: `select * /* @multi line
89 @comment
90 */
91 /* /* with @nesting */ */
92 from users
93 where id = $1;`,
94 expectedArgs: []any{int32(42)},
95 },
96
97
98 } {
99 sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
100 require.NoError(t, err)
101 assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
102 assert.Equalf(t, tt.expectedArgs, args, "%d", i)
103 }
104 }
105
View as plain text