1
2
3
4
5
6
7 package bsoncodec
8
9 import (
10 "errors"
11 "fmt"
12 "reflect"
13 "sort"
14 "strings"
15 "sync"
16 "time"
17
18 "go.mongodb.org/mongo-driver/bson/bsonoptions"
19 "go.mongodb.org/mongo-driver/bson/bsonrw"
20 "go.mongodb.org/mongo-driver/bson/bsontype"
21 )
22
23
// DecodeError wraps an error raised while decoding a struct field and records
// the chain of BSON keys that leads to the failing field.
type DecodeError struct {
	// keys is stored innermost-first: every enclosing struct appends its own
	// field key as the error travels outward (see newDecodeError).
	keys []string
	// wrapped is the underlying cause.
	wrapped error
}

// Unwrap returns the underlying error so errors.Is/errors.As can see it.
func (de *DecodeError) Unwrap() error { return de.wrapped }

// Error renders the error as "error decoding key <path>: <cause>", where
// <path> is the dot-separated key path from the outermost document inward.
func (de *DecodeError) Error() string {
	path := strings.Join(de.Keys(), ".")
	return fmt.Sprintf("error decoding key %s: %v", path, de.wrapped)
}

// Keys returns the key path ordered from the outermost document down to the
// field that failed. The result is a fresh slice; the receiver is untouched.
func (de *DecodeError) Keys() []string {
	count := len(de.keys)
	ordered := make([]string, count)
	for pos, key := range de.keys {
		ordered[count-1-pos] = key
	}
	return ordered
}
53
54
55
56
// Zeroer is implemented by types that can report whether they hold their
// "zero" (empty) state. isEmpty consults it when deciding whether a field
// tagged omitempty should be left out of the encoded document.
type Zeroer interface {
	// IsZero reports whether the receiver should be treated as empty.
	IsZero() bool
}
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79 type StructCodec struct {
80 cache sync.Map
81 parser StructTagParser
82
83
84
85
86
87 DecodeZeroStruct bool
88
89
90
91
92
93 DecodeDeepZeroInline bool
94
95
96
97
98
99
100 EncodeOmitDefaultStruct bool
101
102
103
104
105
106 AllowUnexportedFields bool
107
108
109
110
111
112
113
114 OverwriteDuplicatedInlinedFields bool
115 }
116
117 var _ ValueEncoder = &StructCodec{}
118 var _ ValueDecoder = &StructCodec{}
119
120
121
122
123
124 func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) {
125 if p == nil {
126 return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
127 }
128
129 structOpt := bsonoptions.MergeStructCodecOptions(opts...)
130
131 codec := &StructCodec{
132 parser: p,
133 }
134
135 if structOpt.DecodeZeroStruct != nil {
136 codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct
137 }
138 if structOpt.DecodeDeepZeroInline != nil {
139 codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline
140 }
141 if structOpt.EncodeOmitDefaultStruct != nil {
142 codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
143 }
144 if structOpt.OverwriteDuplicatedInlinedFields != nil {
145 codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields
146 }
147 if structOpt.AllowUnexportedFields != nil {
148 codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields
149 }
150
151 return codec, nil
152 }
153
154
// EncodeValue is the ValueEncoder for struct values.
//
// val must be a valid reflect.Value of kind struct. Every field selected by
// describeStruct is written as an element of a new BSON document, in struct
// declaration order, honoring the omitempty and minsize tag options. If the
// struct has an inline map, its entries are written after the fields.
func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Kind() != reflect.Struct {
		return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
	}

	sd, err := sc.describeStruct(ec.Registry, val.Type(), ec.useJSONStructTags, ec.errorOnInlineDuplicates)
	if err != nil {
		return err
	}

	dw, err := vw.WriteDocument()
	if err != nil {
		return err
	}
	var rv reflect.Value
	for _, desc := range sd.fl {
		if desc.inline == nil {
			rv = val.Field(desc.idx)
		} else {
			// Inlined field: follow the full index path. If a nil embedded
			// pointer sits on the path, the field does not exist in this
			// value and is silently skipped.
			rv, err = fieldByIndexErr(val, desc.inline)
			if err != nil {
				continue
			}
		}

		// desc is a per-iteration copy of the cached description, so
		// replacing its encoder here does not modify the cache.
		desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(ec, desc.encoder, rv)

		if err != nil && !errors.Is(err, errInvalidValue) {
			return err
		}

		// errInvalidValue means the value cannot be encoded as-is
		// (NOTE(review): presumably a nil interface - confirm in
		// lookupElementEncoder). Such a field is dropped under omitempty and
		// written as BSON null otherwise.
		if errors.Is(err, errInvalidValue) {
			if desc.omitEmpty {
				continue
			}
			vw2, err := dw.WriteDocumentElement(desc.name)
			if err != nil {
				return err
			}
			err = vw2.WriteNull()
			if err != nil {
				return err
			}
			continue
		}

		if desc.encoder == nil {
			return ErrNoEncoder{Type: rv.Type()}
		}

		encoder := desc.encoder

		// Emptiness for omitempty: an encoder implementing CodecZeroer decides
		// for itself; an interface value counts as empty only when nil; every
		// other value is judged by isEmpty.
		var empty bool
		if cz, ok := encoder.(CodecZeroer); ok {
			empty = cz.IsTypeZero(rv.Interface())
		} else if rv.Kind() == reflect.Interface {
			// Only the interface itself is checked, not the dynamic value it
			// holds.
			empty = rv.IsNil()
		} else {
			empty = isEmpty(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct)
		}
		if desc.omitEmpty && empty {
			continue
		}

		vw2, err := dw.WriteDocumentElement(desc.name)
		if err != nil {
			return err
		}

		// Per-field context: a minsize tag applies to this field only; the
		// other options listed here are copied from ec.
		ectx := EncodeContext{
			Registry:                ec.Registry,
			MinSize:                 desc.minSize || ec.MinSize,
			errorOnInlineDuplicates: ec.errorOnInlineDuplicates,
			stringifyMapKeysWithFmt: ec.stringifyMapKeysWithFmt,
			nilMapAsEmpty:           ec.nilMapAsEmpty,
			nilSliceAsEmpty:         ec.nilSliceAsEmpty,
			nilByteSliceAsEmpty:     ec.nilByteSliceAsEmpty,
			omitZeroStruct:          ec.omitZeroStruct,
			useJSONStructTags:       ec.useJSONStructTags,
		}
		err = encoder.EncodeValue(ectx, vw2, rv)
		if err != nil {
			return err
		}
	}

	if sd.inlineMap >= 0 {
		rv := val.Field(sd.inlineMap)
		// collisionFn reports whether a map key is already used by a struct
		// field; how a collision is handled is up to mapEncodeValue.
		collisionFn := func(key string) bool {
			_, exists := sd.fm[key]
			return exists
		}

		// NOTE(review): WriteDocumentEnd is not called on this path, so
		// mapEncodeValue is presumably responsible for ending dw - confirm.
		return defaultMapCodec.mapEncodeValue(ec, dw, rv, collisionFn)
	}

	return dw.WriteDocumentEnd()
}
255
256 func newDecodeError(key string, original error) error {
257 var de *DecodeError
258 if !errors.As(original, &de) {
259 return &DecodeError{
260 keys: []string{key},
261 wrapped: original,
262 }
263 }
264
265 de.keys = append(de.keys, key)
266 return de
267 }
268
269
270
271
// DecodeValue is the ValueDecoder for struct values.
//
// val must be a settable reflect.Value of kind struct. A BSON null or
// undefined sets val to its zero value. An embedded document (or a value
// whose type is reported as 0) is decoded element by element:
//   - each key is matched to a struct field, with the lowercased key as a
//     fallback;
//   - keys without a matching field go into the inline map when the struct
//     has one, and are skipped otherwise.
// Errors from field decoders are wrapped with the field's key via
// newDecodeError.
func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Kind() != reflect.Struct {
		return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
	}

	switch vrType := vr.Type(); vrType {
	case bsontype.Type(0), bsontype.EmbeddedDocument:
		// Decoded as a document below. NOTE(review): Type(0) is presumably
		// what a top-level document reports - confirm in bsonrw.
	case bsontype.Null:
		if err := vr.ReadNull(); err != nil {
			return err
		}

		val.Set(reflect.Zero(val.Type()))
		return nil
	case bsontype.Undefined:
		if err := vr.ReadUndefined(); err != nil {
			return err
		}

		val.Set(reflect.Zero(val.Type()))
		return nil
	default:
		return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
	}

	// Duplicate inline keys are never treated as an error when decoding.
	sd, err := sc.describeStruct(dc.Registry, val.Type(), dc.useJSONStructTags, false)
	if err != nil {
		return err
	}

	// Optional reset of the destination before any element is decoded.
	if sc.DecodeZeroStruct || dc.zeroStructs {
		val.Set(reflect.Zero(val.Type()))
	}
	if sc.DecodeDeepZeroInline && sd.inline {
		val.Set(deepZero(val.Type()))
	}

	// Resolve the decoder for inline-map values once, before the loop.
	var decoder ValueDecoder
	var inlineMap reflect.Value
	if sd.inlineMap >= 0 {
		inlineMap = val.Field(sd.inlineMap)
		decoder, err = dc.LookupDecoder(inlineMap.Type().Elem())
		if err != nil {
			return err
		}
	}

	dr, err := vr.ReadDocument()
	if err != nil {
		return err
	}

	for {
		// vr and err are shadowed here: vr now reads the current element.
		name, vr, err := dr.ReadElement()
		if errors.Is(err, bsonrw.ErrEOD) {
			break
		}
		if err != nil {
			return err
		}

		fd, exists := sd.fm[name]
		if !exists {
			// Fallback: retry with the lowercased key, so a field whose key
			// is the lowercase form of the incoming key still matches.
			fd, exists = sd.fm[strings.ToLower(name)]
		}

		if !exists {
			if sd.inlineMap < 0 {
				// No matching field and no inline map: discard the value.
				err = vr.Skip()
				if err != nil {
					return err
				}
				continue
			}

			// Allocate the inline map lazily, on the first unmatched key.
			if inlineMap.IsNil() {
				inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
			}

			elem := reflect.New(inlineMap.Type().Elem()).Elem()
			// dc is this call's own copy (passed by value), so setting
			// Ancestor does not leak to the caller.
			dc.Ancestor = inlineMap.Type()
			err = decoder.DecodeValue(dc, vr, elem)
			if err != nil {
				return err
			}
			inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
			continue
		}

		var field reflect.Value
		if fd.inline == nil {
			field = val.Field(fd.idx)
		} else {
			// getInlineField allocates nil embedded struct pointers on the path.
			field, err = getInlineField(val, fd.inline)
			if err != nil {
				return err
			}
		}

		if !field.CanSet() {
			innerErr := fmt.Errorf("field %v is not settable", field)
			return newDecodeError(fd.name, innerErr)
		}
		// A nil pointer field receives fresh memory to decode into.
		if field.Kind() == reflect.Ptr && field.IsNil() {
			field.Set(reflect.New(field.Type().Elem()))
		}
		// field.Elem() below is therefore the (settable) field itself.
		field = field.Addr()

		// Per-field context: the truncate tag applies to this field only.
		dctx := DecodeContext{
			Registry:            dc.Registry,
			Truncate:            fd.truncate || dc.Truncate,
			defaultDocumentType: dc.defaultDocumentType,
			binaryAsSlice:       dc.binaryAsSlice,
			useJSONStructTags:   dc.useJSONStructTags,
			useLocalTimeZone:    dc.useLocalTimeZone,
			zeroMaps:            dc.zeroMaps,
			zeroStructs:         dc.zeroStructs,
		}

		if fd.decoder == nil {
			return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()})
		}

		err = fd.decoder.DecodeValue(dctx, vr, field.Elem())
		if err != nil {
			return newDecodeError(fd.name, err)
		}
	}

	return nil
}
408
409 func isEmpty(v reflect.Value, omitZeroStruct bool) bool {
410 kind := v.Kind()
411 if (kind != reflect.Ptr || !v.IsNil()) && v.Type().Implements(tZeroer) {
412 return v.Interface().(Zeroer).IsZero()
413 }
414 switch kind {
415 case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
416 return v.Len() == 0
417 case reflect.Struct:
418 if !omitZeroStruct {
419 return false
420 }
421 vt := v.Type()
422 if vt == tTime {
423 return v.Interface().(time.Time).IsZero()
424 }
425 numField := vt.NumField()
426 for i := 0; i < numField; i++ {
427 ff := vt.Field(i)
428 if ff.PkgPath != "" && !ff.Anonymous {
429 continue
430 }
431 if !isEmpty(v.Field(i), omitZeroStruct) {
432 return false
433 }
434 }
435 return true
436 }
437 return !v.IsValid() || v.IsZero()
438 }
439
440 type structDescription struct {
441 fm map[string]fieldDescription
442 fl []fieldDescription
443 inlineMap int
444 inline bool
445 }
446
447 type fieldDescription struct {
448 name string
449 fieldName string
450 idx int
451 omitEmpty bool
452 minSize bool
453 truncate bool
454 inline []int
455 encoder ValueEncoder
456 decoder ValueDecoder
457 }
458
459 type byIndex []fieldDescription
460
461 func (bi byIndex) Len() int { return len(bi) }
462
463 func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] }
464
465 func (bi byIndex) Less(i, j int) bool {
466
467 iIdx, jIdx := bi[i].idx, bi[j].idx
468 if len(bi[i].inline) > 0 {
469 iIdx = bi[i].inline[0]
470 }
471 if len(bi[j].inline) > 0 {
472 jIdx = bi[j].inline[0]
473 }
474 if iIdx != jIdx {
475 return iIdx < jIdx
476 }
477 for k, biik := range bi[i].inline {
478 if k >= len(bi[j].inline) {
479 return false
480 }
481 if biik != bi[j].inline[k] {
482 return biik < bi[j].inline[k]
483 }
484 }
485 return len(bi[i].inline) < len(bi[j].inline)
486 }
487
488 func (sc *StructCodec) describeStruct(
489 r *Registry,
490 t reflect.Type,
491 useJSONStructTags bool,
492 errorOnDuplicates bool,
493 ) (*structDescription, error) {
494
495
496 if v, ok := sc.cache.Load(t); ok {
497 return v.(*structDescription), nil
498 }
499
500
501 ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates)
502 if err != nil {
503 return nil, err
504 }
505 if v, loaded := sc.cache.LoadOrStore(t, ds); loaded {
506 ds = v.(*structDescription)
507 }
508 return ds, nil
509 }
510
// describeStructSlow builds the structDescription for struct type t without
// consulting the cache for t itself. Inline structs are still described
// through describeStruct.
//
// Rules applied:
//   - Unexported fields are skipped, except embedded ones when
//     AllowUnexportedFields is set.
//   - Tags are parsed with JSONFallbackStructTagParser when useJSONStructTags
//     is true, and with the codec's own parser otherwise.
//   - An inline map becomes the struct's inline map (at most one, and it must
//     have string keys).
//   - An inline struct or struct pointer contributes its own fields, with
//     their index paths prefixed by the inlining field's index.
//   - Fields that end up sharing a key are resolved at the end of this
//     function, or rejected.
func (sc *StructCodec) describeStructSlow(
	r *Registry,
	t reflect.Type,
	useJSONStructTags bool,
	errorOnDuplicates bool,
) (*structDescription, error) {
	numFields := t.NumField()
	sd := &structDescription{
		fm:        make(map[string]fieldDescription, numFields),
		fl:        make([]fieldDescription, 0, numFields),
		inlineMap: -1,
	}

	var fields []fieldDescription
	for i := 0; i < numFields; i++ {
		sf := t.Field(i)
		if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) {
			// Unexported field: only embedded ones are kept, and only when
			// AllowUnexportedFields is set.
			continue
		}

		sfType := sf.Type
		// A missing encoder/decoder is not an error here: EncodeValue and
		// DecodeValue check for nil when the field is actually used.
		encoder, err := r.LookupEncoder(sfType)
		if err != nil {
			encoder = nil
		}
		decoder, err := r.LookupDecoder(sfType)
		if err != nil {
			decoder = nil
		}

		description := fieldDescription{
			fieldName: sf.Name,
			idx:       i,
			encoder:   encoder,
			decoder:   decoder,
		}

		var stags StructTags
		// The JSON-fallback parser replaces the codec's own parser when the
		// context asks for JSON struct tags.
		if useJSONStructTags {
			stags, err = JSONFallbackStructTagParser.ParseStructTags(sf)
		} else {
			stags, err = sc.parser.ParseStructTags(sf)
		}
		if err != nil {
			return nil, err
		}
		if stags.Skip {
			continue
		}
		description.name = stags.Name
		description.omitEmpty = stags.OmitEmpty
		description.minSize = stags.MinSize
		description.truncate = stags.Truncate

		if stags.Inline {
			sd.inline = true
			switch sfType.Kind() {
			case reflect.Map:
				if sd.inlineMap >= 0 {
					return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
				}
				if sfType.Key() != tString {
					return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
				}
				sd.inlineMap = description.idx
			case reflect.Ptr:
				sfType = sfType.Elem()
				if sfType.Kind() != reflect.Struct {
					return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
				}
				fallthrough
			case reflect.Struct:
				inlinesf, err := sc.describeStruct(r, sfType, useJSONStructTags, errorOnDuplicates)
				if err != nil {
					return nil, err
				}
				// Re-root each inner field's index path at this struct.
				for _, fd := range inlinesf.fl {
					if fd.inline == nil {
						fd.inline = []int{i, fd.idx}
					} else {
						fd.inline = append([]int{i}, fd.inline...)
					}
					fields = append(fields, fd)

				}
			default:
				return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
			}
			continue
		}
		fields = append(fields, description)
	}

	// Group fields by key: sort by name, then by inline depth (shallowest
	// first), then by declaration order. Each group of equal names is then
	// contiguous with its shallowest candidate first.
	sort.Slice(fields, func(i, j int) bool {
		x := fields
		// sort field by name, breaking ties with depth, then
		// breaking ties with index sequence.
		if x[i].name != x[j].name {
			return x[i].name < x[j].name
		}
		if len(x[i].inline) != len(x[j].inline) {
			return len(x[i].inline) < len(x[j].inline)
		}
		return byIndex(x).Less(i, j)
	})

	for advance, i := 0, 0; i < len(fields); i += advance {
		// One group per outer iteration: advance ends up as the number of
		// fields sharing fields[i].name.
		fi := fields[i]
		name := fi.name
		for advance = 1; i+advance < len(fields); advance++ {
			fj := fields[i+advance]
			if fj.name != name {
				break
			}
		}
		if advance == 1 {
			sd.fl = append(sd.fl, fi)
			sd.fm[name] = fi
			continue
		}
		// A duplicated key is accepted only if dominantField finds a unique
		// shallowest field, OverwriteDuplicatedInlinedFields is set, and the
		// caller did not ask for duplicates to be reported.
		dominant, ok := dominantField(fields[i : i+advance])
		if !ok || !sc.OverwriteDuplicatedInlinedFields || errorOnDuplicates {
			return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name)
		}
		sd.fl = append(sd.fl, dominant)
		sd.fm[name] = dominant
	}

	// Restore declaration order, which EncodeValue uses as element order.
	sort.Sort(byIndex(sd.fl))

	return sd, nil
}
649
650
651
652
653
654
655 func dominantField(fields []fieldDescription) (fieldDescription, bool) {
656
657
658
659 if len(fields) > 1 &&
660 len(fields[0].inline) == len(fields[1].inline) {
661 return fieldDescription{}, false
662 }
663 return fields[0], true
664 }
665
// fieldByIndexErr is reflect.Value.FieldByIndex, except that a panic is
// returned as an error instead of propagating. FieldByIndex panics, for
// instance, when index walks through a nil pointer to an embedded struct.
// On error the returned Value is the zero Value.
func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			result = reflect.Value{}
			switch r := recovered.(type) {
			case string:
				err = fmt.Errorf("%s", r)
			case error:
				err = r
			default:
				// Previously such panics were swallowed, yielding an invalid
				// Value together with a nil error.
				err = fmt.Errorf("%v", r)
			}
		}
	}()

	return v.FieldByIndex(index), nil
}

