1
2
3
4
5
6 package gengogrpc
7
8 import (
9 "fmt"
10 "strconv"
11 "strings"
12
13 "google.golang.org/protobuf/compiler/protogen"
14
15 "google.golang.org/protobuf/types/descriptorpb"
16 )
17
18 const (
19 contextPackage = protogen.GoImportPath("context")
20 grpcPackage = protogen.GoImportPath("google.golang.org/grpc")
21 codesPackage = protogen.GoImportPath("google.golang.org/grpc/codes")
22 statusPackage = protogen.GoImportPath("google.golang.org/grpc/status")
23 )
24
25
26 func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
27 if len(file.Services) == 0 {
28 return nil
29 }
30 filename := file.GeneratedFilenamePrefix + "_grpc.pb.go"
31 g := gen.NewGeneratedFile(filename, file.GoImportPath)
32 g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.")
33 g.P()
34 g.P("package ", file.GoPackageName)
35 g.P()
36 GenerateFileContent(gen, file, g)
37 return g
38 }
39
40
41 func GenerateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) {
42 if len(file.Services) == 0 {
43 return
44 }
45
46
47 g.P("// Reference imports to suppress errors if they are not otherwise used.")
48 g.P("var _ ", contextPackage.Ident("Context"))
49 g.P("var _ ", grpcPackage.Ident("ClientConnInterface"))
50 g.P()
51
52 g.P("// This is a compile-time assertion to ensure that this generated file")
53 g.P("// is compatible with the grpc package it is being compiled against.")
54 g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion6"))
55 g.P()
56 for _, service := range file.Services {
57 genService(gen, file, g, service)
58 }
59 }
60
61 func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) {
62 clientName := service.GoName + "Client"
63
64 g.P("// ", clientName, " is the client API for ", service.GoName, " service.")
65 g.P("//")
66 g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.")
67
68
69 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
70 g.P("//")
71 g.P(deprecationComment)
72 }
73 g.Annotate(clientName, service.Location)
74 g.P("type ", clientName, " interface {")
75 for _, method := range service.Methods {
76 g.Annotate(clientName+"."+method.GoName, method.Location)
77 if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
78 g.P(deprecationComment)
79 }
80 g.P(method.Comments.Leading,
81 clientSignature(g, method))
82 }
83 g.P("}")
84 g.P()
85
86
87 g.P("type ", unexport(clientName), " struct {")
88 g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
89 g.P("}")
90 g.P()
91
92
93 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
94 g.P(deprecationComment)
95 }
96 g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {")
97 g.P("return &", unexport(clientName), "{cc}")
98 g.P("}")
99 g.P()
100
101 var methodIndex, streamIndex int
102
103 for _, method := range service.Methods {
104 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
105
106 genClientMethod(gen, file, g, method, methodIndex)
107 methodIndex++
108 } else {
109
110 genClientMethod(gen, file, g, method, streamIndex)
111 streamIndex++
112 }
113 }
114
115
116 serverType := service.GoName + "Server"
117 g.P("// ", serverType, " is the server API for ", service.GoName, " service.")
118 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
119 g.P("//")
120 g.P(deprecationComment)
121 }
122 g.Annotate(serverType, service.Location)
123 g.P("type ", serverType, " interface {")
124 for _, method := range service.Methods {
125 g.Annotate(serverType+"."+method.GoName, method.Location)
126 if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
127 g.P(deprecationComment)
128 }
129 g.P(method.Comments.Leading,
130 serverSignature(g, method))
131 }
132 g.P("}")
133 g.P()
134
135
136 g.P("// Unimplemented", serverType, " can be embedded to have forward compatible implementations.")
137 g.P("type Unimplemented", serverType, " struct {")
138 g.P("}")
139 g.P()
140 for _, method := range service.Methods {
141 nilArg := ""
142 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
143 nilArg = "nil,"
144 }
145 g.P("func (*Unimplemented", serverType, ") ", serverSignature(g, method), "{")
146 g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`)
147 g.P("}")
148 }
149 g.P()
150
151
152 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
153 g.P(deprecationComment)
154 }
155 serviceDescVar := "_" + service.GoName + "_serviceDesc"
156 g.P("func Register", service.GoName, "Server(s *", grpcPackage.Ident("Server"), ", srv ", serverType, ") {")
157 g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
158 g.P("}")
159 g.P()
160
161
162 var handlerNames []string
163 for _, method := range service.Methods {
164 hname := genServerMethod(gen, file, g, method)
165 handlerNames = append(handlerNames, hname)
166 }
167
168
169 g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
170 g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
171 g.P("HandlerType: (*", serverType, ")(nil),")
172 g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
173 for i, method := range service.Methods {
174 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
175 continue
176 }
177 g.P("{")
178 g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
179 g.P("Handler: ", handlerNames[i], ",")
180 g.P("},")
181 }
182 g.P("},")
183 g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
184 for i, method := range service.Methods {
185 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
186 continue
187 }
188 g.P("{")
189 g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
190 g.P("Handler: ", handlerNames[i], ",")
191 if method.Desc.IsStreamingServer() {
192 g.P("ServerStreams: true,")
193 }
194 if method.Desc.IsStreamingClient() {
195 g.P("ClientStreams: true,")
196 }
197 g.P("},")
198 }
199 g.P("},")
200 g.P("Metadata: \"", file.Desc.Path(), "\",")
201 g.P("}")
202 g.P()
203 }
204
205 func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
206 s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context"))
207 if !method.Desc.IsStreamingClient() {
208 s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent)
209 }
210 s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") ("
211 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
212 s += "*" + g.QualifiedGoIdent(method.Output.GoIdent)
213 } else {
214 s += method.Parent.GoName + "_" + method.GoName + "Client"
215 }
216 s += ", error)"
217 return s
218 }
219
// genClientMethod emits the client-side implementation of one RPC method on
// the unexported <service>Client struct.
//
// For a unary method, this is a single function that calls c.cc.Invoke.
// For a streaming method, it emits:
//   - a function that opens the stream with c.cc.NewStream;
//   - the <Service>_<Method>Client interface and its unexported implementation;
//   - the Send/Recv/CloseAndRecv helpers that match the method's streaming
//     direction(s).
//
// index is only read for streaming methods. genService passes the method's
// position among the service's streaming methods, which selects its entry in
// _<Service>_serviceDesc.Streams.
func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) {
	service := method.Parent
	// Wire path of the RPC: "/<fully.qualified.Service>/<proto method name>".
	sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())

	if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
		g.P(deprecationComment)
	}
	g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{")
	if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
		// Unary RPC: one Invoke call filling a freshly allocated response.
		g.P("out := new(", method.Output.GoIdent, ")")
		g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`)
		g.P("if err != nil { return nil, err }")
		g.P("return out, nil")
		g.P("}")
		g.P()
		return
	}
	// Streaming RPC: open a stream described by the matching Streams entry and
	// wrap it in the typed stream implementation.
	streamType := unexport(service.GoName) + method.GoName + "Client"
	serviceDescVar := "_" + service.GoName + "_serviceDesc"
	g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`)
	g.P("if err != nil { return nil, err }")
	g.P("x := &", streamType, "{stream}")
	if !method.Desc.IsStreamingClient() {
		// Server-streaming only: send the single request up front and close
		// the send side, leaving the caller to receive responses.
		g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
		g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
	}
	g.P("return x, nil")
	g.P("}")
	g.P()

	// Helpers needed by the typed stream: Send when the client streams, Recv
	// when the server streams, and CloseAndRecv for client-streaming-only
	// calls (unary methods returned above).
	genSend := method.Desc.IsStreamingClient()
	genRecv := method.Desc.IsStreamingServer()
	genCloseAndRecv := !method.Desc.IsStreamingServer()

	// Exported stream interface returned to callers.
	g.P("type ", service.GoName, "_", method.GoName, "Client interface {")
	if genSend {
		g.P("Send(*", method.Input.GoIdent, ") error")
	}
	if genRecv {
		g.P("Recv() (*", method.Output.GoIdent, ", error)")
	}
	if genCloseAndRecv {
		g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
	}
	g.P(grpcPackage.Ident("ClientStream"))
	g.P("}")
	g.P()

	// Unexported implementation embedding the untyped grpc.ClientStream.
	g.P("type ", streamType, " struct {")
	g.P(grpcPackage.Ident("ClientStream"))
	g.P("}")
	g.P()

	if genSend {
		g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {")
		g.P("return x.ClientStream.SendMsg(m)")
		g.P("}")
		g.P()
	}
	if genRecv {
		g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {")
		g.P("m := new(", method.Output.GoIdent, ")")
		g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
		g.P("return m, nil")
		g.P("}")
		g.P()
	}
	if genCloseAndRecv {
		// Close the send side first, then read the single response.
		g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
		g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
		g.P("m := new(", method.Output.GoIdent, ")")
		g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
		g.P("return m, nil")
		g.P("}")
		g.P()
	}
}
298
299 func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
300 var reqArgs []string
301 ret := "error"
302 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
303 reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context")))
304 ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)"
305 }
306 if !method.Desc.IsStreamingClient() {
307 reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent))
308 }
309 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
310 reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server")
311 }
312 return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
313 }
314
315 func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string {
316 service := method.Parent
317 hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
318
319 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
320 g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
321 g.P("in := new(", method.Input.GoIdent, ")")
322 g.P("if err := dec(in); err != nil { return nil, err }")
323 g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }")
324 g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{")
325 g.P("Server: srv,")
326 g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.GoName)), ",")
327 g.P("}")
328 g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {")
329 g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.Input.GoIdent, "))")
330 g.P("}")
331 g.P("return interceptor(ctx, in, info, handler)")
332 g.P("}")
333 g.P()
334 return hname
335 }
336 streamType := unexport(service.GoName) + method.GoName + "Server"
337 g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
338 if !method.Desc.IsStreamingClient() {
339 g.P("m := new(", method.Input.GoIdent, ")")
340 g.P("if err := stream.RecvMsg(m); err != nil { return err }")
341 g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})")
342 } else {
343 g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})")
344 }
345 g.P("}")
346 g.P()
347
348 genSend := method.Desc.IsStreamingServer()
349 genSendAndClose := !method.Desc.IsStreamingServer()
350 genRecv := method.Desc.IsStreamingClient()
351
352
353 g.P("type ", service.GoName, "_", method.GoName, "Server interface {")
354 if genSend {
355 g.P("Send(*", method.Output.GoIdent, ") error")
356 }
357 if genSendAndClose {
358 g.P("SendAndClose(*", method.Output.GoIdent, ") error")
359 }
360 if genRecv {
361 g.P("Recv() (*", method.Input.GoIdent, ", error)")
362 }
363 g.P(grpcPackage.Ident("ServerStream"))
364 g.P("}")
365 g.P()
366
367 g.P("type ", streamType, " struct {")
368 g.P(grpcPackage.Ident("ServerStream"))
369 g.P("}")
370 g.P()
371
372 if genSend {
373 g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {")
374 g.P("return x.ServerStream.SendMsg(m)")
375 g.P("}")
376 g.P()
377 }
378 if genSendAndClose {
379 g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
380 g.P("return x.ServerStream.SendMsg(m)")
381 g.P("}")
382 g.P()
383 }
384 if genRecv {
385 g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {")
386 g.P("m := new(", method.Input.GoIdent, ")")
387 g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
388 g.P("return m, nil")
389 g.P("}")
390 g.P()
391 }
392
393 return hname
394 }
395
const (
	// deprecationComment is the doc-comment line emitted above generated
	// declarations whose proto service or method sets the `deprecated` option.
	deprecationComment = "// Deprecated: Do not use."
)
397
// unexport lower-cases the first character of s, turning an exported Go
// identifier such as "FooClient" into the unexported "fooClient".
//
// It works on the first rune rather than the first byte. This avoids
// splitting a multi-byte leading character, and it returns "" for an empty
// input instead of panicking on s[:1].
func unexport(s string) string {
	for i := range s {
		if i > 0 {
			// i is the byte offset of the second rune.
			return strings.ToLower(s[:i]) + s[i:]
		}
	}
	// s is empty or a single rune.
	return strings.ToLower(s)
}
399
View as plain text