1
2
3
4
5 package proto_test
6
7 import (
8 "fmt"
9 "reflect"
10 "sync"
11 "testing"
12
13 "github.com/google/go-cmp/cmp"
14
15 "google.golang.org/protobuf/internal/test/race"
16 "google.golang.org/protobuf/proto"
17 "google.golang.org/protobuf/reflect/protoreflect"
18 "google.golang.org/protobuf/runtime/protoimpl"
19 "google.golang.org/protobuf/testing/protocmp"
20
21 legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20160225_2fc053c5"
22 testpb "google.golang.org/protobuf/internal/testprotos/test"
23 test3pb "google.golang.org/protobuf/internal/testprotos/test3"
24 testeditionspb "google.golang.org/protobuf/internal/testprotos/testeditions"
25 descpb "google.golang.org/protobuf/types/descriptorpb"
26 )
27
28 func TestExtensionFuncs(t *testing.T) {
29 for _, test := range []struct {
30 message proto.Message
31 ext protoreflect.ExtensionType
32 wantDefault interface{}
33 value interface{}
34 }{
35 {
36 message: &testpb.TestAllExtensions{},
37 ext: testpb.E_OptionalInt32,
38 wantDefault: int32(0),
39 value: int32(1),
40 },
41 {
42 message: &testpb.TestAllExtensions{},
43 ext: testpb.E_RepeatedString,
44 wantDefault: ([]string)(nil),
45 value: []string{"a", "b", "c"},
46 },
47 {
48 message: &testeditionspb.TestAllExtensions{},
49 ext: testeditionspb.E_OptionalInt32,
50 wantDefault: int32(0),
51 value: int32(1),
52 },
53 {
54 message: &testeditionspb.TestAllExtensions{},
55 ext: testeditionspb.E_RepeatedString,
56 wantDefault: ([]string)(nil),
57 value: []string{"a", "b", "c"},
58 },
59 {
60 message: protoimpl.X.MessageOf(&legacy1pb.Message{}).Interface(),
61 ext: legacy1pb.E_Message_ExtensionOptionalBool,
62 wantDefault: false,
63 value: true,
64 },
65 {
66 message: &descpb.MessageOptions{},
67 ext: test3pb.E_OptionalInt32Ext,
68 wantDefault: int32(0),
69 value: int32(1),
70 },
71 {
72 message: &descpb.MessageOptions{},
73 ext: test3pb.E_RepeatedInt32Ext,
74 wantDefault: ([]int32)(nil),
75 value: []int32{1, 2, 3},
76 },
77 } {
78 if test.ext.TypeDescriptor().HasPresence() == test.ext.TypeDescriptor().IsList() {
79 t.Errorf("Extension %v has presence = %v, want %v", test.ext.TypeDescriptor().FullName(), test.ext.TypeDescriptor().HasPresence(), !test.ext.TypeDescriptor().IsList())
80 }
81 desc := fmt.Sprintf("Extension %v, value %v", test.ext.TypeDescriptor().FullName(), test.value)
82 if proto.HasExtension(test.message, test.ext) {
83 t.Errorf("%v:\nbefore setting extension HasExtension(...) = true, want false", desc)
84 }
85 got := proto.GetExtension(test.message, test.ext)
86 if d := cmp.Diff(test.wantDefault, got); d != "" {
87 t.Errorf("%v:\nbefore setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d)
88 }
89 proto.SetExtension(test.message, test.ext, test.value)
90 if !proto.HasExtension(test.message, test.ext) {
91 t.Errorf("%v:\nafter setting extension HasExtension(...) = false, want true", desc)
92 }
93 got = proto.GetExtension(test.message, test.ext)
94 if d := cmp.Diff(test.value, got); d != "" {
95 t.Errorf("%v:\nafter setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d)
96 }
97 proto.ClearExtension(test.message, test.ext)
98 if proto.HasExtension(test.message, test.ext) {
99 t.Errorf("%v:\nafter clearing extension HasExtension(...) = true, want false", desc)
100 }
101 }
102 }
103
104 func TestHasExtensionNoAlloc(t *testing.T) {
105
106
107
108
109 if race.Enabled {
110 t.Skip("HasExtension always allocates in -race mode")
111 }
112
113
114
115 want := int32(42)
116 mEager := &testpb.TestAllExtensions{}
117 proto.SetExtension(mEager, testpb.E_OptionalNestedMessage, &testpb.TestAllExtensions_NestedMessage{
118 A: proto.Int32(want),
119 Corecursive: &testpb.TestAllExtensions{},
120 })
121
122 b, err := proto.Marshal(mEager)
123 if err != nil {
124 t.Fatal(err)
125 }
126 mLazy := &testpb.TestAllExtensions{}
127 if err := proto.Unmarshal(b, mLazy); err != nil {
128 t.Fatal(err)
129 }
130
131 for _, tc := range []struct {
132 name string
133 m proto.Message
134 }{
135 {name: "Nil", m: nil},
136 {name: "Eager", m: mEager},
137 {name: "Lazy", m: mLazy},
138 } {
139 t.Run(tc.name, func(t *testing.T) {
140
141
142
143
144
145
146
147 warmup := true
148 avg := testing.AllocsPerRun(1, func() {
149 if warmup {
150 warmup = false
151 return
152 }
153 proto.HasExtension(tc.m, testpb.E_OptionalNestedMessage)
154 })
155 if avg != 0 {
156 t.Errorf("proto.HasExtension should not allocate, but allocated %.2fx per run", avg)
157 }
158 })
159 }
160 }
161
162 func TestIsValid(t *testing.T) {
163 tests := []struct {
164 xt protoreflect.ExtensionType
165 vi interface{}
166 want bool
167 }{
168 {testpb.E_OptionalBool, nil, false},
169 {testpb.E_OptionalBool, bool(true), true},
170 {testpb.E_OptionalBool, new(bool), false},
171 {testpb.E_OptionalInt32, nil, false},
172 {testpb.E_OptionalInt32, int32(0), true},
173 {testpb.E_OptionalInt32, new(int32), false},
174 {testpb.E_OptionalInt64, nil, false},
175 {testpb.E_OptionalInt64, int64(0), true},
176 {testpb.E_OptionalInt64, new(int64), false},
177 {testpb.E_OptionalUint32, nil, false},
178 {testpb.E_OptionalUint32, uint32(0), true},
179 {testpb.E_OptionalUint32, new(uint32), false},
180 {testpb.E_OptionalUint64, nil, false},
181 {testpb.E_OptionalUint64, uint64(0), true},
182 {testpb.E_OptionalUint64, new(uint64), false},
183 {testpb.E_OptionalFloat, nil, false},
184 {testpb.E_OptionalFloat, float32(0), true},
185 {testpb.E_OptionalFloat, new(float32), false},
186 {testpb.E_OptionalDouble, nil, false},
187 {testpb.E_OptionalDouble, float64(0), true},
188 {testpb.E_OptionalDouble, new(float32), false},
189 {testpb.E_OptionalString, nil, false},
190 {testpb.E_OptionalString, string(""), true},
191 {testpb.E_OptionalString, new(string), false},
192 {testpb.E_OptionalNestedEnum, nil, false},
193 {testpb.E_OptionalNestedEnum, testpb.TestAllTypes_BAZ, true},
194 {testpb.E_OptionalNestedEnum, testpb.TestAllTypes_BAZ.Enum(), false},
195 {testpb.E_OptionalNestedMessage, nil, false},
196 {testpb.E_OptionalNestedMessage, (*testpb.TestAllExtensions_NestedMessage)(nil), true},
197 {testpb.E_OptionalNestedMessage, new(testpb.TestAllExtensions_NestedMessage), true},
198 {testpb.E_OptionalNestedMessage, new(testpb.TestAllExtensions), false},
199 {testpb.E_RepeatedBool, nil, false},
200 {testpb.E_RepeatedBool, []bool(nil), true},
201 {testpb.E_RepeatedBool, []bool{}, true},
202 {testpb.E_RepeatedBool, []bool{false}, true},
203 {testpb.E_RepeatedBool, []*bool{}, false},
204 {testpb.E_RepeatedInt32, nil, false},
205 {testpb.E_RepeatedInt32, []int32(nil), true},
206 {testpb.E_RepeatedInt32, []int32{}, true},
207 {testpb.E_RepeatedInt32, []int32{0}, true},
208 {testpb.E_RepeatedInt32, []*int32{}, false},
209 {testpb.E_RepeatedInt64, nil, false},
210 {testpb.E_RepeatedInt64, []int64(nil), true},
211 {testpb.E_RepeatedInt64, []int64{}, true},
212 {testpb.E_RepeatedInt64, []int64{0}, true},
213 {testpb.E_RepeatedInt64, []*int64{}, false},
214 {testpb.E_RepeatedUint32, nil, false},
215 {testpb.E_RepeatedUint32, []uint32(nil), true},
216 {testpb.E_RepeatedUint32, []uint32{}, true},
217 {testpb.E_RepeatedUint32, []uint32{0}, true},
218 {testpb.E_RepeatedUint32, []*uint32{}, false},
219 {testpb.E_RepeatedUint64, nil, false},
220 {testpb.E_RepeatedUint64, []uint64(nil), true},
221 {testpb.E_RepeatedUint64, []uint64{}, true},
222 {testpb.E_RepeatedUint64, []uint64{0}, true},
223 {testpb.E_RepeatedUint64, []*uint64{}, false},
224 {testpb.E_RepeatedFloat, nil, false},
225 {testpb.E_RepeatedFloat, []float32(nil), true},
226 {testpb.E_RepeatedFloat, []float32{}, true},
227 {testpb.E_RepeatedFloat, []float32{0}, true},
228 {testpb.E_RepeatedFloat, []*float32{}, false},
229 {testpb.E_RepeatedDouble, nil, false},
230 {testpb.E_RepeatedDouble, []float64(nil), true},
231 {testpb.E_RepeatedDouble, []float64{}, true},
232 {testpb.E_RepeatedDouble, []float64{0}, true},
233 {testpb.E_RepeatedDouble, []*float64{}, false},
234 {testpb.E_RepeatedString, nil, false},
235 {testpb.E_RepeatedString, []string(nil), true},
236 {testpb.E_RepeatedString, []string{}, true},
237 {testpb.E_RepeatedString, []string{""}, true},
238 {testpb.E_RepeatedString, []*string{}, false},
239 {testpb.E_RepeatedNestedEnum, nil, false},
240 {testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum(nil), true},
241 {testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum{}, true},
242 {testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum{0}, true},
243 {testpb.E_RepeatedNestedEnum, []*testpb.TestAllTypes_NestedEnum{}, false},
244 {testpb.E_RepeatedNestedMessage, nil, false},
245 {testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage(nil), true},
246 {testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage{}, true},
247 {testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage{{}}, true},
248 {testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions{}, false},
249 }
250
251 for _, tt := range tests {
252
253 got := tt.xt.IsValidInterface(tt.vi)
254 if got != tt.want {
255 t.Errorf("%v.IsValidInterface() = %v, want %v", tt.xt.TypeDescriptor().FullName(), got, tt.want)
256 }
257 if !got {
258 continue
259 }
260
261
262 wantHas := true
263 pv := tt.xt.ValueOf(tt.vi)
264 switch v := pv.Interface().(type) {
265 case protoreflect.List:
266 wantHas = v.Len() > 0
267 case protoreflect.Message:
268 wantHas = v.IsValid()
269 }
270 m := &testpb.TestAllExtensions{}
271 proto.SetExtension(m, tt.xt, tt.vi)
272 gotHas := proto.HasExtension(m, tt.xt)
273 if gotHas != wantHas {
274 t.Errorf("HasExtension(%q) = %v, want %v", tt.xt.TypeDescriptor().FullName(), gotHas, wantHas)
275 }
276
277
278 got = tt.xt.IsValidValue(pv)
279 if got != tt.want {
280 t.Errorf("%v.IsValidValue() = %v, want %v", tt.xt.TypeDescriptor().FullName(), got, tt.want)
281 }
282 if !got {
283 continue
284 }
285
286
287
288 vi := tt.xt.InterfaceOf(pv)
289 if !reflect.DeepEqual(vi, tt.vi) {
290 t.Errorf("InterfaceOf(ValueOf(...)) round-trip mismatch: got %v, want %v", vi, tt.vi)
291 }
292 }
293 }
294
295 func TestExtensionRanger(t *testing.T) {
296 tests := []struct {
297 msg proto.Message
298 want map[protoreflect.ExtensionType]interface{}
299 }{{
300 msg: &testpb.TestAllExtensions{},
301 want: map[protoreflect.ExtensionType]interface{}{
302 testpb.E_OptionalInt32: int32(5),
303 testpb.E_OptionalString: string("hello"),
304 testpb.E_OptionalNestedMessage: &testpb.TestAllExtensions_NestedMessage{},
305 testpb.E_OptionalNestedEnum: testpb.TestAllTypes_BAZ,
306 testpb.E_RepeatedFloat: []float32{+32.32, -32.32},
307 testpb.E_RepeatedNestedMessage: []*testpb.TestAllExtensions_NestedMessage{{}},
308 testpb.E_RepeatedNestedEnum: []testpb.TestAllTypes_NestedEnum{testpb.TestAllTypes_BAZ},
309 },
310 }, {
311 msg: &testeditionspb.TestAllExtensions{},
312 want: map[protoreflect.ExtensionType]interface{}{
313 testeditionspb.E_OptionalInt32: int32(5),
314 testeditionspb.E_OptionalString: string("hello"),
315 testeditionspb.E_OptionalNestedMessage: &testeditionspb.TestAllExtensions_NestedMessage{},
316 testeditionspb.E_OptionalNestedEnum: testeditionspb.TestAllTypes_BAZ,
317 testeditionspb.E_RepeatedFloat: []float32{+32.32, -32.32},
318 testeditionspb.E_RepeatedNestedMessage: []*testeditionspb.TestAllExtensions_NestedMessage{{}},
319 testeditionspb.E_RepeatedNestedEnum: []testeditionspb.TestAllTypes_NestedEnum{testeditionspb.TestAllTypes_BAZ},
320 },
321 }, {
322 msg: &descpb.MessageOptions{},
323 want: map[protoreflect.ExtensionType]interface{}{
324 test3pb.E_OptionalInt32Ext: int32(5),
325 test3pb.E_OptionalStringExt: string("hello"),
326 test3pb.E_OptionalForeignMessageExt: &test3pb.ForeignMessage{},
327 test3pb.E_OptionalForeignEnumExt: test3pb.ForeignEnum_FOREIGN_BAR,
328
329 test3pb.E_OptionalOptionalInt32Ext: int32(5),
330 test3pb.E_OptionalOptionalStringExt: string("hello"),
331 test3pb.E_OptionalOptionalForeignMessageExt: &test3pb.ForeignMessage{},
332 test3pb.E_OptionalOptionalForeignEnumExt: test3pb.ForeignEnum_FOREIGN_BAR,
333 },
334 }}
335
336 for _, tt := range tests {
337 for xt, v := range tt.want {
338 proto.SetExtension(tt.msg, xt, v)
339 }
340
341 got := make(map[protoreflect.ExtensionType]interface{})
342 proto.RangeExtensions(tt.msg, func(xt protoreflect.ExtensionType, v interface{}) bool {
343 got[xt] = v
344 return true
345 })
346
347 if diff := cmp.Diff(tt.want, got, protocmp.Transform()); diff != "" {
348 t.Errorf("proto.RangeExtensions mismatch (-want +got):\n%s", diff)
349 }
350 }
351 }
352
353 func TestExtensionGetRace(t *testing.T) {
354
355
356
357 want := int32(42)
358 m1 := &testpb.TestAllExtensions{}
359 proto.SetExtension(m1, testpb.E_OptionalNestedMessage, &testpb.TestAllExtensions_NestedMessage{A: proto.Int32(want)})
360 b, err := proto.Marshal(m1)
361 if err != nil {
362 t.Fatal(err)
363 }
364 m := &testpb.TestAllExtensions{}
365 if err := proto.Unmarshal(b, m); err != nil {
366 t.Fatal(err)
367 }
368 var wg sync.WaitGroup
369 for i := 0; i < 3; i++ {
370 wg.Add(1)
371 go func() {
372 defer wg.Done()
373 if _, err := proto.Marshal(m); err != nil {
374 t.Error(err)
375 }
376 }()
377 wg.Add(1)
378 go func() {
379 defer wg.Done()
380 got := proto.GetExtension(m, testpb.E_OptionalNestedMessage).(*testpb.TestAllExtensions_NestedMessage).GetA()
381 if got != want {
382 t.Errorf("GetExtension(optional_nested_message).a = %v, want %v", got, want)
383 }
384 }()
385 }
386 wg.Wait()
387 }
388
389 func TestFeatureResolution(t *testing.T) {
390 for _, tc := range []struct {
391 input interface {
392 TypeDescriptor() protoreflect.ExtensionTypeDescriptor
393 }
394 wantPacked bool
395 }{
396 {testeditionspb.E_GlobalExpandedExtension, false},
397 {testeditionspb.E_GlobalPackedExtensionOverriden, true},
398 {testeditionspb.E_RepeatedFieldEncoding_MessageExpandedExtension, false},
399 {testeditionspb.E_RepeatedFieldEncoding_MessagePackedExtensionOverriden, true},
400 {testeditionspb.E_OtherFileGlobalExpandedExtensionOverriden, false},
401 {testeditionspb.E_OtherFileGlobalPackedExtension, true},
402 {testeditionspb.E_OtherRepeatedFieldEncoding_OtherFileMessagePackedExtension, true},
403 {testeditionspb.E_OtherRepeatedFieldEncoding_OtherFileMessageExpandedExtensionOverriden, false},
404 } {
405 if got, want := tc.input.TypeDescriptor().IsPacked(), tc.wantPacked; got != want {
406 t.Errorf("%v.IsPacked() = %v, want %v", tc.input.TypeDescriptor().FullName(), got, want)
407 }
408 }
409 }
410
View as plain text