1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29 package compare
30
31 import (
32 "github.com/gogo/protobuf/gogoproto"
33 "github.com/gogo/protobuf/proto"
34 descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
35 "github.com/gogo/protobuf/protoc-gen-gogo/generator"
36 "github.com/gogo/protobuf/vanity"
37 )
38
39 type plugin struct {
40 *generator.Generator
41 generator.PluginImports
42 fmtPkg generator.Single
43 bytesPkg generator.Single
44 sortkeysPkg generator.Single
45 protoPkg generator.Single
46 }
47
48 func NewPlugin() *plugin {
49 return &plugin{}
50 }
51
52 func (p *plugin) Name() string {
53 return "compare"
54 }
55
56 func (p *plugin) Init(g *generator.Generator) {
57 p.Generator = g
58 }
59
60 func (p *plugin) Generate(file *generator.FileDescriptor) {
61 p.PluginImports = generator.NewPluginImports(p.Generator)
62 p.fmtPkg = p.NewImport("fmt")
63 p.bytesPkg = p.NewImport("bytes")
64 p.sortkeysPkg = p.NewImport("github.com/gogo/protobuf/sortkeys")
65 p.protoPkg = p.NewImport("github.com/gogo/protobuf/proto")
66
67 for _, msg := range file.Messages() {
68 if msg.DescriptorProto.GetOptions().GetMapEntry() {
69 continue
70 }
71 if gogoproto.HasCompare(file.FileDescriptorProto, msg.DescriptorProto) {
72 p.generateMessage(file, msg)
73 }
74 }
75 }
76
77 func (p *plugin) generateNullableField(fieldname string) {
78 p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`)
79 p.In()
80 p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`)
81 p.In()
82 p.P(`if *this.`, fieldname, ` < *that1.`, fieldname, `{`)
83 p.In()
84 p.P(`return -1`)
85 p.Out()
86 p.P(`}`)
87 p.P(`return 1`)
88 p.Out()
89 p.P(`}`)
90 p.Out()
91 p.P(`} else if this.`, fieldname, ` != nil {`)
92 p.In()
93 p.P(`return 1`)
94 p.Out()
95 p.P(`} else if that1.`, fieldname, ` != nil {`)
96 p.In()
97 p.P(`return -1`)
98 p.Out()
99 p.P(`}`)
100 }
101
102 func (p *plugin) generateMsgNullAndTypeCheck(ccTypeName string) {
103 p.P(`if that == nil {`)
104 p.In()
105 p.P(`if this == nil {`)
106 p.In()
107 p.P(`return 0`)
108 p.Out()
109 p.P(`}`)
110 p.P(`return 1`)
111 p.Out()
112 p.P(`}`)
113 p.P(``)
114 p.P(`that1, ok := that.(*`, ccTypeName, `)`)
115 p.P(`if !ok {`)
116 p.In()
117 p.P(`that2, ok := that.(`, ccTypeName, `)`)
118 p.P(`if ok {`)
119 p.In()
120 p.P(`that1 = &that2`)
121 p.Out()
122 p.P(`} else {`)
123 p.In()
124 p.P(`return 1`)
125 p.Out()
126 p.P(`}`)
127 p.Out()
128 p.P(`}`)
129 p.P(`if that1 == nil {`)
130 p.In()
131 p.P(`if this == nil {`)
132 p.In()
133 p.P(`return 0`)
134 p.Out()
135 p.P(`}`)
136 p.P(`return 1`)
137 p.Out()
138 p.P(`} else if this == nil {`)
139 p.In()
140 p.P(`return -1`)
141 p.Out()
142 p.P(`}`)
143 }
144
145 func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) {
146 proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
147 fieldname := p.GetOneOfFieldName(message, field)
148 repeated := field.IsRepeated()
149 ctype := gogoproto.IsCustomType(field)
150 nullable := gogoproto.IsNullable(field)
151
152 if !repeated {
153 if ctype {
154 if nullable {
155 p.P(`if that1.`, fieldname, ` == nil {`)
156 p.In()
157 p.P(`if this.`, fieldname, ` != nil {`)
158 p.In()
159 p.P(`return 1`)
160 p.Out()
161 p.P(`}`)
162 p.Out()
163 p.P(`} else if this.`, fieldname, ` == nil {`)
164 p.In()
165 p.P(`return -1`)
166 p.Out()
167 p.P(`} else if c := this.`, fieldname, `.Compare(*that1.`, fieldname, `); c != 0 {`)
168 } else {
169 p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
170 }
171 p.In()
172 p.P(`return c`)
173 p.Out()
174 p.P(`}`)
175 } else {
176 if field.IsMessage() || p.IsGroup(field) {
177 if nullable {
178 p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
179 } else {
180 p.P(`if c := this.`, fieldname, `.Compare(&that1.`, fieldname, `); c != 0 {`)
181 }
182 p.In()
183 p.P(`return c`)
184 p.Out()
185 p.P(`}`)
186 } else if field.IsBytes() {
187 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
188 p.In()
189 p.P(`return c`)
190 p.Out()
191 p.P(`}`)
192 } else if field.IsString() {
193 if nullable && !proto3 {
194 p.generateNullableField(fieldname)
195 } else {
196 p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
197 p.In()
198 p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`)
199 p.In()
200 p.P(`return -1`)
201 p.Out()
202 p.P(`}`)
203 p.P(`return 1`)
204 p.Out()
205 p.P(`}`)
206 }
207 } else if field.IsBool() {
208 if nullable && !proto3 {
209 p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`)
210 p.In()
211 p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`)
212 p.In()
213 p.P(`if !*this.`, fieldname, ` {`)
214 p.In()
215 p.P(`return -1`)
216 p.Out()
217 p.P(`}`)
218 p.P(`return 1`)
219 p.Out()
220 p.P(`}`)
221 p.Out()
222 p.P(`} else if this.`, fieldname, ` != nil {`)
223 p.In()
224 p.P(`return 1`)
225 p.Out()
226 p.P(`} else if that1.`, fieldname, ` != nil {`)
227 p.In()
228 p.P(`return -1`)
229 p.Out()
230 p.P(`}`)
231 } else {
232 p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
233 p.In()
234 p.P(`if !this.`, fieldname, ` {`)
235 p.In()
236 p.P(`return -1`)
237 p.Out()
238 p.P(`}`)
239 p.P(`return 1`)
240 p.Out()
241 p.P(`}`)
242 }
243 } else {
244 if nullable && !proto3 {
245 p.generateNullableField(fieldname)
246 } else {
247 p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
248 p.In()
249 p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`)
250 p.In()
251 p.P(`return -1`)
252 p.Out()
253 p.P(`}`)
254 p.P(`return 1`)
255 p.Out()
256 p.P(`}`)
257 }
258 }
259 }
260 } else {
261 p.P(`if len(this.`, fieldname, `) != len(that1.`, fieldname, `) {`)
262 p.In()
263 p.P(`if len(this.`, fieldname, `) < len(that1.`, fieldname, `) {`)
264 p.In()
265 p.P(`return -1`)
266 p.Out()
267 p.P(`}`)
268 p.P(`return 1`)
269 p.Out()
270 p.P(`}`)
271 p.P(`for i := range this.`, fieldname, ` {`)
272 p.In()
273 if ctype {
274 p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
275 p.In()
276 p.P(`return c`)
277 p.Out()
278 p.P(`}`)
279 } else {
280 if p.IsMap(field) {
281 m := p.GoMapType(nil, field)
282 valuegoTyp, _ := p.GoType(nil, m.ValueField)
283 valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField)
284 nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)
285
286 mapValue := m.ValueAliasField
287 if mapValue.IsMessage() || p.IsGroup(mapValue) {
288 if nullable && valuegoTyp == valuegoAliasTyp {
289 p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
290 } else {
291
292 a := `this.` + fieldname + `[i]`
293 b := `that1.` + fieldname + `[i]`
294 if valuegoTyp != valuegoAliasTyp {
295
296 a = `(` + valuegoTyp + `)(` + a + `)`
297 b = `(` + valuegoTyp + `)(` + b + `)`
298 }
299 p.P(`a := `, a)
300 p.P(`b := `, b)
301 if nullable {
302 p.P(`if c := a.Compare(b); c != 0 {`)
303 } else {
304 p.P(`if c := (&a).Compare(&b); c != 0 {`)
305 }
306 }
307 p.In()
308 p.P(`return c`)
309 p.Out()
310 p.P(`}`)
311 } else if mapValue.IsBytes() {
312 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`)
313 p.In()
314 p.P(`return c`)
315 p.Out()
316 p.P(`}`)
317 } else if mapValue.IsString() {
318 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
319 p.In()
320 p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
321 p.In()
322 p.P(`return -1`)
323 p.Out()
324 p.P(`}`)
325 p.P(`return 1`)
326 p.Out()
327 p.P(`}`)
328 } else {
329 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
330 p.In()
331 p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
332 p.In()
333 p.P(`return -1`)
334 p.Out()
335 p.P(`}`)
336 p.P(`return 1`)
337 p.Out()
338 p.P(`}`)
339 }
340 } else if field.IsMessage() || p.IsGroup(field) {
341 if nullable {
342 p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
343 p.In()
344 p.P(`return c`)
345 p.Out()
346 p.P(`}`)
347 } else {
348 p.P(`if c := this.`, fieldname, `[i].Compare(&that1.`, fieldname, `[i]); c != 0 {`)
349 p.In()
350 p.P(`return c`)
351 p.Out()
352 p.P(`}`)
353 }
354 } else if field.IsBytes() {
355 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`)
356 p.In()
357 p.P(`return c`)
358 p.Out()
359 p.P(`}`)
360 } else if field.IsString() {
361 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
362 p.In()
363 p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
364 p.In()
365 p.P(`return -1`)
366 p.Out()
367 p.P(`}`)
368 p.P(`return 1`)
369 p.Out()
370 p.P(`}`)
371 } else if field.IsBool() {
372 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
373 p.In()
374 p.P(`if !this.`, fieldname, `[i] {`)
375 p.In()
376 p.P(`return -1`)
377 p.Out()
378 p.P(`}`)
379 p.P(`return 1`)
380 p.Out()
381 p.P(`}`)
382 } else {
383 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
384 p.In()
385 p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
386 p.In()
387 p.P(`return -1`)
388 p.Out()
389 p.P(`}`)
390 p.P(`return 1`)
391 p.Out()
392 p.P(`}`)
393 }
394 }
395 p.Out()
396 p.P(`}`)
397 }
398 }
399
400 func (p *plugin) generateMessage(file *generator.FileDescriptor, message *generator.Descriptor) {
401 ccTypeName := generator.CamelCaseSlice(message.TypeName())
402 p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
403 p.In()
404 p.generateMsgNullAndTypeCheck(ccTypeName)
405 oneofs := make(map[string]struct{})
406
407 for _, field := range message.Field {
408 oneof := field.OneofIndex != nil
409 if oneof {
410 fieldname := p.GetFieldName(message, field)
411 if _, ok := oneofs[fieldname]; ok {
412 continue
413 } else {
414 oneofs[fieldname] = struct{}{}
415 }
416 p.P(`if that1.`, fieldname, ` == nil {`)
417 p.In()
418 p.P(`if this.`, fieldname, ` != nil {`)
419 p.In()
420 p.P(`return 1`)
421 p.Out()
422 p.P(`}`)
423 p.Out()
424 p.P(`} else if this.`, fieldname, ` == nil {`)
425 p.In()
426 p.P(`return -1`)
427 p.Out()
428 p.P(`} else {`)
429 p.In()
430
431
432
433
434 p.P(`thisType := -1`)
435 p.P(`switch this.`, fieldname, `.(type) {`)
436 for i, subfield := range message.Field {
437 if *subfield.OneofIndex == *field.OneofIndex {
438 ccTypeName := p.OneOfTypeName(message, subfield)
439 p.P(`case *`, ccTypeName, `:`)
440 p.In()
441 p.P(`thisType = `, i)
442 p.Out()
443 }
444 }
445 p.P(`default:`)
446 p.In()
447 p.P(`panic(fmt.Sprintf("compare: unexpected type %T in oneof", this.`, fieldname, `))`)
448 p.Out()
449 p.P(`}`)
450
451 p.P(`that1Type := -1`)
452 p.P(`switch that1.`, fieldname, `.(type) {`)
453 for i, subfield := range message.Field {
454 if *subfield.OneofIndex == *field.OneofIndex {
455 ccTypeName := p.OneOfTypeName(message, subfield)
456 p.P(`case *`, ccTypeName, `:`)
457 p.In()
458 p.P(`that1Type = `, i)
459 p.Out()
460 }
461 }
462 p.P(`default:`)
463 p.In()
464 p.P(`panic(fmt.Sprintf("compare: unexpected type %T in oneof", that1.`, fieldname, `))`)
465 p.Out()
466 p.P(`}`)
467
468 p.P(`if thisType == that1Type {`)
469 p.In()
470 p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
471 p.In()
472 p.P(`return c`)
473 p.Out()
474 p.P(`}`)
475 p.Out()
476 p.P(`} else if thisType < that1Type {`)
477 p.In()
478 p.P(`return -1`)
479 p.Out()
480 p.P(`} else if thisType > that1Type {`)
481 p.In()
482 p.P(`return 1`)
483 p.Out()
484 p.P(`}`)
485 p.Out()
486 p.P(`}`)
487 } else {
488 p.generateField(file, message, field)
489 }
490 }
491 if message.DescriptorProto.HasExtension() {
492 if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) {
493 p.P(`thismap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(this)`)
494 p.P(`thatmap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(that1)`)
495 p.P(`extkeys := make([]int32, 0, len(thismap)+len(thatmap))`)
496 p.P(`for k, _ := range thismap {`)
497 p.In()
498 p.P(`extkeys = append(extkeys, k)`)
499 p.Out()
500 p.P(`}`)
501 p.P(`for k, _ := range thatmap {`)
502 p.In()
503 p.P(`if _, ok := thismap[k]; !ok {`)
504 p.In()
505 p.P(`extkeys = append(extkeys, k)`)
506 p.Out()
507 p.P(`}`)
508 p.Out()
509 p.P(`}`)
510 p.P(p.sortkeysPkg.Use(), `.Int32s(extkeys)`)
511 p.P(`for _, k := range extkeys {`)
512 p.In()
513 p.P(`if v, ok := thismap[k]; ok {`)
514 p.In()
515 p.P(`if v2, ok := thatmap[k]; ok {`)
516 p.In()
517 p.P(`if c := v.Compare(&v2); c != 0 {`)
518 p.In()
519 p.P(`return c`)
520 p.Out()
521 p.P(`}`)
522 p.Out()
523 p.P(`} else {`)
524 p.In()
525 p.P(`return 1`)
526 p.Out()
527 p.P(`}`)
528 p.Out()
529 p.P(`} else {`)
530 p.In()
531 p.P(`return -1`)
532 p.Out()
533 p.P(`}`)
534 p.Out()
535 p.P(`}`)
536 } else {
537 fieldname := "XXX_extensions"
538 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
539 p.In()
540 p.P(`return c`)
541 p.Out()
542 p.P(`}`)
543 }
544 }
545 if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) {
546 fieldname := "XXX_unrecognized"
547 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
548 p.In()
549 p.P(`return c`)
550 p.Out()
551 p.P(`}`)
552 }
553 p.P(`return 0`)
554 p.Out()
555 p.P(`}`)
556
557
558 m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto)
559 for _, field := range m.Field {
560 oneof := field.OneofIndex != nil
561 if !oneof {
562 continue
563 }
564 ccTypeName := p.OneOfTypeName(message, field)
565 p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
566 p.In()
567
568 p.generateMsgNullAndTypeCheck(ccTypeName)
569 vanity.TurnOffNullableForNativeTypes(field)
570 p.generateField(file, message, field)
571
572 p.P(`return 0`)
573 p.Out()
574 p.P(`}`)
575 }
576 }
577
578 func init() {
579 generator.RegisterPlugin(NewPlugin())
580 }
581
View as plain text