1
2
3
4
5 package jsonpb
6
7 import (
8 "encoding/json"
9 "errors"
10 "fmt"
11 "io"
12 "math"
13 "reflect"
14 "strconv"
15 "strings"
16 "time"
17
18 "github.com/golang/protobuf/proto"
19 "google.golang.org/protobuf/encoding/protojson"
20 protoV2 "google.golang.org/protobuf/proto"
21 "google.golang.org/protobuf/reflect/protoreflect"
22 "google.golang.org/protobuf/reflect/protoregistry"
23 )
24
// wrapJSONUnmarshalV2 selects the decoding path used by
// (*Unmarshaler).UnmarshalNext: when true, decoding is delegated to
// protojson; when false (the current setting), the reflection-based decoder
// in this file (unmarshalMessage) is used.
const wrapJSONUnmarshalV2 bool = false
26
27
28 func UnmarshalNext(d *json.Decoder, m proto.Message) error {
29 return new(Unmarshaler).UnmarshalNext(d, m)
30 }
31
32
// Unmarshal decodes a JSON-encoded message read from r into m, using an
// Unmarshaler with default options.
func Unmarshal(r io.Reader, m proto.Message) error {
	u := &Unmarshaler{}
	return u.Unmarshal(r, m)
}
36
37
// UnmarshalString decodes the JSON text s into m, using an Unmarshaler with
// default options. It is a convenience wrapper around Unmarshal.
func UnmarshalString(s string, m proto.Message) error {
	return Unmarshal(strings.NewReader(s), m)
}
41
42
43
// Unmarshaler holds the options that control how JSON is decoded into
// protocol buffer messages. The zero value is ready to use.
type Unmarshaler struct {
	// AllowUnknownFields, when true, makes the decoder ignore JSON object
	// keys that match no field or registered extension of the target
	// message. When false, such a key produces an "unknown field" error.
	// On the protojson path it is passed through as DiscardUnknown.
	AllowUnknownFields bool

	// AnyResolver, if non-nil, resolves the type URL of a
	// google.protobuf.Any to a concrete message. If nil, the type is looked
	// up in protoregistry.GlobalTypes.
	AnyResolver AnyResolver
}
53
54
55
56
57
58
59
60
61
62
63 type JSONPBUnmarshaler interface {
64 UnmarshalJSONPB(*Unmarshaler, []byte) error
65 }
66
67
// Unmarshal decodes the first JSON value read from r into m, using the
// options held by u.
func (u *Unmarshaler) Unmarshal(r io.Reader, m proto.Message) error {
	dec := json.NewDecoder(r)
	return u.UnmarshalNext(dec, m)
}
71
72
73 func (u *Unmarshaler) UnmarshalNext(d *json.Decoder, m proto.Message) error {
74 if m == nil {
75 return errors.New("invalid nil message")
76 }
77
78
79 raw := json.RawMessage{}
80 if err := d.Decode(&raw); err != nil {
81 return err
82 }
83
84
85
86 if jsu, ok := m.(JSONPBUnmarshaler); ok {
87 return jsu.UnmarshalJSONPB(u, raw)
88 }
89
90 mr := proto.MessageReflect(m)
91
92
93
94 if string(raw) == "null" && mr.Descriptor().FullName() != "google.protobuf.Value" {
95 return nil
96 }
97
98 if wrapJSONUnmarshalV2 {
99
100
101
102 isEmpty := true
103 mr.Range(func(protoreflect.FieldDescriptor, protoreflect.Value) bool {
104 isEmpty = false
105 return false
106 })
107 if !isEmpty {
108
109 mr = mr.New()
110
111
112 dst := proto.MessageReflect(m)
113 defer mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
114 dst.Set(fd, v)
115 return true
116 })
117 }
118
119
120 opts := protojson.UnmarshalOptions{
121 DiscardUnknown: u.AllowUnknownFields,
122 }
123 if u.AnyResolver != nil {
124 opts.Resolver = anyResolver{u.AnyResolver}
125 }
126 return opts.Unmarshal(raw, mr.Interface())
127 } else {
128 if err := u.unmarshalMessage(mr, raw); err != nil {
129 return err
130 }
131 return protoV2.CheckInitialized(mr.Interface())
132 }
133 }
134
// unmarshalMessage decodes the JSON text in into m.
//
// Resolution order:
//   - a message implementing JSONPBUnmarshaler decodes itself;
//   - JSON null is a no-op, except for google.protobuf.Value;
//   - well-known types recognized by wellKnownType use their special JSON
//     forms (Any, wrappers, Duration, Timestamp, Value, ListValue, Struct);
//   - any other message is read from a JSON object whose keys are proto
//     field names, JSON field names, or bracketed extension names "[x.y]".
//
// Keys that are not consumed produce an error unless u.AllowUnknownFields
// is set. Fields are set on m as they are decoded, so m may be partially
// populated when an error is returned.
func (u *Unmarshaler) unmarshalMessage(m protoreflect.Message, in []byte) error {
	md := m.Descriptor()
	fds := md.Fields()

	if jsu, ok := proto.MessageV1(m.Interface()).(JSONPBUnmarshaler); ok {
		return jsu.UnmarshalJSONPB(u, in)
	}

	if string(in) == "null" && md.FullName() != "google.protobuf.Value" {
		return nil
	}

	switch wellKnownType(md.FullName()) {
	case "Any":
		var jsonObject map[string]json.RawMessage
		if err := json.Unmarshal(in, &jsonObject); err != nil {
			return err
		}

		rawTypeURL, ok := jsonObject["@type"]
		if !ok {
			return errors.New("Any JSON doesn't have '@type'")
		}
		typeURL, err := unquoteString(string(rawTypeURL))
		if err != nil {
			return fmt.Errorf("can't unmarshal Any's '@type': %q", rawTypeURL)
		}
		// Field 1 holds the type URL string.
		m.Set(fds.ByNumber(1), protoreflect.ValueOfString(typeURL))

		// Resolve the embedded message type: the caller's resolver if one
		// was supplied, otherwise the global registry.
		var m2 protoreflect.Message
		if u.AnyResolver != nil {
			mi, err := u.AnyResolver.Resolve(typeURL)
			if err != nil {
				return err
			}
			m2 = proto.MessageReflect(mi)
		} else {
			mt, err := protoregistry.GlobalTypes.FindMessageByURL(typeURL)
			if err != nil {
				if err == protoregistry.NotFound {
					return fmt.Errorf("could not resolve Any message type: %v", typeURL)
				}
				return err
			}
			m2 = mt.New()
		}

		// An embedded well-known type carries its JSON form under "value".
		// Any other message has its fields inlined next to "@type".
		if wellKnownType(m2.Descriptor().FullName()) != "" {
			rawValue, ok := jsonObject["value"]
			if !ok {
				return errors.New("Any JSON doesn't have 'value'")
			}
			if err := u.unmarshalMessage(m2, rawValue); err != nil {
				return fmt.Errorf("can't unmarshal Any nested proto %v: %v", typeURL, err)
			}
		} else {
			delete(jsonObject, "@type")
			rawJSON, err := json.Marshal(jsonObject)
			if err != nil {
				return fmt.Errorf("can't generate JSON for Any's nested proto to be unmarshaled: %v", err)
			}
			if err = u.unmarshalMessage(m2, rawJSON); err != nil {
				return fmt.Errorf("can't unmarshal Any nested proto %v: %v", typeURL, err)
			}
		}

		// Field 2 holds the embedded message in binary wire format.
		rawWire, err := protoV2.Marshal(m2.Interface())
		if err != nil {
			return fmt.Errorf("can't marshal proto %v into Any.Value: %v", typeURL, err)
		}
		m.Set(fds.ByNumber(2), protoreflect.ValueOfBytes(rawWire))
		return nil
	case "BoolValue", "BytesValue", "StringValue",
		"Int32Value", "UInt32Value", "FloatValue",
		"Int64Value", "UInt64Value", "DoubleValue":
		// Wrapper types are encoded as the bare value of their field 1.
		fd := fds.ByNumber(1)
		v, err := u.unmarshalValue(m.NewField(fd), in, fd)
		if err != nil {
			return err
		}
		m.Set(fd, v)
		return nil
	case "Duration":
		v, err := unquoteString(string(in))
		if err != nil {
			return err
		}
		d, err := time.ParseDuration(v)
		if err != nil {
			return fmt.Errorf("bad Duration: %v", err)
		}

		// Split into whole seconds (field 1) and remaining nanoseconds
		// (field 2). Go's truncating division gives both the same sign.
		sec := d.Nanoseconds() / 1e9
		nsec := d.Nanoseconds() % 1e9
		m.Set(fds.ByNumber(1), protoreflect.ValueOfInt64(sec))
		m.Set(fds.ByNumber(2), protoreflect.ValueOfInt32(int32(nsec)))
		return nil
	case "Timestamp":
		v, err := unquoteString(string(in))
		if err != nil {
			return err
		}
		t, err := time.Parse(time.RFC3339Nano, v)
		if err != nil {
			return fmt.Errorf("bad Timestamp: %v", err)
		}

		// Seconds since the Unix epoch (field 1) plus nanoseconds (field 2).
		sec := t.Unix()
		nsec := t.Nanosecond()
		m.Set(fds.ByNumber(1), protoreflect.ValueOfInt64(sec))
		m.Set(fds.ByNumber(2), protoreflect.ValueOfInt32(int32(nsec)))
		return nil
	case "Value":
		// Choose the oneof member from the shape of the JSON value.
		switch {
		case string(in) == "null":
			m.Set(fds.ByNumber(1), protoreflect.ValueOfEnum(0))
		case string(in) == "true":
			m.Set(fds.ByNumber(4), protoreflect.ValueOfBool(true))
		case string(in) == "false":
			m.Set(fds.ByNumber(4), protoreflect.ValueOfBool(false))
		case hasPrefixAndSuffix('"', in, '"'):
			s, err := unquoteString(string(in))
			if err != nil {
				return fmt.Errorf("unrecognized type for Value %q", in)
			}
			m.Set(fds.ByNumber(3), protoreflect.ValueOfString(s))
		case hasPrefixAndSuffix('[', in, ']'):
			v := m.Mutable(fds.ByNumber(6))
			return u.unmarshalMessage(v.Message(), in)
		case hasPrefixAndSuffix('{', in, '}'):
			v := m.Mutable(fds.ByNumber(5))
			return u.unmarshalMessage(v.Message(), in)
		default:
			f, err := strconv.ParseFloat(string(in), 0)
			if err != nil {
				return fmt.Errorf("unrecognized type for Value %q", in)
			}
			m.Set(fds.ByNumber(2), protoreflect.ValueOfFloat64(f))
		}
		return nil
	case "ListValue":
		var jsonArray []json.RawMessage
		if err := json.Unmarshal(in, &jsonArray); err != nil {
			return fmt.Errorf("bad ListValue: %v", err)
		}

		lv := m.Mutable(fds.ByNumber(1)).List()
		for _, raw := range jsonArray {
			ve := lv.NewElement()
			if err := u.unmarshalMessage(ve.Message(), raw); err != nil {
				return err
			}
			lv.Append(ve)
		}
		return nil
	case "Struct":
		var jsonObject map[string]json.RawMessage
		if err := json.Unmarshal(in, &jsonObject); err != nil {
			return fmt.Errorf("bad StructValue: %v", err)
		}

		mv := m.Mutable(fds.ByNumber(1)).Map()
		for key, raw := range jsonObject {
			kv := protoreflect.ValueOf(key).MapKey()
			vv := mv.NewValue()
			if err := u.unmarshalMessage(vv.Message(), raw); err != nil {
				return fmt.Errorf("bad value in StructValue for key %q: %v", key, err)
			}
			mv.Set(kv, vv)
		}
		return nil
	}

	var jsonObject map[string]json.RawMessage
	if err := json.Unmarshal(in, &jsonObject); err != nil {
		return err
	}

	// Regular fields. Every key consumed here is deleted, so whatever remains
	// afterwards is either an extension or unknown.
	for i := 0; i < fds.Len(); i++ {
		fd := fds.Get(i)
		// Skip weak fields whose message descriptor is an unresolved placeholder.
		if fd.IsWeak() && fd.Message().IsPlaceholder() {
			continue
		}

		// Accept the proto name (for groups, the group's message name) and
		// then the JSON name. If both keys are present, the JSON name wins.
		var raw json.RawMessage
		name := string(fd.Name())
		if fd.Kind() == protoreflect.GroupKind {
			name = string(fd.Message().Name())
		}
		if v, ok := jsonObject[name]; ok {
			delete(jsonObject, name)
			raw = v
		}
		name = string(fd.JSONName())
		if v, ok := jsonObject[name]; ok {
			delete(jsonObject, name)
			raw = v
		}

		field := m.NewField(fd)
		// Absent keys are skipped. null is skipped too, except for singular
		// Value/NullValue fields and singular custom-unmarshaler messages,
		// where null carries meaning.
		if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd) && !isSingularJSONPBUnmarshaler(field, fd)) {
			continue
		}
		v, err := u.unmarshalValue(field, raw, fd)
		if err != nil {
			return err
		}
		m.Set(fd, v)
	}

	// Extensions are written as "[full.extension.name]" keys. Unresolvable
	// names are left in jsonObject and reported as unknown below.
	for name, raw := range jsonObject {
		if !strings.HasPrefix(name, "[") || !strings.HasSuffix(name, "]") {
			continue
		}

		xname := protoreflect.FullName(name[len("[") : len(name)-len("]")])
		xt, _ := protoregistry.GlobalTypes.FindExtensionByName(xname)
		if xt == nil && isMessageSet(md) {
			// For MessageSet containers, also try the name's
			// "message_set_extension" child.
			xt, _ = protoregistry.GlobalTypes.FindExtensionByName(xname.Append("message_set_extension"))
		}
		if xt == nil {
			continue
		}
		delete(jsonObject, name)
		fd := xt.TypeDescriptor()
		if fd.ContainingMessage().FullName() != m.Descriptor().FullName() {
			return fmt.Errorf("extension field %q does not extend message %q", xname, m.Descriptor().FullName())
		}

		field := m.NewField(fd)
		if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd) && !isSingularJSONPBUnmarshaler(field, fd)) {
			continue
		}
		v, err := u.unmarshalValue(field, raw, fd)
		if err != nil {
			return err
		}
		m.Set(fd, v)
	}

	if !u.AllowUnknownFields && len(jsonObject) > 0 {
		// Go randomizes map iteration order, so pick the lexicographically
		// smallest leftover key. This keeps the error identical across runs.
		var unknown string
		found := false
		for name := range jsonObject {
			if !found || name < unknown {
				unknown, found = name, true
			}
		}
		return fmt.Errorf("unknown field %q in %v", unknown, md.FullName())
	}
	return nil
}
388
389 func isSingularWellKnownValue(fd protoreflect.FieldDescriptor) bool {
390 if fd.Cardinality() == protoreflect.Repeated {
391 return false
392 }
393 if md := fd.Message(); md != nil {
394 return md.FullName() == "google.protobuf.Value"
395 }
396 if ed := fd.Enum(); ed != nil {
397 return ed.FullName() == "google.protobuf.NullValue"
398 }
399 return false
400 }
401
402 func isSingularJSONPBUnmarshaler(v protoreflect.Value, fd protoreflect.FieldDescriptor) bool {
403 if fd.Message() != nil && fd.Cardinality() != protoreflect.Repeated {
404 _, ok := proto.MessageV1(v.Interface()).(JSONPBUnmarshaler)
405 return ok
406 }
407 return false
408 }
409
// unmarshalValue decodes in into v, a freshly allocated value for field fd
// (callers pass m.NewField(fd)), and returns the populated value.
//
// Repeated fields expect a JSON array and are appended to v element by
// element. Map fields expect a JSON object and are stored into v; keys that
// are not string-typed are parsed from the object-key text. Everything else
// is delegated to unmarshalSingularValue. On error the returned value should
// be ignored.
func (u *Unmarshaler) unmarshalValue(v protoreflect.Value, in []byte, fd protoreflect.FieldDescriptor) (protoreflect.Value, error) {
	switch {
	case fd.IsList():
		var jsonArray []json.RawMessage
		if err := json.Unmarshal(in, &jsonArray); err != nil {
			return v, err
		}
		lv := v.List()
		for _, raw := range jsonArray {
			ve, err := u.unmarshalSingularValue(lv.NewElement(), raw, fd)
			if err != nil {
				return v, err
			}
			lv.Append(ve)
		}
		return v, nil
	case fd.IsMap():
		var jsonObject map[string]json.RawMessage
		if err := json.Unmarshal(in, &jsonObject); err != nil {
			return v, err
		}
		kfd := fd.MapKey()
		vfd := fd.MapValue()
		mv := v.Map()
		for key, raw := range jsonObject {
			var kv protoreflect.MapKey
			if kfd.Kind() == protoreflect.StringKind {
				kv = protoreflect.ValueOf(key).MapKey()
			} else {
				// Use a distinct name rather than shadowing v. Otherwise the
				// error return would yield the key instead of the map value.
				keyVal, err := u.unmarshalSingularValue(kfd.Default(), []byte(key), kfd)
				if err != nil {
					return v, err
				}
				kv = keyVal.MapKey()
			}

			vv, err := u.unmarshalSingularValue(mv.NewValue(), raw, vfd)
			if err != nil {
				return v, err
			}
			mv.Set(kv, vv)
		}
		return v, nil
	default:
		return u.unmarshalSingularValue(v, in, fd)
	}
}
457
458 var nonFinite = map[string]float64{
459 `"NaN"`: math.NaN(),
460 `"Infinity"`: math.Inf(+1),
461 `"-Infinity"`: math.Inf(-1),
462 }
463
// unmarshalSingularValue decodes one non-list, non-map JSON value in for a
// field of kind fd.Kind().
//
// v is the pre-allocated value for fd. It is the decode target for message
// and group kinds, and it is the value returned with an enum lookup error.
// Integer and floating-point kinds accept JSON numbers or quoted numbers.
// Float and double also accept "NaN", "Infinity" and "-Infinity". Enums
// accept a quoted value name or a number. An unexpected kind panics, since it
// means fd is malformed.
func (u *Unmarshaler) unmarshalSingularValue(v protoreflect.Value, in []byte, fd protoreflect.FieldDescriptor) (protoreflect.Value, error) {
	switch fd.Kind() {
	case protoreflect.BoolKind:
		return unmarshalValue(in, new(bool))
	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
		return unmarshalValue(trimQuote(in), new(int32))
	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		return unmarshalValue(trimQuote(in), new(int64))
	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
		return unmarshalValue(trimQuote(in), new(uint32))
	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
		return unmarshalValue(trimQuote(in), new(uint64))
	case protoreflect.FloatKind:
		if f, ok := nonFinite[string(in)]; ok {
			return protoreflect.ValueOfFloat32(float32(f)), nil
		}
		return unmarshalValue(trimQuote(in), new(float32))
	case protoreflect.DoubleKind:
		if f, ok := nonFinite[string(in)]; ok {
			return protoreflect.ValueOfFloat64(f), nil
		}
		return unmarshalValue(trimQuote(in), new(float64))
	case protoreflect.StringKind:
		return unmarshalValue(in, new(string))
	case protoreflect.BytesKind:
		return unmarshalValue(in, new([]byte))
	case protoreflect.EnumKind:
		if hasPrefixAndSuffix('"', in, '"') {
			// Decode the JSON string properly instead of just stripping the
			// quotes, so a value name written with JSON escapes resolves.
			if name, err := unquoteString(string(in)); err == nil {
				if vd := fd.Enum().Values().ByName(protoreflect.Name(name)); vd != nil {
					return protoreflect.ValueOfEnum(vd.Number()), nil
				}
			}
			return v, fmt.Errorf("unknown value %q for enum %s", in, fd.Enum().FullName())
		}
		return unmarshalValue(in, new(protoreflect.EnumNumber))
	case protoreflect.MessageKind, protoreflect.GroupKind:
		err := u.unmarshalMessage(v.Message(), in)
		return v, err
	default:
		panic(fmt.Sprintf("invalid kind %v", fd.Kind()))
	}
}
506
// unmarshalValue JSON-decodes in into *v (v must be a pointer) and wraps the
// resulting Go value as a protoreflect.Value. The value is returned even
// when decoding fails; callers should ignore it in that case.
func unmarshalValue(in []byte, v interface{}) (protoreflect.Value, error) {
	decodeErr := json.Unmarshal(in, v)
	decoded := reflect.ValueOf(v).Elem().Interface()
	return protoreflect.ValueOf(decoded), decodeErr
}
511
512 func unquoteString(in string) (out string, err error) {
513 err = json.Unmarshal([]byte(in), &out)
514 return out, err
515 }
516
// hasPrefixAndSuffix reports whether in has at least two bytes, starts with
// prefix, and ends with suffix. The length check ensures a single byte never
// counts as both the prefix and the suffix.
func hasPrefixAndSuffix(prefix byte, in []byte, suffix byte) bool {
	return len(in) >= 2 && in[0] == prefix && in[len(in)-1] == suffix
}
523
524
525
// trimQuote returns in without its surrounding double quotes when it both
// starts and ends with '"' (and has at least two bytes). Otherwise in is
// returned unchanged. No unescaping is performed, and the result aliases in.
func trimQuote(in []byte) []byte {
	n := len(in)
	if n < 2 || in[0] != '"' || in[n-1] != '"' {
		return in
	}
	return in[1 : n-1]
}
532
View as plain text