...

Source file src/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_codec.go

Documentation: go.mongodb.org/mongo-driver/bson/bsoncodec

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package bsoncodec
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"reflect"
    13  	"sort"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"go.mongodb.org/mongo-driver/bson/bsonoptions"
    19  	"go.mongodb.org/mongo-driver/bson/bsonrw"
    20  	"go.mongodb.org/mongo-driver/bson/bsontype"
    21  )
    22  
    23  // DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type.
    24  type DecodeError struct {
    25  	keys    []string
    26  	wrapped error
    27  }
    28  
    29  // Unwrap returns the underlying error
    30  func (de *DecodeError) Unwrap() error {
    31  	return de.wrapped
    32  }
    33  
    34  // Error implements the error interface.
    35  func (de *DecodeError) Error() string {
    36  	// The keys are stored in reverse order because the de.keys slice is builtup while propagating the error up the
    37  	// stack of BSON keys, so we call de.Keys(), which reverses them.
    38  	keyPath := strings.Join(de.Keys(), ".")
    39  	return fmt.Sprintf("error decoding key %s: %v", keyPath, de.wrapped)
    40  }
    41  
    42  // Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down
    43  // order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be
    44  // a string, the keys slice will be ["a", "b", "c"].
    45  func (de *DecodeError) Keys() []string {
    46  	reversedKeys := make([]string, 0, len(de.keys))
    47  	for idx := len(de.keys) - 1; idx >= 0; idx-- {
    48  		reversedKeys = append(reversedKeys, de.keys[idx])
    49  	}
    50  
    51  	return reversedKeys
    52  }
    53  
    54  // Zeroer allows custom struct types to implement a report of zero
    55  // state. All struct types that don't implement Zeroer or where IsZero
    56  // returns false are considered to be not zero.
    57  type Zeroer interface {
    58  	IsZero() bool
    59  }
    60  
// StructCodec is the Codec used for struct values.
//
// Deprecated: StructCodec will not be directly configurable in Go Driver 2.0.
// To configure the struct encode and decode behavior, use the configuration
// methods on a [go.mongodb.org/mongo-driver/bson.Encoder] or
// [go.mongodb.org/mongo-driver/bson.Decoder]. To configure the struct encode
// and decode behavior for a mongo.Client, use
// [go.mongodb.org/mongo-driver/mongo/options.ClientOptions.SetBSONOptions].
//
// For example, to configure a mongo.Client to omit zero-value structs when
// using the "omitempty" struct tag, use:
//
//	opt := options.Client().SetBSONOptions(&options.BSONOptions{
//	    OmitZeroStruct: true,
//	})
//
// See the deprecation notice for each field in StructCodec for the corresponding
// settings.
type StructCodec struct {
	cache  sync.Map // cached *structDescription values, populated by describeStruct
	parser StructTagParser // parser used for struct tags unless JSON struct tags are requested

	// DecodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the
	// destination value passed to Decode before unmarshaling BSON documents into them.
	//
	// Deprecated: Use bson.Decoder.ZeroStructs or options.BSONOptions.ZeroStructs instead.
	DecodeZeroStruct bool

	// DecodeDeepZeroInline causes DecodeValue, for structs that have at least one field
	// tagged "inline", to replace the destination value with the value built by deepZero
	// before unmarshaling BSON documents into it.
	//
	// Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0.
	DecodeDeepZeroInline bool

	// EncodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g.
	// MyStruct{}) as empty and omit it from the marshaled BSON when the "omitempty" struct tag
	// option is set.
	//
	// Deprecated: Use bson.Encoder.OmitZeroStruct or options.BSONOptions.OmitZeroStruct instead.
	EncodeOmitDefaultStruct bool

	// AllowUnexportedFields allows encoding and decoding values from un-exported struct fields.
	//
	// Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be
	// supported in Go Driver 2.0.
	AllowUnexportedFields bool

	// OverwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is
	// a duplicate field in the marshaled BSON when the "inline" struct tag option is set. The
	// default value is true.
	// NOTE(review): that default is presumably supplied by bsonoptions when the codec is built via
	// NewStructCodec; the Go zero value of this field (e.g. StructCodec{}) is false — confirm.
	//
	// Deprecated: Use bson.Encoder.ErrorOnInlineDuplicates or
	// options.BSONOptions.ErrorOnInlineDuplicates instead.
	OverwriteDuplicatedInlinedFields bool
}

