...
1
2
3
4
5
6
7 package mgocompat
8
9 import (
10 "errors"
11 "reflect"
12
13 "go.mongodb.org/mongo-driver/bson"
14 "go.mongodb.org/mongo-driver/bson/bsoncodec"
15 "go.mongodb.org/mongo-driver/bson/bsonrw"
16 )
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 type Setter interface {
39 SetBSON(raw bson.RawValue) error
40 }
41
42
43
44
45
46
47
48 type Getter interface {
49 GetBSON() (interface{}, error)
50 }
51
52
53 func SetterDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
54 if !val.IsValid() || (!val.Type().Implements(tSetter) && !reflect.PtrTo(val.Type()).Implements(tSetter)) {
55 return bsoncodec.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
56 }
57
58 if val.Kind() == reflect.Ptr && val.IsNil() {
59 if !val.CanSet() {
60 return bsoncodec.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
61 }
62 val.Set(reflect.New(val.Type().Elem()))
63 }
64
65 if !val.Type().Implements(tSetter) {
66 if !val.CanAddr() {
67 return bsoncodec.ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
68 }
69 val = val.Addr()
70 }
71
72 t, src, err := bsonrw.Copier{}.CopyValueToBytes(vr)
73 if err != nil {
74 return err
75 }
76
77 m, ok := val.Interface().(Setter)
78 if !ok {
79 return bsoncodec.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
80 }
81 if err := m.SetBSON(bson.RawValue{Type: t, Value: src}); err != nil {
82 if !errors.Is(err, ErrSetZero) {
83 return err
84 }
85 val.Set(reflect.Zero(val.Type()))
86 }
87 return nil
88 }
89
90
91 func GetterEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
92
93 switch {
94 case !val.IsValid():
95 return bsoncodec.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val}
96 case val.Type().Implements(tGetter):
97
98 if isImplementationNil(val, tGetter) {
99 return vw.WriteNull()
100 }
101 case reflect.PtrTo(val.Type()).Implements(tGetter) && val.CanAddr():
102 val = val.Addr()
103 default:
104 return bsoncodec.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val}
105 }
106
107 m, ok := val.Interface().(Getter)
108 if !ok {
109 return vw.WriteNull()
110 }
111 x, err := m.GetBSON()
112 if err != nil {
113 return err
114 }
115 if x == nil {
116 return vw.WriteNull()
117 }
118 vv := reflect.ValueOf(x)
119 encoder, err := ec.Registry.LookupEncoder(vv.Type())
120 if err != nil {
121 return err
122 }
123 return encoder.EncodeValue(ec, vw, vv)
124 }
125
126
// isImplementationNil reports whether val is a nil pointer whose innermost
// pointee type implements inter. The pointee type is found by stripping every
// level of pointer indirection.
//
// For such a value, calling one of the pointee's value-receiver methods
// through the nil pointer would panic, so callers treat it as absent.
func isImplementationNil(val reflect.Value, inter reflect.Type) bool {
	base := val.Type()
	for base.Kind() == reflect.Ptr {
		base = base.Elem()
	}

	switch {
	case !base.Implements(inter):
		return false
	case val.Kind() != reflect.Ptr:
		return false
	default:
		return val.IsNil()
	}
}
134
View as plain text