1
2
3
4
5 package json
6
7 import (
8 "errors"
9 "fmt"
10 "reflect"
11 "sync"
12 )
13
14
15
16
17
18
19
// SkipFunc is the sentinel error that a marshal or unmarshal function returns
// to decline handling a value, so that the next applicable function (or the
// default behavior) is tried instead.
//
// Only functions registered through MarshalFuncV2 and UnmarshalFuncV2 may
// return SkipFunc, and only if they have not written or read any JSON tokens.
// Functions registered through MarshalFuncV1 and UnmarshalFuncV1 cannot skip;
// returning SkipFunc from them is reported as an error.
const SkipFunc = jsonError("skip function")
21
22
23
24
// Marshalers is a list of type-specific marshal functions.
// Individual entries are created with MarshalFuncV1 or MarshalFuncV2 and
// several lists are merged with NewMarshalers.
// A nil *Marshalers behaves like an empty list.
type Marshalers = typedMarshalers
26
27
28
29
30
31
32
33
34
35
36
37
38 func NewMarshalers(ms ...*Marshalers) *Marshalers {
39 return newMarshalers(ms...)
40 }
41
42
43
44
// Unmarshalers is a list of type-specific unmarshal functions.
// Individual entries are created with UnmarshalFuncV1 or UnmarshalFuncV2 and
// several lists are merged with NewUnmarshalers.
// A nil *Unmarshalers behaves like an empty list.
type Unmarshalers = typedUnmarshalers
46
47
48
49
50
51
52
53
54
55
56
57
58 func NewUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
59 return newUnmarshalers(us...)
60 }
61
// typedMarshalers and typedUnmarshalers are the marshal and unmarshal
// instantiations of typedArshalers.
type typedMarshalers = typedArshalers[MarshalOptions, Encoder]
type typedUnmarshalers = typedArshalers[UnmarshalOptions, Decoder]

// typedArshalers is an ordered list of type-specific functions together with
// a per-type cache of the lookup result. Options and Coder are the options
// and encoder/decoder types that the functions operate on.
type typedArshalers[Options, Coder any] struct {
	nonComparable

	// fncVals holds the registered functions in priority order;
	// lookup scans it front to back.
	fncVals []typedArshaler[Options, Coder]
	// fncCache maps a reflect.Type to the result of lookup for that type:
	// either the combined function to call, or nil when no registered
	// function applies to the type.
	fncCache sync.Map

	// fromAny reports whether at least one registered function is castable
	// from one of the types checked by castableToFromAny
	// (anyType, boolType, stringType, float64Type, mapStringAnyType,
	// sliceAnyType).
	fromAny bool
}
// typedMarshaler and typedUnmarshaler are the marshal and unmarshal
// instantiations of typedArshaler.
type typedMarshaler = typedArshaler[MarshalOptions, Encoder]
type typedUnmarshaler = typedArshaler[UnmarshalOptions, Decoder]

