1 package pgxtype
2
3 import (
4 "context"
5 "errors"
6
7 "github.com/jackc/pgconn"
8 "github.com/jackc/pgtype"
9 "github.com/jackc/pgx/v4"
10 )
11
12 type Querier interface {
13 Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
14 Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error)
15 QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row
16 }
17
18
19
20 func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) {
21 var oid uint32
22
23 err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
24 if err != nil {
25 return pgtype.DataType{}, err
26 }
27
28 var typtype string
29
30 err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
31 if err != nil {
32 return pgtype.DataType{}, err
33 }
34
35 switch typtype {
36 case "b":
37 elementOID, err := GetArrayElementOID(ctx, conn, oid)
38 if err != nil {
39 return pgtype.DataType{}, err
40 }
41
42 var element pgtype.ValueTranscoder
43 if dt, ok := ci.DataTypeForOID(elementOID); ok {
44 if element, ok = dt.Value.(pgtype.ValueTranscoder); !ok {
45 return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder")
46 }
47 }
48
49 newElement := func() pgtype.ValueTranscoder {
50 return pgtype.NewValue(element).(pgtype.ValueTranscoder)
51 }
52
53 at := pgtype.NewArrayType(typeName, elementOID, newElement)
54 return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil
55 case "c":
56 fields, err := GetCompositeFields(ctx, conn, oid)
57 if err != nil {
58 return pgtype.DataType{}, err
59 }
60 ct, err := pgtype.NewCompositeType(typeName, fields, ci)
61 if err != nil {
62 return pgtype.DataType{}, err
63 }
64 return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil
65 case "e":
66 members, err := GetEnumMembers(ctx, conn, oid)
67 if err != nil {
68 return pgtype.DataType{}, err
69 }
70 return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil
71 default:
72 return pgtype.DataType{}, errors.New("unknown typtype")
73 }
74 }
75
76 func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) {
77 var typelem uint32
78
79 err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem)
80 if err != nil {
81 return 0, err
82 }
83
84 return typelem, nil
85 }
86
87
88 func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) {
89 var typrelid uint32
90
91 err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
92 if err != nil {
93 return nil, err
94 }
95
96 var fields []pgtype.CompositeTypeField
97
98 rows, err := conn.Query(ctx, `select attname, atttypid
99 from pg_attribute
100 where attrelid=$1
101 order by attnum`, typrelid)
102 if err != nil {
103 return nil, err
104 }
105
106 for rows.Next() {
107 var f pgtype.CompositeTypeField
108 err := rows.Scan(&f.Name, &f.OID)
109 if err != nil {
110 return nil, err
111 }
112 fields = append(fields, f)
113 }
114
115 if rows.Err() != nil {
116 return nil, rows.Err()
117 }
118
119 return fields, nil
120 }
121
122
123 func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) {
124 members := []string{}
125
126 rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid)
127 if err != nil {
128 return nil, err
129 }
130
131 for rows.Next() {
132 var m string
133 err := rows.Scan(&m)
134 if err != nil {
135 return nil, err
136 }
137 members = append(members, m)
138 }
139
140 if rows.Err() != nil {
141 return nil, rows.Err()
142 }
143
144 return members, nil
145 }
146
View as plain text