1
2
3
4
5 package proto_test
6
7 import (
8 "bytes"
9 "fmt"
10 "reflect"
11 "testing"
12
13 "google.golang.org/protobuf/encoding/prototext"
14 "google.golang.org/protobuf/proto"
15 "google.golang.org/protobuf/reflect/protoreflect"
16 "google.golang.org/protobuf/testing/protopack"
17 "google.golang.org/protobuf/types/known/durationpb"
18
19 "google.golang.org/protobuf/internal/errors"
20 testpb "google.golang.org/protobuf/internal/testprotos/test"
21 test3pb "google.golang.org/protobuf/internal/testprotos/test3"
22 )
23
24 func TestDecode(t *testing.T) {
25 for _, test := range testValidMessages {
26 if len(test.decodeTo) == 0 {
27 t.Errorf("%v: no test message types", test.desc)
28 }
29 for _, want := range test.decodeTo {
30 t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
31 opts := test.unmarshalOptions
32 opts.AllowPartial = test.partial
33 wire := append(([]byte)(nil), test.wire...)
34 got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
35 if err := opts.Unmarshal(wire, got); err != nil {
36 t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, prototext.Format(want))
37 return
38 }
39
40
41
42
43 if !bytes.Equal(test.wire, wire) {
44 t.Errorf("Unmarshal unexpectedly modified its input")
45 }
46 for i := range wire {
47 wire[i] = 0
48 }
49 if !proto.Equal(got, want) && got.ProtoReflect().IsValid() && want.ProtoReflect().IsValid() {
50 t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", prototext.Format(got), prototext.Format(want))
51 }
52 })
53 }
54 }
55 }
56
57 func TestDecodeRequiredFieldChecks(t *testing.T) {
58 for _, test := range testValidMessages {
59 if !test.partial {
60 continue
61 }
62 for _, m := range test.decodeTo {
63 t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
64 opts := test.unmarshalOptions
65 opts.AllowPartial = false
66 got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
67 if err := proto.Unmarshal(test.wire, got); err == nil {
68 t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", prototext.Format(got))
69 }
70 })
71 }
72 }
73 }
74
75 func TestDecodeInvalidMessages(t *testing.T) {
76 for _, test := range testInvalidMessages {
77 if len(test.decodeTo) == 0 {
78 t.Errorf("%v: no test message types", test.desc)
79 }
80 for _, want := range test.decodeTo {
81 t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
82 opts := test.unmarshalOptions
83 opts.AllowPartial = test.partial
84 got := want.ProtoReflect().New().Interface()
85 if err := opts.Unmarshal(test.wire, got); err == nil {
86 t.Errorf("Unmarshal unexpectedly succeeded\ninput bytes: [%x]\nMessage:\n%v", test.wire, prototext.Format(got))
87 } else if !errors.Is(err, proto.Error) {
88 t.Errorf("Unmarshal error is not a proto.Error: %v", err)
89 }
90 })
91 }
92 }
93 }
94
95 func TestDecodeZeroLengthBytes(t *testing.T) {
96
97
98 wire := protopack.Message{
99 protopack.Tag{94, protopack.BytesType}, protopack.Bytes(nil),
100 }.Marshal()
101 m := &test3pb.TestAllTypes{}
102 if err := proto.Unmarshal(wire, m); err != nil {
103 t.Fatal(err)
104 }
105 if m.OptionalBytes != nil {
106 t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes)
107 }
108 }
109
110 func TestDecodeOneofNilWrapper(t *testing.T) {
111 wire := protopack.Message{
112 protopack.Tag{111, protopack.VarintType}, protopack.Varint(1111),
113 }.Marshal()
114 m := &testpb.TestAllTypes{OneofField: (*testpb.TestAllTypes_OneofUint32)(nil)}
115 if err := proto.Unmarshal(wire, m); err != nil {
116 t.Fatal(err)
117 }
118 if got := m.GetOneofUint32(); got != 1111 {
119 t.Errorf("GetOneofUint32() = %v, want %v", got, 1111)
120 }
121 }
122
123 func TestDecodeEmptyBytes(t *testing.T) {
124
125
126
127 m := &testpb.TestAllTypes{}
128 b := protopack.Message{
129 protopack.Tag{45, protopack.BytesType}, protopack.Bytes(nil),
130 }.Marshal()
131 if err := proto.Unmarshal(b, m); err != nil {
132 t.Fatal(err)
133 }
134 if m.RepeatedBytes[0] == nil {
135 t.Errorf("unmarshaling repeated bytes field containing zero-length value: Got nil bytes, want non-nil")
136 }
137 }
138
139 func build(m proto.Message, opts ...buildOpt) proto.Message {
140 for _, opt := range opts {
141 opt(m)
142 }
143 return m
144 }
145
146 type buildOpt func(proto.Message)
147
148 func unknown(raw protoreflect.RawFields) buildOpt {
149 return func(m proto.Message) {
150 m.ProtoReflect().SetUnknown(raw)
151 }
152 }
153
154 func extend(desc protoreflect.ExtensionType, value interface{}) buildOpt {
155 return func(m proto.Message) {
156 proto.SetExtension(m, desc, value)
157 }
158 }
159
160
161
162 func ExampleUnmarshal() {
163
164
165 b := []byte{0x10, 0x7d}
166
167 var dur durationpb.Duration
168 if err := proto.Unmarshal(b, &dur); err != nil {
169 panic(err)
170 }
171
172 fmt.Printf("Protobuf wire format decoded to duration %v\n", dur.AsDuration())
173
174
175 }
176
View as plain text