1
2
3
4
5
6
7 package bsoncodec
8
9 import (
10 "encoding"
11 "errors"
12 "fmt"
13 "reflect"
14 "strconv"
15
16 "go.mongodb.org/mongo-driver/bson/bsonoptions"
17 "go.mongodb.org/mongo-driver/bson/bsonrw"
18 "go.mongodb.org/mongo-driver/bson/bsontype"
19 )
20
21 var defaultMapCodec = NewMapCodec()
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41 type MapCodec struct {
42
43
44
45
46 DecodeZerosMap bool
47
48
49
50
51
52 EncodeNilAsEmpty bool
53
54
55
56
57
58
59 EncodeKeysWithStringer bool
60 }
61
62
63
// KeyMarshaler is implemented by map key types that produce their own BSON
// document field name. MapCodec consults it when encoding a key whose kind is
// not string, before encoding.TextMarshaler, unless fmt-based key encoding
// (EncodeKeysWithStringer) is in effect. A nil pointer key that implements it
// is encoded as the empty string without calling MarshalKey.
type KeyMarshaler interface {
	MarshalKey() (key string, err error)
}

// KeyUnmarshaler is the decoding counterpart of KeyMarshaler: MapCodec calls
// UnmarshalKey on a newly allocated key when a pointer to the map's key type
// implements this interface. It is preferred over encoding.TextUnmarshaler and
// is skipped when EncodeKeysWithStringer is set.
//
// Note that the lookup happens even for key types of string kind, whereas on
// the encoding side a string-kind key is always written verbatim and its
// MarshalKey method is never called.
type KeyUnmarshaler interface {
	UnmarshalKey(key string) error
}
77
78
79
80
81
82 func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
83 mapOpt := bsonoptions.MergeMapCodecOptions(opts...)
84
85 codec := MapCodec{}
86 if mapOpt.DecodeZerosMap != nil {
87 codec.DecodeZerosMap = *mapOpt.DecodeZerosMap
88 }
89 if mapOpt.EncodeNilAsEmpty != nil {
90 codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty
91 }
92 if mapOpt.EncodeKeysWithStringer != nil {
93 codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer
94 }
95 return &codec
96 }
97
98
99 func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
100 if !val.IsValid() || val.Kind() != reflect.Map {
101 return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
102 }
103
104 if val.IsNil() && !mc.EncodeNilAsEmpty && !ec.nilMapAsEmpty {
105
106
107
108
109
110 err := vw.WriteNull()
111 if err == nil {
112 return nil
113 }
114 }
115
116 dw, err := vw.WriteDocument()
117 if err != nil {
118 return err
119 }
120
121 return mc.mapEncodeValue(ec, dw, val, nil)
122 }
123
124
125
126
127 func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
128
129 elemType := val.Type().Elem()
130 encoder, err := ec.LookupEncoder(elemType)
131 if err != nil && elemType.Kind() != reflect.Interface {
132 return err
133 }
134
135 keys := val.MapKeys()
136 for _, key := range keys {
137 keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt)
138 if err != nil {
139 return err
140 }
141
142 if collisionFn != nil && collisionFn(keyStr) {
143 return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
144 }
145
146 currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key))
147 if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
148 return lookupErr
149 }
150
151 vw, err := dw.WriteDocumentElement(keyStr)
152 if err != nil {
153 return err
154 }
155
156 if errors.Is(lookupErr, errInvalidValue) {
157 err = vw.WriteNull()
158 if err != nil {
159 return err
160 }
161 continue
162 }
163
164 err = currEncoder.EncodeValue(ec, vw, currVal)
165 if err != nil {
166 return err
167 }
168 }
169
170 return dw.WriteDocumentEnd()
171 }
172
173
174 func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
175 if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
176 return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
177 }
178
179 switch vrType := vr.Type(); vrType {
180 case bsontype.Type(0), bsontype.EmbeddedDocument:
181 case bsontype.Null:
182 val.Set(reflect.Zero(val.Type()))
183 return vr.ReadNull()
184 case bsontype.Undefined:
185 val.Set(reflect.Zero(val.Type()))
186 return vr.ReadUndefined()
187 default:
188 return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
189 }
190
191 dr, err := vr.ReadDocument()
192 if err != nil {
193 return err
194 }
195
196 if val.IsNil() {
197 val.Set(reflect.MakeMap(val.Type()))
198 }
199
200 if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) {
201 clearMap(val)
202 }
203
204 eType := val.Type().Elem()
205 decoder, err := dc.LookupDecoder(eType)
206 if err != nil {
207 return err
208 }
209 eTypeDecoder, _ := decoder.(typeDecoder)
210
211 if eType == tEmpty {
212 dc.Ancestor = val.Type()
213 }
214
215 keyType := val.Type().Key()
216
217 for {
218 key, vr, err := dr.ReadElement()
219 if errors.Is(err, bsonrw.ErrEOD) {
220 break
221 }
222 if err != nil {
223 return err
224 }
225
226 k, err := mc.decodeKey(key, keyType)
227 if err != nil {
228 return err
229 }
230
231 elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
232 if err != nil {
233 return newDecodeError(key, err)
234 }
235
236 val.SetMapIndex(k, elem)
237 }
238 return nil
239 }
240
// clearMap deletes every entry from the map held in m. Keys are snapshotted
// first, and each one is removed by storing the zero reflect.Value, which
// SetMapIndex treats as a delete.
func clearMap(m reflect.Value) {
	keys := m.MapKeys()
	for i := range keys {
		m.SetMapIndex(keys[i], reflect.Value{})
	}
}
247
248 func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) {
249 if mc.EncodeKeysWithStringer || encodeKeysWithStringer {
250 return fmt.Sprint(val), nil
251 }
252
253
254 if val.Kind() == reflect.String {
255 return val.String(), nil
256 }
257
258 if km, ok := val.Interface().(KeyMarshaler); ok {
259 if val.Kind() == reflect.Ptr && val.IsNil() {
260 return "", nil
261 }
262 buf, err := km.MarshalKey()
263 if err == nil {
264 return buf, nil
265 }
266 return "", err
267 }
268
269 if km, ok := val.Interface().(encoding.TextMarshaler); ok {
270 if val.Kind() == reflect.Ptr && val.IsNil() {
271 return "", nil
272 }
273
274 buf, err := km.MarshalText()
275 if err != nil {
276 return "", err
277 }
278
279 return string(buf), nil
280 }
281
282 switch val.Kind() {
283 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
284 return strconv.FormatInt(val.Int(), 10), nil
285 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
286 return strconv.FormatUint(val.Uint(), 10), nil
287 }
288 return "", fmt.Errorf("unsupported key type: %v", val.Type())
289 }
290
291 var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
292 var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
293
294 func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
295 keyVal := reflect.ValueOf(key)
296 var err error
297 switch {
298
299 case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
300 keyVal = reflect.New(keyType)
301 v := keyVal.Interface().(KeyUnmarshaler)
302 err = v.UnmarshalKey(key)
303 keyVal = keyVal.Elem()
304
305 case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
306 keyVal = reflect.New(keyType)
307 v := keyVal.Interface().(encoding.TextUnmarshaler)
308 err = v.UnmarshalText([]byte(key))
309 keyVal = keyVal.Elem()
310
311 default:
312 switch keyType.Kind() {
313 case reflect.String:
314 keyVal = reflect.ValueOf(key).Convert(keyType)
315 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
316 n, parseErr := strconv.ParseInt(key, 10, 64)
317 if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
318 err = fmt.Errorf("failed to unmarshal number key %v", key)
319 }
320 keyVal = reflect.ValueOf(n).Convert(keyType)
321 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
322 n, parseErr := strconv.ParseUint(key, 10, 64)
323 if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
324 err = fmt.Errorf("failed to unmarshal number key %v", key)
325 break
326 }
327 keyVal = reflect.ValueOf(n).Convert(keyType)
328 case reflect.Float32, reflect.Float64:
329 if mc.EncodeKeysWithStringer {
330 parsed, err := strconv.ParseFloat(key, 64)
331 if err != nil {
332 return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err)
333 }
334 keyVal = reflect.ValueOf(parsed)
335 break
336 }
337 fallthrough
338 default:
339 return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
340 }
341 }
342 return keyVal, err
343 }
344
View as plain text