...

Source file src/google.golang.org/protobuf/proto/extension_test.go

Documentation: google.golang.org/protobuf/proto

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package proto_test
     6  
     7  import (
     8  	"fmt"
     9  	"reflect"
    10  	"sync"
    11  	"testing"
    12  
    13  	"github.com/google/go-cmp/cmp"
    14  
    15  	"google.golang.org/protobuf/internal/test/race"
    16  	"google.golang.org/protobuf/proto"
    17  	"google.golang.org/protobuf/reflect/protoreflect"
    18  	"google.golang.org/protobuf/runtime/protoimpl"
    19  	"google.golang.org/protobuf/testing/protocmp"
    20  
    21  	legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20160225_2fc053c5"
    22  	testpb "google.golang.org/protobuf/internal/testprotos/test"
    23  	test3pb "google.golang.org/protobuf/internal/testprotos/test3"
    24  	testeditionspb "google.golang.org/protobuf/internal/testprotos/testeditions"
    25  	descpb "google.golang.org/protobuf/types/descriptorpb"
    26  )
    27  
    28  func TestExtensionFuncs(t *testing.T) {
    29  	for _, test := range []struct {
    30  		message     proto.Message
    31  		ext         protoreflect.ExtensionType
    32  		wantDefault interface{}
    33  		value       interface{}
    34  	}{
    35  		{
    36  			message:     &testpb.TestAllExtensions{},
    37  			ext:         testpb.E_OptionalInt32,
    38  			wantDefault: int32(0),
    39  			value:       int32(1),
    40  		},
    41  		{
    42  			message:     &testpb.TestAllExtensions{},
    43  			ext:         testpb.E_RepeatedString,
    44  			wantDefault: ([]string)(nil),
    45  			value:       []string{"a", "b", "c"},
    46  		},
    47  		{
    48  			message:     &testeditionspb.TestAllExtensions{},
    49  			ext:         testeditionspb.E_OptionalInt32,
    50  			wantDefault: int32(0),
    51  			value:       int32(1),
    52  		},
    53  		{
    54  			message:     &testeditionspb.TestAllExtensions{},
    55  			ext:         testeditionspb.E_RepeatedString,
    56  			wantDefault: ([]string)(nil),
    57  			value:       []string{"a", "b", "c"},
    58  		},
    59  		{
    60  			message:     protoimpl.X.MessageOf(&legacy1pb.Message{}).Interface(),
    61  			ext:         legacy1pb.E_Message_ExtensionOptionalBool,
    62  			wantDefault: false,
    63  			value:       true,
    64  		},
    65  		{
    66  			message:     &descpb.MessageOptions{},
    67  			ext:         test3pb.E_OptionalInt32Ext,
    68  			wantDefault: int32(0),
    69  			value:       int32(1),
    70  		},
    71  		{
    72  			message:     &descpb.MessageOptions{},
    73  			ext:         test3pb.E_RepeatedInt32Ext,
    74  			wantDefault: ([]int32)(nil),
    75  			value:       []int32{1, 2, 3},
    76  		},
    77  	} {
    78  		if test.ext.TypeDescriptor().HasPresence() == test.ext.TypeDescriptor().IsList() {
    79  			t.Errorf("Extension %v has presence = %v, want %v", test.ext.TypeDescriptor().FullName(), test.ext.TypeDescriptor().HasPresence(), !test.ext.TypeDescriptor().IsList())
    80  		}
    81  		desc := fmt.Sprintf("Extension %v, value %v", test.ext.TypeDescriptor().FullName(), test.value)
    82  		if proto.HasExtension(test.message, test.ext) {
    83  			t.Errorf("%v:\nbefore setting extension HasExtension(...) = true, want false", desc)
    84  		}
    85  		got := proto.GetExtension(test.message, test.ext)
    86  		if d := cmp.Diff(test.wantDefault, got); d != "" {
    87  			t.Errorf("%v:\nbefore setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d)
    88  		}
    89  		proto.SetExtension(test.message, test.ext, test.value)
    90  		if !proto.HasExtension(test.message, test.ext) {
    91  			t.Errorf("%v:\nafter setting extension HasExtension(...) = false, want true", desc)
    92  		}
    93  		got = proto.GetExtension(test.message, test.ext)
    94  		if d := cmp.Diff(test.value, got); d != "" {
    95  			t.Errorf("%v:\nafter setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d)
    96  		}
    97  		proto.ClearExtension(test.message, test.ext)
    98  		if proto.HasExtension(test.message, test.ext) {
    99  			t.Errorf("%v:\nafter clearing extension HasExtension(...) = true, want false", desc)
   100  		}
   101  	}
   102  }
   103  
   104  func TestHasExtensionNoAlloc(t *testing.T) {
   105  	// If extensions are lazy, they are unmarshaled on first use. Verify that
   106  	// HasExtension does not do this by testing that it does not allocation. This
   107  	// test always passes if extension are eager (the default if protolegacy =
   108  	// false).
   109  	if race.Enabled {
   110  		t.Skip("HasExtension always allocates in -race mode")
   111  	}
   112  	// Create a message with a message extension. Doing it this way produces a
   113  	// non-lazy (eager) variant. Then do a marshal/unmarshal roundtrip to produce
   114  	// a lazy version (if protolegacy = true).
   115  	want := int32(42)
   116  	mEager := &testpb.TestAllExtensions{}
   117  	proto.SetExtension(mEager, testpb.E_OptionalNestedMessage, &testpb.TestAllExtensions_NestedMessage{
   118  		A:           proto.Int32(want),
   119  		Corecursive: &testpb.TestAllExtensions{},
   120  	})
   121  
   122  	b, err := proto.Marshal(mEager)
   123  	if err != nil {
   124  		t.Fatal(err)
   125  	}
   126  	mLazy := &testpb.TestAllExtensions{}
   127  	if err := proto.Unmarshal(b, mLazy); err != nil {
   128  		t.Fatal(err)
   129  	}
   130  
   131  	for _, tc := range []struct {
   132  		name string
   133  		m    proto.Message
   134  	}{
   135  		{name: "Nil", m: nil},
   136  		{name: "Eager", m: mEager},
   137  		{name: "Lazy", m: mLazy},
   138  	} {
   139  		t.Run(tc.name, func(t *testing.T) {
   140  			// Testing for allocations can be done with `testing.AllocsPerRun`, but it
   141  			// has some snags that complicate its use for us:
   142  			//  - It performs a warmup invocation before starting the measurement. We
   143  			//    want to skip this because lazy initialization only happens once.
   144  			//  - Despite returning a float64, the returned value is an integer, so <1
   145  			//    allocations per operation are returned as 0. Therefore, pass runs =
   146  			//    1.
   147  			warmup := true
   148  			avg := testing.AllocsPerRun(1, func() {
   149  				if warmup {
   150  					warmup = false
   151  					return
   152  				}
   153  				proto.HasExtension(tc.m, testpb.E_OptionalNestedMessage)
   154  			})
   155  			if avg != 0 {
   156  				t.Errorf("proto.HasExtension should not allocate, but allocated %.2fx per run", avg)
   157  			}
   158  		})
   159  	}
   160  }
   161  
   162  func TestIsValid(t *testing.T) {
   163  	tests := []struct {
   164  		xt   protoreflect.ExtensionType
   165  		vi   interface{}
   166  		want bool
   167  	}{
   168  		{testpb.E_OptionalBool, nil, false},
   169  		{testpb.E_OptionalBool, bool(true), true},
   170  		{testpb.E_OptionalBool, new(bool), false},
   171  		{testpb.E_OptionalInt32, nil, false},
   172  		{testpb.E_OptionalInt32, int32(0), true},
   173  		{testpb.E_OptionalInt32, new(int32), false},
   174  		{testpb.E_OptionalInt64, nil, false},
   175  		{testpb.E_OptionalInt64, int64(0), true},
   176  		{testpb.E_OptionalInt64, new(int64), false},
   177  		{testpb.E_OptionalUint32, nil, false},
   178  		{testpb.E_OptionalUint32, uint32(0), true},
   179  		{testpb.E_OptionalUint32, new(uint32), false},
   180  		{testpb.E_OptionalUint64, nil, false},
   181  		{testpb.E_OptionalUint64, uint64(0), true},
   182  		{testpb.E_OptionalUint64, new(uint64), false},
   183  		{testpb.E_OptionalFloat, nil, false},
   184  		{testpb.E_OptionalFloat, float32(0), true},
   185  		{testpb.E_OptionalFloat, new(float32), false},
   186  		{testpb.E_OptionalDouble, nil, false},
   187  		{testpb.E_OptionalDouble, float64(0), true},
   188  		{testpb.E_OptionalDouble, new(float32), false},
   189  		{testpb.E_OptionalString, nil, false},
   190  		{testpb.E_OptionalString, string(""), true},
   191  		{testpb.E_OptionalString, new(string), false},
   192  		{testpb.E_OptionalNestedEnum, nil, false},
   193  		{testpb.E_OptionalNestedEnum, testpb.TestAllTypes_BAZ, true},
   194  		{testpb.E_OptionalNestedEnum, testpb.TestAllTypes_BAZ.Enum(), false},
   195  		{testpb.E_OptionalNestedMessage, nil, false},
   196  		{testpb.E_OptionalNestedMessage, (*testpb.TestAllExtensions_NestedMessage)(nil), true},
   197  		{testpb.E_OptionalNestedMessage, new(testpb.TestAllExtensions_NestedMessage), true},
   198  		{testpb.E_OptionalNestedMessage, new(testpb.TestAllExtensions), false},
   199  		{testpb.E_RepeatedBool, nil, false},
   200  		{testpb.E_RepeatedBool, []bool(nil), true},
   201  		{testpb.E_RepeatedBool, []bool{}, true},
   202  		{testpb.E_RepeatedBool, []bool{false}, true},
   203  		{testpb.E_RepeatedBool, []*bool{}, false},
   204  		{testpb.E_RepeatedInt32, nil, false},
   205  		{testpb.E_RepeatedInt32, []int32(nil), true},
   206  		{testpb.E_RepeatedInt32, []int32{}, true},
   207  		{testpb.E_RepeatedInt32, []int32{0}, true},
   208  		{testpb.E_RepeatedInt32, []*int32{}, false},
   209  		{testpb.E_RepeatedInt64, nil, false},
   210  		{testpb.E_RepeatedInt64, []int64(nil), true},
   211  		{testpb.E_RepeatedInt64, []int64{}, true},
   212  		{testpb.E_RepeatedInt64, []int64{0}, true},
   213  		{testpb.E_RepeatedInt64, []*int64{}, false},
   214  		{testpb.E_RepeatedUint32, nil, false},
   215  		{testpb.E_RepeatedUint32, []uint32(nil), true},
   216  		{testpb.E_RepeatedUint32, []uint32{}, true},
   217  		{testpb.E_RepeatedUint32, []uint32{0}, true},
   218  		{testpb.E_RepeatedUint32, []*uint32{}, false},
   219  		{testpb.E_RepeatedUint64, nil, false},
   220  		{testpb.E_RepeatedUint64, []uint64(nil), true},
   221  		{testpb.E_RepeatedUint64, []uint64{}, true},
   222  		{testpb.E_RepeatedUint64, []uint64{0}, true},
   223  		{testpb.E_RepeatedUint64, []*uint64{}, false},
   224  		{testpb.E_RepeatedFloat, nil, false},
   225  		{testpb.E_RepeatedFloat, []float32(nil), true},
   226  		{testpb.E_RepeatedFloat, []float32{}, true},
   227  		{testpb.E_RepeatedFloat, []float32{0}, true},
   228  		{testpb.E_RepeatedFloat, []*float32{}, false},
   229  		{testpb.E_RepeatedDouble, nil, false},
   230  		{testpb.E_RepeatedDouble, []float64(nil), true},
   231  		{testpb.E_RepeatedDouble, []float64{}, true},
   232  		{testpb.E_RepeatedDouble, []float64{0}, true},
   233  		{testpb.E_RepeatedDouble, []*float64{}, false},
   234  		{testpb.E_RepeatedString, nil, false},
   235  		{testpb.E_RepeatedString, []string(nil), true},
   236  		{testpb.E_RepeatedString, []string{}, true},
   237  		{testpb.E_RepeatedString, []string{""}, true},
   238  		{testpb.E_RepeatedString, []*string{}, false},
   239  		{testpb.E_RepeatedNestedEnum, nil, false},
   240  		{testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum(nil), true},
   241  		{testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum{}, true},
   242  		{testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum{0}, true},
   243  		{testpb.E_RepeatedNestedEnum, []*testpb.TestAllTypes_NestedEnum{}, false},
   244  		{testpb.E_RepeatedNestedMessage, nil, false},
   245  		{testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage(nil), true},
   246  		{testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage{}, true},
   247  		{testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage{{}}, true},
   248  		{testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions{}, false},
   249  	}
   250  
   251  	for _, tt := range tests {
   252  		// Check the results of IsValidInterface.
   253  		got := tt.xt.IsValidInterface(tt.vi)
   254  		if got != tt.want {
   255  			t.Errorf("%v.IsValidInterface() = %v, want %v", tt.xt.TypeDescriptor().FullName(), got, tt.want)
   256  		}
   257  		if !got {
   258  			continue
   259  		}
   260  
   261  		// Set the extension value and verify the results of Has.
   262  		wantHas := true
   263  		pv := tt.xt.ValueOf(tt.vi)
   264  		switch v := pv.Interface().(type) {
   265  		case protoreflect.List:
   266  			wantHas = v.Len() > 0
   267  		case protoreflect.Message:
   268  			wantHas = v.IsValid()
   269  		}
   270  		m := &testpb.TestAllExtensions{}
   271  		proto.SetExtension(m, tt.xt, tt.vi)
   272  		gotHas := proto.HasExtension(m, tt.xt)
   273  		if gotHas != wantHas {
   274  			t.Errorf("HasExtension(%q) = %v, want %v", tt.xt.TypeDescriptor().FullName(), gotHas, wantHas)
   275  		}
   276  
   277  		// Check consistency of IsValidInterface and IsValidValue.
   278  		got = tt.xt.IsValidValue(pv)
   279  		if got != tt.want {
   280  			t.Errorf("%v.IsValidValue() = %v, want %v", tt.xt.TypeDescriptor().FullName(), got, tt.want)
   281  		}
   282  		if !got {
   283  			continue
   284  		}
   285  
   286  		// Use of reflect.DeepEqual is intentional.
   287  		// We really do want to ensure that the memory layout is identical.
   288  		vi := tt.xt.InterfaceOf(pv)
   289  		if !reflect.DeepEqual(vi, tt.vi) {
   290  			t.Errorf("InterfaceOf(ValueOf(...)) round-trip mismatch: got %v, want %v", vi, tt.vi)
   291  		}
   292  	}
   293  }
   294  
   295  func TestExtensionRanger(t *testing.T) {
   296  	tests := []struct {
   297  		msg  proto.Message
   298  		want map[protoreflect.ExtensionType]interface{}
   299  	}{{
   300  		msg: &testpb.TestAllExtensions{},
   301  		want: map[protoreflect.ExtensionType]interface{}{
   302  			testpb.E_OptionalInt32:         int32(5),
   303  			testpb.E_OptionalString:        string("hello"),
   304  			testpb.E_OptionalNestedMessage: &testpb.TestAllExtensions_NestedMessage{},
   305  			testpb.E_OptionalNestedEnum:    testpb.TestAllTypes_BAZ,
   306  			testpb.E_RepeatedFloat:         []float32{+32.32, -32.32},
   307  			testpb.E_RepeatedNestedMessage: []*testpb.TestAllExtensions_NestedMessage{{}},
   308  			testpb.E_RepeatedNestedEnum:    []testpb.TestAllTypes_NestedEnum{testpb.TestAllTypes_BAZ},
   309  		},
   310  	}, {
   311  		msg: &testeditionspb.TestAllExtensions{},
   312  		want: map[protoreflect.ExtensionType]interface{}{
   313  			testeditionspb.E_OptionalInt32:         int32(5),
   314  			testeditionspb.E_OptionalString:        string("hello"),
   315  			testeditionspb.E_OptionalNestedMessage: &testeditionspb.TestAllExtensions_NestedMessage{},
   316  			testeditionspb.E_OptionalNestedEnum:    testeditionspb.TestAllTypes_BAZ,
   317  			testeditionspb.E_RepeatedFloat:         []float32{+32.32, -32.32},
   318  			testeditionspb.E_RepeatedNestedMessage: []*testeditionspb.TestAllExtensions_NestedMessage{{}},
   319  			testeditionspb.E_RepeatedNestedEnum:    []testeditionspb.TestAllTypes_NestedEnum{testeditionspb.TestAllTypes_BAZ},
   320  		},
   321  	}, {
   322  		msg: &descpb.MessageOptions{},
   323  		want: map[protoreflect.ExtensionType]interface{}{
   324  			test3pb.E_OptionalInt32Ext:          int32(5),
   325  			test3pb.E_OptionalStringExt:         string("hello"),
   326  			test3pb.E_OptionalForeignMessageExt: &test3pb.ForeignMessage{},
   327  			test3pb.E_OptionalForeignEnumExt:    test3pb.ForeignEnum_FOREIGN_BAR,
   328  
   329  			test3pb.E_OptionalOptionalInt32Ext:          int32(5),
   330  			test3pb.E_OptionalOptionalStringExt:         string("hello"),
   331  			test3pb.E_OptionalOptionalForeignMessageExt: &test3pb.ForeignMessage{},
   332  			test3pb.E_OptionalOptionalForeignEnumExt:    test3pb.ForeignEnum_FOREIGN_BAR,
   333  		},
   334  	}}
   335  
   336  	for _, tt := range tests {
   337  		for xt, v := range tt.want {
   338  			proto.SetExtension(tt.msg, xt, v)
   339  		}
   340  
   341  		got := make(map[protoreflect.ExtensionType]interface{})
   342  		proto.RangeExtensions(tt.msg, func(xt protoreflect.ExtensionType, v interface{}) bool {
   343  			got[xt] = v
   344  			return true
   345  		})
   346  
   347  		if diff := cmp.Diff(tt.want, got, protocmp.Transform()); diff != "" {
   348  			t.Errorf("proto.RangeExtensions mismatch (-want +got):\n%s", diff)
   349  		}
   350  	}
   351  }
   352  
   353  func TestExtensionGetRace(t *testing.T) {
   354  	// Concurrently fetch an extension value while marshaling the message containing it.
   355  	// Create the message with proto.Unmarshal to give lazy extension decoding (if present)
   356  	// a chance to occur.
   357  	want := int32(42)
   358  	m1 := &testpb.TestAllExtensions{}
   359  	proto.SetExtension(m1, testpb.E_OptionalNestedMessage, &testpb.TestAllExtensions_NestedMessage{A: proto.Int32(want)})
   360  	b, err := proto.Marshal(m1)
   361  	if err != nil {
   362  		t.Fatal(err)
   363  	}
   364  	m := &testpb.TestAllExtensions{}
   365  	if err := proto.Unmarshal(b, m); err != nil {
   366  		t.Fatal(err)
   367  	}
   368  	var wg sync.WaitGroup
   369  	for i := 0; i < 3; i++ {
   370  		wg.Add(1)
   371  		go func() {
   372  			defer wg.Done()
   373  			if _, err := proto.Marshal(m); err != nil {
   374  				t.Error(err)
   375  			}
   376  		}()
   377  		wg.Add(1)
   378  		go func() {
   379  			defer wg.Done()
   380  			got := proto.GetExtension(m, testpb.E_OptionalNestedMessage).(*testpb.TestAllExtensions_NestedMessage).GetA()
   381  			if got != want {
   382  				t.Errorf("GetExtension(optional_nested_message).a = %v, want %v", got, want)
   383  			}
   384  		}()
   385  	}
   386  	wg.Wait()
   387  }
   388  
   389  func TestFeatureResolution(t *testing.T) {
   390  	for _, tc := range []struct {
   391  		input interface {
   392  			TypeDescriptor() protoreflect.ExtensionTypeDescriptor
   393  		}
   394  		wantPacked bool
   395  	}{
   396  		{testeditionspb.E_GlobalExpandedExtension, false},
   397  		{testeditionspb.E_GlobalPackedExtensionOverriden, true},
   398  		{testeditionspb.E_RepeatedFieldEncoding_MessageExpandedExtension, false},
   399  		{testeditionspb.E_RepeatedFieldEncoding_MessagePackedExtensionOverriden, true},
   400  		{testeditionspb.E_OtherFileGlobalExpandedExtensionOverriden, false},
   401  		{testeditionspb.E_OtherFileGlobalPackedExtension, true},
   402  		{testeditionspb.E_OtherRepeatedFieldEncoding_OtherFileMessagePackedExtension, true},
   403  		{testeditionspb.E_OtherRepeatedFieldEncoding_OtherFileMessageExpandedExtensionOverriden, false},
   404  	} {
   405  		if got, want := tc.input.TypeDescriptor().IsPacked(), tc.wantPacked; got != want {
   406  			t.Errorf("%v.IsPacked() = %v, want %v", tc.input.TypeDescriptor().FullName(), got, want)
   407  		}
   408  	}
   409  }
   410  

View as plain text