// typedArshaler is a single registered function bound to the type it handles.
type typedArshaler[Options, Coder any] struct {
	// typ is the type T that the function was registered for.
	typ reflect.Type
	// fnc adapts the user-provided function to operate on an addressableValue.
	fnc func(Options, *Coder, addressableValue) error
	// maySkip reports whether fnc is allowed to return SkipFunc.
	// lookup keeps collecting later candidates only after entries
	// that may skip.
	maySkip bool
}
88
89 func newMarshalers(ms ...*Marshalers) *Marshalers { return newTypedArshalers(ms...) }
90 func newUnmarshalers(us ...*Unmarshalers) *Unmarshalers { return newTypedArshalers(us...) }
91 func newTypedArshalers[Options, Coder any](as ...*typedArshalers[Options, Coder]) *typedArshalers[Options, Coder] {
92 var a typedArshalers[Options, Coder]
93 for _, a2 := range as {
94 if a2 != nil {
95 a.fncVals = append(a.fncVals, a2.fncVals...)
96 a.fromAny = a.fromAny || a2.fromAny
97 }
98 }
99 if len(a.fncVals) == 0 {
100 return nil
101 }
102 return &a
103 }
104
105 func (a *typedArshalers[Options, Coder]) lookup(fnc func(Options, *Coder, addressableValue) error, t reflect.Type) (func(Options, *Coder, addressableValue) error, bool) {
106 if a == nil {
107 return fnc, false
108 }
109 if v, ok := a.fncCache.Load(t); ok {
110 if v == nil {
111 return fnc, false
112 }
113 return v.(func(Options, *Coder, addressableValue) error), true
114 }
115
116
117
118 var fncs []func(Options, *Coder, addressableValue) error
119 for _, fncVal := range a.fncVals {
120 if !castableTo(t, fncVal.typ) {
121 continue
122 }
123 fncs = append(fncs, fncVal.fnc)
124 if !fncVal.maySkip {
125 break
126 }
127 }
128
129 if len(fncs) == 0 {
130 a.fncCache.Store(t, nil)
131 return fnc, false
132 }
133
134
135 fncDefault := fnc
136 fnc = func(o Options, c *Coder, v addressableValue) error {
137 for _, fnc := range fncs {
138 if err := fnc(o, c, v); err != SkipFunc {
139 return err
140 }
141 }
142 return fncDefault(o, c, v)
143 }
144
145
146 v, _ := a.fncCache.LoadOrStore(t, fnc)
147 return v.(func(Options, *Coder, addressableValue) error), true
148 }
149
150
151
152
153
154
155
156
157
158
159 func MarshalFuncV1[T any](fn func(T) ([]byte, error)) *Marshalers {
160 t := reflect.TypeOf((*T)(nil)).Elem()
161 assertCastableTo(t, true)
162 typFnc := typedMarshaler{
163 typ: t,
164 fnc: func(mo MarshalOptions, enc *Encoder, va addressableValue) error {
165 val, err := fn(va.castTo(t).Interface().(T))
166 if err != nil {
167 err = wrapSkipFunc(err, "marshal function of type func(T) ([]byte, error)")
168
169 return &SemanticError{action: "marshal", GoType: t, Err: err}
170 }
171 if err := enc.WriteValue(val); err != nil {
172
173 return &SemanticError{action: "marshal", JSONKind: RawValue(val).Kind(), GoType: t, Err: err}
174 }
175 return nil
176 },
177 }
178 return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
179 }
180
181
182
183
184
185
186
187
188
189
190
191
192
193 func MarshalFuncV2[T any](fn func(MarshalOptions, *Encoder, T) error) *Marshalers {
194 t := reflect.TypeOf((*T)(nil)).Elem()
195 assertCastableTo(t, true)
196 typFnc := typedMarshaler{
197 typ: t,
198 fnc: func(mo MarshalOptions, enc *Encoder, va addressableValue) error {
199 prevDepth, prevLength := enc.tokens.depthLength()
200 err := fn(mo, enc, va.castTo(t).Interface().(T))
201 currDepth, currLength := enc.tokens.depthLength()
202 if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
203 err = errors.New("must write exactly one JSON value")
204 }
205 if err != nil {
206 if err == SkipFunc {
207 if prevDepth == currDepth && prevLength == currLength {
208 return SkipFunc
209 }
210 err = errors.New("must not write any JSON tokens when skipping")
211 }
212
213 return &SemanticError{action: "marshal", GoType: t, Err: err}
214 }
215 return nil
216 },
217 maySkip: true,
218 }
219 return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
220 }
221
222
223
224
225
226
227
228
229
230
231 func UnmarshalFuncV1[T any](fn func([]byte, T) error) *Unmarshalers {
232 t := reflect.TypeOf((*T)(nil)).Elem()
233 assertCastableTo(t, false)
234 typFnc := typedUnmarshaler{
235 typ: t,
236 fnc: func(uo UnmarshalOptions, dec *Decoder, va addressableValue) error {
237 val, err := dec.ReadValue()
238 if err != nil {
239 return err
240 }
241 err = fn(val, va.castTo(t).Interface().(T))
242 if err != nil {
243 err = wrapSkipFunc(err, "unmarshal function of type func([]byte, T) error")
244
245 return &SemanticError{action: "unmarshal", JSONKind: val.Kind(), GoType: t, Err: err}
246 }
247 return nil
248 },
249 }
250 return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
251 }
252
253
254
255
256
257
258
259
260
261
262
263
264 func UnmarshalFuncV2[T any](fn func(UnmarshalOptions, *Decoder, T) error) *Unmarshalers {
265 t := reflect.TypeOf((*T)(nil)).Elem()
266 assertCastableTo(t, false)
267 typFnc := typedUnmarshaler{
268 typ: t,
269 fnc: func(uo UnmarshalOptions, dec *Decoder, va addressableValue) error {
270 prevDepth, prevLength := dec.tokens.depthLength()
271 err := fn(uo, dec, va.castTo(t).Interface().(T))
272 currDepth, currLength := dec.tokens.depthLength()
273 if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
274 err = errors.New("must read exactly one JSON value")
275 }
276 if err != nil {
277 if err == SkipFunc {
278 if prevDepth == currDepth && prevLength == currLength {
279 return SkipFunc
280 }
281 err = errors.New("must not read any JSON tokens when skipping")
282 }
283
284 return &SemanticError{action: "unmarshal", GoType: t, Err: err}
285 }
286 return nil
287 },
288 maySkip: true,
289 }
290 return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
291 }
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
// assertCastableTo panics unless to is an acceptable type parameter for a
// type-specific marshal or unmarshal function.
//
// Interface types and unnamed pointer types are always accepted.
// Non-pointer, non-interface types are accepted only when marshal is true.
// Named pointer types are never accepted.
func assertCastableTo(to reflect.Type, marshal bool) {
	kind := to.Kind()
	if kind == reflect.Interface {
		return
	}
	if kind == reflect.Pointer {
		// Only unnamed pointer types are accepted.
		if to.Name() == "" {
			return
		}
	} else if marshal {
		// Plain value types are fine when marshaling.
		return
	}

	if marshal {
		panic(fmt.Sprintf("input type %v must be an interface type, an unnamed pointer type, or a non-pointer type", to))
	}
	panic(fmt.Sprintf("input type %v must be an interface type or an unnamed pointer type", to))
}
331
332
333
334
335
// castableTo reports whether an addressable value of type from can be handed
// to a function registered for type to.
//
//   - If to is an interface type, a pointer to from must implement it.
//   - If to is a pointer type, it must be exactly the pointer-to-from type.
//   - Otherwise from and to must be the identical type.
func castableTo(from, to reflect.Type) bool {
	kind := to.Kind()
	if kind != reflect.Interface && kind != reflect.Pointer {
		return from == to
	}
	ptrToFrom := reflect.PointerTo(from)
	if kind == reflect.Interface {
		return ptrToFrom.Implements(to)
	}
	return ptrToFrom == to
}
354
355
356
357
358
359
360 func (va addressableValue) castTo(to reflect.Type) reflect.Value {
361 switch to.Kind() {
362 case reflect.Interface:
363 return va.Addr().Convert(to)
364 case reflect.Pointer:
365 return va.Addr()
366 default:
367 return va.Value
368 }
369 }
370
371
372
373 func castableToFromAny(to reflect.Type) bool {
374 for _, from := range []reflect.Type{anyType, boolType, stringType, float64Type, mapStringAnyType, sliceAnyType} {
375 if castableTo(from, to) {
376 return true
377 }
378 }
379 return false
380 }
381
382 func wrapSkipFunc(err error, what string) error {
383 if err == SkipFunc {
384 return errors.New(what + " cannot be skipped")
385 }
386 return err
387 }
388
View as plain text