1 package pgtype
2
3 import (
4 "bytes"
5 "database/sql/driver"
6 "encoding/hex"
7 "fmt"
8 )
9
10 type UUID struct {
11 Bytes [16]byte
12 Status Status
13 }
14
15 func (dst *UUID) Set(src interface{}) error {
16 if src == nil {
17 *dst = UUID{Status: Null}
18 return nil
19 }
20
21 switch value := src.(type) {
22 case interface{ Get() interface{} }:
23 value2 := value.Get()
24 if value2 != value {
25 return dst.Set(value2)
26 }
27 case fmt.Stringer:
28 value2 := value.String()
29 return dst.Set(value2)
30 case [16]byte:
31 *dst = UUID{Bytes: value, Status: Present}
32 case []byte:
33 if value != nil {
34 if len(value) != 16 {
35 return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value))
36 }
37 *dst = UUID{Status: Present}
38 copy(dst.Bytes[:], value)
39 } else {
40 *dst = UUID{Status: Null}
41 }
42 case string:
43 uuid, err := parseUUID(value)
44 if err != nil {
45 return err
46 }
47 *dst = UUID{Bytes: uuid, Status: Present}
48 case *string:
49 if value == nil {
50 *dst = UUID{Status: Null}
51 } else {
52 return dst.Set(*value)
53 }
54 default:
55 if originalSrc, ok := underlyingUUIDType(src); ok {
56 return dst.Set(originalSrc)
57 }
58 return fmt.Errorf("cannot convert %v to UUID", value)
59 }
60
61 return nil
62 }
63
64 func (dst UUID) Get() interface{} {
65 switch dst.Status {
66 case Present:
67 return dst.Bytes
68 case Null:
69 return nil
70 default:
71 return dst.Status
72 }
73 }
74
75 func (src *UUID) AssignTo(dst interface{}) error {
76 switch src.Status {
77 case Present:
78 switch v := dst.(type) {
79 case *[16]byte:
80 *v = src.Bytes
81 return nil
82 case *[]byte:
83 *v = make([]byte, 16)
84 copy(*v, src.Bytes[:])
85 return nil
86 case *string:
87 *v = encodeUUID(src.Bytes)
88 return nil
89 default:
90 if nextDst, retry := GetAssignToDstType(v); retry {
91 return src.AssignTo(nextDst)
92 }
93 }
94 case Null:
95 return NullAssignTo(dst)
96 }
97
98 return fmt.Errorf("cannot assign %v into %T", src, dst)
99 }
100
101
// parseUUID decodes a textual UUID into its 16 raw bytes.
//
// Two forms are accepted:
//   - 36 characters in 8-4-4-4-12 groups separated by '-'
//     (e.g. "6ba7b810-9dad-11d1-80b4-00c04fd430c8").
//   - 32 hex digits with no separators.
//
// Hex digits may be upper or lower case. Any other length, misplaced
// separators, or non-hex characters produce an error, and the returned array
// is then zero.
func parseUUID(src string) (dst [16]byte, err error) {
	switch len(src) {
	case 36:
		// The separators must really be dashes. Otherwise, dropping positions
		// 8, 13, 18 and 23 blindly would accept e.g. 36 bare hex digits.
		if src[8] != '-' || src[13] != '-' || src[18] != '-' || src[23] != '-' {
			return dst, fmt.Errorf("cannot parse UUID %v", src)
		}
		src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:]
	case 32:
		// Already undashed; hex decoding below validates every character.
	default:
		return dst, fmt.Errorf("cannot parse UUID %v", src)
	}

	buf, err := hex.DecodeString(src)
	if err != nil {
		return dst, err
	}

	copy(dst[:], buf)
	return dst, nil
}
121
122
// encodeUUID formats src as the 36-character, lower-case, hyphenated
// 8-4-4-4-12 text form (e.g. "6ba7b810-9dad-11d1-80b4-00c04fd430c8").
//
// It hex-encodes into a fixed stack buffer instead of using fmt.Sprintf. That
// avoids boxing five slices into interfaces and fmt's reflection on every
// encode.
func encodeUUID(src [16]byte) string {
	var buf [36]byte
	hex.Encode(buf[0:8], src[0:4])
	buf[8] = '-'
	hex.Encode(buf[9:13], src[4:6])
	buf[13] = '-'
	hex.Encode(buf[14:18], src[6:8])
	buf[18] = '-'
	hex.Encode(buf[19:23], src[8:10])
	buf[23] = '-'
	hex.Encode(buf[24:36], src[10:16])
	return string(buf[:])
}
126
127 func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error {
128 if src == nil {
129 *dst = UUID{Status: Null}
130 return nil
131 }
132
133 if len(src) != 36 {
134 return fmt.Errorf("invalid length for UUID: %v", len(src))
135 }
136
137 buf, err := parseUUID(string(src))
138 if err != nil {
139 return err
140 }
141
142 *dst = UUID{Bytes: buf, Status: Present}
143 return nil
144 }
145
146 func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error {
147 if src == nil {
148 *dst = UUID{Status: Null}
149 return nil
150 }
151
152 if len(src) != 16 {
153 return fmt.Errorf("invalid length for UUID: %v", len(src))
154 }
155
156 *dst = UUID{Status: Present}
157 copy(dst.Bytes[:], src)
158 return nil
159 }
160
161 func (src UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
162 switch src.Status {
163 case Null:
164 return nil, nil
165 case Undefined:
166 return nil, errUndefined
167 }
168
169 return append(buf, encodeUUID(src.Bytes)...), nil
170 }
171
172 func (src UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
173 switch src.Status {
174 case Null:
175 return nil, nil
176 case Undefined:
177 return nil, errUndefined
178 }
179
180 return append(buf, src.Bytes[:]...), nil
181 }
182
183
184 func (dst *UUID) Scan(src interface{}) error {
185 if src == nil {
186 *dst = UUID{Status: Null}
187 return nil
188 }
189
190 switch src := src.(type) {
191 case string:
192 return dst.DecodeText(nil, []byte(src))
193 case []byte:
194 srcCopy := make([]byte, len(src))
195 copy(srcCopy, src)
196 return dst.DecodeText(nil, srcCopy)
197 }
198
199 return fmt.Errorf("cannot scan %T", src)
200 }
201
202
203 func (src UUID) Value() (driver.Value, error) {
204 return EncodeValueText(src)
205 }
206
207 func (src UUID) MarshalJSON() ([]byte, error) {
208 switch src.Status {
209 case Present:
210 var buff bytes.Buffer
211 buff.WriteByte('"')
212 buff.WriteString(encodeUUID(src.Bytes))
213 buff.WriteByte('"')
214 return buff.Bytes(), nil
215 case Null:
216 return []byte("null"), nil
217 case Undefined:
218 return nil, errUndefined
219 }
220 return nil, errBadStatus
221 }
222
223 func (dst *UUID) UnmarshalJSON(src []byte) error {
224 if bytes.Compare(src, []byte("null")) == 0 {
225 return dst.Set(nil)
226 }
227 if len(src) != 38 {
228 return fmt.Errorf("invalid length for UUID: %v", len(src))
229 }
230 return dst.Set(string(src[1 : len(src)-1]))
231 }
232
View as plain text