...
1 package uuid
2
3 import (
4 "database/sql/driver"
5 "errors"
6 "fmt"
7
8 "github.com/gofrs/uuid"
9 "github.com/jackc/pgtype"
10 )
11
// Sentinel errors shared by the encoding and JSON methods of UUID.
var (
	// errUndefined is returned when asked to encode a value whose Status is
	// pgtype.Undefined.
	errUndefined = errors.New("cannot encode status undefined")

	// errBadStatus is returned when a Status is none of Present, Null or
	// Undefined.
	errBadStatus = errors.New("invalid status")
)
14
15 type UUID struct {
16 UUID uuid.UUID
17 Status pgtype.Status
18 }
19
20 func (dst *UUID) Set(src interface{}) error {
21 if src == nil {
22 *dst = UUID{Status: pgtype.Null}
23 return nil
24 }
25
26 if value, ok := src.(interface{ Get() interface{} }); ok {
27 value2 := value.Get()
28 if value2 != value {
29 return dst.Set(value2)
30 }
31 }
32
33 switch value := src.(type) {
34 case uuid.UUID:
35 *dst = UUID{UUID: value, Status: pgtype.Present}
36 case [16]byte:
37 *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present}
38 case []byte:
39 if len(value) != 16 {
40 return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value))
41 }
42 *dst = UUID{Status: pgtype.Present}
43 copy(dst.UUID[:], value)
44 case string:
45 uuid, err := uuid.FromString(value)
46 if err != nil {
47 return err
48 }
49 *dst = UUID{UUID: uuid, Status: pgtype.Present}
50 default:
51
52 pgUUID := &pgtype.UUID{}
53 if err := pgUUID.Set(value); err != nil {
54 return fmt.Errorf("cannot convert %v to UUID", value)
55 }
56
57 *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status}
58 }
59
60 return nil
61 }
62
63 func (dst UUID) Get() interface{} {
64 switch dst.Status {
65 case pgtype.Present:
66 return dst.UUID
67 case pgtype.Null:
68 return nil
69 default:
70 return dst.Status
71 }
72 }
73
74 func (src *UUID) AssignTo(dst interface{}) error {
75 switch src.Status {
76 case pgtype.Present:
77 switch v := dst.(type) {
78 case *uuid.UUID:
79 *v = src.UUID
80 return nil
81 case *[16]byte:
82 *v = [16]byte(src.UUID)
83 return nil
84 case *[]byte:
85 *v = make([]byte, 16)
86 copy(*v, src.UUID[:])
87 return nil
88 case *string:
89 *v = src.UUID.String()
90 return nil
91 default:
92 if nextDst, retry := pgtype.GetAssignToDstType(v); retry {
93 return src.AssignTo(nextDst)
94 }
95 return fmt.Errorf("unable to assign to %T", dst)
96 }
97 case pgtype.Null:
98 return pgtype.NullAssignTo(dst)
99 }
100
101 return fmt.Errorf("cannot assign %v into %T", src, dst)
102 }
103
104 func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
105 if src == nil {
106 *dst = UUID{Status: pgtype.Null}
107 return nil
108 }
109
110 u, err := uuid.FromString(string(src))
111 if err != nil {
112 return err
113 }
114
115 *dst = UUID{UUID: u, Status: pgtype.Present}
116 return nil
117 }
118
119 func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
120 if src == nil {
121 *dst = UUID{Status: pgtype.Null}
122 return nil
123 }
124
125 if len(src) != 16 {
126 return fmt.Errorf("invalid length for UUID: %v", len(src))
127 }
128
129 *dst = UUID{Status: pgtype.Present}
130 copy(dst.UUID[:], src)
131 return nil
132 }
133
134 func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
135 switch src.Status {
136 case pgtype.Null:
137 return nil, nil
138 case pgtype.Undefined:
139 return nil, errUndefined
140 }
141
142 return append(buf, src.UUID.String()...), nil
143 }
144
145 func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
146 switch src.Status {
147 case pgtype.Null:
148 return nil, nil
149 case pgtype.Undefined:
150 return nil, errUndefined
151 }
152
153 return append(buf, src.UUID[:]...), nil
154 }
155
156
157 func (dst *UUID) Scan(src interface{}) error {
158 if src == nil {
159 *dst = UUID{Status: pgtype.Null}
160 return nil
161 }
162
163 switch src := src.(type) {
164 case string:
165 return dst.DecodeText(nil, []byte(src))
166 case []byte:
167 return dst.DecodeText(nil, src)
168 }
169
170 return fmt.Errorf("cannot scan %T", src)
171 }
172
173
174 func (src UUID) Value() (driver.Value, error) {
175 return pgtype.EncodeValueText(src)
176 }
177
178 func (src UUID) MarshalJSON() ([]byte, error) {
179 switch src.Status {
180 case pgtype.Present:
181 return []byte(`"` + src.UUID.String() + `"`), nil
182 case pgtype.Null:
183 return []byte("null"), nil
184 case pgtype.Undefined:
185 return nil, errUndefined
186 }
187
188 return nil, errBadStatus
189 }
190
191 func (dst *UUID) UnmarshalJSON(b []byte) error {
192 u := uuid.NullUUID{}
193 err := u.UnmarshalJSON(b)
194 if err != nil {
195 return err
196 }
197
198 status := pgtype.Null
199 if u.Valid {
200 status = pgtype.Present
201 }
202 *dst = UUID{UUID: u.UUID, Status: status}
203
204 return nil
205 }
206
View as plain text