1 package descriptor
2
3 import (
4 "fmt"
5 "sort"
6 "strings"
7
8 "github.com/grpc-ecosystem/grpc-gateway/v2/internal/codegenerator"
9 "github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor/openapiconfig"
10 "github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2/options"
11 "golang.org/x/text/cases"
12 "golang.org/x/text/language"
13 "google.golang.org/genproto/googleapis/api/annotations"
14 "google.golang.org/grpc/grpclog"
15 "google.golang.org/protobuf/compiler/protogen"
16 "google.golang.org/protobuf/types/descriptorpb"
17 "google.golang.org/protobuf/types/pluginpb"
18 )
19
20
// Registry holds the descriptors registered from a protoc plugin request (see
// Load and LoadFromPlugin) together with the generator options configured via
// its Set* methods. Use NewRegistry to obtain an instance whose maps are
// allocated.
type Registry struct {
	// msgs maps fully-qualified message names (Message.FQMN) to messages.
	msgs map[string]*Message

	// enums maps fully-qualified enum names (Enum.FQEN) to enums.
	enums map[string]*Enum

	// files maps proto file paths, as keyed in protogen.Plugin.FilesByPath, to
	// the registered files.
	files map[string]*File

	// meths maps method names to methods. It is populated outside this file;
	// judging by GetAllFQMethNs the keys are fully-qualified method names.
	meths map[string]*Method

	// prefix is stored by SetPrefix; it is not read in this file.
	// NOTE(review): presumably a prefix for generated Go package paths — confirm.
	prefix string

	// pkgMap maps a proto file to a proto package, as supplied via AddPkgMap.
	pkgMap map[string]string

	// pkgAliases maps each reserved Go package alias to the import path that
	// reserved it (see ReserveGoPackageAlias).
	pkgAliases map[string]string

	// allowDeleteBody is stored by SetAllowDeleteBody.
	// NOTE(review): presumably permits request bodies on DELETE rules — confirm.
	allowDeleteBody bool

	// externalHTTPRules maps a fully-qualified method name to the HTTP rules
	// registered for it through AddExternalHTTPRule.
	externalHTTPRules map[string][]*annotations.HttpRule

	// allowMerge is stored by SetAllowMerge and reported by IsAllowMerge.
	allowMerge bool

	// mergeFileName is stored by SetMergeFileName and returned by
	// GetMergeFileName.
	mergeFileName string

	// includePackageInTags is stored by SetIncludePackageInTags and reported
	// by IsIncludePackageInTags.
	includePackageInTags bool

	// repeatedPathParamSeparator is the separator for repeated fields in path
	// parameters. NewRegistry defaults it to csv (','), and
	// SetRepeatedPathParamSeparator accepts csv, pipes, ssv or tsv.
	repeatedPathParamSeparator repeatedFieldSeparator

	// useJSONNamesForFields makes FieldName return a field's JSON name instead
	// of its proto name.
	useJSONNamesForFields bool

	// openAPINamingStrategy names the strategy for OpenAPI definition names.
	// NewRegistry defaults it to "legacy". The value "fqn" is what
	// GetUseFQNForOpenAPIName reports as true.
	openAPINamingStrategy string

	// visibilityRestrictionSelectors is the set of whitespace-trimmed selectors
	// passed to SetVisibilityRestrictionSelectors.
	visibilityRestrictionSelectors map[string]bool

	// useGoTemplate is stored by SetUseGoTemplate and reported by
	// GetUseGoTemplate.
	useGoTemplate bool

	// goTemplateArgs holds key/value pairs parsed from "key=value" strings by
	// SetGoTemplateArgs. It is nil until that method is called.
	goTemplateArgs map[string]string

	// ignoreComments is stored by SetIgnoreComments.
	ignoreComments bool

	// removeInternalComments is stored by SetRemoveInternalComments.
	removeInternalComments bool

	// enumsAsInts is stored by SetEnumsAsInts.
	enumsAsInts bool

	// omitEnumDefaultValue is stored by SetOmitEnumDefaultValue.
	omitEnumDefaultValue bool

	// disableDefaultErrors is stored by SetDisableDefaultErrors.
	disableDefaultErrors bool

	// simpleOperationIDs is stored by SetSimpleOperationIDs.
	simpleOperationIDs bool

	// standalone, when true, gives loaded files an "ext"-prefixed Go package
	// alias and registers messages, fields and enums with ForcePrefixedName set.
	standalone bool

	// warnOnUnboundMethods is stored by SetWarnOnUnboundMethods; this file has
	// no getter for it.
	warnOnUnboundMethods bool

	// proto3OptionalNullable is stored by SetProto3OptionalNullable.
	proto3OptionalNullable bool

	// fileOptions holds OpenAPI file options from RegisterOpenAPIOptions,
	// keyed by proto file path.
	fileOptions map[string]*options.Swagger

	// methodOptions holds OpenAPI method options from RegisterOpenAPIOptions,
	// keyed by fully-qualified method name (with leading dot).
	methodOptions map[string]*options.Operation

	// messageOptions holds OpenAPI message options from RegisterOpenAPIOptions,
	// keyed by fully-qualified message name (with leading dot).
	messageOptions map[string]*options.Schema

	// serviceOptions holds OpenAPI service options from RegisterOpenAPIOptions,
	// keyed by fully-qualified service name (with leading dot).
	serviceOptions map[string]*options.Tag

	// fieldOptions holds OpenAPI field options from RegisterOpenAPIOptions,
	// keyed by fully-qualified field name (with leading dot).
	fieldOptions map[string]*options.JSONSchema

	// generateUnboundMethods is stored by SetGenerateUnboundMethods; this file
	// has no getter for it.
	generateUnboundMethods bool

	// omitPackageDoc is stored by SetOmitPackageDoc.
	omitPackageDoc bool

	// recursiveDepth defaults to 1000 in NewRegistry.
	// NOTE(review): presumably a limit on recursive message expansion — confirm.
	recursiveDepth int

	// annotationMap records the (HTTP method, path template, service)
	// combinations already seen by CheckDuplicateAnnotation.
	annotationMap map[annotationIdentifier]struct{}

	// disableServiceTags is stored by SetDisableServiceTags.
	disableServiceTags bool

	// disableDefaultResponses is stored by SetDisableDefaultResponses.
	disableDefaultResponses bool

	// useAllOfForRefs is stored by SetUseAllOfForRefs.
	useAllOfForRefs bool

	// allowPatchFeature is stored by SetAllowPatchFeature.
	allowPatchFeature bool

	// preserveRPCOrder is stored by SetPreserveRPCOrder and reported by
	// IsPreserveRPCOrder.
	preserveRPCOrder bool
}
164
// repeatedFieldSeparator pairs a separator's configuration name (for example
// "csv") with the rune it stands for (for example ',').
type repeatedFieldSeparator struct {
	name string
	sep  rune
}
169
// annotationIdentifier is the map key CheckDuplicateAnnotation uses to detect
// a repeated HTTP binding. Because service is a *Service, two bindings only
// collide when they belong to the same Service value (pointer identity).
type annotationIdentifier struct {
	method       string
	pathTemplate string
	service      *Service
}
175
176
177 func NewRegistry() *Registry {
178 return &Registry{
179 msgs: make(map[string]*Message),
180 enums: make(map[string]*Enum),
181 meths: make(map[string]*Method),
182 files: make(map[string]*File),
183 pkgMap: make(map[string]string),
184 pkgAliases: make(map[string]string),
185 externalHTTPRules: make(map[string][]*annotations.HttpRule),
186 openAPINamingStrategy: "legacy",
187 visibilityRestrictionSelectors: make(map[string]bool),
188 repeatedPathParamSeparator: repeatedFieldSeparator{
189 name: "csv",
190 sep: ',',
191 },
192 fileOptions: make(map[string]*options.Swagger),
193 methodOptions: make(map[string]*options.Operation),
194 messageOptions: make(map[string]*options.Schema),
195 serviceOptions: make(map[string]*options.Tag),
196 fieldOptions: make(map[string]*options.JSONSchema),
197 annotationMap: make(map[annotationIdentifier]struct{}),
198 recursiveDepth: 1000,
199 }
200 }
201
202
203 func (r *Registry) Load(req *pluginpb.CodeGeneratorRequest) error {
204 gen, err := protogen.Options{}.New(req)
205 if err != nil {
206 return err
207 }
208
209
210
211 codegenerator.SetSupportedFeaturesOnPluginGen(gen)
212 return r.load(gen)
213 }
214
215 func (r *Registry) LoadFromPlugin(gen *protogen.Plugin) error {
216 return r.load(gen)
217 }
218
219 func (r *Registry) load(gen *protogen.Plugin) error {
220 filePaths := make([]string, 0, len(gen.FilesByPath))
221 for filePath := range gen.FilesByPath {
222 filePaths = append(filePaths, filePath)
223 }
224 sort.Strings(filePaths)
225
226 for _, filePath := range filePaths {
227 r.loadFile(filePath, gen.FilesByPath[filePath])
228 }
229
230 for _, filePath := range filePaths {
231 if !gen.FilesByPath[filePath].Generate {
232 continue
233 }
234 file := r.files[filePath]
235 if err := r.loadServices(file); err != nil {
236 return err
237 }
238 }
239
240 return nil
241 }
242
243
244
245
246 func (r *Registry) loadFile(filePath string, file *protogen.File) {
247 pkg := GoPackage{
248 Path: string(file.GoImportPath),
249 Name: string(file.GoPackageName),
250 }
251 if r.standalone {
252 pkg.Alias = "ext" + cases.Title(language.AmericanEnglish).String(pkg.Name)
253 }
254
255 if err := r.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
256 for i := 0; ; i++ {
257 alias := fmt.Sprintf("%s_%d", pkg.Name, i)
258 if err := r.ReserveGoPackageAlias(alias, pkg.Path); err == nil {
259 pkg.Alias = alias
260 break
261 }
262 }
263 }
264 f := &File{
265 FileDescriptorProto: file.Proto,
266 GoPkg: pkg,
267 GeneratedFilenamePrefix: file.GeneratedFilenamePrefix,
268 }
269
270 r.files[filePath] = f
271 r.registerMsg(f, nil, file.Proto.MessageType)
272 r.registerEnum(f, nil, file.Proto.EnumType)
273 }
274
275 func (r *Registry) registerMsg(file *File, outerPath []string, msgs []*descriptorpb.DescriptorProto) {
276 for i, md := range msgs {
277 m := &Message{
278 File: file,
279 Outers: outerPath,
280 DescriptorProto: md,
281 Index: i,
282 ForcePrefixedName: r.standalone,
283 }
284 for _, fd := range md.GetField() {
285 m.Fields = append(m.Fields, &Field{
286 Message: m,
287 FieldDescriptorProto: fd,
288 ForcePrefixedName: r.standalone,
289 })
290 }
291 file.Messages = append(file.Messages, m)
292 r.msgs[m.FQMN()] = m
293 if grpclog.V(1) {
294 grpclog.Infof("Register name: %s", m.FQMN())
295 }
296
297 var outers []string
298 outers = append(outers, outerPath...)
299 outers = append(outers, m.GetName())
300 r.registerMsg(file, outers, m.GetNestedType())
301 r.registerEnum(file, outers, m.GetEnumType())
302 }
303 }
304
305 func (r *Registry) registerEnum(file *File, outerPath []string, enums []*descriptorpb.EnumDescriptorProto) {
306 for i, ed := range enums {
307 e := &Enum{
308 File: file,
309 Outers: outerPath,
310 EnumDescriptorProto: ed,
311 Index: i,
312 ForcePrefixedName: r.standalone,
313 }
314 file.Enums = append(file.Enums, e)
315 r.enums[e.FQEN()] = e
316 if grpclog.V(1) {
317 grpclog.Infof("Register enum name: %s", e.FQEN())
318 }
319 }
320 }
321
322
323
324 func (r *Registry) LookupMsg(location, name string) (*Message, error) {
325 if grpclog.V(1) {
326 grpclog.Infof("Lookup %s from %s", name, location)
327 }
328 if strings.HasPrefix(name, ".") {
329 m, ok := r.msgs[name]
330 if !ok {
331 return nil, fmt.Errorf("no message found: %s", name)
332 }
333 return m, nil
334 }
335
336 if !strings.HasPrefix(location, ".") {
337 location = fmt.Sprintf(".%s", location)
338 }
339 components := strings.Split(location, ".")
340 for len(components) > 0 {
341 fqmn := strings.Join(append(components, name), ".")
342 if m, ok := r.msgs[fqmn]; ok {
343 return m, nil
344 }
345 components = components[:len(components)-1]
346 }
347 return nil, fmt.Errorf("no message found: %s", name)
348 }
349
350
351
352 func (r *Registry) LookupEnum(location, name string) (*Enum, error) {
353 if grpclog.V(1) {
354 grpclog.Infof("Lookup enum %s from %s", name, location)
355 }
356 if strings.HasPrefix(name, ".") {
357 e, ok := r.enums[name]
358 if !ok {
359 return nil, fmt.Errorf("no enum found: %s", name)
360 }
361 return e, nil
362 }
363
364 if !strings.HasPrefix(location, ".") {
365 location = fmt.Sprintf(".%s", location)
366 }
367 components := strings.Split(location, ".")
368 for len(components) > 0 {
369 fqen := strings.Join(append(components, name), ".")
370 if e, ok := r.enums[fqen]; ok {
371 return e, nil
372 }
373 components = components[:len(components)-1]
374 }
375 return nil, fmt.Errorf("no enum found: %s", name)
376 }
377
378
379 func (r *Registry) LookupFile(name string) (*File, error) {
380 f, ok := r.files[name]
381 if !ok {
382 return nil, fmt.Errorf("no such file given: %s", name)
383 }
384 return f, nil
385 }
386
387
388 func (r *Registry) LookupExternalHTTPRules(qualifiedMethodName string) []*annotations.HttpRule {
389 return r.externalHTTPRules[qualifiedMethodName]
390 }
391
392
393 func (r *Registry) AddExternalHTTPRule(qualifiedMethodName string, rule *annotations.HttpRule) {
394 r.externalHTTPRules[qualifiedMethodName] = append(r.externalHTTPRules[qualifiedMethodName], rule)
395 }
396
397
398
399 func (r *Registry) UnboundExternalHTTPRules() []string {
400 allServiceMethods := make(map[string]struct{})
401 for _, f := range r.files {
402 for _, s := range f.GetService() {
403 svc := &Service{File: f, ServiceDescriptorProto: s}
404 for _, m := range s.GetMethod() {
405 method := &Method{Service: svc, MethodDescriptorProto: m}
406 allServiceMethods[method.FQMN()] = struct{}{}
407 }
408 }
409 }
410
411 var missingMethods []string
412 for httpRuleMethod := range r.externalHTTPRules {
413 if _, ok := allServiceMethods[httpRuleMethod]; !ok {
414 missingMethods = append(missingMethods, httpRuleMethod)
415 }
416 }
417 return missingMethods
418 }
419
420
421 func (r *Registry) AddPkgMap(file, protoPkg string) {
422 r.pkgMap[file] = protoPkg
423 }
424
425
426 func (r *Registry) SetPrefix(prefix string) {
427 r.prefix = prefix
428 }
429
430
431 func (r *Registry) SetStandalone(standalone bool) {
432 r.standalone = standalone
433 }
434
435
436 func (r *Registry) SetRecursiveDepth(count int) {
437 r.recursiveDepth = count
438 }
439
440
441 func (r *Registry) GetRecursiveDepth() int {
442 return r.recursiveDepth
443 }
444
445
446
447
448
449 func (r *Registry) ReserveGoPackageAlias(alias, pkgpath string) error {
450 if taken, ok := r.pkgAliases[alias]; ok {
451 if taken == pkgpath {
452 return nil
453 }
454 return fmt.Errorf("package name %s is already taken. Use another alias", alias)
455 }
456 r.pkgAliases[alias] = pkgpath
457 return nil
458 }
459
460
461 func (r *Registry) GetAllFQMNs() []string {
462 keys := make([]string, 0, len(r.msgs))
463 for k := range r.msgs {
464 keys = append(keys, k)
465 }
466 return keys
467 }
468
469
470 func (r *Registry) GetAllFQENs() []string {
471 keys := make([]string, 0, len(r.enums))
472 for k := range r.enums {
473 keys = append(keys, k)
474 }
475 return keys
476 }
477
478 func (r *Registry) GetAllFQMethNs() []string {
479 keys := make([]string, 0, len(r.meths))
480 for k := range r.meths {
481 keys = append(keys, k)
482 }
483 return keys
484 }
485
486
487
488 func (r *Registry) SetAllowDeleteBody(allow bool) {
489 r.allowDeleteBody = allow
490 }
491
492
493 func (r *Registry) SetAllowMerge(allow bool) {
494 r.allowMerge = allow
495 }
496
497
498 func (r *Registry) IsAllowMerge() bool {
499 return r.allowMerge
500 }
501
502
503 func (r *Registry) SetMergeFileName(mergeFileName string) {
504 r.mergeFileName = mergeFileName
505 }
506
507
508
509 func (r *Registry) SetIncludePackageInTags(allow bool) {
510 r.includePackageInTags = allow
511 }
512
513
514
515 func (r *Registry) IsIncludePackageInTags() bool {
516 return r.includePackageInTags
517 }
518
519
520
521 func (r *Registry) GetRepeatedPathParamSeparator() rune {
522 return r.repeatedPathParamSeparator.sep
523 }
524
525
526
527 func (r *Registry) GetRepeatedPathParamSeparatorName() string {
528 return r.repeatedPathParamSeparator.name
529 }
530
531
532
533 func (r *Registry) SetRepeatedPathParamSeparator(name string) error {
534 var sep rune
535 switch name {
536 case "csv":
537 sep = ','
538 case "pipes":
539 sep = '|'
540 case "ssv":
541 sep = ' '
542 case "tsv":
543 sep = '\t'
544 default:
545 return fmt.Errorf("unknown repeated path parameter separator: %s", name)
546 }
547 r.repeatedPathParamSeparator = repeatedFieldSeparator{
548 name: name,
549 sep: sep,
550 }
551 return nil
552 }
553
554
555 func (r *Registry) SetUseJSONNamesForFields(use bool) {
556 r.useJSONNamesForFields = use
557 }
558
559
560 func (r *Registry) GetUseJSONNamesForFields() bool {
561 return r.useJSONNamesForFields
562 }
563
564
565
566 func (r *Registry) SetUseFQNForOpenAPIName(use bool) {
567 r.openAPINamingStrategy = "fqn"
568 }
569
570
571
572 func (r *Registry) GetUseFQNForOpenAPIName() bool {
573 return r.openAPINamingStrategy == "fqn"
574 }
575
576
577 func (r *Registry) GetMergeFileName() string {
578 return r.mergeFileName
579 }
580
581
582 func (r *Registry) SetOpenAPINamingStrategy(strategy string) {
583 r.openAPINamingStrategy = strategy
584 }
585
586
587 func (r *Registry) GetOpenAPINamingStrategy() string {
588 return r.openAPINamingStrategy
589 }
590
591
592 func (r *Registry) SetUseGoTemplate(use bool) {
593 r.useGoTemplate = use
594 }
595
596
597 func (r *Registry) GetUseGoTemplate() bool {
598 return r.useGoTemplate
599 }
600
601 func (r *Registry) SetGoTemplateArgs(kvs []string) {
602 r.goTemplateArgs = make(map[string]string)
603 for _, kv := range kvs {
604 if key, value, found := strings.Cut(kv, "="); found {
605 r.goTemplateArgs[key] = value
606 }
607 }
608 }
609
610 func (r *Registry) GetGoTemplateArgs() map[string]string {
611 return r.goTemplateArgs
612 }
613
614
615 func (r *Registry) SetIgnoreComments(ignore bool) {
616 r.ignoreComments = ignore
617 }
618
619
620 func (r *Registry) GetIgnoreComments() bool {
621 return r.ignoreComments
622 }
623
624
625 func (r *Registry) SetRemoveInternalComments(remove bool) {
626 r.removeInternalComments = remove
627 }
628
629
630 func (r *Registry) GetRemoveInternalComments() bool {
631 return r.removeInternalComments
632 }
633
634
635 func (r *Registry) SetEnumsAsInts(enumsAsInts bool) {
636 r.enumsAsInts = enumsAsInts
637 }
638
639
640 func (r *Registry) GetEnumsAsInts() bool {
641 return r.enumsAsInts
642 }
643
644
645 func (r *Registry) SetOmitEnumDefaultValue(omit bool) {
646 r.omitEnumDefaultValue = omit
647 }
648
649
650 func (r *Registry) GetOmitEnumDefaultValue() bool {
651 return r.omitEnumDefaultValue
652 }
653
654
655 func (r *Registry) SetVisibilityRestrictionSelectors(selectors []string) {
656 r.visibilityRestrictionSelectors = make(map[string]bool)
657 for _, selector := range selectors {
658 r.visibilityRestrictionSelectors[strings.TrimSpace(selector)] = true
659 }
660 }
661
662
663 func (r *Registry) GetVisibilityRestrictionSelectors() map[string]bool {
664 return r.visibilityRestrictionSelectors
665 }
666
667
668 func (r *Registry) SetDisableDefaultErrors(use bool) {
669 r.disableDefaultErrors = use
670 }
671
672
673 func (r *Registry) GetDisableDefaultErrors() bool {
674 return r.disableDefaultErrors
675 }
676
677
678 func (r *Registry) SetSimpleOperationIDs(use bool) {
679 r.simpleOperationIDs = use
680 }
681
682
683 func (r *Registry) GetSimpleOperationIDs() bool {
684 return r.simpleOperationIDs
685 }
686
687
688 func (r *Registry) SetWarnOnUnboundMethods(warn bool) {
689 r.warnOnUnboundMethods = warn
690 }
691
692
693 func (r *Registry) SetGenerateUnboundMethods(generate bool) {
694 r.generateUnboundMethods = generate
695 }
696
697
698 func (r *Registry) SetOmitPackageDoc(omit bool) {
699 r.omitPackageDoc = omit
700 }
701
702
703 func (r *Registry) GetOmitPackageDoc() bool {
704 return r.omitPackageDoc
705 }
706
707
708 func (r *Registry) SetProto3OptionalNullable(proto3OtionalNullable bool) {
709 r.proto3OptionalNullable = proto3OtionalNullable
710 }
711
712
713 func (r *Registry) GetProto3OptionalNullable() bool {
714 return r.proto3OptionalNullable
715 }
716
717
718 func (r *Registry) RegisterOpenAPIOptions(opts *openapiconfig.OpenAPIOptions) error {
719 if opts == nil {
720 return nil
721 }
722
723 for _, opt := range opts.File {
724 if _, ok := r.files[opt.File]; !ok {
725 return fmt.Errorf("no file %s found", opt.File)
726 }
727 r.fileOptions[opt.File] = opt.Option
728 }
729
730
731 methods := make(map[string]struct{})
732 services := make(map[string]struct{})
733 for _, f := range r.files {
734 for _, s := range f.Services {
735 services[s.FQSN()] = struct{}{}
736 for _, m := range s.Methods {
737 methods[m.FQMN()] = struct{}{}
738 }
739 }
740 }
741
742 for _, opt := range opts.Method {
743 qualifiedMethod := "." + opt.Method
744 if _, ok := methods[qualifiedMethod]; !ok {
745 return fmt.Errorf("no method %s found", opt.Method)
746 }
747 r.methodOptions[qualifiedMethod] = opt.Option
748 }
749
750 for _, opt := range opts.Message {
751 qualifiedMessage := "." + opt.Message
752 if _, ok := r.msgs[qualifiedMessage]; !ok {
753 return fmt.Errorf("no message %s found", opt.Message)
754 }
755 r.messageOptions[qualifiedMessage] = opt.Option
756 }
757
758 for _, opt := range opts.Service {
759 qualifiedService := "." + opt.Service
760 if _, ok := services[qualifiedService]; !ok {
761 return fmt.Errorf("no service %s found", opt.Service)
762 }
763 r.serviceOptions[qualifiedService] = opt.Option
764 }
765
766
767 fields := make(map[string]struct{})
768 for _, m := range r.msgs {
769 for _, f := range m.Fields {
770 fields[f.FQFN()] = struct{}{}
771 }
772 }
773 for _, opt := range opts.Field {
774 qualifiedField := "." + opt.Field
775 if _, ok := fields[qualifiedField]; !ok {
776 return fmt.Errorf("no field %s found", opt.Field)
777 }
778 r.fieldOptions[qualifiedField] = opt.Option
779 }
780 return nil
781 }
782
783
784 func (r *Registry) GetOpenAPIFileOption(file string) (*options.Swagger, bool) {
785 opt, ok := r.fileOptions[file]
786 return opt, ok
787 }
788
789
790 func (r *Registry) GetOpenAPIMethodOption(qualifiedMethod string) (*options.Operation, bool) {
791 opt, ok := r.methodOptions[qualifiedMethod]
792 return opt, ok
793 }
794
795
796 func (r *Registry) GetOpenAPIMessageOption(qualifiedMessage string) (*options.Schema, bool) {
797 opt, ok := r.messageOptions[qualifiedMessage]
798 return opt, ok
799 }
800
801
802 func (r *Registry) GetOpenAPIServiceOption(qualifiedService string) (*options.Tag, bool) {
803 opt, ok := r.serviceOptions[qualifiedService]
804 return opt, ok
805 }
806
807
808 func (r *Registry) GetOpenAPIFieldOption(qualifiedField string) (*options.JSONSchema, bool) {
809 opt, ok := r.fieldOptions[qualifiedField]
810 return opt, ok
811 }
812
813 func (r *Registry) FieldName(f *Field) string {
814 if r.useJSONNamesForFields {
815 return f.GetJsonName()
816 }
817 return f.GetName()
818 }
819
820 func (r *Registry) CheckDuplicateAnnotation(httpMethod string, httpTemplate string, svc *Service) error {
821 a := annotationIdentifier{method: httpMethod, pathTemplate: httpTemplate, service: svc}
822 if _, ok := r.annotationMap[a]; ok {
823 return fmt.Errorf("duplicate annotation: method=%s, template=%s", httpMethod, httpTemplate)
824 }
825 r.annotationMap[a] = struct{}{}
826 return nil
827 }
828
829
830 func (r *Registry) SetDisableServiceTags(use bool) {
831 r.disableServiceTags = use
832 }
833
834
835 func (r *Registry) GetDisableServiceTags() bool {
836 return r.disableServiceTags
837 }
838
839
840 func (r *Registry) SetDisableDefaultResponses(use bool) {
841 r.disableDefaultResponses = use
842 }
843
844
845 func (r *Registry) GetDisableDefaultResponses() bool {
846 return r.disableDefaultResponses
847 }
848
849
850 func (r *Registry) SetUseAllOfForRefs(use bool) {
851 r.useAllOfForRefs = use
852 }
853
854
855 func (r *Registry) GetUseAllOfForRefs() bool {
856 return r.useAllOfForRefs
857 }
858
859
860 func (r *Registry) SetAllowPatchFeature(allow bool) {
861 r.allowPatchFeature = allow
862 }
863
864
865 func (r *Registry) GetAllowPatchFeature() bool {
866 return r.allowPatchFeature
867 }
868
869
870 func (r *Registry) SetPreserveRPCOrder(preserve bool) {
871 r.preserveRPCOrder = preserve
872 }
873
874
875 func (r *Registry) IsPreserveRPCOrder() bool {
876 return r.preserveRPCOrder
877 }
878
View as plain text