...

Source file src/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go

Documentation: go.mongodb.org/mongo-driver/bson/bsoncodec

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package bsoncodec
     8  
     9  import (
    10  	"encoding"
    11  	"errors"
    12  	"fmt"
    13  	"reflect"
    14  	"strconv"
    15  
    16  	"go.mongodb.org/mongo-driver/bson/bsonoptions"
    17  	"go.mongodb.org/mongo-driver/bson/bsonrw"
    18  	"go.mongodb.org/mongo-driver/bson/bsontype"
    19  )
    20  
// defaultMapCodec is the package-level MapCodec created by calling NewMapCodec
// with no options.
var defaultMapCodec = NewMapCodec()
    22  
    23  // MapCodec is the Codec used for map values.
    24  //
    25  // Deprecated: MapCodec will not be directly configurable in Go Driver 2.0. To
    26  // configure the map encode and decode behavior, use the configuration methods
    27  // on a [go.mongodb.org/mongo-driver/bson.Encoder] or
    28  // [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the map encode and
    29  // decode behavior for a mongo.Client, use
    30  // [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions].
    31  //
    32  // For example, to configure a mongo.Client to marshal nil Go maps as empty BSON
    33  // documents, use:
    34  //
    35  //	opt := options.Client().SetBSONOptions(&options.BSONOptions{
    36  //	    NilMapAsEmpty: true,
    37  //	})
    38  //
    39  // See the deprecation notice for each field in MapCodec for the corresponding
    40  // settings.
    41  type MapCodec struct {
    42  	// DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination
    43  	// value passed to Decode before unmarshaling BSON documents into them.
    44  	//
    45  	// Deprecated: Use bson.Decoder.ZeroMaps or options.BSONOptions.ZeroMaps instead.
    46  	DecodeZerosMap bool
    47  
    48  	// EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of
    49  	// BSON null.
    50  	//
    51  	// Deprecated: Use bson.Encoder.NilMapAsEmpty or options.BSONOptions.NilMapAsEmpty instead.
    52  	EncodeNilAsEmpty bool
    53  
    54  	// EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name
    55  	// strings using fmt.Sprintf() instead of the default string conversion logic.
    56  	//
    57  	// Deprecated: Use bson.Encoder.StringifyMapKeysWithFmt or
    58  	// options.BSONOptions.StringifyMapKeysWithFmt instead.
    59  	EncodeKeysWithStringer bool
    60  }
    61  
    62  // KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
    63  // This applies to types used as map keys and is similar to encoding.TextMarshaler.
    64  type KeyMarshaler interface {
    65  	MarshalKey() (key string, err error)
    66  }
    67  
    68  // KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
    69  // of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
    70  //
    71  // UnmarshalKey must be able to decode the form generated by MarshalKey.
    72  // UnmarshalKey must copy the text if it wishes to retain the text
    73  // after returning.
    74  type KeyUnmarshaler interface {
    75  	UnmarshalKey(key string) error
    76  }
    77  
    78  // NewMapCodec returns a MapCodec with options opts.
    79  //
    80  // Deprecated: NewMapCodec will not be available in Go Driver 2.0. See
    81  // [MapCodec] for more details.
    82  func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
    83  	mapOpt := bsonoptions.MergeMapCodecOptions(opts...)
    84  
    85  	codec := MapCodec{}
    86  	if mapOpt.DecodeZerosMap != nil {
    87  		codec.DecodeZerosMap = *mapOpt.DecodeZerosMap
    88  	}
    89  	if mapOpt.EncodeNilAsEmpty != nil {
    90  		codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty
    91  	}
    92  	if mapOpt.EncodeKeysWithStringer != nil {
    93  		codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer
    94  	}
    95  	return &codec
    96  }
    97  
    98  // EncodeValue is the ValueEncoder for map[*]* types.
    99  func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
   100  	if !val.IsValid() || val.Kind() != reflect.Map {
   101  		return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
   102  	}
   103  
   104  	if val.IsNil() && !mc.EncodeNilAsEmpty && !ec.nilMapAsEmpty {
   105  		// If we have a nil map but we can't WriteNull, that means we're probably trying to encode
   106  		// to a TopLevel document. We can't currently tell if this is what actually happened, but if
   107  		// there's a deeper underlying problem, the error will also be returned from WriteDocument,
   108  		// so just continue. The operations on a map reflection value are valid, so we can call
   109  		// MapKeys within mapEncodeValue without a problem.
   110  		err := vw.WriteNull()
   111  		if err == nil {
   112  			return nil
   113  		}
   114  	}
   115  
   116  	dw, err := vw.WriteDocument()
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	return mc.mapEncodeValue(ec, dw, val, nil)
   122  }
   123  
   124  // mapEncodeValue handles encoding of the values of a map. The collisionFn returns
   125  // true if the provided key exists, this is mainly used for inline maps in the
   126  // struct codec.
   127  func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
   128  
   129  	elemType := val.Type().Elem()
   130  	encoder, err := ec.LookupEncoder(elemType)
   131  	if err != nil && elemType.Kind() != reflect.Interface {
   132  		return err
   133  	}
   134  
   135  	keys := val.MapKeys()
   136  	for _, key := range keys {
   137  		keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt)
   138  		if err != nil {
   139  			return err
   140  		}
   141  
   142  		if collisionFn != nil && collisionFn(keyStr) {
   143  			return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
   144  		}
   145  
   146  		currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key))
   147  		if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
   148  			return lookupErr
   149  		}
   150  
   151  		vw, err := dw.WriteDocumentElement(keyStr)
   152  		if err != nil {
   153  			return err
   154  		}
   155  
   156  		if errors.Is(lookupErr, errInvalidValue) {
   157  			err = vw.WriteNull()
   158  			if err != nil {
   159  				return err
   160  			}
   161  			continue
   162  		}
   163  
   164  		err = currEncoder.EncodeValue(ec, vw, currVal)
   165  		if err != nil {
   166  			return err
   167  		}
   168  	}
   169  
   170  	return dw.WriteDocumentEnd()
   171  }
   172  
   173  // DecodeValue is the ValueDecoder for map[string/decimal]* types.
   174  func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
   175  	if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
   176  		return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
   177  	}
   178  
   179  	switch vrType := vr.Type(); vrType {
   180  	case bsontype.Type(0), bsontype.EmbeddedDocument:
   181  	case bsontype.Null:
   182  		val.Set(reflect.Zero(val.Type()))
   183  		return vr.ReadNull()
   184  	case bsontype.Undefined:
   185  		val.Set(reflect.Zero(val.Type()))
   186  		return vr.ReadUndefined()
   187  	default:
   188  		return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
   189  	}
   190  
   191  	dr, err := vr.ReadDocument()
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	if val.IsNil() {
   197  		val.Set(reflect.MakeMap(val.Type()))
   198  	}
   199  
   200  	if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) {
   201  		clearMap(val)
   202  	}
   203  
   204  	eType := val.Type().Elem()
   205  	decoder, err := dc.LookupDecoder(eType)
   206  	if err != nil {
   207  		return err
   208  	}
   209  	eTypeDecoder, _ := decoder.(typeDecoder)
   210  
   211  	if eType == tEmpty {
   212  		dc.Ancestor = val.Type()
   213  	}
   214  
   215  	keyType := val.Type().Key()
   216  
   217  	for {
   218  		key, vr, err := dr.ReadElement()
   219  		if errors.Is(err, bsonrw.ErrEOD) {
   220  			break
   221  		}
   222  		if err != nil {
   223  			return err
   224  		}
   225  
   226  		k, err := mc.decodeKey(key, keyType)
   227  		if err != nil {
   228  			return err
   229  		}
   230  
   231  		elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
   232  		if err != nil {
   233  			return newDecodeError(key, err)
   234  		}
   235  
   236  		val.SetMapIndex(k, elem)
   237  	}
   238  	return nil
   239  }
   240  
   241  func clearMap(m reflect.Value) {
   242  	var none reflect.Value
   243  	for _, k := range m.MapKeys() {
   244  		m.SetMapIndex(k, none)
   245  	}
   246  }
   247  
   248  func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) {
   249  	if mc.EncodeKeysWithStringer || encodeKeysWithStringer {
   250  		return fmt.Sprint(val), nil
   251  	}
   252  
   253  	// keys of any string type are used directly
   254  	if val.Kind() == reflect.String {
   255  		return val.String(), nil
   256  	}
   257  	// KeyMarshalers are marshaled
   258  	if km, ok := val.Interface().(KeyMarshaler); ok {
   259  		if val.Kind() == reflect.Ptr && val.IsNil() {
   260  			return "", nil
   261  		}
   262  		buf, err := km.MarshalKey()
   263  		if err == nil {
   264  			return buf, nil
   265  		}
   266  		return "", err
   267  	}
   268  	// keys implement encoding.TextMarshaler are marshaled.
   269  	if km, ok := val.Interface().(encoding.TextMarshaler); ok {
   270  		if val.Kind() == reflect.Ptr && val.IsNil() {
   271  			return "", nil
   272  		}
   273  
   274  		buf, err := km.MarshalText()
   275  		if err != nil {
   276  			return "", err
   277  		}
   278  
   279  		return string(buf), nil
   280  	}
   281  
   282  	switch val.Kind() {
   283  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   284  		return strconv.FormatInt(val.Int(), 10), nil
   285  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   286  		return strconv.FormatUint(val.Uint(), 10), nil
   287  	}
   288  	return "", fmt.Errorf("unsupported key type: %v", val.Type())
   289  }
   290  
   291  var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
   292  var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
   293  
   294  func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
   295  	keyVal := reflect.ValueOf(key)
   296  	var err error
   297  	switch {
   298  	// First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
   299  	case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
   300  		keyVal = reflect.New(keyType)
   301  		v := keyVal.Interface().(KeyUnmarshaler)
   302  		err = v.UnmarshalKey(key)
   303  		keyVal = keyVal.Elem()
   304  	// Try to decode encoding.TextUnmarshalers.
   305  	case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
   306  		keyVal = reflect.New(keyType)
   307  		v := keyVal.Interface().(encoding.TextUnmarshaler)
   308  		err = v.UnmarshalText([]byte(key))
   309  		keyVal = keyVal.Elem()
   310  	// Otherwise, go to type specific behavior
   311  	default:
   312  		switch keyType.Kind() {
   313  		case reflect.String:
   314  			keyVal = reflect.ValueOf(key).Convert(keyType)
   315  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   316  			n, parseErr := strconv.ParseInt(key, 10, 64)
   317  			if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
   318  				err = fmt.Errorf("failed to unmarshal number key %v", key)
   319  			}
   320  			keyVal = reflect.ValueOf(n).Convert(keyType)
   321  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   322  			n, parseErr := strconv.ParseUint(key, 10, 64)
   323  			if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
   324  				err = fmt.Errorf("failed to unmarshal number key %v", key)
   325  				break
   326  			}
   327  			keyVal = reflect.ValueOf(n).Convert(keyType)
   328  		case reflect.Float32, reflect.Float64:
   329  			if mc.EncodeKeysWithStringer {
   330  				parsed, err := strconv.ParseFloat(key, 64)
   331  				if err != nil {
   332  					return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err)
   333  				}
   334  				keyVal = reflect.ValueOf(parsed)
   335  				break
   336  			}
   337  			fallthrough
   338  		default:
   339  			return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
   340  		}
   341  	}
   342  	return keyVal, err
   343  }
   344  

View as plain text