1 package sqlxx
2
3 import (
4 "bytes"
5 "database/sql"
6 "database/sql/driver"
7 "encoding/json"
8 "fmt"
9 "strings"
10 "time"
11
12 "github.com/pkg/errors"
13
14 "github.com/ory/x/stringsx"
15 )
16
17
18 type StringSlicePipeDelimiter []string
19
20
21 func (n *StringSlicePipeDelimiter) Scan(value interface{}) error {
22 var s sql.NullString
23 if err := s.Scan(value); err != nil {
24 return err
25 }
26 *n = scanStringSlice('|', s.String)
27 return nil
28 }
29
30
31 func (n StringSlicePipeDelimiter) Value() (driver.Value, error) {
32 return valueStringSlice('|', n), nil
33 }
34
35 func scanStringSlice(delimiter rune, value interface{}) []string {
36 return stringsx.Splitx(fmt.Sprintf("%s", value), string(delimiter))
37 }
38
// valueStringSlice joins the elements of value with delimiter. No escaping is
// applied, so elements must not contain delimiter themselves.
func valueStringSlice(delimiter rune, value []string) string {
	sep := string(delimiter)
	return strings.Join(value, sep)
}
42
43
44 type NullString string
45
46
47 func (ns NullString) MarshalJSON() ([]byte, error) {
48 return json.Marshal(string(ns))
49 }
50
51
52 func (ns *NullString) UnmarshalJSON(data []byte) error {
53 if ns == nil {
54 return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
55 }
56 if len(data) == 0 {
57 return nil
58 }
59 return errors.WithStack(json.Unmarshal(data, ns))
60 }
61
62
63 func (ns *NullString) Scan(value interface{}) error {
64 var v sql.NullString
65 if err := (&v).Scan(value); err != nil {
66 return err
67 }
68 *ns = NullString(v.String)
69 return nil
70 }
71
72
73 func (ns NullString) Value() (driver.Value, error) {
74 if len(ns) == 0 {
75 return sql.NullString{}.Value()
76 }
77 return sql.NullString{Valid: true, String: string(ns)}.Value()
78 }
79
80
81 func (ns NullString) String() string {
82 return string(ns)
83 }
84
85
// NullTime is a time.Time whose zero value is written to SQL as NULL and
// encoded to JSON as null.
type NullTime time.Time

// Scan implements sql.Scanner. A SQL NULL is scanned as the zero time.
func (ns *NullTime) Scan(value interface{}) error {
	var scanned sql.NullTime
	err := scanned.Scan(value)
	if err != nil {
		return err
	}
	*ns = NullTime(scanned.Time)
	return nil
}

// MarshalJSON implements json.Marshaler. The zero time is encoded as null;
// any other time uses time.Time's own JSON representation.
func (ns NullTime) MarshalJSON() ([]byte, error) {
	t := time.Time(ns)
	if t.IsZero() {
		return json.Marshal((*time.Time)(nil))
	}
	return json.Marshal(&t)
}

// UnmarshalJSON implements json.Unmarshaler. It accepts whatever time.Time
// accepts; a JSON null resets ns to the zero time. On error ns is unchanged.
func (ns *NullTime) UnmarshalJSON(data []byte) error {
	var decoded time.Time
	err := json.Unmarshal(data, &decoded)
	if err == nil {
		*ns = NullTime(decoded)
	}
	return err
}

// Value implements driver.Valuer. The zero time is written as SQL NULL.
func (ns NullTime) Value() (driver.Value, error) {
	t := time.Time(ns)
	return sql.NullTime{Time: t, Valid: !t.IsZero()}.Value()
}
122
123
124 type MapStringInterface map[string]interface{}
125
126
127 func (n *MapStringInterface) Scan(value interface{}) error {
128 v := fmt.Sprintf("%s", value)
129 if len(v) == 0 {
130 return nil
131 }
132 return errors.WithStack(json.Unmarshal([]byte(v), n))
133 }
134
135
136 func (n MapStringInterface) Value() (driver.Value, error) {
137 value, err := json.Marshal(n)
138 if err != nil {
139 return nil, errors.WithStack(err)
140 }
141 return string(value), nil
142 }
143
144
// JSONRawMessage is a raw JSON document stored in SQL as text. An empty value
// is written to SQL and encoded to JSON as the literal null.
type JSONRawMessage json.RawMessage

