...

Source file src/github.com/gogo/protobuf/plugin/compare/compare.go

Documentation: github.com/gogo/protobuf/plugin/compare

     1  // Protocol Buffers for Go with Gadgets
     2  //
     3  // Copyright (c) 2013, The GoGo Authors. All rights reserved.
     4  // http://github.com/gogo/protobuf
     5  //
     6  // Redistribution and use in source and binary forms, with or without
     7  // modification, are permitted provided that the following conditions are
     8  // met:
     9  //
    10  //     * Redistributions of source code must retain the above copyright
    11  // notice, this list of conditions and the following disclaimer.
    12  //     * Redistributions in binary form must reproduce the above
    13  // copyright notice, this list of conditions and the following disclaimer
    14  // in the documentation and/or other materials provided with the
    15  // distribution.
    16  //
    17  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    18  // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    19  // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    20  // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    21  // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    22  // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    23  // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    24  // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    25  // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    26  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    27  // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    28  
    29  package compare
    30  
    31  import (
    32  	"github.com/gogo/protobuf/gogoproto"
    33  	"github.com/gogo/protobuf/proto"
    34  	descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
    35  	"github.com/gogo/protobuf/protoc-gen-gogo/generator"
    36  	"github.com/gogo/protobuf/vanity"
    37  )
    38  
    39  type plugin struct {
    40  	*generator.Generator
    41  	generator.PluginImports
    42  	fmtPkg      generator.Single
    43  	bytesPkg    generator.Single
    44  	sortkeysPkg generator.Single
    45  	protoPkg    generator.Single
    46  }
    47  
    48  func NewPlugin() *plugin {
    49  	return &plugin{}
    50  }
    51  
    52  func (p *plugin) Name() string {
    53  	return "compare"
    54  }
    55  
    56  func (p *plugin) Init(g *generator.Generator) {
    57  	p.Generator = g
    58  }
    59  
    60  func (p *plugin) Generate(file *generator.FileDescriptor) {
    61  	p.PluginImports = generator.NewPluginImports(p.Generator)
    62  	p.fmtPkg = p.NewImport("fmt")
    63  	p.bytesPkg = p.NewImport("bytes")
    64  	p.sortkeysPkg = p.NewImport("github.com/gogo/protobuf/sortkeys")
    65  	p.protoPkg = p.NewImport("github.com/gogo/protobuf/proto")
    66  
    67  	for _, msg := range file.Messages() {
    68  		if msg.DescriptorProto.GetOptions().GetMapEntry() {
    69  			continue
    70  		}
    71  		if gogoproto.HasCompare(file.FileDescriptorProto, msg.DescriptorProto) {
    72  			p.generateMessage(file, msg)
    73  		}
    74  	}
    75  }
    76  
    77  func (p *plugin) generateNullableField(fieldname string) {
    78  	p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`)
    79  	p.In()
    80  	p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`)
    81  	p.In()
    82  	p.P(`if *this.`, fieldname, ` < *that1.`, fieldname, `{`)
    83  	p.In()
    84  	p.P(`return -1`)
    85  	p.Out()
    86  	p.P(`}`)
    87  	p.P(`return 1`)
    88  	p.Out()
    89  	p.P(`}`)
    90  	p.Out()
    91  	p.P(`} else if this.`, fieldname, ` != nil {`)
    92  	p.In()
    93  	p.P(`return 1`)
    94  	p.Out()
    95  	p.P(`} else if that1.`, fieldname, ` != nil {`)
    96  	p.In()
    97  	p.P(`return -1`)
    98  	p.Out()
    99  	p.P(`}`)
   100  }
   101  
   102  func (p *plugin) generateMsgNullAndTypeCheck(ccTypeName string) {
   103  	p.P(`if that == nil {`)
   104  	p.In()
   105  	p.P(`if this == nil {`)
   106  	p.In()
   107  	p.P(`return 0`)
   108  	p.Out()
   109  	p.P(`}`)
   110  	p.P(`return 1`)
   111  	p.Out()
   112  	p.P(`}`)
   113  	p.P(``)
   114  	p.P(`that1, ok := that.(*`, ccTypeName, `)`)
   115  	p.P(`if !ok {`)
   116  	p.In()
   117  	p.P(`that2, ok := that.(`, ccTypeName, `)`)
   118  	p.P(`if ok {`)
   119  	p.In()
   120  	p.P(`that1 = &that2`)
   121  	p.Out()
   122  	p.P(`} else {`)
   123  	p.In()
   124  	p.P(`return 1`)
   125  	p.Out()
   126  	p.P(`}`)
   127  	p.Out()
   128  	p.P(`}`)
   129  	p.P(`if that1 == nil {`)
   130  	p.In()
   131  	p.P(`if this == nil {`)
   132  	p.In()
   133  	p.P(`return 0`)
   134  	p.Out()
   135  	p.P(`}`)
   136  	p.P(`return 1`)
   137  	p.Out()
   138  	p.P(`} else if this == nil {`)
   139  	p.In()
   140  	p.P(`return -1`)
   141  	p.Out()
   142  	p.P(`}`)
   143  }
   144  
   145  func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) {
   146  	proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
   147  	fieldname := p.GetOneOfFieldName(message, field)
   148  	repeated := field.IsRepeated()
   149  	ctype := gogoproto.IsCustomType(field)
   150  	nullable := gogoproto.IsNullable(field)
   151  	// oneof := field.OneofIndex != nil
   152  	if !repeated {
   153  		if ctype {
   154  			if nullable {
   155  				p.P(`if that1.`, fieldname, ` == nil {`)
   156  				p.In()
   157  				p.P(`if this.`, fieldname, ` != nil {`)
   158  				p.In()
   159  				p.P(`return 1`)
   160  				p.Out()
   161  				p.P(`}`)
   162  				p.Out()
   163  				p.P(`} else if this.`, fieldname, ` == nil {`)
   164  				p.In()
   165  				p.P(`return -1`)
   166  				p.Out()
   167  				p.P(`} else if c := this.`, fieldname, `.Compare(*that1.`, fieldname, `); c != 0 {`)
   168  			} else {
   169  				p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
   170  			}
   171  			p.In()
   172  			p.P(`return c`)
   173  			p.Out()
   174  			p.P(`}`)
   175  		} else {
   176  			if field.IsMessage() || p.IsGroup(field) {
   177  				if nullable {
   178  					p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
   179  				} else {
   180  					p.P(`if c := this.`, fieldname, `.Compare(&that1.`, fieldname, `); c != 0 {`)
   181  				}
   182  				p.In()
   183  				p.P(`return c`)
   184  				p.Out()
   185  				p.P(`}`)
   186  			} else if field.IsBytes() {
   187  				p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
   188  				p.In()
   189  				p.P(`return c`)
   190  				p.Out()
   191  				p.P(`}`)
   192  			} else if field.IsString() {
   193  				if nullable && !proto3 {
   194  					p.generateNullableField(fieldname)
   195  				} else {
   196  					p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
   197  					p.In()
   198  					p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`)
   199  					p.In()
   200  					p.P(`return -1`)
   201  					p.Out()
   202  					p.P(`}`)
   203  					p.P(`return 1`)
   204  					p.Out()
   205  					p.P(`}`)
   206  				}
   207  			} else if field.IsBool() {
   208  				if nullable && !proto3 {
   209  					p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`)
   210  					p.In()
   211  					p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`)
   212  					p.In()
   213  					p.P(`if !*this.`, fieldname, ` {`)
   214  					p.In()
   215  					p.P(`return -1`)
   216  					p.Out()
   217  					p.P(`}`)
   218  					p.P(`return 1`)
   219  					p.Out()
   220  					p.P(`}`)
   221  					p.Out()
   222  					p.P(`} else if this.`, fieldname, ` != nil {`)
   223  					p.In()
   224  					p.P(`return 1`)
   225  					p.Out()
   226  					p.P(`} else if that1.`, fieldname, ` != nil {`)
   227  					p.In()
   228  					p.P(`return -1`)
   229  					p.Out()
   230  					p.P(`}`)
   231  				} else {
   232  					p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
   233  					p.In()
   234  					p.P(`if !this.`, fieldname, ` {`)
   235  					p.In()
   236  					p.P(`return -1`)
   237  					p.Out()
   238  					p.P(`}`)
   239  					p.P(`return 1`)
   240  					p.Out()
   241  					p.P(`}`)
   242  				}
   243  			} else {
   244  				if nullable && !proto3 {
   245  					p.generateNullableField(fieldname)
   246  				} else {
   247  					p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
   248  					p.In()
   249  					p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`)
   250  					p.In()
   251  					p.P(`return -1`)
   252  					p.Out()
   253  					p.P(`}`)
   254  					p.P(`return 1`)
   255  					p.Out()
   256  					p.P(`}`)
   257  				}
   258  			}
   259  		}
   260  	} else {
   261  		p.P(`if len(this.`, fieldname, `) != len(that1.`, fieldname, `) {`)
   262  		p.In()
   263  		p.P(`if len(this.`, fieldname, `) < len(that1.`, fieldname, `) {`)
   264  		p.In()
   265  		p.P(`return -1`)
   266  		p.Out()
   267  		p.P(`}`)
   268  		p.P(`return 1`)
   269  		p.Out()
   270  		p.P(`}`)
   271  		p.P(`for i := range this.`, fieldname, ` {`)
   272  		p.In()
   273  		if ctype {
   274  			p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
   275  			p.In()
   276  			p.P(`return c`)
   277  			p.Out()
   278  			p.P(`}`)
   279  		} else {
   280  			if p.IsMap(field) {
   281  				m := p.GoMapType(nil, field)
   282  				valuegoTyp, _ := p.GoType(nil, m.ValueField)
   283  				valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField)
   284  				nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)
   285  
   286  				mapValue := m.ValueAliasField
   287  				if mapValue.IsMessage() || p.IsGroup(mapValue) {
   288  					if nullable && valuegoTyp == valuegoAliasTyp {
   289  						p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
   290  					} else {
   291  						// Compare() has a pointer receiver, but map value is a value type
   292  						a := `this.` + fieldname + `[i]`
   293  						b := `that1.` + fieldname + `[i]`
   294  						if valuegoTyp != valuegoAliasTyp {
   295  							// cast back to the type that has the generated methods on it
   296  							a = `(` + valuegoTyp + `)(` + a + `)`
   297  							b = `(` + valuegoTyp + `)(` + b + `)`
   298  						}
   299  						p.P(`a := `, a)
   300  						p.P(`b := `, b)
   301  						if nullable {
   302  							p.P(`if c := a.Compare(b); c != 0 {`)
   303  						} else {
   304  							p.P(`if c := (&a).Compare(&b); c != 0 {`)
   305  						}
   306  					}
   307  					p.In()
   308  					p.P(`return c`)
   309  					p.Out()
   310  					p.P(`}`)
   311  				} else if mapValue.IsBytes() {
   312  					p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`)
   313  					p.In()
   314  					p.P(`return c`)
   315  					p.Out()
   316  					p.P(`}`)
   317  				} else if mapValue.IsString() {
   318  					p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
   319  					p.In()
   320  					p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
   321  					p.In()
   322  					p.P(`return -1`)
   323  					p.Out()
   324  					p.P(`}`)
   325  					p.P(`return 1`)
   326  					p.Out()
   327  					p.P(`}`)
   328  				} else {
   329  					p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
   330  					p.In()
   331  					p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
   332  					p.In()
   333  					p.P(`return -1`)
   334  					p.Out()
   335  					p.P(`}`)
   336  					p.P(`return 1`)
   337  					p.Out()
   338  					p.P(`}`)
   339  				}
   340  			} else if field.IsMessage() || p.IsGroup(field) {
   341  				if nullable {
   342  					p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
   343  					p.In()
   344  					p.P(`return c`)
   345  					p.Out()
   346  					p.P(`}`)
   347  				} else {
   348  					p.P(`if c := this.`, fieldname, `[i].Compare(&that1.`, fieldname, `[i]); c != 0 {`)
   349  					p.In()
   350  					p.P(`return c`)
   351  					p.Out()
   352  					p.P(`}`)
   353  				}
   354  			} else if field.IsBytes() {
   355  				p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`)
   356  				p.In()
   357  				p.P(`return c`)
   358  				p.Out()
   359  				p.P(`}`)
   360  			} else if field.IsString() {
   361  				p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
   362  				p.In()
   363  				p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
   364  				p.In()
   365  				p.P(`return -1`)
   366  				p.Out()
   367  				p.P(`}`)
   368  				p.P(`return 1`)
   369  				p.Out()
   370  				p.P(`}`)
   371  			} else if field.IsBool() {
   372  				p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
   373  				p.In()
   374  				p.P(`if !this.`, fieldname, `[i] {`)
   375  				p.In()
   376  				p.P(`return -1`)
   377  				p.Out()
   378  				p.P(`}`)
   379  				p.P(`return 1`)
   380  				p.Out()
   381  				p.P(`}`)
   382  			} else {
   383  				p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
   384  				p.In()
   385  				p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
   386  				p.In()
   387  				p.P(`return -1`)
   388  				p.Out()
   389  				p.P(`}`)
   390  				p.P(`return 1`)
   391  				p.Out()
   392  				p.P(`}`)
   393  			}
   394  		}
   395  		p.Out()
   396  		p.P(`}`)
   397  	}
   398  }
   399  
   400  func (p *plugin) generateMessage(file *generator.FileDescriptor, message *generator.Descriptor) {
   401  	ccTypeName := generator.CamelCaseSlice(message.TypeName())
   402  	p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
   403  	p.In()
   404  	p.generateMsgNullAndTypeCheck(ccTypeName)
   405  	oneofs := make(map[string]struct{})
   406  
   407  	for _, field := range message.Field {
   408  		oneof := field.OneofIndex != nil
   409  		if oneof {
   410  			fieldname := p.GetFieldName(message, field)
   411  			if _, ok := oneofs[fieldname]; ok {
   412  				continue
   413  			} else {
   414  				oneofs[fieldname] = struct{}{}
   415  			}
   416  			p.P(`if that1.`, fieldname, ` == nil {`)
   417  			p.In()
   418  			p.P(`if this.`, fieldname, ` != nil {`)
   419  			p.In()
   420  			p.P(`return 1`)
   421  			p.Out()
   422  			p.P(`}`)
   423  			p.Out()
   424  			p.P(`} else if this.`, fieldname, ` == nil {`)
   425  			p.In()
   426  			p.P(`return -1`)
   427  			p.Out()
   428  			p.P(`} else {`)
   429  			p.In()
   430  
   431  			// Generate two type switches in order to compare the
   432  			// types of the oneofs. If they are of the same type
   433  			// call Compare, otherwise return 1 or -1.
   434  			p.P(`thisType := -1`)
   435  			p.P(`switch this.`, fieldname, `.(type) {`)
   436  			for i, subfield := range message.Field {
   437  				if *subfield.OneofIndex == *field.OneofIndex {
   438  					ccTypeName := p.OneOfTypeName(message, subfield)
   439  					p.P(`case *`, ccTypeName, `:`)
   440  					p.In()
   441  					p.P(`thisType = `, i)
   442  					p.Out()
   443  				}
   444  			}
   445  			p.P(`default:`)
   446  			p.In()
   447  			p.P(`panic(fmt.Sprintf("compare: unexpected type %T in oneof", this.`, fieldname, `))`)
   448  			p.Out()
   449  			p.P(`}`)
   450  
   451  			p.P(`that1Type := -1`)
   452  			p.P(`switch that1.`, fieldname, `.(type) {`)
   453  			for i, subfield := range message.Field {
   454  				if *subfield.OneofIndex == *field.OneofIndex {
   455  					ccTypeName := p.OneOfTypeName(message, subfield)
   456  					p.P(`case *`, ccTypeName, `:`)
   457  					p.In()
   458  					p.P(`that1Type = `, i)
   459  					p.Out()
   460  				}
   461  			}
   462  			p.P(`default:`)
   463  			p.In()
   464  			p.P(`panic(fmt.Sprintf("compare: unexpected type %T in oneof", that1.`, fieldname, `))`)
   465  			p.Out()
   466  			p.P(`}`)
   467  
   468  			p.P(`if thisType == that1Type {`)
   469  			p.In()
   470  			p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
   471  			p.In()
   472  			p.P(`return c`)
   473  			p.Out()
   474  			p.P(`}`)
   475  			p.Out()
   476  			p.P(`} else if thisType < that1Type {`)
   477  			p.In()
   478  			p.P(`return -1`)
   479  			p.Out()
   480  			p.P(`} else if thisType > that1Type {`)
   481  			p.In()
   482  			p.P(`return 1`)
   483  			p.Out()
   484  			p.P(`}`)
   485  			p.Out()
   486  			p.P(`}`)
   487  		} else {
   488  			p.generateField(file, message, field)
   489  		}
   490  	}
   491  	if message.DescriptorProto.HasExtension() {
   492  		if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) {
   493  			p.P(`thismap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(this)`)
   494  			p.P(`thatmap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(that1)`)
   495  			p.P(`extkeys := make([]int32, 0, len(thismap)+len(thatmap))`)
   496  			p.P(`for k, _ := range thismap {`)
   497  			p.In()
   498  			p.P(`extkeys = append(extkeys, k)`)
   499  			p.Out()
   500  			p.P(`}`)
   501  			p.P(`for k, _ := range thatmap {`)
   502  			p.In()
   503  			p.P(`if _, ok := thismap[k]; !ok {`)
   504  			p.In()
   505  			p.P(`extkeys = append(extkeys, k)`)
   506  			p.Out()
   507  			p.P(`}`)
   508  			p.Out()
   509  			p.P(`}`)
   510  			p.P(p.sortkeysPkg.Use(), `.Int32s(extkeys)`)
   511  			p.P(`for _, k := range extkeys {`)
   512  			p.In()
   513  			p.P(`if v, ok := thismap[k]; ok {`)
   514  			p.In()
   515  			p.P(`if v2, ok := thatmap[k]; ok {`)
   516  			p.In()
   517  			p.P(`if c := v.Compare(&v2); c != 0 {`)
   518  			p.In()
   519  			p.P(`return c`)
   520  			p.Out()
   521  			p.P(`}`)
   522  			p.Out()
   523  			p.P(`} else  {`)
   524  			p.In()
   525  			p.P(`return 1`)
   526  			p.Out()
   527  			p.P(`}`)
   528  			p.Out()
   529  			p.P(`} else {`)
   530  			p.In()
   531  			p.P(`return -1`)
   532  			p.Out()
   533  			p.P(`}`)
   534  			p.Out()
   535  			p.P(`}`)
   536  		} else {
   537  			fieldname := "XXX_extensions"
   538  			p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
   539  			p.In()
   540  			p.P(`return c`)
   541  			p.Out()
   542  			p.P(`}`)
   543  		}
   544  	}
   545  	if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) {
   546  		fieldname := "XXX_unrecognized"
   547  		p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
   548  		p.In()
   549  		p.P(`return c`)
   550  		p.Out()
   551  		p.P(`}`)
   552  	}
   553  	p.P(`return 0`)
   554  	p.Out()
   555  	p.P(`}`)
   556  
   557  	//Generate Compare methods for oneof fields
   558  	m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto)
   559  	for _, field := range m.Field {
   560  		oneof := field.OneofIndex != nil
   561  		if !oneof {
   562  			continue
   563  		}
   564  		ccTypeName := p.OneOfTypeName(message, field)
   565  		p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
   566  		p.In()
   567  
   568  		p.generateMsgNullAndTypeCheck(ccTypeName)
   569  		vanity.TurnOffNullableForNativeTypes(field)
   570  		p.generateField(file, message, field)
   571  
   572  		p.P(`return 0`)
   573  		p.Out()
   574  		p.P(`}`)
   575  	}
   576  }
   577  
   578  func init() {
   579  	generator.RegisterPlugin(NewPlugin())
   580  }
   581  

View as plain text