...

Source file src/github.com/golang/protobuf/proto/extensions_test.go

Documentation: github.com/golang/protobuf/proto

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package proto_test
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"reflect"
    11  	"sort"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  
    16  	"github.com/golang/protobuf/proto"
    17  
    18  	pb2 "github.com/golang/protobuf/internal/testprotos/proto2_proto"
    19  )
    20  
    21  func TestGetExtensionsWithMissingExtensions(t *testing.T) {
    22  	msg := &pb2.MyMessage{}
    23  	ext1 := &pb2.Ext{}
    24  	if err := proto.SetExtension(msg, pb2.E_Ext_More, ext1); err != nil {
    25  		t.Fatalf("Could not set ext1: %s", err)
    26  	}
    27  	exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
    28  		pb2.E_Ext_More,
    29  		pb2.E_Ext_Text,
    30  	})
    31  	if err != nil {
    32  		t.Fatalf("GetExtensions() failed: %s", err)
    33  	}
    34  	if exts[0] != ext1 {
    35  		t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
    36  	}
    37  	if exts[1] != nil {
    38  		t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
    39  	}
    40  }
    41  
    42  func TestGetExtensionForIncompleteDesc(t *testing.T) {
    43  	msg := &pb2.MyMessage{Count: proto.Int32(0)}
    44  	extdesc1 := &proto.ExtensionDesc{
    45  		ExtendedType:  (*pb2.MyMessage)(nil),
    46  		ExtensionType: (*bool)(nil),
    47  		Field:         123456789,
    48  		Name:          "a.b",
    49  		Tag:           "varint,123456789,opt",
    50  	}
    51  	ext1 := proto.Bool(true)
    52  	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
    53  		t.Fatalf("Could not set ext1: %s", err)
    54  	}
    55  	extdesc2 := &proto.ExtensionDesc{
    56  		ExtendedType:  (*pb2.MyMessage)(nil),
    57  		ExtensionType: ([]byte)(nil),
    58  		Field:         123456790,
    59  		Name:          "a.c",
    60  		Tag:           "bytes,123456790,opt",
    61  	}
    62  	ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7}
    63  	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
    64  		t.Fatalf("Could not set ext2: %s", err)
    65  	}
    66  	extdesc3 := &proto.ExtensionDesc{
    67  		ExtendedType:  (*pb2.MyMessage)(nil),
    68  		ExtensionType: (*pb2.Ext)(nil),
    69  		Field:         123456791,
    70  		Name:          "a.d",
    71  		Tag:           "bytes,123456791,opt",
    72  	}
    73  	ext3 := &pb2.Ext{Data: proto.String("foo")}
    74  	if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
    75  		t.Fatalf("Could not set ext3: %s", err)
    76  	}
    77  
    78  	b, err := proto.Marshal(msg)
    79  	if err != nil {
    80  		t.Fatalf("Could not marshal msg: %v", err)
    81  	}
    82  	if err := proto.Unmarshal(b, msg); err != nil {
    83  		t.Fatalf("Could not unmarshal into msg: %v", err)
    84  	}
    85  
    86  	var expected proto.Buffer
    87  	if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
    88  		t.Fatalf("failed to compute expected prefix for ext1: %s", err)
    89  	}
    90  	if err := expected.EncodeVarint(1 /* bool true */); err != nil {
    91  		t.Fatalf("failed to compute expected value for ext1: %s", err)
    92  	}
    93  
    94  	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
    95  		t.Fatalf("Failed to get raw value for ext1: %s", err)
    96  	} else if !reflect.DeepEqual(b, expected.Bytes()) {
    97  		t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
    98  	}
    99  
   100  	expected = proto.Buffer{} // reset
   101  	if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
   102  		t.Fatalf("failed to compute expected prefix for ext2: %s", err)
   103  	}
   104  	if err := expected.EncodeRawBytes(ext2); err != nil {
   105  		t.Fatalf("failed to compute expected value for ext2: %s", err)
   106  	}
   107  
   108  	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
   109  		t.Fatalf("Failed to get raw value for ext2: %s", err)
   110  	} else if !reflect.DeepEqual(b, expected.Bytes()) {
   111  		t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
   112  	}
   113  
   114  	expected = proto.Buffer{} // reset
   115  	if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
   116  		t.Fatalf("failed to compute expected prefix for ext3: %s", err)
   117  	}
   118  	if b, err := proto.Marshal(ext3); err != nil {
   119  		t.Fatalf("failed to compute expected value for ext3: %s", err)
   120  	} else if err := expected.EncodeRawBytes(b); err != nil {
   121  		t.Fatalf("failed to compute expected value for ext3: %s", err)
   122  	}
   123  
   124  	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
   125  		t.Fatalf("Failed to get raw value for ext3: %s", err)
   126  	} else if !reflect.DeepEqual(b, expected.Bytes()) {
   127  		t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
   128  	}
   129  }
   130  
   131  func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
   132  	msg := &pb2.MyMessage{Count: proto.Int32(0)}
   133  	extdesc1 := pb2.E_Ext_More
   134  	if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
   135  		t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
   136  	}
   137  
   138  	ext1 := &pb2.Ext{}
   139  	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
   140  		t.Fatalf("Could not set ext1: %s", err)
   141  	}
   142  	extdesc2 := &proto.ExtensionDesc{
   143  		ExtendedType:  (*pb2.MyMessage)(nil),
   144  		ExtensionType: (*bool)(nil),
   145  		Field:         123456789,
   146  		Name:          "a.b",
   147  		Tag:           "varint,123456789,opt",
   148  	}
   149  	ext2 := proto.Bool(false)
   150  	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
   151  		t.Fatalf("Could not set ext2: %s", err)
   152  	}
   153  
   154  	b, err := proto.Marshal(msg)
   155  	if err != nil {
   156  		t.Fatalf("Could not marshal msg: %v", err)
   157  	}
   158  	if err := proto.Unmarshal(b, msg); err != nil {
   159  		t.Fatalf("Could not unmarshal into msg: %v", err)
   160  	}
   161  
   162  	descs, err := proto.ExtensionDescs(msg)
   163  	if err != nil {
   164  		t.Fatalf("proto.ExtensionDescs: got error %v", err)
   165  	}
   166  	sortExtDescs(descs)
   167  	wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
   168  	if !reflect.DeepEqual(descs, wantDescs) {
   169  		t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
   170  	}
   171  }
   172  
   173  type ExtensionDescSlice []*proto.ExtensionDesc
   174  
   175  func (s ExtensionDescSlice) Len() int           { return len(s) }
   176  func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
   177  func (s ExtensionDescSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
   178  
   179  func sortExtDescs(s []*proto.ExtensionDesc) {
   180  	sort.Sort(ExtensionDescSlice(s))
   181  }
   182  
   183  func TestGetExtensionStability(t *testing.T) {
   184  	check := func(m *pb2.MyMessage) bool {
   185  		ext1, err := proto.GetExtension(m, pb2.E_Ext_More)
   186  		if err != nil {
   187  			t.Fatalf("GetExtension() failed: %s", err)
   188  		}
   189  		ext2, err := proto.GetExtension(m, pb2.E_Ext_More)
   190  		if err != nil {
   191  			t.Fatalf("GetExtension() failed: %s", err)
   192  		}
   193  		return ext1 == ext2
   194  	}
   195  	msg := &pb2.MyMessage{Count: proto.Int32(4)}
   196  	ext0 := &pb2.Ext{}
   197  	if err := proto.SetExtension(msg, pb2.E_Ext_More, ext0); err != nil {
   198  		t.Fatalf("Could not set ext1: %s", ext0)
   199  	}
   200  	if !check(msg) {
   201  		t.Errorf("GetExtension() not stable before marshaling")
   202  	}
   203  	bb, err := proto.Marshal(msg)
   204  	if err != nil {
   205  		t.Fatalf("Marshal() failed: %s", err)
   206  	}
   207  	msg1 := &pb2.MyMessage{}
   208  	err = proto.Unmarshal(bb, msg1)
   209  	if err != nil {
   210  		t.Fatalf("Unmarshal() failed: %s", err)
   211  	}
   212  	if !check(msg1) {
   213  		t.Errorf("GetExtension() not stable after unmarshaling")
   214  	}
   215  }
   216  
   217  func TestGetExtensionDefaults(t *testing.T) {
   218  	var setFloat64 float64 = 1
   219  	var setFloat32 float32 = 2
   220  	var setInt32 int32 = 3
   221  	var setInt64 int64 = 4
   222  	var setUint32 uint32 = 5
   223  	var setUint64 uint64 = 6
   224  	var setBool = true
   225  	var setBool2 = false
   226  	var setString = "Goodnight string"
   227  	var setBytes = []byte("Goodnight bytes")
   228  	var setEnum = pb2.DefaultsMessage_TWO
   229  
   230  	type testcase struct {
   231  		ext  *proto.ExtensionDesc // Extension we are testing.
   232  		want interface{}          // Expected value of extension, or nil (meaning that GetExtension will fail).
   233  		def  interface{}          // Expected value of extension after ClearExtension().
   234  	}
   235  	tests := []testcase{
   236  		{pb2.E_NoDefaultDouble, setFloat64, nil},
   237  		{pb2.E_NoDefaultFloat, setFloat32, nil},
   238  		{pb2.E_NoDefaultInt32, setInt32, nil},
   239  		{pb2.E_NoDefaultInt64, setInt64, nil},
   240  		{pb2.E_NoDefaultUint32, setUint32, nil},
   241  		{pb2.E_NoDefaultUint64, setUint64, nil},
   242  		{pb2.E_NoDefaultSint32, setInt32, nil},
   243  		{pb2.E_NoDefaultSint64, setInt64, nil},
   244  		{pb2.E_NoDefaultFixed32, setUint32, nil},
   245  		{pb2.E_NoDefaultFixed64, setUint64, nil},
   246  		{pb2.E_NoDefaultSfixed32, setInt32, nil},
   247  		{pb2.E_NoDefaultSfixed64, setInt64, nil},
   248  		{pb2.E_NoDefaultBool, setBool, nil},
   249  		{pb2.E_NoDefaultBool, setBool2, nil},
   250  		{pb2.E_NoDefaultString, setString, nil},
   251  		{pb2.E_NoDefaultBytes, setBytes, nil},
   252  		{pb2.E_NoDefaultEnum, setEnum, nil},
   253  		{pb2.E_DefaultDouble, setFloat64, float64(3.1415)},
   254  		{pb2.E_DefaultFloat, setFloat32, float32(3.14)},
   255  		{pb2.E_DefaultInt32, setInt32, int32(42)},
   256  		{pb2.E_DefaultInt64, setInt64, int64(43)},
   257  		{pb2.E_DefaultUint32, setUint32, uint32(44)},
   258  		{pb2.E_DefaultUint64, setUint64, uint64(45)},
   259  		{pb2.E_DefaultSint32, setInt32, int32(46)},
   260  		{pb2.E_DefaultSint64, setInt64, int64(47)},
   261  		{pb2.E_DefaultFixed32, setUint32, uint32(48)},
   262  		{pb2.E_DefaultFixed64, setUint64, uint64(49)},
   263  		{pb2.E_DefaultSfixed32, setInt32, int32(50)},
   264  		{pb2.E_DefaultSfixed64, setInt64, int64(51)},
   265  		{pb2.E_DefaultBool, setBool, true},
   266  		{pb2.E_DefaultBool, setBool2, true},
   267  		{pb2.E_DefaultString, setString, "Hello, string,def=foo"},
   268  		{pb2.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
   269  		{pb2.E_DefaultEnum, setEnum, pb2.DefaultsMessage_ONE},
   270  	}
   271  
   272  	checkVal := func(t *testing.T, name string, test testcase, msg *pb2.DefaultsMessage, valWant interface{}) {
   273  		t.Run(name, func(t *testing.T) {
   274  			val, err := proto.GetExtension(msg, test.ext)
   275  			if err != nil {
   276  				if valWant != nil {
   277  					t.Errorf("GetExtension(): %s", err)
   278  					return
   279  				}
   280  				if want := proto.ErrMissingExtension; err != want {
   281  					t.Errorf("Unexpected error: got %v, want %v", err, want)
   282  					return
   283  				}
   284  				return
   285  			}
   286  
   287  			// All proto2 extension values are either a pointer to a value or a slice of values.
   288  			ty := reflect.TypeOf(val)
   289  			tyWant := reflect.TypeOf(test.ext.ExtensionType)
   290  			if got, want := ty, tyWant; got != want {
   291  				t.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
   292  				return
   293  			}
   294  			tye := ty.Elem()
   295  			tyeWant := tyWant.Elem()
   296  			if got, want := tye, tyeWant; got != want {
   297  				t.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
   298  				return
   299  			}
   300  
   301  			// Check the name of the type of the value.
   302  			// If it is an enum it will be type int32 with the name of the enum.
   303  			if got, want := tye.Name(), tye.Name(); got != want {
   304  				t.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
   305  				return
   306  			}
   307  
   308  			// Check that value is what we expect.
   309  			// If we have a pointer in val, get the value it points to.
   310  			valExp := val
   311  			if ty.Kind() == reflect.Ptr {
   312  				valExp = reflect.ValueOf(val).Elem().Interface()
   313  			}
   314  			if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
   315  				t.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
   316  				return
   317  			}
   318  		})
   319  	}
   320  
   321  	setTo := func(test testcase) interface{} {
   322  		setTo := reflect.ValueOf(test.want)
   323  		if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
   324  			setTo = reflect.New(typ).Elem()
   325  			setTo.Set(reflect.New(setTo.Type().Elem()))
   326  			setTo.Elem().Set(reflect.ValueOf(test.want))
   327  		}
   328  		return setTo.Interface()
   329  	}
   330  
   331  	for _, test := range tests {
   332  		msg := &pb2.DefaultsMessage{}
   333  		name := test.ext.Name
   334  
   335  		// Check the initial value.
   336  		checkVal(t, name+"/initial", test, msg, test.def)
   337  
   338  		// Set the per-type value and check value.
   339  		if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
   340  			t.Errorf("%s: SetExtension(): %v", name, err)
   341  			continue
   342  		}
   343  		checkVal(t, name+"/set", test, msg, test.want)
   344  
   345  		// Set and check the value.
   346  		proto.ClearExtension(msg, test.ext)
   347  		checkVal(t, name+"/cleared", test, msg, test.def)
   348  	}
   349  }
   350  
   351  func TestNilMessage(t *testing.T) {
   352  	name := "nil interface"
   353  	if got, err := proto.GetExtension(nil, pb2.E_Ext_More); err == nil {
   354  		t.Errorf("%s: got %T %v, expected to fail", name, got, got)
   355  	} else if !strings.Contains(err.Error(), "extendable") {
   356  		t.Errorf("%s: got error %v, expected not-extendable error", name, err)
   357  	}
   358  
   359  	// Regression tests: all functions of the Extension API
   360  	// used to panic when passed (*M)(nil), where M is a concrete message
   361  	// type.  Now they handle this gracefully as a no-op or reported error.
   362  	var nilMsg *pb2.MyMessage
   363  	desc := pb2.E_Ext_More
   364  
   365  	isNotExtendable := func(err error) bool {
   366  		return strings.Contains(fmt.Sprint(err), "not an extendable")
   367  	}
   368  
   369  	if proto.HasExtension(nilMsg, desc) {
   370  		t.Error("HasExtension(nil) = true")
   371  	}
   372  
   373  	if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) {
   374  		t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
   375  	}
   376  
   377  	if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) {
   378  		t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
   379  	}
   380  
   381  	if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) {
   382  		t.Errorf("SetExtension(nil) = %q (wrong error)", err)
   383  	}
   384  
   385  	proto.ClearExtension(nilMsg, desc) // no-op
   386  	proto.ClearAllExtensions(nilMsg)   // no-op
   387  }
   388  
   389  func TestExtensionsRoundTrip(t *testing.T) {
   390  	msg := &pb2.MyMessage{}
   391  	ext1 := &pb2.Ext{
   392  		Data: proto.String("hi"),
   393  	}
   394  	ext2 := &pb2.Ext{
   395  		Data: proto.String("there"),
   396  	}
   397  	exists := proto.HasExtension(msg, pb2.E_Ext_More)
   398  	if exists {
   399  		t.Error("Extension More present unexpectedly")
   400  	}
   401  	if err := proto.SetExtension(msg, pb2.E_Ext_More, ext1); err != nil {
   402  		t.Error(err)
   403  	}
   404  	if err := proto.SetExtension(msg, pb2.E_Ext_More, ext2); err != nil {
   405  		t.Error(err)
   406  	}
   407  	e, err := proto.GetExtension(msg, pb2.E_Ext_More)
   408  	if err != nil {
   409  		t.Error(err)
   410  	}
   411  	x, ok := e.(*pb2.Ext)
   412  	if !ok {
   413  		t.Errorf("e has type %T, expected test_proto.Ext", e)
   414  	} else if *x.Data != "there" {
   415  		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
   416  	}
   417  	proto.ClearExtension(msg, pb2.E_Ext_More)
   418  	if _, err = proto.GetExtension(msg, pb2.E_Ext_More); err != proto.ErrMissingExtension {
   419  		t.Errorf("got %v, expected ErrMissingExtension", e)
   420  	}
   421  	if err := proto.SetExtension(msg, pb2.E_Ext_More, 12); err == nil {
   422  		t.Error("expected some sort of type mismatch error, got nil")
   423  	}
   424  }
   425  
   426  func TestNilExtension(t *testing.T) {
   427  	msg := &pb2.MyMessage{
   428  		Count: proto.Int32(1),
   429  	}
   430  	if err := proto.SetExtension(msg, pb2.E_Ext_Text, proto.String("hello")); err != nil {
   431  		t.Fatal(err)
   432  	}
   433  	if err := proto.SetExtension(msg, pb2.E_Ext_More, (*pb2.Ext)(nil)); err == nil {
   434  		t.Error("expected SetExtension to fail due to a nil extension")
   435  	} else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb2.Ext)); err.Error() != want {
   436  		t.Errorf("expected error %v, got %v", want, err)
   437  	}
   438  	// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
   439  	// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
   440  }
   441  
   442  func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
   443  	// Add a repeated extension to the result.
   444  	tests := []struct {
   445  		name string
   446  		ext  []*pb2.ComplexExtension
   447  	}{
   448  		{
   449  			"two fields",
   450  			[]*pb2.ComplexExtension{
   451  				{First: proto.Int32(7)},
   452  				{Second: proto.Int32(11)},
   453  			},
   454  		},
   455  		{
   456  			"repeated field",
   457  			[]*pb2.ComplexExtension{
   458  				{Third: []int32{1000}},
   459  				{Third: []int32{2000}},
   460  			},
   461  		},
   462  		{
   463  			"two fields and repeated field",
   464  			[]*pb2.ComplexExtension{
   465  				{Third: []int32{1000}},
   466  				{First: proto.Int32(9)},
   467  				{Second: proto.Int32(21)},
   468  				{Third: []int32{2000}},
   469  			},
   470  		},
   471  	}
   472  	for _, test := range tests {
   473  		// Marshal message with a repeated extension.
   474  		msg1 := new(pb2.OtherMessage)
   475  		err := proto.SetExtension(msg1, pb2.E_RComplex, test.ext)
   476  		if err != nil {
   477  			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
   478  		}
   479  		b, err := proto.Marshal(msg1)
   480  		if err != nil {
   481  			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
   482  		}
   483  
   484  		// Unmarshal and read the merged proto.
   485  		msg2 := new(pb2.OtherMessage)
   486  		err = proto.Unmarshal(b, msg2)
   487  		if err != nil {
   488  			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
   489  		}
   490  		e, err := proto.GetExtension(msg2, pb2.E_RComplex)
   491  		if err != nil {
   492  			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
   493  		}
   494  		ext := e.([]*pb2.ComplexExtension)
   495  		if ext == nil {
   496  			t.Fatalf("[%s] Invalid extension", test.name)
   497  		}
   498  		if len(ext) != len(test.ext) {
   499  			t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
   500  		}
   501  		for i := range test.ext {
   502  			if !proto.Equal(ext[i], test.ext[i]) {
   503  				t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
   504  			}
   505  		}
   506  	}
   507  }
   508  
   509  func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
   510  	// We may see multiple instances of the same extension in the wire
   511  	// format. For example, the proto compiler may encode custom options in
   512  	// this way. Here, we verify that we merge the extensions together.
   513  	tests := []struct {
   514  		name string
   515  		ext  []*pb2.ComplexExtension
   516  	}{
   517  		{
   518  			"two fields",
   519  			[]*pb2.ComplexExtension{
   520  				{First: proto.Int32(7)},
   521  				{Second: proto.Int32(11)},
   522  			},
   523  		},
   524  		{
   525  			"repeated field",
   526  			[]*pb2.ComplexExtension{
   527  				{Third: []int32{1000}},
   528  				{Third: []int32{2000}},
   529  			},
   530  		},
   531  		{
   532  			"two fields and repeated field",
   533  			[]*pb2.ComplexExtension{
   534  				{Third: []int32{1000}},
   535  				{First: proto.Int32(9)},
   536  				{Second: proto.Int32(21)},
   537  				{Third: []int32{2000}},
   538  			},
   539  		},
   540  	}
   541  	for _, test := range tests {
   542  		var buf bytes.Buffer
   543  		var want pb2.ComplexExtension
   544  
   545  		// Generate a serialized representation of a repeated extension
   546  		// by catenating bytes together.
   547  		for i, e := range test.ext {
   548  			// Merge to create the wanted proto.
   549  			proto.Merge(&want, e)
   550  
   551  			// serialize the message
   552  			msg := new(pb2.OtherMessage)
   553  			err := proto.SetExtension(msg, pb2.E_Complex, e)
   554  			if err != nil {
   555  				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
   556  			}
   557  			b, err := proto.Marshal(msg)
   558  			if err != nil {
   559  				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
   560  			}
   561  			buf.Write(b)
   562  		}
   563  
   564  		// Unmarshal and read the merged proto.
   565  		msg2 := new(pb2.OtherMessage)
   566  		err := proto.Unmarshal(buf.Bytes(), msg2)
   567  		if err != nil {
   568  			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
   569  		}
   570  		e, err := proto.GetExtension(msg2, pb2.E_Complex)
   571  		if err != nil {
   572  			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
   573  		}
   574  		ext := e.(*pb2.ComplexExtension)
   575  		if ext == nil {
   576  			t.Fatalf("[%s] Invalid extension", test.name)
   577  		}
   578  		if !proto.Equal(ext, &want) {
   579  			t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want)
   580  		}
   581  	}
   582  }
   583  
   584  func TestClearAllExtensions(t *testing.T) {
   585  	// unregistered extension
   586  	desc := &proto.ExtensionDesc{
   587  		ExtendedType:  (*pb2.MyMessage)(nil),
   588  		ExtensionType: (*bool)(nil),
   589  		Field:         101010100,
   590  		Name:          "emptyextension",
   591  		Tag:           "varint,0,opt",
   592  	}
   593  	m := &pb2.MyMessage{}
   594  	if proto.HasExtension(m, desc) {
   595  		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
   596  	}
   597  	if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
   598  		t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
   599  	}
   600  	if !proto.HasExtension(m, desc) {
   601  		t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
   602  	}
   603  	proto.ClearAllExtensions(m)
   604  	if proto.HasExtension(m, desc) {
   605  		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
   606  	}
   607  }
   608  
   609  func TestMarshalRace(t *testing.T) {
   610  	ext := &pb2.Ext{}
   611  	m := &pb2.MyMessage{Count: proto.Int32(4)}
   612  	if err := proto.SetExtension(m, pb2.E_Ext_More, ext); err != nil {
   613  		t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
   614  	}
   615  
   616  	b, err := proto.Marshal(m)
   617  	if err != nil {
   618  		t.Fatalf("Could not marshal message: %v", err)
   619  	}
   620  	if err := proto.Unmarshal(b, m); err != nil {
   621  		t.Fatalf("Could not unmarshal message: %v", err)
   622  	}
   623  	// after Unmarshal, the extension is in undecoded form.
   624  	// GetExtension will decode it lazily. Make sure this does
   625  	// not race against Marshal.
   626  
   627  	wg := sync.WaitGroup{}
   628  	errs := make(chan error, 3)
   629  	for n := 3; n > 0; n-- {
   630  		wg.Add(1)
   631  		go func() {
   632  			defer wg.Done()
   633  			_, err := proto.Marshal(m)
   634  			errs <- err
   635  		}()
   636  	}
   637  	wg.Wait()
   638  	close(errs)
   639  
   640  	for err = range errs {
   641  		if err != nil {
   642  			t.Fatal(err)
   643  		}
   644  	}
   645  }
   646  

View as plain text