1 package pgtype_test
2
3 import (
4 "context"
5 "fmt"
6 "os"
7 "testing"
8
9 "github.com/jackc/pgtype"
10 "github.com/jackc/pgtype/testutil"
11 pgx "github.com/jackc/pgx/v4"
12 "github.com/stretchr/testify/assert"
13 "github.com/stretchr/testify/require"
14 )
15
16 func TestCompositeTypeSetAndGet(t *testing.T) {
17 ci := pgtype.NewConnInfo()
18 ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
19 {"a", pgtype.TextOID},
20 {"b", pgtype.Int4OID},
21 }, ci)
22 require.NoError(t, err)
23 assert.Equal(t, pgtype.Undefined, ct.Get())
24
25 nilTests := []struct {
26 src interface{}
27 }{
28 {nil},
29 {(*[]interface{})(nil)},
30 }
31
32 for i, tt := range nilTests {
33 err := ct.Set(tt.src)
34 assert.NoErrorf(t, err, "%d", i)
35 assert.Equal(t, nil, ct.Get())
36 }
37
38 compatibleValuesTests := []struct {
39 src []interface{}
40 expected map[string]interface{}
41 }{
42 {
43 src: []interface{}{"foo", int32(42)},
44 expected: map[string]interface{}{"a": "foo", "b": int32(42)},
45 },
46 {
47 src: []interface{}{nil, nil},
48 expected: map[string]interface{}{"a": nil, "b": nil},
49 },
50 {
51 src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}},
52 expected: map[string]interface{}{"a": "hi", "b": int32(7)},
53 },
54 }
55
56 for i, tt := range compatibleValuesTests {
57 err := ct.Set(tt.src)
58 assert.NoErrorf(t, err, "%d", i)
59 assert.EqualValues(t, tt.expected, ct.Get())
60 }
61 }
62
63 func TestCompositeTypeAssignTo(t *testing.T) {
64 ci := pgtype.NewConnInfo()
65 ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
66 {"a", pgtype.TextOID},
67 {"b", pgtype.Int4OID},
68 }, ci)
69 require.NoError(t, err)
70
71 {
72 err := ct.Set([]interface{}{"foo", int32(42)})
73 assert.NoError(t, err)
74
75 var a string
76 var b int32
77
78 err = ct.AssignTo([]interface{}{&a, &b})
79 assert.NoError(t, err)
80
81 assert.Equal(t, "foo", a)
82 assert.Equal(t, int32(42), b)
83 }
84
85 {
86 err := ct.Set([]interface{}{"foo", int32(42)})
87 assert.NoError(t, err)
88
89 var a pgtype.Text
90 var b pgtype.Int4
91
92 err = ct.AssignTo([]interface{}{&a, &b})
93 assert.NoError(t, err)
94
95 assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a)
96 assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b)
97 }
98
99
100 {
101 err := ct.Set([]interface{}{"foo", int32(42)})
102 assert.NoError(t, err)
103
104 var b int32
105
106 err = ct.AssignTo([]interface{}{nil, &b})
107 assert.NoError(t, err)
108
109 assert.Equal(t, int32(42), b)
110 }
111
112
113 {
114 err := ct.Set(nil)
115 assert.NoError(t, err)
116
117 var a pgtype.Text
118 var b pgtype.Int4
119 dst := []interface{}{&a, &b}
120
121 err = ct.AssignTo(&dst)
122 assert.NoError(t, err)
123
124 assert.Nil(t, dst)
125 }
126
127
128 {
129 err := ct.Set([]interface{}{"foo", int32(42)})
130 assert.NoError(t, err)
131
132 var a pgtype.Text
133 var b pgtype.Int4
134 dst := []interface{}{&a, &b}
135
136 err = ct.AssignTo(&dst)
137 assert.NoError(t, err)
138
139 assert.NotNil(t, dst)
140 assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a)
141 assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b)
142 }
143
144
145 {
146 err := ct.Set([]interface{}{"foo", int32(42)})
147 assert.NoError(t, err)
148
149 s := struct {
150 A string
151 B int32
152 }{}
153
154 err = ct.AssignTo(&s)
155 if assert.NoError(t, err) {
156 assert.Equal(t, "foo", s.A)
157 assert.Equal(t, int32(42), s.B)
158 }
159 }
160 }
161
162 func TestCompositeTypeTranscode(t *testing.T) {
163 conn := testutil.MustConnectPgx(t)
164 defer testutil.MustCloseContext(t, conn)
165
166 _, err := conn.Exec(context.Background(), `drop type if exists ct_test;
167
168 create type ct_test as (
169 a text,
170 b int4
171 );`)
172 require.NoError(t, err)
173 defer conn.Exec(context.Background(), "drop type ct_test")
174
175 var oid uint32
176 err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
177 require.NoError(t, err)
178
179 defer conn.Exec(context.Background(), "drop type ct_test")
180
181 ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{
182 {"a", pgtype.TextOID},
183 {"b", pgtype.Int4OID},
184 }, conn.ConnInfo())
185 require.NoError(t, err)
186 conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
187
188
189 simpleProtocols := []bool{true, false}
190
191 var a string
192 var b int32
193
194 for _, simpleProtocol := range simpleProtocols {
195 err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol),
196 pgtype.CompositeFields{"hi", int32(42)},
197 ).Scan(
198 []interface{}{&a, &b},
199 )
200 if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
201 assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
202 assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol)
203 }
204 }
205 }
206
207
208 func TestCompositeTypeTextDecodeNested(t *testing.T) {
209 newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType {
210 fields := make([]pgtype.CompositeTypeField, len(fieldNames))
211 for i, name := range fieldNames {
212 fields[i] = pgtype.CompositeTypeField{Name: name}
213 }
214
215 rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals)
216 require.NoError(t, err)
217 return rowType
218 }
219
220 dimensionsType := func() pgtype.ValueTranscoder {
221 return newCompositeType(
222 "dimensions",
223 []string{"width", "height"},
224 &pgtype.Int4{},
225 &pgtype.Int4{},
226 )
227 }
228 productImageType := func() pgtype.ValueTranscoder {
229 return newCompositeType(
230 "product_image_type",
231 []string{"source", "dimensions"},
232 &pgtype.Text{},
233 dimensionsType(),
234 )
235 }
236 productImageSetType := newCompositeType(
237 "product_image_set_type",
238 []string{"name", "orig_image", "images"},
239 &pgtype.Text{},
240 productImageType(),
241 pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder {
242 return productImageType()
243 }),
244 )
245
246 err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`))
247 require.NoError(t, err)
248 }
249
250 func Example_composite() {
251 conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
252 if err != nil {
253 fmt.Println(err)
254 return
255 }
256
257 defer conn.Close(context.Background())
258 _, err = conn.Exec(context.Background(), `drop type if exists mytype;`)
259 if err != nil {
260 fmt.Println(err)
261 return
262 }
263
264 _, err = conn.Exec(context.Background(), `create type mytype as (
265 a int4,
266 b text
267 );`)
268 if err != nil {
269 fmt.Println(err)
270 return
271 }
272 defer conn.Exec(context.Background(), "drop type mytype")
273
274 var oid uint32
275 err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid)
276 if err != nil {
277 fmt.Println(err)
278 return
279 }
280
281 ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{
282 {"a", pgtype.Int4OID},
283 {"b", pgtype.TextOID},
284 }, conn.ConnInfo())
285 if err != nil {
286 fmt.Println(err)
287 return
288 }
289 conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
290
291 var a int
292 var b *string
293
294 err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b})
295 if err != nil {
296 fmt.Println(err)
297 return
298 }
299
300 fmt.Printf("First: a=%d b=%s\n", a, *b)
301
302 err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b})
303 if err != nil {
304 fmt.Println(err)
305 return
306 }
307
308 fmt.Printf("Second: a=%d b=%v\n", a, b)
309
310 scanTarget := []interface{}{&a, &b}
311 err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget)
312 E(err)
313
314 fmt.Printf("Third: isNull=%v\n", scanTarget == nil)
315
316
317
318
319
320 }
321
View as plain text