1 package gengateway
2
3 import (
4 "errors"
5 "fmt"
6 "go/format"
7 "path"
8
9 "github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor"
10 gen "github.com/grpc-ecosystem/grpc-gateway/v2/internal/generator"
11 "google.golang.org/grpc/grpclog"
12 "google.golang.org/protobuf/proto"
13 "google.golang.org/protobuf/types/pluginpb"
14 )
15
var (
	// errNoTargetService is the sentinel Generate uses to recognise a proto
	// file with nothing to emit: it is matched with errors.Is and the file is
	// skipped rather than failing the whole run. (Presumably it is returned
	// from applyTemplate; that code is not visible here.)
	errNoTargetService = errors.New("no target service defined in the file")
)
17
// generator is the gen.Generator implementation returned by New. It renders
// one gateway Go source file per proto file via applyTemplate.
type generator struct {
	// reg is used to reserve base-import aliases (in New), to resolve the
	// enum types of path parameters, to read the omit-package-doc option, and
	// is passed through to applyTemplate. generate tolerates a nil reg when
	// reading the option.
	reg *descriptor.Registry
	// baseImports are imported by every generated file; any alias needed to
	// avoid a name clash was chosen in New.
	baseImports []descriptor.GoPackage
	// useRequestContext, registerFuncSuffix and allowPatchFeature are copied
	// unchanged into the template parameters.
	useRequestContext  bool
	registerFuncSuffix string
	allowPatchFeature  bool
	// standalone makes every generated file also import the Go package of the
	// proto file it was generated from.
	standalone bool
}
26
27
28 func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix string,
29 allowPatchFeature, standalone bool) gen.Generator {
30 var imports []descriptor.GoPackage
31 for _, pkgpath := range []string{
32 "context",
33 "io",
34 "net/http",
35 "github.com/grpc-ecosystem/grpc-gateway/v2/runtime",
36 "github.com/grpc-ecosystem/grpc-gateway/v2/utilities",
37 "google.golang.org/protobuf/proto",
38 "google.golang.org/grpc",
39 "google.golang.org/grpc/codes",
40 "google.golang.org/grpc/grpclog",
41 "google.golang.org/grpc/metadata",
42 "google.golang.org/grpc/status",
43 } {
44 pkg := descriptor.GoPackage{
45 Path: pkgpath,
46 Name: path.Base(pkgpath),
47 }
48 if err := reg.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
49 for i := 0; ; i++ {
50 alias := fmt.Sprintf("%s_%d", pkg.Name, i)
51 if err := reg.ReserveGoPackageAlias(alias, pkg.Path); err != nil {
52 continue
53 }
54 pkg.Alias = alias
55 break
56 }
57 }
58 imports = append(imports, pkg)
59 }
60
61 return &generator{
62 reg: reg,
63 baseImports: imports,
64 useRequestContext: useRequestContext,
65 registerFuncSuffix: registerFuncSuffix,
66 allowPatchFeature: allowPatchFeature,
67 standalone: standalone,
68 }
69 }
70
71 func (g *generator) Generate(targets []*descriptor.File) ([]*descriptor.ResponseFile, error) {
72 var files []*descriptor.ResponseFile
73 for _, file := range targets {
74 if grpclog.V(1) {
75 grpclog.Infof("Processing %s", file.GetName())
76 }
77
78 code, err := g.generate(file)
79 if errors.Is(err, errNoTargetService) {
80 if grpclog.V(1) {
81 grpclog.Infof("%s: %v", file.GetName(), err)
82 }
83 continue
84 }
85 if err != nil {
86 return nil, err
87 }
88 formatted, err := format.Source([]byte(code))
89 if err != nil {
90 grpclog.Errorf("%v: %s", err, code)
91 return nil, err
92 }
93 files = append(files, &descriptor.ResponseFile{
94 GoPkg: file.GoPkg,
95 CodeGeneratorResponse_File: &pluginpb.CodeGeneratorResponse_File{
96 Name: proto.String(file.GeneratedFilenamePrefix + ".pb.gw.go"),
97 Content: proto.String(string(formatted)),
98 },
99 })
100 }
101 return files, nil
102 }
103
104 func (g *generator) generate(file *descriptor.File) (string, error) {
105 pkgSeen := make(map[string]bool)
106 var imports []descriptor.GoPackage
107 for _, pkg := range g.baseImports {
108 pkgSeen[pkg.Path] = true
109 imports = append(imports, pkg)
110 }
111
112 if g.standalone {
113 imports = append(imports, file.GoPkg)
114 }
115
116 for _, svc := range file.Services {
117 for _, m := range svc.Methods {
118 imports = append(imports, g.addEnumPathParamImports(file, m, pkgSeen)...)
119 pkg := m.RequestType.File.GoPkg
120 if len(m.Bindings) == 0 ||
121 pkg == file.GoPkg || pkgSeen[pkg.Path] {
122 continue
123 }
124 pkgSeen[pkg.Path] = true
125 imports = append(imports, pkg)
126 }
127 }
128 params := param{
129 File: file,
130 Imports: imports,
131 UseRequestContext: g.useRequestContext,
132 RegisterFuncSuffix: g.registerFuncSuffix,
133 AllowPatchFeature: g.allowPatchFeature,
134 }
135 if g.reg != nil {
136 params.OmitPackageDoc = g.reg.GetOmitPackageDoc()
137 }
138 return applyTemplate(params, g.reg)
139 }
140
141
142 func (g *generator) addEnumPathParamImports(file *descriptor.File, m *descriptor.Method, pkgSeen map[string]bool) []descriptor.GoPackage {
143 var imports []descriptor.GoPackage
144 for _, b := range m.Bindings {
145 for _, p := range b.PathParams {
146 e, err := g.reg.LookupEnum("", p.Target.GetTypeName())
147 if err != nil {
148 continue
149 }
150 pkg := e.File.GoPkg
151 if pkg == file.GoPkg || pkgSeen[pkg.Path] {
152 continue
153 }
154 pkgSeen[pkg.Path] = true
155 imports = append(imports, pkg)
156 }
157 }
158 return imports
159 }
160
View as plain text