1
2
3
4
5 package impl
6
7 import (
8 "reflect"
9 "sort"
10
11 "google.golang.org/protobuf/encoding/protowire"
12 "google.golang.org/protobuf/internal/errors"
13 "google.golang.org/protobuf/internal/genid"
14 "google.golang.org/protobuf/reflect/protoreflect"
15 )
16
17 type mapInfo struct {
18 goType reflect.Type
19 keyWiretag uint64
20 valWiretag uint64
21 keyFuncs valueCoderFuncs
22 valFuncs valueCoderFuncs
23 keyZero protoreflect.Value
24 keyKind protoreflect.Kind
25 conv *mapConverter
26 }
27
28 func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
29
30 keyField := fd.MapKey()
31 valField := fd.MapValue()
32 keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
33 valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
34 keyFuncs := encoderFuncsForValue(keyField)
35 valFuncs := encoderFuncsForValue(valField)
36 conv := newMapConverter(ft, fd)
37
38 mapi := &mapInfo{
39 goType: ft,
40 keyWiretag: keyWiretag,
41 valWiretag: valWiretag,
42 keyFuncs: keyFuncs,
43 valFuncs: valFuncs,
44 keyZero: keyField.Default(),
45 keyKind: keyField.Kind(),
46 conv: conv,
47 }
48 if valField.Kind() == protoreflect.MessageKind {
49 valueMessage = getMessageInfo(ft.Elem())
50 }
51
52 funcs = pointerCoderFuncs{
53 size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
54 return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
55 },
56 marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
57 return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
58 },
59 unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
60 mp := p.AsValueOf(ft)
61 if mp.Elem().IsNil() {
62 mp.Elem().Set(reflect.MakeMap(mapi.goType))
63 }
64 if f.mi == nil {
65 return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
66 } else {
67 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
68 }
69 },
70 }
71 switch valField.Kind() {
72 case protoreflect.MessageKind:
73 funcs.merge = mergeMapOfMessage
74 case protoreflect.BytesKind:
75 funcs.merge = mergeMapOfBytes
76 default:
77 funcs.merge = mergeMap
78 }
79 if valFuncs.isInit != nil {
80 funcs.isInit = func(p pointer, f *coderFieldInfo) error {
81 return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
82 }
83 }
84 return valueMessage, funcs
85 }
86
// Encoded sizes, in bytes, of the tags for a map entry's key (field 1) and
// value (field 2); both tags fit in a single varint byte.
const mapKeyTagSize, mapValTagSize = 1, 1
91
92 func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
93 if mapv.Len() == 0 {
94 return 0
95 }
96 n := 0
97 iter := mapRange(mapv)
98 for iter.Next() {
99 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
100 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
101 var valSize int
102 value := mapi.conv.valConv.PBValueOf(iter.Value())
103 if f.mi == nil {
104 valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
105 } else {
106 p := pointerOfValue(iter.Value())
107 valSize += mapValTagSize
108 valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
109 }
110 n += f.tagsize + protowire.SizeBytes(keySize+valSize)
111 }
112 return n
113 }
114
115 func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
116 if wtyp != protowire.BytesType {
117 return out, errUnknown
118 }
119 b, n := protowire.ConsumeBytes(b)
120 if n < 0 {
121 return out, errDecode
122 }
123 var (
124 key = mapi.keyZero
125 val = mapi.conv.valConv.New()
126 )
127 for len(b) > 0 {
128 num, wtyp, n := protowire.ConsumeTag(b)
129 if n < 0 {
130 return out, errDecode
131 }
132 if num > protowire.MaxValidNumber {
133 return out, errDecode
134 }
135 b = b[n:]
136 err := errUnknown
137 switch num {
138 case genid.MapEntry_Key_field_number:
139 var v protoreflect.Value
140 var o unmarshalOutput
141 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
142 if err != nil {
143 break
144 }
145 key = v
146 n = o.n
147 case genid.MapEntry_Value_field_number:
148 var v protoreflect.Value
149 var o unmarshalOutput
150 v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
151 if err != nil {
152 break
153 }
154 val = v
155 n = o.n
156 }
157 if err == errUnknown {
158 n = protowire.ConsumeFieldValue(num, wtyp, b)
159 if n < 0 {
160 return out, errDecode
161 }
162 } else if err != nil {
163 return out, err
164 }
165 b = b[n:]
166 }
167 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
168 out.n = n
169 return out, nil
170 }
171
172 func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
173 if wtyp != protowire.BytesType {
174 return out, errUnknown
175 }
176 b, n := protowire.ConsumeBytes(b)
177 if n < 0 {
178 return out, errDecode
179 }
180 var (
181 key = mapi.keyZero
182 val = reflect.New(f.mi.GoReflectType.Elem())
183 )
184 for len(b) > 0 {
185 num, wtyp, n := protowire.ConsumeTag(b)
186 if n < 0 {
187 return out, errDecode
188 }
189 if num > protowire.MaxValidNumber {
190 return out, errDecode
191 }
192 b = b[n:]
193 err := errUnknown
194 switch num {
195 case 1:
196 var v protoreflect.Value
197 var o unmarshalOutput
198 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
199 if err != nil {
200 break
201 }
202 key = v
203 n = o.n
204 case 2:
205 if wtyp != protowire.BytesType {
206 break
207 }
208 var v []byte
209 v, n = protowire.ConsumeBytes(b)
210 if n < 0 {
211 return out, errDecode
212 }
213 var o unmarshalOutput
214 o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
215 if o.initialized {
216
217
218 out.initialized = true
219 }
220 }
221 if err == errUnknown {
222 n = protowire.ConsumeFieldValue(num, wtyp, b)
223 if n < 0 {
224 return out, errDecode
225 }
226 } else if err != nil {
227 return out, err
228 }
229 b = b[n:]
230 }
231 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
232 out.n = n
233 return out, nil
234 }
235
236 func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
237 if f.mi == nil {
238 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
239 val := mapi.conv.valConv.PBValueOf(valrv)
240 size := 0
241 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
242 size += mapi.valFuncs.size(val, mapValTagSize, opts)
243 b = protowire.AppendVarint(b, uint64(size))
244 before := len(b)
245 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
246 if err != nil {
247 return nil, err
248 }
249 b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
250 if measuredSize := len(b) - before; size != measuredSize && err == nil {
251 return nil, errors.MismatchedSizeCalculation(size, measuredSize)
252 }
253 return b, err
254 } else {
255 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
256 val := pointerOfValue(valrv)
257 valSize := f.mi.sizePointer(val, opts)
258 size := 0
259 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
260 size += mapValTagSize + protowire.SizeBytes(valSize)
261 b = protowire.AppendVarint(b, uint64(size))
262 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
263 if err != nil {
264 return nil, err
265 }
266 b = protowire.AppendVarint(b, mapi.valWiretag)
267 b = protowire.AppendVarint(b, uint64(valSize))
268 before := len(b)
269 b, err = f.mi.marshalAppendPointer(b, val, opts)
270 if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
271 return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
272 }
273 return b, err
274 }
275 }
276
277 func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
278 if mapv.Len() == 0 {
279 return b, nil
280 }
281 if opts.Deterministic() {
282 return appendMapDeterministic(b, mapv, mapi, f, opts)
283 }
284 iter := mapRange(mapv)
285 for iter.Next() {
286 var err error
287 b = protowire.AppendVarint(b, f.wiretag)
288 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
289 if err != nil {
290 return b, err
291 }
292 }
293 return b, nil
294 }
295
296 func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
297 keys := mapv.MapKeys()
298 sort.Slice(keys, func(i, j int) bool {
299 switch keys[i].Kind() {
300 case reflect.Bool:
301 return !keys[i].Bool() && keys[j].Bool()
302 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
303 return keys[i].Int() < keys[j].Int()
304 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
305 return keys[i].Uint() < keys[j].Uint()
306 case reflect.Float32, reflect.Float64:
307 return keys[i].Float() < keys[j].Float()
308 case reflect.String:
309 return keys[i].String() < keys[j].String()
310 default:
311 panic("invalid kind: " + keys[i].Kind().String())
312 }
313 })
314 for _, key := range keys {
315 var err error
316 b = protowire.AppendVarint(b, f.wiretag)
317 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
318 if err != nil {
319 return b, err
320 }
321 }
322 return b, nil
323 }
324
325 func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
326 if mi := f.mi; mi != nil {
327 mi.init()
328 if !mi.needsInitCheck {
329 return nil
330 }
331 iter := mapRange(mapv)
332 for iter.Next() {
333 val := pointerOfValue(iter.Value())
334 if err := mi.checkInitializedPointer(val); err != nil {
335 return err
336 }
337 }
338 } else {
339 iter := mapRange(mapv)
340 for iter.Next() {
341 val := mapi.conv.valConv.PBValueOf(iter.Value())
342 if err := mapi.valFuncs.isInit(val); err != nil {
343 return err
344 }
345 }
346 }
347 return nil
348 }
349
350 func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
351 dstm := dst.AsValueOf(f.ft).Elem()
352 srcm := src.AsValueOf(f.ft).Elem()
353 if srcm.Len() == 0 {
354 return
355 }
356 if dstm.IsNil() {
357 dstm.Set(reflect.MakeMap(f.ft))
358 }
359 iter := mapRange(srcm)
360 for iter.Next() {
361 dstm.SetMapIndex(iter.Key(), iter.Value())
362 }
363 }
364
365 func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
366 dstm := dst.AsValueOf(f.ft).Elem()
367 srcm := src.AsValueOf(f.ft).Elem()
368 if srcm.Len() == 0 {
369 return
370 }
371 if dstm.IsNil() {
372 dstm.Set(reflect.MakeMap(f.ft))
373 }
374 iter := mapRange(srcm)
375 for iter.Next() {
376 dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
377 }
378 }
379
380 func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
381 dstm := dst.AsValueOf(f.ft).Elem()
382 srcm := src.AsValueOf(f.ft).Elem()
383 if srcm.Len() == 0 {
384 return
385 }
386 if dstm.IsNil() {
387 dstm.Set(reflect.MakeMap(f.ft))
388 }
389 iter := mapRange(srcm)
390 for iter.Next() {
391 val := reflect.New(f.ft.Elem().Elem())
392 if f.mi != nil {
393 f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
394 } else {
395 opts.Merge(asMessage(val), asMessage(iter.Value()))
396 }
397 dstm.SetMapIndex(iter.Key(), val)
398 }
399 }
400
View as plain text