// Compile-time checks that *StructCodec satisfies both codec interfaces.
var _ ValueEncoder = &StructCodec{}
var _ ValueDecoder = &StructCodec{}
   119  
   120  // NewStructCodec returns a StructCodec that uses p for struct tag parsing.
   121  //
   122  // Deprecated: NewStructCodec will not be available in Go Driver 2.0. See
   123  // [StructCodec] for more details.
   124  func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) {
   125  	if p == nil {
   126  		return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
   127  	}
   128  
   129  	structOpt := bsonoptions.MergeStructCodecOptions(opts...)
   130  
   131  	codec := &StructCodec{
   132  		parser: p,
   133  	}
   134  
   135  	if structOpt.DecodeZeroStruct != nil {
   136  		codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct
   137  	}
   138  	if structOpt.DecodeDeepZeroInline != nil {
   139  		codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline
   140  	}
   141  	if structOpt.EncodeOmitDefaultStruct != nil {
   142  		codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
   143  	}
   144  	if structOpt.OverwriteDuplicatedInlinedFields != nil {
   145  		codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields
   146  	}
   147  	if structOpt.AllowUnexportedFields != nil {
   148  		codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields
   149  	}
   150  
   151  	return codec, nil
   152  }
   153  
// EncodeValue handles encoding generic struct types.
//
// val must be a valid reflect.Value of kind Struct, otherwise a ValueEncoderError is
// returned. The struct is written as a BSON document whose elements follow the field
// order recorded in the struct description (sd.fl). Fields tagged "omitempty" are
// skipped when empty. When the struct has an inline map, its entries are written
// after the regular fields by the map codec, which also finishes the document; keys
// that already name a struct field are treated as collisions.
func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Kind() != reflect.Struct {
		return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
	}

	sd, err := sc.describeStruct(ec.Registry, val.Type(), ec.useJSONStructTags, ec.errorOnInlineDuplicates)
	if err != nil {
		return err
	}

	dw, err := vw.WriteDocument()
	if err != nil {
		return err
	}
	var rv reflect.Value
	for _, desc := range sd.fl {
		if desc.inline == nil {
			rv = val.Field(desc.idx)
		} else {
			// Inlined fields are reached through their index path. If the path cannot
			// be followed (fieldByIndexErr reports an error, e.g. for a nil embedded
			// pointer), the field is silently left out of the document.
			rv, err = fieldByIndexErr(val, desc.inline)
			if err != nil {
				continue
			}
		}

		// desc is a per-iteration copy of the sd.fl entry, so storing the looked-up
		// encoder in desc.encoder does not modify the cached description.
		desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(ec, desc.encoder, rv)

		if err != nil && !errors.Is(err, errInvalidValue) {
			return err
		}

		// errInvalidValue: the field has no encodable value; omit it when tagged
		// "omitempty", otherwise write an explicit BSON null.
		if errors.Is(err, errInvalidValue) {
			if desc.omitEmpty {
				continue
			}
			vw2, err := dw.WriteDocumentElement(desc.name)
			if err != nil {
				return err
			}
			err = vw2.WriteNull()
			if err != nil {
				return err
			}
			continue
		}

		if desc.encoder == nil {
			return ErrNoEncoder{Type: rv.Type()}
		}

		encoder := desc.encoder

		// Emptiness check for "omitempty": an encoder implementing CodecZeroer has the
		// final say; otherwise fall back to the generic rules.
		var empty bool
		if cz, ok := encoder.(CodecZeroer); ok {
			empty = cz.IsTypeZero(rv.Interface())
		} else if rv.Kind() == reflect.Interface {
			// isEmpty will not treat an interface rv as an interface, so we need to check for the
			// nil interface separately.
			empty = rv.IsNil()
		} else {
			empty = isEmpty(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct)
		}
		if desc.omitEmpty && empty {
			continue
		}

		vw2, err := dw.WriteDocumentElement(desc.name)
		if err != nil {
			return err
		}

		// Each field is encoded with a copy of ec whose MinSize is additionally
		// enabled by the field's own "minsize" tag option.
		ectx := EncodeContext{
			Registry:                ec.Registry,
			MinSize:                 desc.minSize || ec.MinSize,
			errorOnInlineDuplicates: ec.errorOnInlineDuplicates,
			stringifyMapKeysWithFmt: ec.stringifyMapKeysWithFmt,
			nilMapAsEmpty:           ec.nilMapAsEmpty,
			nilSliceAsEmpty:         ec.nilSliceAsEmpty,
			nilByteSliceAsEmpty:     ec.nilByteSliceAsEmpty,
			omitZeroStruct:          ec.omitZeroStruct,
			useJSONStructTags:       ec.useJSONStructTags,
		}
		err = encoder.EncodeValue(ectx, vw2, rv)
		if err != nil {
			return err
		}
	}

	if sd.inlineMap >= 0 {
		rv := val.Field(sd.inlineMap)
		// A map key collides when a struct field already uses it as its BSON name.
		collisionFn := func(key string) bool {
			_, exists := sd.fm[key]
			return exists
		}

		// The document end is left to mapEncodeValue in this branch (it is handed dw).
		return defaultMapCodec.mapEncodeValue(ec, dw, rv, collisionFn)
	}

	return dw.WriteDocumentEnd()
}
   255  
   256  func newDecodeError(key string, original error) error {
   257  	var de *DecodeError
   258  	if !errors.As(original, &de) {
   259  		return &DecodeError{
   260  			keys:    []string{key},
   261  			wrapped: original,
   262  		}
   263  	}
   264  
   265  	de.keys = append(de.keys, key)
   266  	return de
   267  }
   268  
