1
2
3
4
5
6
7
8
9
10
11 package protogen
12
13 import (
14 "bufio"
15 "bytes"
16 "fmt"
17 "go/ast"
18 "go/parser"
19 "go/printer"
20 "go/token"
21 "go/types"
22 "io"
23 "os"
24 "path"
25 "path/filepath"
26 "sort"
27 "strconv"
28 "strings"
29
30 "google.golang.org/protobuf/encoding/prototext"
31 "google.golang.org/protobuf/internal/genid"
32 "google.golang.org/protobuf/internal/strs"
33 "google.golang.org/protobuf/proto"
34 "google.golang.org/protobuf/reflect/protodesc"
35 "google.golang.org/protobuf/reflect/protoreflect"
36 "google.golang.org/protobuf/reflect/protoregistry"
37
38 "google.golang.org/protobuf/types/descriptorpb"
39 "google.golang.org/protobuf/types/dynamicpb"
40 "google.golang.org/protobuf/types/pluginpb"
41 )
42
// goPackageDocURL is the documentation link cited by New's error messages
// when a file's Go import path is missing or invalid.
const goPackageDocURL = "https://protobuf.dev/reference/go/go-generated#package"
44
45
46
47
48
49
50
51
52 func (opts Options) Run(f func(*Plugin) error) {
53 if err := run(opts, f); err != nil {
54 fmt.Fprintf(os.Stderr, "%s: %v\n", filepath.Base(os.Args[0]), err)
55 os.Exit(1)
56 }
57 }
58
59 func run(opts Options, f func(*Plugin) error) error {
60 if len(os.Args) > 1 {
61 return fmt.Errorf("unknown argument %q (this program should be run by protoc, not directly)", os.Args[1])
62 }
63 in, err := io.ReadAll(os.Stdin)
64 if err != nil {
65 return err
66 }
67 req := &pluginpb.CodeGeneratorRequest{}
68 if err := proto.Unmarshal(in, req); err != nil {
69 return err
70 }
71 gen, err := opts.New(req)
72 if err != nil {
73 return err
74 }
75 if err := f(gen); err != nil {
76
77
78
79
80
81
82 gen.Error(err)
83 }
84 resp := gen.Response()
85 out, err := proto.Marshal(resp)
86 if err != nil {
87 return err
88 }
89 if _, err := os.Stdout.Write(out); err != nil {
90 return err
91 }
92 return nil
93 }
94
95
// A Plugin is a protoc plugin invocation, built by Options.New.
type Plugin struct {
	// Request is the CodeGeneratorRequest provided by protoc.
	Request *pluginpb.CodeGeneratorRequest

	// Files holds one File per entry of Request.ProtoFile, in the same
	// order. FilesByPath indexes those Files by .proto file name. Files
	// named in Request.FileToGenerate have their Generate flag set.
	Files       []*File
	FilesByPath map[string]*File

	// SupportedFeatures is reported to protoc in the response when it is
	// non-zero.
	SupportedFeatures uint64

	// SupportedEditionsMinimum and SupportedEditionsMaximum are reported to
	// protoc only when both differ from EDITION_UNKNOWN.
	SupportedEditionsMinimum descriptorpb.Edition
	SupportedEditionsMaximum descriptorpb.Edition

	// fileReg holds the file descriptors registered so far while New builds Files.
	fileReg *protoregistry.Files

	// enumsByName and messagesByName index every Enum and Message built,
	// by full name, so field and method types can be resolved.
	enumsByName    map[protoreflect.FullName]*Enum
	messagesByName map[protoreflect.FullName]*Message

	// annotateCode is set by the annotate_code parameter; Response then also
	// emits a ".meta" file for each generated .go file.
	annotateCode bool

	// pathType is set by the paths parameter and controls output file names.
	pathType pathType

	// module is set by the module parameter; Response strips "<module>/"
	// from every output file name.
	module string

	// genFiles lists every file created by NewGeneratedFile, in creation order.
	genFiles []*GeneratedFile

	// opts are the Options passed to New.
	opts Options

	// err is the first error recorded via Error.
	err error
}
124
// Options are optional parameters to New.
type Options struct {
	// If ParamFunc is non-nil, New calls it with each generator parameter it
	// does not handle itself.
	//
	// protoc passes plugin parameters as a single comma-separated list of
	// name=value pairs. New consumes the empty parameter, "module", "paths",
	// "annotate_code", and every parameter starting with 'M'. It passes all
	// others to ParamFunc, and a non-nil error from ParamFunc aborts New.
	//
	// The (*flag.FlagSet).Set method has this signature, so parameters can
	// be bound to flags:
	//
	//	var flags flag.FlagSet
	//	value := flags.Bool("param", false, "")
	//	opts := &protogen.Options{
	//		ParamFunc: flags.Set,
	//	}
	ParamFunc func(name, value string) error

	// ImportRewriteFunc, if non-nil, is called with the import path of each
	// package imported by a generated .go file. The path it returns is used
	// in the file's import declaration instead.
	ImportRewriteFunc func(GoImportPath) GoImportPath
}
156
157
// New returns a new Plugin for req.
//
// It parses the generator parameters, determines the Go import path and
// package name of every file in req.ProtoFile, builds a File for each one,
// and marks the files in req.FileToGenerate for generation. It returns an
// error for bad parameters, missing or invalid Go import paths, inconsistent
// package names, duplicate or invalid file descriptors, and unknown files to
// generate.
func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) {
	gen := &Plugin{
		Request:        req,
		FilesByPath:    make(map[string]*File),
		fileReg:        new(protoregistry.Files),
		enumsByName:    make(map[protoreflect.FullName]*Enum),
		messagesByName: make(map[protoreflect.FullName]*Message),
		opts:           opts,
	}

	// Go package name and import path per .proto file name. M parameters
	// fill these first; the go_package option only fills entries still empty.
	packageNames := make(map[string]GoPackageName)
	importPaths := make(map[string]GoImportPath)
	// Parameters arrive as one comma-separated list of name[=value] items.
	for _, param := range strings.Split(req.GetParameter(), ",") {
		var value string
		if i := strings.Index(param, "="); i >= 0 {
			value = param[i+1:]
			param = param[0:i]
		}
		switch param {
		case "":
			// Empty item (e.g. no parameters at all): ignore.
		case "module":
			gen.module = value
		case "paths":
			switch value {
			case "import":
				gen.pathType = pathTypeImport
			case "source_relative":
				gen.pathType = pathTypeSourceRelative
			default:
				return nil, fmt.Errorf(`unknown path type %q: want "import" or "source_relative"`, value)
			}
		case "annotate_code":
			switch value {
			case "true", "":
				gen.annotateCode = true
			case "false":
				// Leave annotateCode at its default (false).
			default:
				return nil, fmt.Errorf(`bad value for parameter %q: want "true" or "false"`, param)
			}
		default:
			// M<file>=<import_path>[;<package_name>] overrides the Go import
			// path and/or package name for <file>. param is non-empty here
			// because the "" case is handled above.
			if param[0] == 'M' {
				impPath, pkgName := splitImportPathAndPackageName(value)
				if pkgName != "" {
					packageNames[param[1:]] = pkgName
				}
				if impPath != "" {
					importPaths[param[1:]] = impPath
				}
				continue
			}
			// Anything else belongs to the plugin, if it asked for it.
			if opts.ParamFunc != nil {
				if err := opts.ParamFunc(param, value); err != nil {
					return nil, err
				}
			}
		}
	}

	// module= trims an import-path-based prefix from output names, which has
	// no meaning when outputs are placed relative to the source file.
	if gen.module != "" && gen.pathType == pathTypeSourceRelative {
		return nil, fmt.Errorf("cannot use module= with paths=source_relative")
	}

	// Complete the import path and package name of every file in the request.
	for _, fdesc := range gen.Request.ProtoFile {
		// go_package values only fill entries that M parameters left empty.
		filename := fdesc.GetName()
		impPath, pkgName := splitImportPathAndPackageName(fdesc.GetOptions().GetGoPackage())
		if importPaths[filename] == "" && impPath != "" {
			importPaths[filename] = impPath
		}
		if packageNames[filename] == "" && pkgName != "" {
			packageNames[filename] = pkgName
		}
		switch {
		case importPaths[filename] == "":
			// No M parameter and no go_package import path.
			return nil, fmt.Errorf(
				"unable to determine Go import path for %q\n\n"+
					"Please specify either:\n"+
					"\t• a \"go_package\" option in the .proto source file, or\n"+
					"\t• a \"M\" argument on the command line.\n\n"+
					"See %v for more information.\n",
				fdesc.GetName(), goPackageDocURL)
		case !strings.Contains(string(importPaths[filename]), ".") &&
			!strings.Contains(string(importPaths[filename]), "/"):
			// Reject paths with neither '.' nor '/'.
			return nil, fmt.Errorf(
				"invalid Go import path %q for %q\n\n"+
					"The import path must contain at least one period ('.') or forward slash ('/') character.\n\n"+
					"See %v for more information.\n",
				string(importPaths[filename]), fdesc.GetName(), goPackageDocURL)
		case packageNames[filename] == "":
			// No explicit package name. Derive one from the last element of
			// the go_package import path if present, otherwise from the
			// resolved import path.
			if impPath == "" {
				impPath = importPaths[filename]
			}
			packageNames[filename] = cleanPackageName(path.Base(string(impPath)))
		}
	}

	// Every file sharing a Go import path must agree on the package name.
	// Files with a known import path but no package name are skipped.
	// NOTE(review): map iteration makes the reported pair nondeterministic
	// when several conflicts exist.
	packageFiles := make(map[GoImportPath][]string)
	for filename, importPath := range importPaths {
		if _, ok := packageNames[filename]; !ok {
			// Presumably an M parameter naming a file not in this request.
			continue
		}
		packageFiles[importPath] = append(packageFiles[importPath], filename)
	}
	for importPath, filenames := range packageFiles {
		for i := 1; i < len(filenames); i++ {
			if a, b := packageNames[filenames[0]], packageNames[filenames[i]]; a != b {
				return nil, fmt.Errorf("Go package %v has inconsistent names %v (%v) and %v (%v)",
					importPath, a, filenames[0], b, filenames[i])
			}
		}
	}

	// Build a File per descriptor and collect the extensions each declares.
	typeRegistry := newExtensionRegistry()
	for _, fdesc := range gen.Request.ProtoFile {
		filename := fdesc.GetName()
		if gen.FilesByPath[filename] != nil {
			return nil, fmt.Errorf("duplicate file name: %q", filename)
		}
		f, err := newFile(gen, fdesc, packageNames[filename], importPaths[filename])
		if err != nil {
			return nil, err
		}
		gen.Files = append(gen.Files, f)
		gen.FilesByPath[filename] = f
		if err = typeRegistry.registerAllExtensionsFromFile(f.Desc); err != nil {
			return nil, err
		}
	}
	for _, filename := range gen.Request.FileToGenerate {
		f, ok := gen.FilesByPath[filename]
		if !ok {
			return nil, fmt.Errorf("no descriptor for generated file: %v", filename)
		}
		f.Generate = true
	}

	// If the request defines extensions unknown to the global registry,
	// round-trip every File.Proto through the wire format. The resolver then
	// decodes options that use those extensions as known fields.
	if typeRegistry.hasNovelExtensions() {
		for _, f := range gen.Files {
			b, err := proto.Marshal(f.Proto.ProtoReflect().Interface())
			if err != nil {
				return nil, err
			}
			err = proto.UnmarshalOptions{Resolver: typeRegistry}.Unmarshal(b, f.Proto)
			if err != nil {
				return nil, err
			}
		}
	}
	return gen, nil
}
346
347
348
349 func (gen *Plugin) Error(err error) {
350 if gen.err == nil {
351 gen.err = err
352 }
353 }
354
355
356 func (gen *Plugin) Response() *pluginpb.CodeGeneratorResponse {
357 resp := &pluginpb.CodeGeneratorResponse{}
358 if gen.err != nil {
359 resp.Error = proto.String(gen.err.Error())
360 return resp
361 }
362 for _, g := range gen.genFiles {
363 if g.skip {
364 continue
365 }
366 content, err := g.Content()
367 if err != nil {
368 return &pluginpb.CodeGeneratorResponse{
369 Error: proto.String(err.Error()),
370 }
371 }
372 filename := g.filename
373 if gen.module != "" {
374 trim := gen.module + "/"
375 if !strings.HasPrefix(filename, trim) {
376 return &pluginpb.CodeGeneratorResponse{
377 Error: proto.String(fmt.Sprintf("%v: generated file does not match prefix %q", filename, gen.module)),
378 }
379 }
380 filename = strings.TrimPrefix(filename, trim)
381 }
382 resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
383 Name: proto.String(filename),
384 Content: proto.String(string(content)),
385 })
386 if gen.annotateCode && strings.HasSuffix(g.filename, ".go") {
387 meta, err := g.metaFile(content)
388 if err != nil {
389 return &pluginpb.CodeGeneratorResponse{
390 Error: proto.String(err.Error()),
391 }
392 }
393 resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
394 Name: proto.String(filename + ".meta"),
395 Content: proto.String(meta),
396 })
397 }
398 }
399 if gen.SupportedFeatures > 0 {
400 resp.SupportedFeatures = proto.Uint64(gen.SupportedFeatures)
401 }
402 if gen.SupportedEditionsMinimum != descriptorpb.Edition_EDITION_UNKNOWN && gen.SupportedEditionsMaximum != descriptorpb.Edition_EDITION_UNKNOWN {
403 resp.MinimumEdition = proto.Int32(int32(gen.SupportedEditionsMinimum))
404 resp.MaximumEdition = proto.Int32(int32(gen.SupportedEditionsMaximum))
405 }
406 return resp
407 }
408
409
// A File describes a .proto source file.
type File struct {
	Desc  protoreflect.FileDescriptor
	Proto *descriptorpb.FileDescriptorProto

	// GoDescriptorIdent is "File_" plus the sanitized .proto file name, in
	// this file's Go package.
	GoDescriptorIdent GoIdent
	GoPackageName     GoPackageName // name of this file's Go package
	GoImportPath      GoImportPath  // import path of this file's Go package

	Enums      []*Enum      // top-level enum declarations
	Messages   []*Message   // top-level message declarations
	Extensions []*Extension // top-level extension declarations
	Services   []*Service   // top-level service declarations

	Generate bool // true if the file was listed in FileToGenerate

	// GeneratedFilenamePrefix is used to build the names of generated files
	// for this source file. It is the .proto name without a ".proto" or
	// ".protodevel" extension. For example, "dir/foo.proto" becomes
	// "dir/foo" with paths=source_relative. With paths=import, the base name
	// is placed under the Go import path instead, e.g.
	// "example.com/x/foo".
	GeneratedFilenamePrefix string

	location Location // root source location (the file itself, empty path)
}
434
435 func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPackageName, importPath GoImportPath) (*File, error) {
436 desc, err := protodesc.NewFile(p, gen.fileReg)
437 if err != nil {
438 return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err)
439 }
440 if err := gen.fileReg.RegisterFile(desc); err != nil {
441 return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err)
442 }
443 f := &File{
444 Desc: desc,
445 Proto: p,
446 GoPackageName: packageName,
447 GoImportPath: importPath,
448 location: Location{SourceFile: desc.Path()},
449 }
450
451
452 prefix := p.GetName()
453 if ext := path.Ext(prefix); ext == ".proto" || ext == ".protodevel" {
454 prefix = prefix[:len(prefix)-len(ext)]
455 }
456 switch gen.pathType {
457 case pathTypeImport:
458
459 prefix = path.Join(string(f.GoImportPath), path.Base(prefix))
460 case pathTypeSourceRelative:
461
462
463 }
464 f.GoDescriptorIdent = GoIdent{
465 GoName: "File_" + strs.GoSanitized(p.GetName()),
466 GoImportPath: f.GoImportPath,
467 }
468 f.GeneratedFilenamePrefix = prefix
469
470 for i, eds := 0, desc.Enums(); i < eds.Len(); i++ {
471 f.Enums = append(f.Enums, newEnum(gen, f, nil, eds.Get(i)))
472 }
473 for i, mds := 0, desc.Messages(); i < mds.Len(); i++ {
474 f.Messages = append(f.Messages, newMessage(gen, f, nil, mds.Get(i)))
475 }
476 for i, xds := 0, desc.Extensions(); i < xds.Len(); i++ {
477 f.Extensions = append(f.Extensions, newField(gen, f, nil, xds.Get(i)))
478 }
479 for i, sds := 0, desc.Services(); i < sds.Len(); i++ {
480 f.Services = append(f.Services, newService(gen, f, sds.Get(i)))
481 }
482 for _, message := range f.Messages {
483 if err := message.resolveDependencies(gen); err != nil {
484 return nil, err
485 }
486 }
487 for _, extension := range f.Extensions {
488 if err := extension.resolveDependencies(gen); err != nil {
489 return nil, err
490 }
491 }
492 for _, service := range f.Services {
493 for _, method := range service.Methods {
494 if err := method.resolveDependencies(gen); err != nil {
495 return nil, err
496 }
497 }
498 }
499 return f, nil
500 }
501
502
503
504 func splitImportPathAndPackageName(s string) (GoImportPath, GoPackageName) {
505 if i := strings.Index(s, ";"); i >= 0 {
506 return GoImportPath(s[:i]), GoPackageName(s[i+1:])
507 }
508 return GoImportPath(s), ""
509 }
510
511
512 type Enum struct {
513 Desc protoreflect.EnumDescriptor
514
515 GoIdent GoIdent
516
517 Values []*EnumValue
518
519 Location Location
520 Comments CommentSet
521 }
522
523 func newEnum(gen *Plugin, f *File, parent *Message, desc protoreflect.EnumDescriptor) *Enum {
524 var loc Location
525 if parent != nil {
526 loc = parent.Location.appendPath(genid.DescriptorProto_EnumType_field_number, desc.Index())
527 } else {
528 loc = f.location.appendPath(genid.FileDescriptorProto_EnumType_field_number, desc.Index())
529 }
530 enum := &Enum{
531 Desc: desc,
532 GoIdent: newGoIdent(f, desc),
533 Location: loc,
534 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
535 }
536 gen.enumsByName[desc.FullName()] = enum
537 for i, vds := 0, enum.Desc.Values(); i < vds.Len(); i++ {
538 enum.Values = append(enum.Values, newEnumValue(gen, f, parent, enum, vds.Get(i)))
539 }
540 return enum
541 }
542
543
544 type EnumValue struct {
545 Desc protoreflect.EnumValueDescriptor
546
547 GoIdent GoIdent
548
549 Parent *Enum
550
551 Location Location
552 Comments CommentSet
553 }
554
555 func newEnumValue(gen *Plugin, f *File, message *Message, enum *Enum, desc protoreflect.EnumValueDescriptor) *EnumValue {
556
557
558
559
560 parentIdent := enum.GoIdent
561 if message != nil {
562 parentIdent = message.GoIdent
563 }
564 name := parentIdent.GoName + "_" + string(desc.Name())
565 loc := enum.Location.appendPath(genid.EnumDescriptorProto_Value_field_number, desc.Index())
566 return &EnumValue{
567 Desc: desc,
568 GoIdent: f.GoImportPath.Ident(name),
569 Parent: enum,
570 Location: loc,
571 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
572 }
573 }
574
575
// A Message describes a message.
type Message struct {
	Desc protoreflect.MessageDescriptor

	GoIdent GoIdent // Go identifier derived from the message's full name

	Fields []*Field // all field declarations, including oneof members
	Oneofs []*Oneof // oneof declarations

	Enums      []*Enum      // nested enum declarations
	Messages   []*Message   // nested message declarations
	Extensions []*Extension // nested extension declarations

	Location Location   // location of this message
	Comments CommentSet // comments associated with this message
}