// Scan implements sql.Scanner. The driver value is rendered with %s, so string
// and []byte columns are both copied verbatim into a new buffer. A SQL NULL is
// stored as the JSON literal null.
func (m *JSONRawMessage) Scan(value interface{}) error {
	if value == nil {
		// fmt.Sprintf("%s", nil) would produce "%!s(<nil>)", which is not JSON
		// and would make any later JSON encoding of m fail.
		*m = JSONRawMessage("null")
		return nil
	}
	*m = []byte(fmt.Sprintf("%s", value))
	return nil
}

// Value implements driver.Valuer. An empty message is written as the string
// "null"; otherwise the raw JSON text is written as-is.
func (m JSONRawMessage) Value() (driver.Value, error) {
	if len(m) == 0 {
		return "null", nil
	}
	return string(m), nil
}

// MarshalJSON implements json.Marshaler. An empty message encodes as null;
// otherwise m is returned unchanged.
func (m JSONRawMessage) MarshalJSON() ([]byte, error) {
	if len(m) == 0 {
		return []byte("null"), nil
	}
	return m, nil
}
168
169
170 func (m *JSONRawMessage) UnmarshalJSON(data []byte) error {
171 if m == nil {
172 return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
173 }
174 *m = append((*m)[0:0], data...)
175 return nil
176 }
177
178
// NullJSONRawMessage is a raw JSON document whose empty value is written to
// SQL as NULL and encoded to JSON as null.
type NullJSONRawMessage json.RawMessage

// Scan implements sql.Scanner. A SQL NULL is stored as the JSON literal null;
// any other value is rendered with %s into a new buffer.
func (m *NullJSONRawMessage) Scan(value interface{}) error {
	raw := "null"
	if value != nil {
		raw = fmt.Sprintf("%s", value)
	}
	*m = NullJSONRawMessage(raw)
	return nil
}

// Value implements driver.Valuer. An empty message is written as SQL NULL;
// otherwise the raw JSON text is written as a string.
func (m NullJSONRawMessage) Value() (driver.Value, error) {
	if len(m) > 0 {
		return string(m), nil
	}
	return nil, nil
}

// MarshalJSON implements json.Marshaler. An empty message encodes as null;
// otherwise m is returned unchanged.
func (m NullJSONRawMessage) MarshalJSON() ([]byte, error) {
	if len(m) > 0 {
		return m, nil
	}
	return []byte("null"), nil
}
205
206
207 func (m *NullJSONRawMessage) UnmarshalJSON(data []byte) error {
208 if m == nil {
209 return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
210 }
211 *m = append((*m)[0:0], data...)
212 return nil
213 }
214
215
// JSONScan decodes a SQL column value holding JSON into dst, which should be a
// non-nil pointer (for example &myStruct). The value is rendered with %s, so
// both string and []byte columns are accepted. A SQL NULL is decoded as the
// JSON literal null.
//
// Decode errors are wrapped with %w, so callers can inspect the underlying
// encoding/json error with errors.Is / errors.As (e.g. *json.SyntaxError).
func JSONScan(dst interface{}, value interface{}) error {
	if value == nil {
		value = "null"
	}
	// NOTE(review): &dst is passed rather than dst. encoding/json follows the
	// pointer stored in the interface for regular payloads, but for a null
	// payload it may only clear the local interface. Kept as-is to preserve
	// existing behavior.
	if err := json.Unmarshal([]byte(fmt.Sprintf("%s", value)), &dst); err != nil {
		return fmt.Errorf("unable to decode payload to: %w", err)
	}
	return nil
}
225
226
// JSONValue encodes src as JSON for storage in SQL and returns it as a string.
// A nil src is written as SQL NULL.
//
// The result is produced by json.Encoder and therefore ends with a newline.
func JSONValue(src interface{}) (driver.Value, error) {
	if src == nil {
		return nil, nil
	}
	buf := new(bytes.Buffer)
	enc := json.NewEncoder(buf)
	if err := enc.Encode(&src); err != nil {
		return nil, err
	}
	return buf.String(), nil
}
237
View as plain text