// DecodeValue implements the Codec interface.
// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
//
// val must be a settable struct value. A BSON null or undefined sets val to its zero
// value. Any BSON type other than an embedded document (or an unset type) is an error.
// Document elements whose key matches no struct field are skipped, unless the struct
// has an inline map, in which case they are decoded into that map (allocating it when
// nil). Errors from field decoders are wrapped in a *DecodeError recording the key path.
func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Kind() != reflect.Struct {
		return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
	}

	switch vrType := vr.Type(); vrType {
	case bsontype.Type(0), bsontype.EmbeddedDocument:
	case bsontype.Null:
		if err := vr.ReadNull(); err != nil {
			return err
		}

		val.Set(reflect.Zero(val.Type()))
		return nil
	case bsontype.Undefined:
		if err := vr.ReadUndefined(); err != nil {
			return err
		}

		val.Set(reflect.Zero(val.Type()))
		return nil
	default:
		return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
	}

	// Decoding never checks for inline duplicates, hence errorOnDuplicates=false.
	sd, err := sc.describeStruct(dc.Registry, val.Type(), dc.useJSONStructTags, false)
	if err != nil {
		return err
	}

	if sc.DecodeZeroStruct || dc.zeroStructs {
		val.Set(reflect.Zero(val.Type()))
	}
	if sc.DecodeDeepZeroInline && sd.inline {
		val.Set(deepZero(val.Type()))
	}

	// The inline map's element decoder is resolved once, before reading elements.
	var decoder ValueDecoder
	var inlineMap reflect.Value
	if sd.inlineMap >= 0 {
		inlineMap = val.Field(sd.inlineMap)
		decoder, err = dc.LookupDecoder(inlineMap.Type().Elem())
		if err != nil {
			return err
		}
	}

	dr, err := vr.ReadDocument()
	if err != nil {
		return err
	}

	for {
		name, vr, err := dr.ReadElement()
		if errors.Is(err, bsonrw.ErrEOD) {
			break
		}
		if err != nil {
			return err
		}

		fd, exists := sd.fm[name]
		if !exists {
			// if the original name isn't found in the struct description, try again with the name in lowercase
			// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
			// names
			fd, exists = sd.fm[strings.ToLower(name)]
		}

		if !exists {
			if sd.inlineMap < 0 {
				// Unknown keys are skipped without error.
				// The encoding/json package requires a flag to return on error for non-existent fields.
				// This functionality seems appropriate for the struct codec.
				err = vr.Skip()
				if err != nil {
					return err
				}
				continue
			}

			if inlineMap.IsNil() {
				inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
			}

			// Decode into a fresh element and store it under the original (non-lowercased) key.
			elem := reflect.New(inlineMap.Type().Elem()).Elem()
			dc.Ancestor = inlineMap.Type()
			err = decoder.DecodeValue(dc, vr, elem)
			if err != nil {
				return err
			}
			inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
			continue
		}

		var field reflect.Value
		if fd.inline == nil {
			field = val.Field(fd.idx)
		} else {
			// getInlineField allocates nil parent pointers along the inline path.
			field, err = getInlineField(val, fd.inline)
			if err != nil {
				return err
			}
		}

		if !field.CanSet() { // Being settable is a super set of being addressable.
			innerErr := fmt.Errorf("field %v is not settable", field)
			return newDecodeError(fd.name, innerErr)
		}
		if field.Kind() == reflect.Ptr && field.IsNil() {
			field.Set(reflect.New(field.Type().Elem()))
		}
		field = field.Addr()

		// Per-field context: Truncate is additionally enabled by the field's "truncate" tag option.
		dctx := DecodeContext{
			Registry:            dc.Registry,
			Truncate:            fd.truncate || dc.Truncate,
			defaultDocumentType: dc.defaultDocumentType,
			binaryAsSlice:       dc.binaryAsSlice,
			useJSONStructTags:   dc.useJSONStructTags,
			useLocalTimeZone:    dc.useLocalTimeZone,
			zeroMaps:            dc.zeroMaps,
			zeroStructs:         dc.zeroStructs,
		}

		if fd.decoder == nil {
			return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()})
		}

		err = fd.decoder.DecodeValue(dctx, vr, field.Elem())
		if err != nil {
			return newDecodeError(fd.name, err)
		}
	}

	return nil
}
   408  
   409  func isEmpty(v reflect.Value, omitZeroStruct bool) bool {
   410  	kind := v.Kind()
   411  	if (kind != reflect.Ptr || !v.IsNil()) && v.Type().Implements(tZeroer) {
   412  		return v.Interface().(Zeroer).IsZero()
   413  	}
   414  	switch kind {
   415  	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
   416  		return v.Len() == 0
   417  	case reflect.Struct:
   418  		if !omitZeroStruct {
   419  			return false
   420  		}
   421  		vt := v.Type()
   422  		if vt == tTime {
   423  			return v.Interface().(time.Time).IsZero()
   424  		}
   425  		numField := vt.NumField()
   426  		for i := 0; i < numField; i++ {
   427  			ff := vt.Field(i)
   428  			if ff.PkgPath != "" && !ff.Anonymous {
   429  				continue // Private field
   430  			}
   431  			if !isEmpty(v.Field(i), omitZeroStruct) {
   432  				return false
   433  			}
   434  		}
   435  		return true
   436  	}
   437  	return !v.IsValid() || v.IsZero()
   438  }
   439  
