...

Source file src/google.golang.org/protobuf/internal/impl/codec_map.go

Documentation: google.golang.org/protobuf/internal/impl

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package impl
     6  
     7  import (
     8  	"reflect"
     9  	"sort"
    10  
    11  	"google.golang.org/protobuf/encoding/protowire"
    12  	"google.golang.org/protobuf/internal/errors"
    13  	"google.golang.org/protobuf/internal/genid"
    14  	"google.golang.org/protobuf/reflect/protoreflect"
    15  )
    16  
    17  type mapInfo struct {
    18  	goType     reflect.Type
    19  	keyWiretag uint64
    20  	valWiretag uint64
    21  	keyFuncs   valueCoderFuncs
    22  	valFuncs   valueCoderFuncs
    23  	keyZero    protoreflect.Value
    24  	keyKind    protoreflect.Kind
    25  	conv       *mapConverter
    26  }
    27  
    28  func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
    29  	// TODO: Consider generating specialized map coders.
    30  	keyField := fd.MapKey()
    31  	valField := fd.MapValue()
    32  	keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
    33  	valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
    34  	keyFuncs := encoderFuncsForValue(keyField)
    35  	valFuncs := encoderFuncsForValue(valField)
    36  	conv := newMapConverter(ft, fd)
    37  
    38  	mapi := &mapInfo{
    39  		goType:     ft,
    40  		keyWiretag: keyWiretag,
    41  		valWiretag: valWiretag,
    42  		keyFuncs:   keyFuncs,
    43  		valFuncs:   valFuncs,
    44  		keyZero:    keyField.Default(),
    45  		keyKind:    keyField.Kind(),
    46  		conv:       conv,
    47  	}
    48  	if valField.Kind() == protoreflect.MessageKind {
    49  		valueMessage = getMessageInfo(ft.Elem())
    50  	}
    51  
    52  	funcs = pointerCoderFuncs{
    53  		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
    54  			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
    55  		},
    56  		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
    57  			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
    58  		},
    59  		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
    60  			mp := p.AsValueOf(ft)
    61  			if mp.Elem().IsNil() {
    62  				mp.Elem().Set(reflect.MakeMap(mapi.goType))
    63  			}
    64  			if f.mi == nil {
    65  				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
    66  			} else {
    67  				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
    68  			}
    69  		},
    70  	}
    71  	switch valField.Kind() {
    72  	case protoreflect.MessageKind:
    73  		funcs.merge = mergeMapOfMessage
    74  	case protoreflect.BytesKind:
    75  		funcs.merge = mergeMapOfBytes
    76  	default:
    77  		funcs.merge = mergeMap
    78  	}
    79  	if valFuncs.isInit != nil {
    80  		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
    81  			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
    82  		}
    83  	}
    84  	return valueMessage, funcs
    85  }
    86  
    87  const (
    88  	mapKeyTagSize = 1 // field 1, tag size 1.
    89  	mapValTagSize = 1 // field 2, tag size 2.
    90  )
    91  
    92  func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
    93  	if mapv.Len() == 0 {
    94  		return 0
    95  	}
    96  	n := 0
    97  	iter := mapRange(mapv)
    98  	for iter.Next() {
    99  		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
   100  		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
   101  		var valSize int
   102  		value := mapi.conv.valConv.PBValueOf(iter.Value())
   103  		if f.mi == nil {
   104  			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
   105  		} else {
   106  			p := pointerOfValue(iter.Value())
   107  			valSize += mapValTagSize
   108  			valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
   109  		}
   110  		n += f.tagsize + protowire.SizeBytes(keySize+valSize)
   111  	}
   112  	return n
   113  }
   114  
   115  func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
   116  	if wtyp != protowire.BytesType {
   117  		return out, errUnknown
   118  	}
   119  	b, n := protowire.ConsumeBytes(b)
   120  	if n < 0 {
   121  		return out, errDecode
   122  	}
   123  	var (
   124  		key = mapi.keyZero
   125  		val = mapi.conv.valConv.New()
   126  	)
   127  	for len(b) > 0 {
   128  		num, wtyp, n := protowire.ConsumeTag(b)
   129  		if n < 0 {
   130  			return out, errDecode
   131  		}
   132  		if num > protowire.MaxValidNumber {
   133  			return out, errDecode
   134  		}
   135  		b = b[n:]
   136  		err := errUnknown
   137  		switch num {
   138  		case genid.MapEntry_Key_field_number:
   139  			var v protoreflect.Value
   140  			var o unmarshalOutput
   141  			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
   142  			if err != nil {
   143  				break
   144  			}
   145  			key = v
   146  			n = o.n
   147  		case genid.MapEntry_Value_field_number:
   148  			var v protoreflect.Value
   149  			var o unmarshalOutput
   150  			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
   151  			if err != nil {
   152  				break
   153  			}
   154  			val = v
   155  			n = o.n
   156  		}
   157  		if err == errUnknown {
   158  			n = protowire.ConsumeFieldValue(num, wtyp, b)
   159  			if n < 0 {
   160  				return out, errDecode
   161  			}
   162  		} else if err != nil {
   163  			return out, err
   164  		}
   165  		b = b[n:]
   166  	}
   167  	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
   168  	out.n = n
   169  	return out, nil
   170  }
   171  
   172  func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
   173  	if wtyp != protowire.BytesType {
   174  		return out, errUnknown
   175  	}
   176  	b, n := protowire.ConsumeBytes(b)
   177  	if n < 0 {
   178  		return out, errDecode
   179  	}
   180  	var (
   181  		key = mapi.keyZero
   182  		val = reflect.New(f.mi.GoReflectType.Elem())
   183  	)
   184  	for len(b) > 0 {
   185  		num, wtyp, n := protowire.ConsumeTag(b)
   186  		if n < 0 {
   187  			return out, errDecode
   188  		}
   189  		if num > protowire.MaxValidNumber {
   190  			return out, errDecode
   191  		}
   192  		b = b[n:]
   193  		err := errUnknown
   194  		switch num {
   195  		case 1:
   196  			var v protoreflect.Value
   197  			var o unmarshalOutput
   198  			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
   199  			if err != nil {
   200  				break
   201  			}
   202  			key = v
   203  			n = o.n
   204  		case 2:
   205  			if wtyp != protowire.BytesType {
   206  				break
   207  			}
   208  			var v []byte
   209  			v, n = protowire.ConsumeBytes(b)
   210  			if n < 0 {
   211  				return out, errDecode
   212  			}
   213  			var o unmarshalOutput
   214  			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
   215  			if o.initialized {
   216  				// Consider this map item initialized so long as we see
   217  				// an initialized value.
   218  				out.initialized = true
   219  			}
   220  		}
   221  		if err == errUnknown {
   222  			n = protowire.ConsumeFieldValue(num, wtyp, b)
   223  			if n < 0 {
   224  				return out, errDecode
   225  			}
   226  		} else if err != nil {
   227  			return out, err
   228  		}
   229  		b = b[n:]
   230  	}
   231  	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
   232  	out.n = n
   233  	return out, nil
   234  }
   235  
   236  func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   237  	if f.mi == nil {
   238  		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
   239  		val := mapi.conv.valConv.PBValueOf(valrv)
   240  		size := 0
   241  		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
   242  		size += mapi.valFuncs.size(val, mapValTagSize, opts)
   243  		b = protowire.AppendVarint(b, uint64(size))
   244  		before := len(b)
   245  		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
   246  		if err != nil {
   247  			return nil, err
   248  		}
   249  		b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
   250  		if measuredSize := len(b) - before; size != measuredSize && err == nil {
   251  			return nil, errors.MismatchedSizeCalculation(size, measuredSize)
   252  		}
   253  		return b, err
   254  	} else {
   255  		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
   256  		val := pointerOfValue(valrv)
   257  		valSize := f.mi.sizePointer(val, opts)
   258  		size := 0
   259  		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
   260  		size += mapValTagSize + protowire.SizeBytes(valSize)
   261  		b = protowire.AppendVarint(b, uint64(size))
   262  		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
   263  		if err != nil {
   264  			return nil, err
   265  		}
   266  		b = protowire.AppendVarint(b, mapi.valWiretag)
   267  		b = protowire.AppendVarint(b, uint64(valSize))
   268  		before := len(b)
   269  		b, err = f.mi.marshalAppendPointer(b, val, opts)
   270  		if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
   271  			return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
   272  		}
   273  		return b, err
   274  	}
   275  }
   276  
   277  func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   278  	if mapv.Len() == 0 {
   279  		return b, nil
   280  	}
   281  	if opts.Deterministic() {
   282  		return appendMapDeterministic(b, mapv, mapi, f, opts)
   283  	}
   284  	iter := mapRange(mapv)
   285  	for iter.Next() {
   286  		var err error
   287  		b = protowire.AppendVarint(b, f.wiretag)
   288  		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
   289  		if err != nil {
   290  			return b, err
   291  		}
   292  	}
   293  	return b, nil
   294  }
   295  
   296  func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
   297  	keys := mapv.MapKeys()
   298  	sort.Slice(keys, func(i, j int) bool {
   299  		switch keys[i].Kind() {
   300  		case reflect.Bool:
   301  			return !keys[i].Bool() && keys[j].Bool()
   302  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   303  			return keys[i].Int() < keys[j].Int()
   304  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   305  			return keys[i].Uint() < keys[j].Uint()
   306  		case reflect.Float32, reflect.Float64:
   307  			return keys[i].Float() < keys[j].Float()
   308  		case reflect.String:
   309  			return keys[i].String() < keys[j].String()
   310  		default:
   311  			panic("invalid kind: " + keys[i].Kind().String())
   312  		}
   313  	})
   314  	for _, key := range keys {
   315  		var err error
   316  		b = protowire.AppendVarint(b, f.wiretag)
   317  		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
   318  		if err != nil {
   319  			return b, err
   320  		}
   321  	}
   322  	return b, nil
   323  }
   324  
   325  func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
   326  	if mi := f.mi; mi != nil {
   327  		mi.init()
   328  		if !mi.needsInitCheck {
   329  			return nil
   330  		}
   331  		iter := mapRange(mapv)
   332  		for iter.Next() {
   333  			val := pointerOfValue(iter.Value())
   334  			if err := mi.checkInitializedPointer(val); err != nil {
   335  				return err
   336  			}
   337  		}
   338  	} else {
   339  		iter := mapRange(mapv)
   340  		for iter.Next() {
   341  			val := mapi.conv.valConv.PBValueOf(iter.Value())
   342  			if err := mapi.valFuncs.isInit(val); err != nil {
   343  				return err
   344  			}
   345  		}
   346  	}
   347  	return nil
   348  }
   349  
   350  func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
   351  	dstm := dst.AsValueOf(f.ft).Elem()
   352  	srcm := src.AsValueOf(f.ft).Elem()
   353  	if srcm.Len() == 0 {
   354  		return
   355  	}
   356  	if dstm.IsNil() {
   357  		dstm.Set(reflect.MakeMap(f.ft))
   358  	}
   359  	iter := mapRange(srcm)
   360  	for iter.Next() {
   361  		dstm.SetMapIndex(iter.Key(), iter.Value())
   362  	}
   363  }
   364  
   365  func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
   366  	dstm := dst.AsValueOf(f.ft).Elem()
   367  	srcm := src.AsValueOf(f.ft).Elem()
   368  	if srcm.Len() == 0 {
   369  		return
   370  	}
   371  	if dstm.IsNil() {
   372  		dstm.Set(reflect.MakeMap(f.ft))
   373  	}
   374  	iter := mapRange(srcm)
   375  	for iter.Next() {
   376  		dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
   377  	}
   378  }
   379  
   380  func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
   381  	dstm := dst.AsValueOf(f.ft).Elem()
   382  	srcm := src.AsValueOf(f.ft).Elem()
   383  	if srcm.Len() == 0 {
   384  		return
   385  	}
   386  	if dstm.IsNil() {
   387  		dstm.Set(reflect.MakeMap(f.ft))
   388  	}
   389  	iter := mapRange(srcm)
   390  	for iter.Next() {
   391  		val := reflect.New(f.ft.Elem().Elem())
   392  		if f.mi != nil {
   393  			f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
   394  		} else {
   395  			opts.Merge(asMessage(val), asMessage(iter.Value()))
   396  		}
   397  		dstm.SetMapIndex(iter.Key(), val)
   398  	}
   399  }
   400  

View as plain text