// getInlineField returns the field of val at the inline index path.
//
// Nil embedded struct pointers along the path are allocated so that the field
// can be set; val must be addressable for that to work. An error is returned,
// instead of panicking, when the lookup fails for a reason other than a nil
// parent pointer, or when a nil pointer on the path cannot be set (e.g.
// because it was reached through an unexported field).
func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
	field, lookupErr := fieldByIndexErr(val, index)
	if lookupErr == nil {
		return field, nil
	}
	if len(index) < 2 {
		// No parent on the path that could be allocated.
		return reflect.Value{}, lookupErr
	}

	// The lookup failed on the way to the field: resolve the parent first,
	// recursively allocating any nil ancestors of its own.
	parent, err := getInlineField(val, index[:len(index)-1])
	if err != nil {
		return reflect.Value{}, err
	}
	if parent.Kind() != reflect.Ptr || !parent.IsNil() {
		// The failure was not caused by a nil parent pointer. Allocating
		// would not help and could discard an existing value.
		return reflect.Value{}, lookupErr
	}
	if !parent.CanSet() {
		return reflect.Value{}, fmt.Errorf("cannot allocate nil inline field of type %v: field is not settable", parent.Type())
	}
	parent.Set(reflect.New(parent.Type().Elem()))

	return fieldByIndexErr(val, index)
}
701
702
// deepZero returns a new, addressable zero value of type st.
//
// A struct's zero value already has every nested struct-valued field zeroed,
// and pointer fields are left nil. That is exactly what the previous
// implementation produced whenever it did not panic. The previous version
// panicked in two cases:
//   - any exported struct-valued field, because it called Elem() on a
//     struct type;
//   - structs without fields, because it returned an invalid Value that
//     callers then passed to Value.Set.
func deepZero(st reflect.Type) reflect.Value {
	return reflect.New(st).Elem()
}
720
721
// recursivePointerTo returns a pointer to a fresh zero value of v's type
// (after dereferencing v once if it is a pointer).
//
// For struct types, every field that holds a non-nil pointer to a struct in v
// is mirrored in the result by a newly allocated pointer, built recursively by
// the same rule. No other data is copied from v.
func recursivePointerTo(v reflect.Value) reflect.Value {
	src := reflect.Indirect(v)
	ptr := reflect.New(src.Type())
	if src.Kind() != reflect.Struct {
		return ptr
	}

	dst := ptr.Elem()
	for i, n := 0, src.NumField(); i < n; i++ {
		field := src.Field(i)
		if field.Kind() != reflect.Ptr || field.Elem().Kind() != reflect.Struct {
			continue
		}
		dst.Field(i).Set(recursivePointerTo(field))
	}
	return ptr
}
737
View as plain text