// newMessage builds a Message and, recursively, its nested enums, messages,
// fields, oneofs, and extensions. parent is the enclosing message, or nil
// for a top-level message. The Message is indexed in gen.messagesByName. It
// also makes field and oneof Go names unique within the message.
func newMessage(gen *Plugin, f *File, parent *Message, desc protoreflect.MessageDescriptor) *Message {
	var loc Location
	if parent != nil {
		loc = parent.Location.appendPath(genid.DescriptorProto_NestedType_field_number, desc.Index())
	} else {
		loc = f.location.appendPath(genid.FileDescriptorProto_MessageType_field_number, desc.Index())
	}
	message := &Message{
		Desc:     desc,
		GoIdent:  newGoIdent(f, desc),
		Location: loc,
		Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
	}
	gen.messagesByName[desc.FullName()] = message
	for i, eds := 0, desc.Enums(); i < eds.Len(); i++ {
		message.Enums = append(message.Enums, newEnum(gen, f, message, eds.Get(i)))
	}
	for i, mds := 0, desc.Messages(); i < mds.Len(); i++ {
		message.Messages = append(message.Messages, newMessage(gen, f, message, mds.Get(i)))
	}
	for i, fds := 0, desc.Fields(); i < fds.Len(); i++ {
		message.Fields = append(message.Fields, newField(gen, f, message, fds.Get(i)))
	}
	for i, ods := 0, desc.Oneofs(); i < ods.Len(); i++ {
		message.Oneofs = append(message.Oneofs, newOneof(gen, f, message, ods.Get(i)))
	}
	for i, xds := 0, desc.Extensions(); i < xds.Len(); i++ {
		message.Extensions = append(message.Extensions, newField(gen, f, message, xds.Get(i)))
	}

	// Link each oneof member field to its Oneof, in field declaration order.
	for _, field := range message.Fields {
		if od := field.Desc.ContainingOneof(); od != nil {
			oneof := message.Oneofs[od.Index()]
			field.Oneof = oneof
			oneof.Fields = append(oneof.Fields, field)
		}
	}

	// Make field and oneof Go names unique within this message by appending
	// '_' until no collision remains. The names below are reserved up front.
	// NOTE(review): presumably these mirror methods emitted on the generated
	// message type; confirm against the generator.
	usedNames := map[string]bool{
		"Reset":               true,
		"String":              true,
		"ProtoMessage":        true,
		"Marshal":             true,
		"Unmarshal":           true,
		"ExtensionRangeArray": true,
		"ExtensionMap":        true,
		"Descriptor":          true,
	}
	// makeNameUnique reserves name and, if hasGetter, also "Get"+name.
	// NOTE(review): when hasGetter is false this stores false for "Get"+name,
	// which can clear a reservation made earlier by a field literally named
	// "Get<name>". Confirm whether that is intended.
	makeNameUnique := func(name string, hasGetter bool) string {
		for usedNames[name] || (hasGetter && usedNames["Get"+name]) {
			name += "_"
		}
		usedNames[name] = true
		usedNames["Get"+name] = hasGetter
		return name
	}
	for _, field := range message.Fields {
		field.GoName = makeNameUnique(field.GoName, true)
		field.GoIdent.GoName = message.GoIdent.GoName + "_" + field.GoName
		if field.Oneof != nil && field.Oneof.Fields[0] == field {
			// Name the oneof when its first member field is visited, so it is
			// deduplicated at the position of that field.
			field.Oneof.GoName = makeNameUnique(field.Oneof.GoName, false)
			field.Oneof.GoIdent.GoName = message.GoIdent.GoName + "_" + field.Oneof.GoName
		}
	}

	// A oneof member's GoIdent ("<Message>_<Field>") may coincide with the
	// GoIdent of a nested message or enum. Append '_' until it matches
	// neither, restarting the scan after each change.
	for _, field := range message.Fields {
		if field.Oneof != nil {
		Loop:
			for {
				for _, nestedMessage := range message.Messages {
					if nestedMessage.GoIdent == field.GoIdent {
						field.GoIdent.GoName += "_"
						continue Loop
					}
				}
				for _, nestedEnum := range message.Enums {
					if nestedEnum.GoIdent == field.GoIdent {
						field.GoIdent.GoName += "_"
						continue Loop
					}
				}
				break Loop
			}
		}
	}

	return message
}
703
704 func (message *Message) resolveDependencies(gen *Plugin) error {
705 for _, field := range message.Fields {
706 if err := field.resolveDependencies(gen); err != nil {
707 return err
708 }
709 }
710 for _, message := range message.Messages {
711 if err := message.resolveDependencies(gen); err != nil {
712 return err
713 }
714 }
715 for _, extension := range message.Extensions {
716 if err := extension.resolveDependencies(gen); err != nil {
717 return err
718 }
719 }
720 return nil
721 }
722
723
724 type Field struct {
725 Desc protoreflect.FieldDescriptor
726
727
728
729
730 GoName string
731
732
733
734
735
736 GoIdent GoIdent
737
738 Parent *Message
739 Oneof *Oneof
740 Extendee *Message
741
742 Enum *Enum
743 Message *Message
744
745 Location Location
746 Comments CommentSet
747 }
748
749 func newField(gen *Plugin, f *File, message *Message, desc protoreflect.FieldDescriptor) *Field {
750 var loc Location
751 switch {
752 case desc.IsExtension() && message == nil:
753 loc = f.location.appendPath(genid.FileDescriptorProto_Extension_field_number, desc.Index())
754 case desc.IsExtension() && message != nil:
755 loc = message.Location.appendPath(genid.DescriptorProto_Extension_field_number, desc.Index())
756 default:
757 loc = message.Location.appendPath(genid.DescriptorProto_Field_field_number, desc.Index())
758 }
759 camelCased := strs.GoCamelCase(string(desc.Name()))
760 var parentPrefix string
761 if message != nil {
762 parentPrefix = message.GoIdent.GoName + "_"
763 }
764 field := &Field{
765 Desc: desc,
766 GoName: camelCased,
767 GoIdent: GoIdent{
768 GoImportPath: f.GoImportPath,
769 GoName: parentPrefix + camelCased,
770 },
771 Parent: message,
772 Location: loc,
773 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
774 }
775 return field
776 }
777
778 func (field *Field) resolveDependencies(gen *Plugin) error {
779 desc := field.Desc
780 switch desc.Kind() {
781 case protoreflect.EnumKind:
782 name := field.Desc.Enum().FullName()
783 enum, ok := gen.enumsByName[name]
784 if !ok {
785 return fmt.Errorf("field %v: no descriptor for enum %v", desc.FullName(), name)
786 }
787 field.Enum = enum
788 case protoreflect.MessageKind, protoreflect.GroupKind:
789 name := desc.Message().FullName()
790 message, ok := gen.messagesByName[name]
791 if !ok {
792 return fmt.Errorf("field %v: no descriptor for type %v", desc.FullName(), name)
793 }
794 field.Message = message
795 }
796 if desc.IsExtension() {
797 name := desc.ContainingMessage().FullName()
798 message, ok := gen.messagesByName[name]
799 if !ok {
800 return fmt.Errorf("field %v: no descriptor for type %v", desc.FullName(), name)
801 }
802 field.Extendee = message
803 }
804 return nil
805 }
806
807
808 type Oneof struct {
809 Desc protoreflect.OneofDescriptor
810
811
812
813
814 GoName string
815
816
817 GoIdent GoIdent
818
819 Parent *Message
820
821 Fields []*Field
822
823 Location Location
824 Comments CommentSet
825 }
826
827 func newOneof(gen *Plugin, f *File, message *Message, desc protoreflect.OneofDescriptor) *Oneof {
828 loc := message.Location.appendPath(genid.DescriptorProto_OneofDecl_field_number, desc.Index())
829 camelCased := strs.GoCamelCase(string(desc.Name()))
830 parentPrefix := message.GoIdent.GoName + "_"
831 return &Oneof{
832 Desc: desc,
833 Parent: message,
834 GoName: camelCased,
835 GoIdent: GoIdent{
836 GoImportPath: f.GoImportPath,
837 GoName: parentPrefix + camelCased,
838 },
839 Location: loc,
840 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
841 }
842 }
843
844
// Extension is an alias of Field, used for fields built from extension
// declarations.
type Extension = Field
846
847
848 type Service struct {
849 Desc protoreflect.ServiceDescriptor
850
851 GoName string
852
853 Methods []*Method
854
855 Location Location
856 Comments CommentSet
857 }
858
859 func newService(gen *Plugin, f *File, desc protoreflect.ServiceDescriptor) *Service {
860 loc := f.location.appendPath(genid.FileDescriptorProto_Service_field_number, desc.Index())
861 service := &Service{
862 Desc: desc,
863 GoName: strs.GoCamelCase(string(desc.Name())),
864 Location: loc,
865 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
866 }
867 for i, mds := 0, desc.Methods(); i < mds.Len(); i++ {
868 service.Methods = append(service.Methods, newMethod(gen, f, service, mds.Get(i)))
869 }
870 return service
871 }
872
873
874 type Method struct {
875 Desc protoreflect.MethodDescriptor
876
877 GoName string
878
879 Parent *Service
880
881 Input *Message
882 Output *Message
883
884 Location Location
885 Comments CommentSet
886 }
887
888 func newMethod(gen *Plugin, f *File, service *Service, desc protoreflect.MethodDescriptor) *Method {
889 loc := service.Location.appendPath(genid.ServiceDescriptorProto_Method_field_number, desc.Index())
890 method := &Method{
891 Desc: desc,
892 GoName: strs.GoCamelCase(string(desc.Name())),
893 Parent: service,
894 Location: loc,
895 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
896 }
897 return method
898 }
899
900 func (method *Method) resolveDependencies(gen *Plugin) error {
901 desc := method.Desc
902
903 inName := desc.Input().FullName()
904 in, ok := gen.messagesByName[inName]
905 if !ok {
906 return fmt.Errorf("method %v: no descriptor for type %v", desc.FullName(), inName)
907 }
908 method.Input = in
909
910 outName := desc.Output().FullName()
911 out, ok := gen.messagesByName[outName]
912 if !ok {
913 return fmt.Errorf("method %v: no descriptor for type %v", desc.FullName(), outName)
914 }
915 method.Output = out
916
917 return nil
918 }
919
920
921 type GeneratedFile struct {
922 gen *Plugin
923 skip bool
924 filename string
925 goImportPath GoImportPath
926 buf bytes.Buffer
927 packageNames map[GoImportPath]GoPackageName
928 usedPackageNames map[GoPackageName]bool
929 manualImports map[GoImportPath]bool
930 annotations map[string][]Annotation
931 }
932
933
934
935 func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
936 g := &GeneratedFile{
937 gen: gen,
938 filename: filename,
939 goImportPath: goImportPath,
940 packageNames: make(map[GoImportPath]GoPackageName),
941 usedPackageNames: make(map[GoPackageName]bool),
942 manualImports: make(map[GoImportPath]bool),
943 annotations: make(map[string][]Annotation),
944 }
945
946
947 for _, s := range types.Universe.Names() {
948 g.usedPackageNames[GoPackageName(s)] = true
949 }
950
951 gen.genFiles = append(gen.genFiles, g)
952 return g
953 }
954
955
956
957
958 func (g *GeneratedFile) P(v ...interface{}) {
959 for _, x := range v {
960 switch x := x.(type) {
961 case GoIdent:
962 fmt.Fprint(&g.buf, g.QualifiedGoIdent(x))
963 default:
964 fmt.Fprint(&g.buf, x)
965 }
966 }
967 fmt.Fprintln(&g.buf)
968 }
969
970
971
972
973
974
975 func (g *GeneratedFile) QualifiedGoIdent(ident GoIdent) string {
976 if ident.GoImportPath == g.goImportPath {
977 return ident.GoName
978 }
979 if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
980 return string(packageName) + "." + ident.GoName
981 }
982 packageName := cleanPackageName(path.Base(string(ident.GoImportPath)))
983 for i, orig := 1, packageName; g.usedPackageNames[packageName]; i++ {
984 packageName = orig + GoPackageName(strconv.Itoa(i))
985 }
986 g.packageNames[ident.GoImportPath] = packageName
987 g.usedPackageNames[packageName] = true
988 return string(packageName) + "." + ident.GoName
989 }
990
991
992
993
994
995
996 func (g *GeneratedFile) Import(importPath GoImportPath) {
997 g.manualImports[importPath] = true
998 }
999
1000
1001 func (g *GeneratedFile) Write(p []byte) (n int, err error) {
1002 return g.buf.Write(p)
1003 }
1004
1005
1006 func (g *GeneratedFile) Skip() {
1007 g.skip = true
1008 }
1009
1010
1011
1012 func (g *GeneratedFile) Unskip() {
1013 g.skip = false
1014 }
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024 func (g *GeneratedFile) Annotate(symbol string, loc Location) {
1025 g.AnnotateSymbol(symbol, Annotation{Location: loc})
1026 }
1027
1028
1029
1030
1031
1032 type Annotation struct {
1033
1034 Location Location
1035
1036
1037 Semantic *descriptorpb.GeneratedCodeInfo_Annotation_Semantic
1038 }
1039
1040
1041
1042
1043
1044
1045
1046 func (g *GeneratedFile) AnnotateSymbol(symbol string, info Annotation) {
1047 g.annotations[symbol] = append(g.annotations[symbol], info)
1048 }
1049
1050
// Content returns the contents of the generated file.
//
// Files without a ".go" suffix are returned verbatim. For Go files, the
// buffered source is parsed, an import declaration for every referenced
// package is inserted, and the result is reformatted with go/printer.
func (g *GeneratedFile) Content() ([]byte, error) {
	if !strings.HasSuffix(g.filename, ".go") {
		return g.buf.Bytes(), nil
	}

	// Parse the buffered source, keeping comments for reprinting.
	original := g.buf.Bytes()
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
	if err != nil {
		// Include a line-numbered listing of the source so parser positions
		// can be matched to the offending code.
		// NOTE(review): bufio.Scanner stops silently at lines longer than
		// bufio.MaxScanTokenSize, and s.Err() is not checked, so the listing
		// may be truncated.
		var src bytes.Buffer
		s := bufio.NewScanner(bytes.NewReader(original))
		for line := 1; s.Scan(); line++ {
			fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
		}
		return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
	}

	// Collect {local name, path} pairs: one per package referenced through
	// QualifiedGoIdent, plus a blank "_" import for each package requested
	// only via Import. Paths pass through ImportRewriteFunc when it is set.
	var importPaths [][2]string
	rewriteImport := func(importPath string) string {
		if f := g.gen.opts.ImportRewriteFunc; f != nil {
			return string(f(GoImportPath(importPath)))
		}
		return importPath
	}
	for importPath := range g.packageNames {
		pkgName := string(g.packageNames[GoImportPath(importPath)])
		pkgPath := rewriteImport(string(importPath))
		importPaths = append(importPaths, [2]string{pkgName, pkgPath})
	}
	for importPath := range g.manualImports {
		if _, ok := g.packageNames[importPath]; !ok {
			pkgPath := rewriteImport(string(importPath))
			importPaths = append(importPaths, [2]string{"_", pkgPath})
		}
	}
	// Sort by (rewritten) path so output does not depend on map order.
	sort.Slice(importPaths, func(i, j int) bool {
		return importPaths[i][1] < importPaths[j][1]
	})

	// Add the import declaration to the AST.
	if len(importPaths) > 0 {
		// Position the block after the package clause, or after the last
		// comment that starts on or before the package line (e.g. a trailing
		// comment on that line).
		pos := file.Package
		tokFile := fset.File(file.Package)
		pkgLine := tokFile.Line(file.Package)
		for _, c := range file.Comments {
			if tokFile.Line(c.Pos()) > pkgLine {
				break
			}
			pos = c.End()
		}

		// Build the import declaration. Every node shares pos, and the
		// printer lays the block out.
		impDecl := &ast.GenDecl{
			Tok:    token.IMPORT,
			TokPos: pos,
			Lparen: pos,
			Rparen: pos,
		}
		for _, importPath := range importPaths {
			impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
				Name: &ast.Ident{
					Name:    importPath[0],
					NamePos: pos,
				},
				Path: &ast.BasicLit{
					Kind:     token.STRING,
					Value:    strconv.Quote(importPath[1]),
					ValuePos: pos,
				},
				EndPos: pos,
			})
		}
		file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
	}

	// Reprint using tabs for indentation and spaces for alignment.
	var out bytes.Buffer
	if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, file); err != nil {
		return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
	}
	return out.Bytes(), nil
}
1139
1140 func (g *GeneratedFile) generatedCodeInfo(content []byte) (*descriptorpb.GeneratedCodeInfo, error) {
1141 fset := token.NewFileSet()
1142 astFile, err := parser.ParseFile(fset, "", content, 0)
1143 if err != nil {
1144 return nil, err
1145 }
1146 info := &descriptorpb.GeneratedCodeInfo{}
1147
1148 seenAnnotations := make(map[string]bool)
1149 annotate := func(s string, ident *ast.Ident) {
1150 seenAnnotations[s] = true
1151 for _, a := range g.annotations[s] {
1152 info.Annotation = append(info.Annotation, &descriptorpb.GeneratedCodeInfo_Annotation{
1153 SourceFile: proto.String(a.Location.SourceFile),
1154 Path: a.Location.Path,
1155 Begin: proto.Int32(int32(fset.Position(ident.Pos()).Offset)),
1156 End: proto.Int32(int32(fset.Position(ident.End()).Offset)),
1157 Semantic: a.Semantic,
1158 })
1159 }
1160 }
1161 for _, decl := range astFile.Decls {
1162 switch decl := decl.(type) {
1163 case *ast.GenDecl:
1164 for _, spec := range decl.Specs {
1165 switch spec := spec.(type) {
1166 case *ast.TypeSpec:
1167 annotate(spec.Name.Name, spec.Name)
1168 switch st := spec.Type.(type) {
1169 case *ast.StructType:
1170 for _, field := range st.Fields.List {
1171 for _, name := range field.Names {
1172 annotate(spec.Name.Name+"."+name.Name, name)
1173 }
1174 }
1175 case *ast.InterfaceType:
1176 for _, field := range st.Methods.List {
1177 for _, name := range field.Names {
1178 annotate(spec.Name.Name+"."+name.Name, name)
1179 }
1180 }
1181 }
1182 case *ast.ValueSpec:
1183 for _, name := range spec.Names {
1184 annotate(name.Name, name)
1185 }
1186 }
1187 }
1188 case *ast.FuncDecl:
1189 if decl.Recv == nil {
1190 annotate(decl.Name.Name, decl.Name)
1191 } else {
1192 recv := decl.Recv.List[0].Type
1193 if s, ok := recv.(*ast.StarExpr); ok {
1194 recv = s.X
1195 }
1196 if id, ok := recv.(*ast.Ident); ok {
1197 annotate(id.Name+"."+decl.Name.Name, decl.Name)
1198 }
1199 }
1200 }
1201 }
1202 for a := range g.annotations {
1203 if !seenAnnotations[a] {
1204 return nil, fmt.Errorf("%v: no symbol matching annotation %q", g.filename, a)
1205 }
1206 }
1207
1208 return info, nil
1209 }
1210
1211
1212
1213 func (g *GeneratedFile) metaFile(content []byte) (string, error) {
1214 info, err := g.generatedCodeInfo(content)
1215 if err != nil {
1216 return "", err
1217 }
1218
1219 b, err := prototext.Marshal(info)
1220 if err != nil {
1221 return "", err
1222 }
1223 return string(b), nil
1224 }
1225
1226
1227
1228 type GoIdent struct {
1229 GoName string
1230 GoImportPath GoImportPath
1231 }
1232
1233 func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) }
1234
1235
1236 func newGoIdent(f *File, d protoreflect.Descriptor) GoIdent {
1237 name := strings.TrimPrefix(string(d.FullName()), string(f.Desc.Package())+".")
1238 return GoIdent{
1239 GoName: strs.GoCamelCase(name),
1240 GoImportPath: f.GoImportPath,
1241 }
1242 }
1243
1244
1245
1246 type GoImportPath string
1247
1248 func (p GoImportPath) String() string { return strconv.Quote(string(p)) }
1249
1250
1251 func (p GoImportPath) Ident(s string) GoIdent {
1252 return GoIdent{GoName: s, GoImportPath: p}
1253 }
1254
1255
1256 type GoPackageName string
1257
1258
1259 func cleanPackageName(name string) GoPackageName {
1260 return GoPackageName(strs.GoSanitized(name))
1261 }
1262
// pathType selects how output file paths are derived. It is set by the
// "paths" generator parameter; the zero value is pathTypeImport.
type pathType int