// structDescription is the analyzed layout of a struct type, built by describeStructSlow.
type structDescription struct {
	fm        map[string]fieldDescription // BSON key name -> field; used for decode lookups and inline-map collision checks
	fl        []fieldDescription          // encodable fields, sorted into declaration order (byIndex)
	inlineMap int                         // index of the field tagged "inline" whose type is a map, or -1 if none
	inline    bool                        // true when any field carries the "inline" tag option
}

// fieldDescription describes one encodable/decodable field of a struct.
type fieldDescription struct {
	name      string // BSON key name
	fieldName string // struct field name
	idx       int    // index of the field within its immediately enclosing struct
	omitEmpty bool   // "omitempty" tag option
	minSize   bool   // "minsize" tag option
	truncate  bool   // "truncate" tag option
	// inline is the index path from the outermost struct for fields promoted from an
	// inlined struct (first element = index in the outermost struct); nil for direct fields.
	inline  []int
	encoder ValueEncoder // encoder looked up for the field's type; nil if the lookup failed
	decoder ValueDecoder // decoder looked up for the field's type; nil if the lookup failed
}
   458  
   459  type byIndex []fieldDescription
   460  
   461  func (bi byIndex) Len() int { return len(bi) }
   462  
   463  func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] }
   464  
   465  func (bi byIndex) Less(i, j int) bool {
   466  	// If a field is inlined, its index in the top level struct is stored at inline[0]
   467  	iIdx, jIdx := bi[i].idx, bi[j].idx
   468  	if len(bi[i].inline) > 0 {
   469  		iIdx = bi[i].inline[0]
   470  	}
   471  	if len(bi[j].inline) > 0 {
   472  		jIdx = bi[j].inline[0]
   473  	}
   474  	if iIdx != jIdx {
   475  		return iIdx < jIdx
   476  	}
   477  	for k, biik := range bi[i].inline {
   478  		if k >= len(bi[j].inline) {
   479  			return false
   480  		}
   481  		if biik != bi[j].inline[k] {
   482  			return biik < bi[j].inline[k]
   483  		}
   484  	}
   485  	return len(bi[i].inline) < len(bi[j].inline)
   486  }
   487  
   488  func (sc *StructCodec) describeStruct(
   489  	r *Registry,
   490  	t reflect.Type,
   491  	useJSONStructTags bool,
   492  	errorOnDuplicates bool,
   493  ) (*structDescription, error) {
   494  	// We need to analyze the struct, including getting the tags, collecting
   495  	// information about inlining, and create a map of the field name to the field.
   496  	if v, ok := sc.cache.Load(t); ok {
   497  		return v.(*structDescription), nil
   498  	}
   499  	// TODO(charlie): Only describe the struct once when called
   500  	// concurrently with the same type.
   501  	ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates)
   502  	if err != nil {
   503  		return nil, err
   504  	}
   505  	if v, loaded := sc.cache.LoadOrStore(t, ds); loaded {
   506  		ds = v.(*structDescription)
   507  	}
   508  	return ds, nil
   509  }
   510  
