1 package pgtype
2
3 import (
4 "bytes"
5 "database/sql/driver"
6 "encoding/hex"
7 "fmt"
8 )
9
// UUIDScanner is implemented by types that can be populated from a UUID.
type UUIDScanner interface {
	ScanUUID(v UUID) error
}

// UUIDValuer is implemented by types that can report their value as a UUID.
type UUIDValuer interface {
	UUIDValue() (UUID, error)
}

// UUID holds 16 raw UUID bytes. Valid is false when the value is NULL; in
// that case Bytes is not meaningful.
type UUID struct {
	Bytes [16]byte
	Valid bool
}

// ScanUUID implements UUIDScanner by overwriting the receiver with v.
// It never returns an error.
func (dst *UUID) ScanUUID(v UUID) error {
	*dst = v
	return nil
}

// UUIDValue implements UUIDValuer by returning a copy of the receiver.
// It never returns an error.
func (src UUID) UUIDValue() (UUID, error) {
	return src, nil
}
31
32
// parseUUID converts a textual UUID into its 16 raw bytes.
//
// Two forms are accepted:
//   - the 36-character dashed form, with '-' at offsets 8, 13, 18 and 23
//     (e.g. "123e4567-e89b-12d3-a456-426614174000");
//   - the 32-character form with no dashes.
//
// Hex digits may be upper- or lower-case. Any other length, a non-dash byte
// where a dash is required, or an invalid hex digit yields an error together
// with a zero array.
func parseUUID(src string) (dst [16]byte, err error) {
	switch len(src) {
	case 36:
		// Check the separators before discarding them; without this, any byte
		// at these offsets would be silently dropped and the input accepted.
		if src[8] != '-' || src[13] != '-' || src[18] != '-' || src[23] != '-' {
			return dst, fmt.Errorf("cannot parse UUID %v", src)
		}
		src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:]
	case 32:
		// Already dash-free; hex validity is checked below.
	default:
		return dst, fmt.Errorf("cannot parse UUID %v", src)
	}

	buf, err := hex.DecodeString(src)
	if err != nil {
		// Wrap with %w so callers can still use errors.As on the hex error.
		return dst, fmt.Errorf("cannot parse UUID: %w", err)
	}

	copy(dst[:], buf)
	return dst, nil
}
52
53
// encodeUUID renders src as lowercase hex in five dash-separated groups of
// 4, 2, 2, 2 and 6 bytes (8-4-4-4-12 hex digits, 36 characters in total).
func encodeUUID(src [16]byte) string {
	// Exclusive end offset, within src, of each dash-separated group.
	groupEnds := [...]int{4, 6, 8, 10, 16}

	var out [36]byte
	pos, start := 0, 0
	for i, end := range groupEnds {
		if i > 0 {
			out[pos] = '-'
			pos++
		}
		pos += hex.Encode(out[pos:], src[start:end])
		start = end
	}
	return string(out[:])
}
69
70
71 func (dst *UUID) Scan(src any) error {
72 if src == nil {
73 *dst = UUID{}
74 return nil
75 }
76
77 switch src := src.(type) {
78 case string:
79 buf, err := parseUUID(src)
80 if err != nil {
81 return err
82 }
83 *dst = UUID{Bytes: buf, Valid: true}
84 return nil
85 }
86
87 return fmt.Errorf("cannot scan %T", src)
88 }
89
90
91 func (src UUID) Value() (driver.Value, error) {
92 if !src.Valid {
93 return nil, nil
94 }
95
96 return encodeUUID(src.Bytes), nil
97 }
98
99 func (src UUID) MarshalJSON() ([]byte, error) {
100 if !src.Valid {
101 return []byte("null"), nil
102 }
103
104 var buff bytes.Buffer
105 buff.WriteByte('"')
106 buff.WriteString(encodeUUID(src.Bytes))
107 buff.WriteByte('"')
108 return buff.Bytes(), nil
109 }
110
111 func (dst *UUID) UnmarshalJSON(src []byte) error {
112 if bytes.Equal(src, []byte("null")) {
113 *dst = UUID{}
114 return nil
115 }
116 if len(src) != 38 {
117 return fmt.Errorf("invalid length for UUID: %v", len(src))
118 }
119 buf, err := parseUUID(string(src[1 : len(src)-1]))
120 if err != nil {
121 return err
122 }
123 *dst = UUID{Bytes: buf, Valid: true}
124 return nil
125 }
126
127 type UUIDCodec struct{}
128
129 func (UUIDCodec) FormatSupported(format int16) bool {
130 return format == TextFormatCode || format == BinaryFormatCode
131 }
132
133 func (UUIDCodec) PreferredFormat() int16 {
134 return BinaryFormatCode
135 }
136
137 func (UUIDCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
138 if _, ok := value.(UUIDValuer); !ok {
139 return nil
140 }
141
142 switch format {
143 case BinaryFormatCode:
144 return encodePlanUUIDCodecBinaryUUIDValuer{}
145 case TextFormatCode:
146 return encodePlanUUIDCodecTextUUIDValuer{}
147 }
148
149 return nil
150 }
151
152 type encodePlanUUIDCodecBinaryUUIDValuer struct{}
153
154 func (encodePlanUUIDCodecBinaryUUIDValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
155 uuid, err := value.(UUIDValuer).UUIDValue()
156 if err != nil {
157 return nil, err
158 }
159
160 if !uuid.Valid {
161 return nil, nil
162 }
163
164 return append(buf, uuid.Bytes[:]...), nil
165 }
166
167 type encodePlanUUIDCodecTextUUIDValuer struct{}
168
169 func (encodePlanUUIDCodecTextUUIDValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
170 uuid, err := value.(UUIDValuer).UUIDValue()
171 if err != nil {
172 return nil, err
173 }
174
175 if !uuid.Valid {
176 return nil, nil
177 }
178
179 return append(buf, encodeUUID(uuid.Bytes)...), nil
180 }
181
182 func (UUIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
183 switch format {
184 case BinaryFormatCode:
185 switch target.(type) {
186 case UUIDScanner:
187 return scanPlanBinaryUUIDToUUIDScanner{}
188 case TextScanner:
189 return scanPlanBinaryUUIDToTextScanner{}
190 }
191 case TextFormatCode:
192 switch target.(type) {
193 case UUIDScanner:
194 return scanPlanTextAnyToUUIDScanner{}
195 }
196 }
197
198 return nil
199 }
200
201 type scanPlanBinaryUUIDToUUIDScanner struct{}
202
203 func (scanPlanBinaryUUIDToUUIDScanner) Scan(src []byte, dst any) error {
204 scanner := (dst).(UUIDScanner)
205
206 if src == nil {
207 return scanner.ScanUUID(UUID{})
208 }
209
210 if len(src) != 16 {
211 return fmt.Errorf("invalid length for UUID: %v", len(src))
212 }
213
214 uuid := UUID{Valid: true}
215 copy(uuid.Bytes[:], src)
216
217 return scanner.ScanUUID(uuid)
218 }
219
220 type scanPlanBinaryUUIDToTextScanner struct{}
221
222 func (scanPlanBinaryUUIDToTextScanner) Scan(src []byte, dst any) error {
223 scanner := (dst).(TextScanner)
224
225 if src == nil {
226 return scanner.ScanText(Text{})
227 }
228
229 if len(src) != 16 {
230 return fmt.Errorf("invalid length for UUID: %v", len(src))
231 }
232
233 var buf [16]byte
234 copy(buf[:], src)
235
236 return scanner.ScanText(Text{String: encodeUUID(buf), Valid: true})
237 }
238
239 type scanPlanTextAnyToUUIDScanner struct{}
240
241 func (scanPlanTextAnyToUUIDScanner) Scan(src []byte, dst any) error {
242 scanner := (dst).(UUIDScanner)
243
244 if src == nil {
245 return scanner.ScanUUID(UUID{})
246 }
247
248 buf, err := parseUUID(string(src))
249 if err != nil {
250 return err
251 }
252
253 return scanner.ScanUUID(UUID{Bytes: buf, Valid: true})
254 }
255
256 func (c UUIDCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
257 if src == nil {
258 return nil, nil
259 }
260
261 var uuid UUID
262 err := codecScan(c, m, oid, format, src, &uuid)
263 if err != nil {
264 return nil, err
265 }
266
267 return encodeUUID(uuid.Bytes), nil
268 }
269
270 func (c UUIDCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
271 if src == nil {
272 return nil, nil
273 }
274
275 var uuid UUID
276 err := codecScan(c, m, oid, format, src, &uuid)
277 if err != nil {
278 return nil, err
279 }
280 return uuid.Bytes, nil
281 }
282
View as plain text