const (
	// pathTypeImport places outputs under the file's Go import path.
	pathTypeImport pathType = 0
	// pathTypeSourceRelative places outputs next to the .proto source.
	pathTypeSourceRelative pathType = 1
)
1269
1270
1271
1272
1273
1274 type Location struct {
1275 SourceFile string
1276 Path protoreflect.SourcePath
1277 }
1278
1279
1280 func (loc Location) appendPath(num protoreflect.FieldNumber, idx int) Location {
1281 loc.Path = append(protoreflect.SourcePath(nil), loc.Path...)
1282 loc.Path = append(loc.Path, int32(num), int32(idx))
1283 return loc
1284 }
1285
1286
1287
1288 type CommentSet struct {
1289 LeadingDetached []Comments
1290 Leading Comments
1291 Trailing Comments
1292 }
1293
1294 func makeCommentSet(loc protoreflect.SourceLocation) CommentSet {
1295 var leadingDetached []Comments
1296 for _, s := range loc.LeadingDetachedComments {
1297 leadingDetached = append(leadingDetached, Comments(s))
1298 }
1299 return CommentSet{
1300 LeadingDetached: leadingDetached,
1301 Leading: Comments(loc.LeadingComments),
1302 Trailing: Comments(loc.TrailingComments),
1303 }
1304 }
1305
1306
// Comments is a comment string as provided by protoc.
type Comments string

