...

Source file src/github.com/gogo/protobuf/proto/extensions_test.go

Documentation: github.com/gogo/protobuf/proto

     1  // Go support for Protocol Buffers - Google's data interchange format
     2  //
     3  // Copyright 2014 The Go Authors.  All rights reserved.
     4  // https://github.com/golang/protobuf
     5  //
     6  // Redistribution and use in source and binary forms, with or without
     7  // modification, are permitted provided that the following conditions are
     8  // met:
     9  //
    10  //     * Redistributions of source code must retain the above copyright
    11  // notice, this list of conditions and the following disclaimer.
    12  //     * Redistributions in binary form must reproduce the above
    13  // copyright notice, this list of conditions and the following disclaimer
    14  // in the documentation and/or other materials provided with the
    15  // distribution.
    16  //     * Neither the name of Google Inc. nor the names of its
    17  // contributors may be used to endorse or promote products derived from
    18  // this software without specific prior written permission.
    19  //
    20  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    21  // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    22  // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    23  // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    24  // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    25  // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    26  // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    27  // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    28  // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    29  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    30  // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    31  
    32  package proto_test
    33  
    34  import (
    35  	"bytes"
    36  	"fmt"
    37  	"io"
    38  	"reflect"
    39  	"sort"
    40  	"strings"
    41  	"testing"
    42  
    43  	"github.com/gogo/protobuf/proto"
    44  	pb "github.com/gogo/protobuf/proto/test_proto"
    45  )
    46  
    47  func TestGetExtensionsWithMissingExtensions(t *testing.T) {
    48  	msg := &pb.MyMessage{}
    49  	ext1 := &pb.Ext{}
    50  	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
    51  		t.Fatalf("Could not set ext1: %s", err)
    52  	}
    53  	exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
    54  		pb.E_Ext_More,
    55  		pb.E_Ext_Text,
    56  	})
    57  	if err != nil {
    58  		t.Fatalf("GetExtensions() failed: %s", err)
    59  	}
    60  	if exts[0] != ext1 {
    61  		t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
    62  	}
    63  	if exts[1] != nil {
    64  		t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
    65  	}
    66  }
    67  
    68  func TestGetExtensionWithEmptyBuffer(t *testing.T) {
    69  	// Make sure that GetExtension returns an error if its
    70  	// undecoded buffer is empty.
    71  	msg := &pb.MyMessage{}
    72  	proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{})
    73  	_, err := proto.GetExtension(msg, pb.E_Ext_More)
    74  	if want := io.ErrUnexpectedEOF; err != want {
    75  		t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want)
    76  	}
    77  }
    78  
    79  func TestGetExtensionForIncompleteDesc(t *testing.T) {
    80  	msg := &pb.MyMessage{Count: proto.Int32(0)}
    81  	extdesc1 := &proto.ExtensionDesc{
    82  		ExtendedType:  (*pb.MyMessage)(nil),
    83  		ExtensionType: (*bool)(nil),
    84  		Field:         123456789,
    85  		Name:          "a.b",
    86  		Tag:           "varint,123456789,opt",
    87  	}
    88  	ext1 := proto.Bool(true)
    89  	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
    90  		t.Fatalf("Could not set ext1: %s", err)
    91  	}
    92  	extdesc2 := &proto.ExtensionDesc{
    93  		ExtendedType:  (*pb.MyMessage)(nil),
    94  		ExtensionType: ([]byte)(nil),
    95  		Field:         123456790,
    96  		Name:          "a.c",
    97  		Tag:           "bytes,123456790,opt",
    98  	}
    99  	ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7}
   100  	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
   101  		t.Fatalf("Could not set ext2: %s", err)
   102  	}
   103  	extdesc3 := &proto.ExtensionDesc{
   104  		ExtendedType:  (*pb.MyMessage)(nil),
   105  		ExtensionType: (*pb.Ext)(nil),
   106  		Field:         123456791,
   107  		Name:          "a.d",
   108  		Tag:           "bytes,123456791,opt",
   109  	}
   110  	ext3 := &pb.Ext{Data: proto.String("foo")}
   111  	if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
   112  		t.Fatalf("Could not set ext3: %s", err)
   113  	}
   114  
   115  	b, err := proto.Marshal(msg)
   116  	if err != nil {
   117  		t.Fatalf("Could not marshal msg: %v", err)
   118  	}
   119  	if err := proto.Unmarshal(b, msg); err != nil {
   120  		t.Fatalf("Could not unmarshal into msg: %v", err)
   121  	}
   122  
   123  	var expected proto.Buffer
   124  	if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
   125  		t.Fatalf("failed to compute expected prefix for ext1: %s", err)
   126  	}
   127  	if err := expected.EncodeVarint(1 /* bool true */); err != nil {
   128  		t.Fatalf("failed to compute expected value for ext1: %s", err)
   129  	}
   130  
   131  	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
   132  		t.Fatalf("Failed to get raw value for ext1: %s", err)
   133  	} else if !reflect.DeepEqual(b, expected.Bytes()) {
   134  		t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
   135  	}
   136  
   137  	expected = proto.Buffer{} // reset
   138  	if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
   139  		t.Fatalf("failed to compute expected prefix for ext2: %s", err)
   140  	}
   141  	if err := expected.EncodeRawBytes(ext2); err != nil {
   142  		t.Fatalf("failed to compute expected value for ext2: %s", err)
   143  	}
   144  
   145  	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
   146  		t.Fatalf("Failed to get raw value for ext2: %s", err)
   147  	} else if !reflect.DeepEqual(b, expected.Bytes()) {
   148  		t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
   149  	}
   150  
   151  	expected = proto.Buffer{} // reset
   152  	if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
   153  		t.Fatalf("failed to compute expected prefix for ext3: %s", err)
   154  	}
   155  	if b, err := proto.Marshal(ext3); err != nil {
   156  		t.Fatalf("failed to compute expected value for ext3: %s", err)
   157  	} else if err := expected.EncodeRawBytes(b); err != nil {
   158  		t.Fatalf("failed to compute expected value for ext3: %s", err)
   159  	}
   160  
   161  	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
   162  		t.Fatalf("Failed to get raw value for ext3: %s", err)
   163  	} else if !reflect.DeepEqual(b, expected.Bytes()) {
   164  		t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
   165  	}
   166  }
   167  
   168  func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
   169  	msg := &pb.MyMessage{Count: proto.Int32(0)}
   170  	extdesc1 := pb.E_Ext_More
   171  	if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
   172  		t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
   173  	}
   174  
   175  	ext1 := &pb.Ext{}
   176  	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
   177  		t.Fatalf("Could not set ext1: %s", err)
   178  	}
   179  	extdesc2 := &proto.ExtensionDesc{
   180  		ExtendedType:  (*pb.MyMessage)(nil),
   181  		ExtensionType: (*bool)(nil),
   182  		Field:         123456789,
   183  		Name:          "a.b",
   184  		Tag:           "varint,123456789,opt",
   185  	}
   186  	ext2 := proto.Bool(false)
   187  	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
   188  		t.Fatalf("Could not set ext2: %s", err)
   189  	}
   190  
   191  	b, err := proto.Marshal(msg)
   192  	if err != nil {
   193  		t.Fatalf("Could not marshal msg: %v", err)
   194  	}
   195  	if err = proto.Unmarshal(b, msg); err != nil {
   196  		t.Fatalf("Could not unmarshal into msg: %v", err)
   197  	}
   198  
   199  	descs, err := proto.ExtensionDescs(msg)
   200  	if err != nil {
   201  		t.Fatalf("proto.ExtensionDescs: got error %v", err)
   202  	}
   203  	sortExtDescs(descs)
   204  	wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
   205  	if !reflect.DeepEqual(descs, wantDescs) {
   206  		t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
   207  	}
   208  }
   209  
   210  type ExtensionDescSlice []*proto.ExtensionDesc
   211  
   212  func (s ExtensionDescSlice) Len() int           { return len(s) }
   213  func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
   214  func (s ExtensionDescSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
   215  
   216  func sortExtDescs(s []*proto.ExtensionDesc) {
   217  	sort.Sort(ExtensionDescSlice(s))
   218  }
   219  
   220  func TestGetExtensionStability(t *testing.T) {
   221  	check := func(m *pb.MyMessage) bool {
   222  		ext1, err := proto.GetExtension(m, pb.E_Ext_More)
   223  		if err != nil {
   224  			t.Fatalf("GetExtension() failed: %s", err)
   225  		}
   226  		ext2, err := proto.GetExtension(m, pb.E_Ext_More)
   227  		if err != nil {
   228  			t.Fatalf("GetExtension() failed: %s", err)
   229  		}
   230  		return ext1 == ext2
   231  	}
   232  	msg := &pb.MyMessage{Count: proto.Int32(4)}
   233  	ext0 := &pb.Ext{}
   234  	if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
   235  		t.Fatalf("Could not set ext1: %s", ext0)
   236  	}
   237  	if !check(msg) {
   238  		t.Errorf("GetExtension() not stable before marshaling")
   239  	}
   240  	bb, err := proto.Marshal(msg)
   241  	if err != nil {
   242  		t.Fatalf("Marshal() failed: %s", err)
   243  	}
   244  	msg1 := &pb.MyMessage{}
   245  	err = proto.Unmarshal(bb, msg1)
   246  	if err != nil {
   247  		t.Fatalf("Unmarshal() failed: %s", err)
   248  	}
   249  	if !check(msg1) {
   250  		t.Errorf("GetExtension() not stable after unmarshaling")
   251  	}
   252  }
   253  
   254  func TestGetExtensionDefaults(t *testing.T) {
   255  	var setFloat64 float64 = 1
   256  	var setFloat32 float32 = 2
   257  	var setInt32 int32 = 3
   258  	var setInt64 int64 = 4
   259  	var setUint32 uint32 = 5
   260  	var setUint64 uint64 = 6
   261  	var setBool = true
   262  	var setBool2 = false
   263  	var setString = "Goodnight string"
   264  	var setBytes = []byte("Goodnight bytes")
   265  	var setEnum = pb.DefaultsMessage_TWO
   266  
   267  	type testcase struct {
   268  		ext  *proto.ExtensionDesc // Extension we are testing.
   269  		want interface{}          // Expected value of extension, or nil (meaning that GetExtension will fail).
   270  		def  interface{}          // Expected value of extension after ClearExtension().
   271  	}
   272  	tests := []testcase{
   273  		{pb.E_NoDefaultDouble, setFloat64, nil},
   274  		{pb.E_NoDefaultFloat, setFloat32, nil},
   275  		{pb.E_NoDefaultInt32, setInt32, nil},
   276  		{pb.E_NoDefaultInt64, setInt64, nil},
   277  		{pb.E_NoDefaultUint32, setUint32, nil},
   278  		{pb.E_NoDefaultUint64, setUint64, nil},
   279  		{pb.E_NoDefaultSint32, setInt32, nil},
   280  		{pb.E_NoDefaultSint64, setInt64, nil},
   281  		{pb.E_NoDefaultFixed32, setUint32, nil},
   282  		{pb.E_NoDefaultFixed64, setUint64, nil},
   283  		{pb.E_NoDefaultSfixed32, setInt32, nil},
   284  		{pb.E_NoDefaultSfixed64, setInt64, nil},
   285  		{pb.E_NoDefaultBool, setBool, nil},
   286  		{pb.E_NoDefaultBool, setBool2, nil},
   287  		{pb.E_NoDefaultString, setString, nil},
   288  		{pb.E_NoDefaultBytes, setBytes, nil},
   289  		{pb.E_NoDefaultEnum, setEnum, nil},
   290  		{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
   291  		{pb.E_DefaultFloat, setFloat32, float32(3.14)},
   292  		{pb.E_DefaultInt32, setInt32, int32(42)},
   293  		{pb.E_DefaultInt64, setInt64, int64(43)},
   294  		{pb.E_DefaultUint32, setUint32, uint32(44)},
   295  		{pb.E_DefaultUint64, setUint64, uint64(45)},
   296  		{pb.E_DefaultSint32, setInt32, int32(46)},
   297  		{pb.E_DefaultSint64, setInt64, int64(47)},
   298  		{pb.E_DefaultFixed32, setUint32, uint32(48)},
   299  		{pb.E_DefaultFixed64, setUint64, uint64(49)},
   300  		{pb.E_DefaultSfixed32, setInt32, int32(50)},
   301  		{pb.E_DefaultSfixed64, setInt64, int64(51)},
   302  		{pb.E_DefaultBool, setBool, true},
   303  		{pb.E_DefaultBool, setBool2, true},
   304  		{pb.E_DefaultString, setString, "Hello, string,def=foo"},
   305  		{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
   306  		{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
   307  	}
   308  
   309  	checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
   310  		val, err := proto.GetExtension(msg, test.ext)
   311  		if err != nil {
   312  			if valWant != nil {
   313  				return fmt.Errorf("GetExtension(): %s", err)
   314  			}
   315  			if want := proto.ErrMissingExtension; err != want {
   316  				return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
   317  			}
   318  			return nil
   319  		}
   320  
   321  		// All proto2 extension values are either a pointer to a value or a slice of values.
   322  		ty := reflect.TypeOf(val)
   323  		tyWant := reflect.TypeOf(test.ext.ExtensionType)
   324  		if got, want := ty, tyWant; got != want {
   325  			return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
   326  		}
   327  		tye := ty.Elem()
   328  		tyeWant := tyWant.Elem()
   329  		if got, want := tye, tyeWant; got != want {
   330  			return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
   331  		}
   332  
   333  		// Check the name of the type of the value.
   334  		// If it is an enum it will be type int32 with the name of the enum.
   335  		if got, want := tye.Name(), tye.Name(); got != want {
   336  			return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
   337  		}
   338  
   339  		// Check that value is what we expect.
   340  		// If we have a pointer in val, get the value it points to.
   341  		valExp := val
   342  		if ty.Kind() == reflect.Ptr {
   343  			valExp = reflect.ValueOf(val).Elem().Interface()
   344  		}
   345  		if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
   346  			return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
   347  		}
   348  
   349  		return nil
   350  	}
   351  
   352  	setTo := func(test testcase) interface{} {
   353  		setTo := reflect.ValueOf(test.want)
   354  		if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
   355  			setTo = reflect.New(typ).Elem()
   356  			setTo.Set(reflect.New(setTo.Type().Elem()))
   357  			setTo.Elem().Set(reflect.ValueOf(test.want))
   358  		}
   359  		return setTo.Interface()
   360  	}
   361  
   362  	for _, test := range tests {
   363  		msg := &pb.DefaultsMessage{}
   364  		name := test.ext.Name
   365  
   366  		// Check the initial value.
   367  		if err := checkVal(test, msg, test.def); err != nil {
   368  			t.Errorf("%s: %v", name, err)
   369  		}
   370  
   371  		// Set the per-type value and check value.
   372  		name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
   373  		if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
   374  			t.Errorf("%s: SetExtension(): %v", name, err)
   375  			continue
   376  		}
   377  		if err := checkVal(test, msg, test.want); err != nil {
   378  			t.Errorf("%s: %v", name, err)
   379  			continue
   380  		}
   381  
   382  		// Set and check the value.
   383  		name += " (cleared)"
   384  		proto.ClearExtension(msg, test.ext)
   385  		if err := checkVal(test, msg, test.def); err != nil {
   386  			t.Errorf("%s: %v", name, err)
   387  		}
   388  	}
   389  }
   390  
   391  func TestNilMessage(t *testing.T) {
   392  	name := "nil interface"
   393  	if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil {
   394  		t.Errorf("%s: got %T %v, expected to fail", name, got, got)
   395  	} else if !strings.Contains(err.Error(), "extendable") {
   396  		t.Errorf("%s: got error %v, expected not-extendable error", name, err)
   397  	}
   398  
   399  	// Regression tests: all functions of the Extension API
   400  	// used to panic when passed (*M)(nil), where M is a concrete message
   401  	// type.  Now they handle this gracefully as a no-op or reported error.
   402  	var nilMsg *pb.MyMessage
   403  	desc := pb.E_Ext_More
   404  
   405  	isNotExtendable := func(err error) bool {
   406  		return strings.Contains(fmt.Sprint(err), "not extendable")
   407  	}
   408  
   409  	if proto.HasExtension(nilMsg, desc) {
   410  		t.Error("HasExtension(nil) = true")
   411  	}
   412  
   413  	if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) {
   414  		t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
   415  	}
   416  
   417  	if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) {
   418  		t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
   419  	}
   420  
   421  	if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) {
   422  		t.Errorf("SetExtension(nil) = %q (wrong error)", err)
   423  	}
   424  
   425  	proto.ClearExtension(nilMsg, desc) // no-op
   426  	proto.ClearAllExtensions(nilMsg)   // no-op
   427  }
   428  
   429  func TestExtensionsRoundTrip(t *testing.T) {
   430  	msg := &pb.MyMessage{}
   431  	ext1 := &pb.Ext{
   432  		Data: proto.String("hi"),
   433  	}
   434  	ext2 := &pb.Ext{
   435  		Data: proto.String("there"),
   436  	}
   437  	exists := proto.HasExtension(msg, pb.E_Ext_More)
   438  	if exists {
   439  		t.Error("Extension More present unexpectedly")
   440  	}
   441  	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
   442  		t.Error(err)
   443  	}
   444  	if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
   445  		t.Error(err)
   446  	}
   447  	e, err := proto.GetExtension(msg, pb.E_Ext_More)
   448  	if err != nil {
   449  		t.Error(err)
   450  	}
   451  	x, ok := e.(*pb.Ext)
   452  	if !ok {
   453  		t.Errorf("e has type %T, expected test_proto.Ext", e)
   454  	} else if *x.Data != "there" {
   455  		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
   456  	}
   457  	proto.ClearExtension(msg, pb.E_Ext_More)
   458  	if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
   459  		t.Errorf("got %v, expected ErrMissingExtension", e)
   460  	}
   461  	if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
   462  		t.Error("expected bad extension error, got nil")
   463  	}
   464  	if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
   465  		t.Error("expected extension err")
   466  	}
   467  	if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
   468  		t.Error("expected some sort of type mismatch error, got nil")
   469  	}
   470  }
   471  
   472  func TestNilExtension(t *testing.T) {
   473  	msg := &pb.MyMessage{
   474  		Count: proto.Int32(1),
   475  	}
   476  	if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
   477  		t.Fatal(err)
   478  	}
   479  	if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
   480  		t.Error("expected SetExtension to fail due to a nil extension")
   481  	} else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want {
   482  		t.Errorf("expected error %v, got %v", want, err)
   483  	}
   484  	// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
   485  	// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
   486  }
   487  
   488  func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
   489  	// Add a repeated extension to the result.
   490  	tests := []struct {
   491  		name string
   492  		ext  []*pb.ComplexExtension
   493  	}{
   494  		{
   495  			"two fields",
   496  			[]*pb.ComplexExtension{
   497  				{First: proto.Int32(7)},
   498  				{Second: proto.Int32(11)},
   499  			},
   500  		},
   501  		{
   502  			"repeated field",
   503  			[]*pb.ComplexExtension{
   504  				{Third: []int32{1000}},
   505  				{Third: []int32{2000}},
   506  			},
   507  		},
   508  		{
   509  			"two fields and repeated field",
   510  			[]*pb.ComplexExtension{
   511  				{Third: []int32{1000}},
   512  				{First: proto.Int32(9)},
   513  				{Second: proto.Int32(21)},
   514  				{Third: []int32{2000}},
   515  			},
   516  		},
   517  	}
   518  	for _, test := range tests {
   519  		// Marshal message with a repeated extension.
   520  		msg1 := new(pb.OtherMessage)
   521  		err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
   522  		if err != nil {
   523  			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
   524  		}
   525  		b, err := proto.Marshal(msg1)
   526  		if err != nil {
   527  			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
   528  		}
   529  
   530  		// Unmarshal and read the merged proto.
   531  		msg2 := new(pb.OtherMessage)
   532  		err = proto.Unmarshal(b, msg2)
   533  		if err != nil {
   534  			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
   535  		}
   536  		e, err := proto.GetExtension(msg2, pb.E_RComplex)
   537  		if err != nil {
   538  			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
   539  		}
   540  		ext := e.([]*pb.ComplexExtension)
   541  		if ext == nil {
   542  			t.Fatalf("[%s] Invalid extension", test.name)
   543  		}
   544  		if len(ext) != len(test.ext) {
   545  			t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
   546  		}
   547  		for i := range test.ext {
   548  			if !proto.Equal(ext[i], test.ext[i]) {
   549  				t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
   550  			}
   551  		}
   552  	}
   553  }
   554  
   555  func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
   556  	// We may see multiple instances of the same extension in the wire
   557  	// format. For example, the proto compiler may encode custom options in
   558  	// this way. Here, we verify that we merge the extensions together.
   559  	tests := []struct {
   560  		name string
   561  		ext  []*pb.ComplexExtension
   562  	}{
   563  		{
   564  			"two fields",
   565  			[]*pb.ComplexExtension{
   566  				{First: proto.Int32(7)},
   567  				{Second: proto.Int32(11)},
   568  			},
   569  		},
   570  		{
   571  			"repeated field",
   572  			[]*pb.ComplexExtension{
   573  				{Third: []int32{1000}},
   574  				{Third: []int32{2000}},
   575  			},
   576  		},
   577  		{
   578  			"two fields and repeated field",
   579  			[]*pb.ComplexExtension{
   580  				{Third: []int32{1000}},
   581  				{First: proto.Int32(9)},
   582  				{Second: proto.Int32(21)},
   583  				{Third: []int32{2000}},
   584  			},
   585  		},
   586  	}
   587  	for _, test := range tests {
   588  		var buf bytes.Buffer
   589  		var want pb.ComplexExtension
   590  
   591  		// Generate a serialized representation of a repeated extension
   592  		// by catenating bytes together.
   593  		for i, e := range test.ext {
   594  			// Merge to create the wanted proto.
   595  			proto.Merge(&want, e)
   596  
   597  			// serialize the message
   598  			msg := new(pb.OtherMessage)
   599  			err := proto.SetExtension(msg, pb.E_Complex, e)
   600  			if err != nil {
   601  				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
   602  			}
   603  			b, err := proto.Marshal(msg)
   604  			if err != nil {
   605  				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
   606  			}
   607  			buf.Write(b)
   608  		}
   609  
   610  		// Unmarshal and read the merged proto.
   611  		msg2 := new(pb.OtherMessage)
   612  		err := proto.Unmarshal(buf.Bytes(), msg2)
   613  		if err != nil {
   614  			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
   615  		}
   616  		e, err := proto.GetExtension(msg2, pb.E_Complex)
   617  		if err != nil {
   618  			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
   619  		}
   620  		ext := e.(*pb.ComplexExtension)
   621  		if ext == nil {
   622  			t.Fatalf("[%s] Invalid extension", test.name)
   623  		}
   624  		if !proto.Equal(ext, &want) {
   625  			t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, &want)
   626  
   627  		}
   628  	}
   629  }
   630  
   631  func TestClearAllExtensions(t *testing.T) {
   632  	// unregistered extension
   633  	desc := &proto.ExtensionDesc{
   634  		ExtendedType:  (*pb.MyMessage)(nil),
   635  		ExtensionType: (*bool)(nil),
   636  		Field:         101010100,
   637  		Name:          "emptyextension",
   638  		Tag:           "varint,0,opt",
   639  	}
   640  	m := &pb.MyMessage{}
   641  	if proto.HasExtension(m, desc) {
   642  		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
   643  	}
   644  	if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
   645  		t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
   646  	}
   647  	if !proto.HasExtension(m, desc) {
   648  		t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
   649  	}
   650  	proto.ClearAllExtensions(m)
   651  	if proto.HasExtension(m, desc) {
   652  		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
   653  	}
   654  }
   655  
   656  func TestMarshalRace(t *testing.T) {
   657  	ext := &pb.Ext{}
   658  	m := &pb.MyMessage{Count: proto.Int32(4)}
   659  	if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil {
   660  		t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
   661  	}
   662  
   663  	b, err := proto.Marshal(m)
   664  	if err != nil {
   665  		t.Fatalf("Could not marshal message: %v", err)
   666  	}
   667  	if err := proto.Unmarshal(b, m); err != nil {
   668  		t.Fatalf("Could not unmarshal message: %v", err)
   669  	}
   670  	// after Unmarshal, the extension is in undecoded form.
   671  	// GetExtension will decode it lazily. Make sure this does
   672  	// not race against Marshal.
   673  
   674  	errChan := make(chan error, 6)
   675  	for n := 3; n > 0; n-- {
   676  		go func() {
   677  			_, err := proto.Marshal(m)
   678  			errChan <- err
   679  		}()
   680  		go func() {
   681  			_, err := proto.GetExtension(m, pb.E_Ext_More)
   682  			errChan <- err
   683  		}()
   684  	}
   685  	for i := 0; i < 6; i++ {
   686  		err := <-errChan
   687  		if err != nil {
   688  			t.Fatal(err)
   689  		}
   690  	}
   691  }
   692  

View as plain text