1 package pgtype_test
2
3 import (
4 "bytes"
5 "context"
6 "testing"
7
8 "github.com/jackc/pgtype"
9 "github.com/jackc/pgtype/testutil"
10 "github.com/jackc/pgx/v4"
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13 )
14
15 func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType {
16 _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;")
17 require.NoError(t, err)
18
19 _, err = conn.Exec(context.Background(), "create type pgtype_enum_color as enum ('blue', 'green', 'purple');")
20 require.NoError(t, err)
21
22 var oid uint32
23 err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "pgtype_enum_color").Scan(&oid)
24 require.NoError(t, err)
25
26 et := pgtype.NewEnumType("pgtype_enum_color", []string{"blue", "green", "purple"})
27 conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "pgtype_enum_color", OID: oid})
28
29 return et
30 }
31
32 func cleanupEnum(t *testing.T, conn *pgx.Conn) {
33 _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;")
34 require.NoError(t, err)
35 }
36
37 func TestEnumTypeTranscode(t *testing.T) {
38 conn := testutil.MustConnectPgx(t)
39 defer testutil.MustCloseContext(t, conn)
40
41 setupEnum(t, conn)
42 defer cleanupEnum(t, conn)
43
44 var dst string
45 err := conn.QueryRow(context.Background(), "select $1::pgtype_enum_color", "blue").Scan(&dst)
46 require.NoError(t, err)
47 require.EqualValues(t, "blue", dst)
48 }
49
50 func TestEnumTypeSet(t *testing.T) {
51 conn := testutil.MustConnectPgx(t)
52 defer testutil.MustCloseContext(t, conn)
53
54 enumType := setupEnum(t, conn)
55 defer cleanupEnum(t, conn)
56
57 successfulTests := []struct {
58 source interface{}
59 result interface{}
60 }{
61 {source: "blue", result: "blue"},
62 {source: _string("green"), result: "green"},
63 {source: (*string)(nil), result: nil},
64 }
65
66 for i, tt := range successfulTests {
67 err := enumType.Set(tt.source)
68 assert.NoErrorf(t, err, "%d", i)
69 assert.Equalf(t, tt.result, enumType.Get(), "%d", i)
70 }
71 }
72
73 func TestEnumTypeAssignTo(t *testing.T) {
74 conn := testutil.MustConnectPgx(t)
75 defer testutil.MustCloseContext(t, conn)
76
77 enumType := setupEnum(t, conn)
78 defer cleanupEnum(t, conn)
79
80 {
81 var s string
82
83 err := enumType.Set("blue")
84 require.NoError(t, err)
85
86 err = enumType.AssignTo(&s)
87 require.NoError(t, err)
88
89 assert.EqualValues(t, "blue", s)
90 }
91
92 {
93 var ps *string
94
95 err := enumType.Set("blue")
96 require.NoError(t, err)
97
98 err = enumType.AssignTo(&ps)
99 require.NoError(t, err)
100
101 assert.EqualValues(t, "blue", *ps)
102 }
103
104 {
105 var ps *string
106
107 err := enumType.Set(nil)
108 require.NoError(t, err)
109
110 err = enumType.AssignTo(&ps)
111 require.NoError(t, err)
112
113 assert.EqualValues(t, (*string)(nil), ps)
114 }
115
116 var buf []byte
117 bytesTests := []struct {
118 src interface{}
119 dst *[]byte
120 expected []byte
121 }{
122 {src: "blue", dst: &buf, expected: []byte("blue")},
123 {src: nil, dst: &buf, expected: nil},
124 }
125
126 for i, tt := range bytesTests {
127 err := enumType.Set(tt.src)
128 require.NoError(t, err, "%d", i)
129
130 err = enumType.AssignTo(tt.dst)
131 require.NoError(t, err, "%d", i)
132
133 if bytes.Compare(*tt.dst, tt.expected) != 0 {
134 t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst)
135 }
136 }
137
138 {
139 var s string
140
141 err := enumType.Set(nil)
142 require.NoError(t, err)
143
144 err = enumType.AssignTo(&s)
145 require.Error(t, err)
146 }
147
148 }
149
View as plain text