// String formats the comments as Go line comments. It prefixes each line
// with "//" and ends every line, including the last, with a newline. A
// single trailing newline in c does not produce an extra empty comment
// line. An empty Comments formats as the empty string.
func (c Comments) String() string {
	if c == "" {
		return ""
	}
	var sb strings.Builder
	body := strings.TrimSuffix(string(c), "\n")
	for _, line := range strings.Split(body, "\n") {
		sb.WriteString("//")
		sb.WriteString(line)
		sb.WriteByte('\n')
	}
	return sb.String()
}
1324
1325
1326
1327
1328
1329
1330 type extensionRegistry struct {
1331 base *protoregistry.Types
1332 local *protoregistry.Types
1333 }
1334
1335 func newExtensionRegistry() *extensionRegistry {
1336 return &extensionRegistry{
1337 base: protoregistry.GlobalTypes,
1338 local: &protoregistry.Types{},
1339 }
1340 }
1341
1342
1343 func (e *extensionRegistry) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
1344 if xt, err := e.local.FindExtensionByName(field); err == nil {
1345 return xt, nil
1346 }
1347
1348 return e.base.FindExtensionByName(field)
1349 }
1350
1351
1352 func (e *extensionRegistry) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
1353 if xt, err := e.local.FindExtensionByNumber(message, field); err == nil {
1354 return xt, nil
1355 }
1356
1357 return e.base.FindExtensionByNumber(message, field)
1358 }
1359
1360 func (e *extensionRegistry) hasNovelExtensions() bool {
1361 return e.local.NumExtensions() > 0
1362 }
1363
1364 func (e *extensionRegistry) registerAllExtensionsFromFile(f protoreflect.FileDescriptor) error {
1365 if err := e.registerAllExtensions(f.Extensions()); err != nil {
1366 return err
1367 }
1368 return nil
1369 }
1370
1371 func (e *extensionRegistry) registerAllExtensionsFromMessage(ms protoreflect.MessageDescriptors) error {
1372 for i := 0; i < ms.Len(); i++ {
1373 m := ms.Get(i)
1374 if err := e.registerAllExtensions(m.Extensions()); err != nil {
1375 return err
1376 }
1377 }
1378 return nil
1379 }
1380
1381 func (e *extensionRegistry) registerAllExtensions(exts protoreflect.ExtensionDescriptors) error {
1382 for i := 0; i < exts.Len(); i++ {
1383 if err := e.registerExtension(exts.Get(i)); err != nil {
1384 return err
1385 }
1386 }
1387 return nil
1388 }
1389
1390
1391
1392 func (e *extensionRegistry) registerExtension(xd protoreflect.ExtensionDescriptor) error {
1393 if _, err := e.FindExtensionByName(xd.FullName()); err != protoregistry.NotFound {
1394
1395 return err
1396 }
1397 return e.local.RegisterExtension(dynamicpb.NewExtensionType(xd))
1398 }
1399
View as plain text