1 package gengateway
2
3 import (
4 "errors"
5 "fmt"
6 "go/format"
7 "path"
8 "path/filepath"
9 "strings"
10
11 "github.com/golang/glog"
12 "github.com/golang/protobuf/proto"
13 plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
14 "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
15 gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator"
16 )
17
// errNoTargetService is the sentinel error for a file that has no target
// service to generate a gateway for. Generate treats it as "skip this file"
// (logged at V(1)) rather than as a failure.
var errNoTargetService = errors.New("no target service defined in the file")
21
// pathType selects how getFilePath lays out generated files. New parses it
// from the plugin's "paths=" option.
type pathType int

// Supported pathType values. The zero value is pathTypeImport, so a pathType
// that is never assigned means the default "paths=import" layout.
const (
	// pathTypeImport ("paths=import" or unset): output is placed under the
	// file's Go package import path.
	pathTypeImport pathType = 0
	// pathTypeSourceRelative ("paths=source_relative"): output keeps the
	// .proto file's own relative name.
	pathTypeSourceRelative pathType = 1
)
28
// generator is the gen.Generator implementation returned by New.
type generator struct {
	// reg is used to reserve import aliases (New), look up enum types
	// (addEnumPathParamImports) and read the omit-package-doc setting
	// (generate, which tolerates a nil reg).
	reg *descriptor.Registry
	// baseImports are the packages every generated file imports, with any
	// alias reserved for them in New.
	baseImports []descriptor.GoPackage
	// useRequestContext and registerFuncSuffix are passed through to the
	// code template unchanged.
	useRequestContext  bool
	registerFuncSuffix string
	// pathType and modulePath control where getFilePath places output files;
	// a non-empty modulePath is the "module=" prefix to strip.
	pathType   pathType
	modulePath string
	// allowPatchFeature is passed through to the code template unchanged.
	allowPatchFeature bool
}
38
39
40 func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix, pathTypeString, modulePathString string, allowPatchFeature bool) gen.Generator {
41 var imports []descriptor.GoPackage
42 for _, pkgpath := range []string{
43 "context",
44 "io",
45 "net/http",
46 "github.com/grpc-ecosystem/grpc-gateway/runtime",
47 "github.com/grpc-ecosystem/grpc-gateway/utilities",
48 "github.com/golang/protobuf/descriptor",
49 "github.com/golang/protobuf/proto",
50 "google.golang.org/grpc",
51 "google.golang.org/grpc/codes",
52 "google.golang.org/grpc/grpclog",
53 "google.golang.org/grpc/metadata",
54 "google.golang.org/grpc/status",
55 } {
56 pkg := descriptor.GoPackage{
57 Path: pkgpath,
58 Name: path.Base(pkgpath),
59 }
60 if err := reg.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
61 for i := 0; ; i++ {
62 alias := fmt.Sprintf("%s_%d", pkg.Name, i)
63 if err := reg.ReserveGoPackageAlias(alias, pkg.Path); err != nil {
64 continue
65 }
66 pkg.Alias = alias
67 break
68 }
69 }
70 imports = append(imports, pkg)
71 }
72
73 var pathType pathType
74 switch pathTypeString {
75 case "", "import":
76
77 case "source_relative":
78 pathType = pathTypeSourceRelative
79 default:
80 glog.Fatalf(`Unknown path type %q: want "import" or "source_relative".`, pathTypeString)
81 }
82
83 return &generator{
84 reg: reg,
85 baseImports: imports,
86 useRequestContext: useRequestContext,
87 registerFuncSuffix: registerFuncSuffix,
88 pathType: pathType,
89 modulePath: modulePathString,
90 allowPatchFeature: allowPatchFeature,
91 }
92 }
93
94 func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) {
95 var files []*plugin.CodeGeneratorResponse_File
96 for _, file := range targets {
97 glog.V(1).Infof("Processing %s", file.GetName())
98 code, err := g.generate(file)
99 if err == errNoTargetService {
100 glog.V(1).Infof("%s: %v", file.GetName(), err)
101 continue
102 }
103 if err != nil {
104 return nil, err
105 }
106 formatted, err := format.Source([]byte(code))
107 if err != nil {
108 glog.Errorf("%v: %s", err, code)
109 return nil, err
110 }
111 name, err := g.getFilePath(file)
112 if err != nil {
113 glog.Errorf("%v: %s", err, code)
114 return nil, err
115 }
116 ext := filepath.Ext(name)
117 base := strings.TrimSuffix(name, ext)
118 output := fmt.Sprintf("%s.pb.gw.go", base)
119 files = append(files, &plugin.CodeGeneratorResponse_File{
120 Name: proto.String(output),
121 Content: proto.String(string(formatted)),
122 })
123 glog.V(1).Infof("Will emit %s", output)
124 }
125 return files, nil
126 }
127
128 func (g *generator) getFilePath(file *descriptor.File) (string, error) {
129 name := file.GetName()
130 switch {
131 case g.modulePath != "" && g.pathType != pathTypeImport:
132 return "", errors.New("cannot use module= with paths=")
133
134 case g.modulePath != "":
135 trimPath, pkgPath := g.modulePath+"/", file.GoPkg.Path+"/"
136 if !strings.HasPrefix(pkgPath, trimPath) {
137 return "", fmt.Errorf("%v: file go path does not match module prefix: %v", file.GoPkg.Path, trimPath)
138 }
139 return filepath.Join(strings.TrimPrefix(pkgPath, trimPath), filepath.Base(name)), nil
140
141 case g.pathType == pathTypeImport && file.GoPkg.Path != "":
142 return fmt.Sprintf("%s/%s", file.GoPkg.Path, filepath.Base(name)), nil
143
144 default:
145 return name, nil
146 }
147 }
148
149 func (g *generator) generate(file *descriptor.File) (string, error) {
150 pkgSeen := make(map[string]bool)
151 var imports []descriptor.GoPackage
152 for _, pkg := range g.baseImports {
153 pkgSeen[pkg.Path] = true
154 imports = append(imports, pkg)
155 }
156 for _, svc := range file.Services {
157 for _, m := range svc.Methods {
158 imports = append(imports, g.addEnumPathParamImports(file, m, pkgSeen)...)
159 pkg := m.RequestType.File.GoPkg
160 if len(m.Bindings) == 0 ||
161 pkg == file.GoPkg || pkgSeen[pkg.Path] {
162 continue
163 }
164 pkgSeen[pkg.Path] = true
165 imports = append(imports, pkg)
166 }
167 }
168 params := param{
169 File: file,
170 Imports: imports,
171 UseRequestContext: g.useRequestContext,
172 RegisterFuncSuffix: g.registerFuncSuffix,
173 AllowPatchFeature: g.allowPatchFeature,
174 }
175 if g.reg != nil {
176 params.OmitPackageDoc = g.reg.GetOmitPackageDoc()
177 }
178 return applyTemplate(params, g.reg)
179 }
180
181
182 func (g *generator) addEnumPathParamImports(file *descriptor.File, m *descriptor.Method, pkgSeen map[string]bool) []descriptor.GoPackage {
183 var imports []descriptor.GoPackage
184 for _, b := range m.Bindings {
185 for _, p := range b.PathParams {
186 e, err := g.reg.LookupEnum("", p.Target.GetTypeName())
187 if err != nil {
188 continue
189 }
190 pkg := e.File.GoPkg
191 if pkg == file.GoPkg || pkgSeen[pkg.Path] {
192 continue
193 }
194 pkgSeen[pkg.Path] = true
195 imports = append(imports, pkg)
196 }
197 }
198 return imports
199 }
200
View as plain text