// describeStructSlow analyzes struct type t without consulting the cache for t itself
// (inlined struct types are described through describeStruct and may be cached).
//
// For every field that is exported (or an anonymous field when AllowUnexportedFields is
// set) and not skipped by its tags, it records the BSON name, tag options and the
// encoder/decoder from r. Fields of inlined structs are flattened into the result with
// their index path; an inlined map is remembered in inlineMap. When several fields share
// a BSON name, dominantField picks the shallowest one; an error is returned if that is
// ambiguous, if OverwriteDuplicatedInlinedFields is false, or if errorOnDuplicates is set.
func (sc *StructCodec) describeStructSlow(
	r *Registry,
	t reflect.Type,
	useJSONStructTags bool,
	errorOnDuplicates bool,
) (*structDescription, error) {
	numFields := t.NumField()
	sd := &structDescription{
		fm:        make(map[string]fieldDescription, numFields),
		fl:        make([]fieldDescription, 0, numFields),
		inlineMap: -1,
	}

	var fields []fieldDescription
	for i := 0; i < numFields; i++ {
		sf := t.Field(i)
		if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) {
			// Unexported field: ignore it unless unexported fields are allowed AND the
			// field is anonymous (embedded).
			continue
		}

		sfType := sf.Type
		// Lookup failures are not fatal here: a nil encoder/decoder is reported
		// later (ErrNoEncoder / ErrNoDecoder) only if the field is actually used.
		encoder, err := r.LookupEncoder(sfType)
		if err != nil {
			encoder = nil
		}
		decoder, err := r.LookupDecoder(sfType)
		if err != nil {
			decoder = nil
		}

		description := fieldDescription{
			fieldName: sf.Name,
			idx:       i,
			encoder:   encoder,
			decoder:   decoder,
		}

		var stags StructTags
		// If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser
		// instead of the parser defined on the codec.
		if useJSONStructTags {
			stags, err = JSONFallbackStructTagParser.ParseStructTags(sf)
		} else {
			stags, err = sc.parser.ParseStructTags(sf)
		}
		if err != nil {
			return nil, err
		}
		if stags.Skip {
			continue
		}
		description.name = stags.Name
		description.omitEmpty = stags.OmitEmpty
		description.minSize = stags.MinSize
		description.truncate = stags.Truncate

		if stags.Inline {
			sd.inline = true
			switch sfType.Kind() {
			case reflect.Map:
				if sd.inlineMap >= 0 {
					return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
				}
				if sfType.Key() != tString {
					return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
				}
				sd.inlineMap = description.idx
			case reflect.Ptr:
				sfType = sfType.Elem()
				if sfType.Kind() != reflect.Struct {
					return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
				}
				fallthrough
			case reflect.Struct:
				inlinesf, err := sc.describeStruct(r, sfType, useJSONStructTags, errorOnDuplicates)
				if err != nil {
					return nil, err
				}
				// Prefix each inner field's path with this field's index. fd is a copy and
				// append([]int{i}, ...) allocates a new slice, so the (possibly cached)
				// inner description is not modified.
				for _, fd := range inlinesf.fl {
					if fd.inline == nil {
						fd.inline = []int{i, fd.idx}
					} else {
						fd.inline = append([]int{i}, fd.inline...)
					}
					fields = append(fields, fd)

				}
			default:
				return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
			}
			continue
		}
		fields = append(fields, description)
	}

	// Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name
	sort.Slice(fields, func(i, j int) bool {
		x := fields
		// sort field by name, breaking ties with depth, then
		// breaking ties with index sequence.
		if x[i].name != x[j].name {
			return x[i].name < x[j].name
		}
		if len(x[i].inline) != len(x[j].inline) {
			return len(x[i].inline) < len(x[j].inline)
		}
		return byIndex(x).Less(i, j)
	})

	for advance, i := 0, 0; i < len(fields); i += advance {
		// One iteration per name.
		// Find the sequence of fields with the name of this first field.
		fi := fields[i]
		name := fi.name
		for advance = 1; i+advance < len(fields); advance++ {
			fj := fields[i+advance]
			if fj.name != name {
				break
			}
		}
		if advance == 1 { // Only one field with this name
			sd.fl = append(sd.fl, fi)
			sd.fm[name] = fi
			continue
		}
		dominant, ok := dominantField(fields[i : i+advance])
		if !ok || !sc.OverwriteDuplicatedInlinedFields || errorOnDuplicates {
			return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name)
		}
		sd.fl = append(sd.fl, dominant)
		sd.fm[name] = dominant
	}

	// Restore declaration order for encoding.
	sort.Sort(byIndex(sd.fl))

	return sd, nil
}
   649  
   650  // dominantField looks through the fields, all of which are known to
   651  // have the same name, to find the single field that dominates the
   652  // others using Go's inlining rules. If there are multiple top-level
   653  // fields, the boolean will be false: This condition is an error in Go
   654  // and we skip all the fields.
   655  func dominantField(fields []fieldDescription) (fieldDescription, bool) {
   656  	// The fields are sorted in increasing index-length order, then by presence of tag.
   657  	// That means that the first field is the dominant one. We need only check
   658  	// for error cases: two fields at top level.
   659  	if len(fields) > 1 &&
   660  		len(fields[0].inline) == len(fields[1].inline) {
   661  		return fieldDescription{}, false
   662  	}
   663  	return fields[0], true
   664  }
   665  
   666  func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
   667  	defer func() {
   668  		if recovered := recover(); recovered != nil {
   669  			switch r := recovered.(type) {
   670  			case string:
   671  				err = fmt.Errorf("%s", r)
   672  			case error:
   673  				err = r
   674  			}
   675  		}
   676  	}()
   677  
   678  	result = v.FieldByIndex(index)
   679  	return
   680  }
   681  
   682  func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
   683  	field, err := fieldByIndexErr(val, index)
   684  	if err == nil {
   685  		return field, nil
   686  	}
   687  
   688  	// if parent of this element doesn't exist, fix its parent
   689  	inlineParent := index[:len(index)-1]
   690  	var fParent reflect.Value
   691  	if fParent, err = fieldByIndexErr(val, inlineParent); err != nil {
   692  		fParent, err = getInlineField(val, inlineParent)
   693  		if err != nil {
   694  			return fParent, err
   695  		}
   696  	}
   697  	fParent.Set(reflect.New(fParent.Type().Elem()))
   698  
   699  	return fieldByIndexErr(val, index)
   700  }
   701  
   702  // DeepZero returns recursive zero object
   703  func deepZero(st reflect.Type) (result reflect.Value) {
   704  	if st.Kind() == reflect.Struct {
   705  		numField := st.NumField()
   706  		for i := 0; i < numField; i++ {
   707  			if result == emptyValue {
   708  				result = reflect.Indirect(reflect.New(st))
   709  			}
   710  			f := result.Field(i)
   711  			if f.CanInterface() {
   712  				if f.Type().Kind() == reflect.Struct {
   713  					result.Field(i).Set(recursivePointerTo(deepZero(f.Type().Elem())))
   714  				}
   715  			}
   716  		}
   717  	}
   718  	return result
   719  }
   720  
   721  // recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside
   722  func recursivePointerTo(v reflect.Value) reflect.Value {
   723  	v = reflect.Indirect(v)
   724  	result := reflect.New(v.Type())
   725  	if v.Kind() == reflect.Struct {
   726  		for i := 0; i < v.NumField(); i++ {
   727  			if f := v.Field(i); f.Kind() == reflect.Ptr {
   728  				if f.Elem().Kind() == reflect.Struct {
   729  					result.Elem().Field(i).Set(recursivePointerTo(f))
   730  				}
   731  			}
   732  		}
   733  	}
   734  
   735  	return result
   736  }
   737  

View as plain text