1
2
3
4
5
6
7
8
9
10
11 package grpc
12
13 import (
14 "fmt"
15 "strconv"
16 "strings"
17
18 pb "github.com/golang/protobuf/protoc-gen-go/descriptor"
19 "github.com/golang/protobuf/protoc-gen-go/generator"
20 )
21
22
23
24
25
// generatedCodeVersion indicates a version of the generated code.
// Generate appends it to "SupportPackageIsVersion" to name the constant in
// the grpc package that every generated file references in a compile-time
// assertion.
const generatedCodeVersion = 6
27
28
29
// Import paths of the packages referenced by the generated code. Each path
// is handed to the generator's AddImport at the point where it is needed.
const (
	contextPkgPath = "context"
	grpcPkgPath    = "google.golang.org/grpc"
	codePkgPath    = "google.golang.org/grpc/codes"
	statusPkgPath  = "google.golang.org/grpc/status"
)
36
37 func init() {
38 generator.RegisterPlugin(new(grpc))
39 }
40
41
42
// grpc is the plugin type registered with the generator in init. Its methods
// emit Go client and server code for the services found in a file.
type grpc struct {
	gen *generator.Generator // assigned by Init; all output goes through it
}
46
47
48 func (g *grpc) Name() string {
49 return "grpc"
50 }
51
52
53
54
// Names under which the imported packages are referenced in the generated
// code. Both are assigned in Generate from the result of AddImport and read
// by the code-emitting helpers, so they hold the empty string until Generate
// has processed a file that declares at least one service.
var (
	contextPkg string
	grpcPkg    string
)
59
60
// Init stores the generator that every other method of the plugin uses to
// emit output and resolve type names.
func (g *grpc) Init(gen *generator.Generator) {
	g.gen = gen
}
64
65
66
67 func (g *grpc) objectNamed(name string) generator.Object {
68 g.gen.RecordTypeUse(name)
69 return g.gen.ObjectNamed(name)
70 }
71
72
73 func (g *grpc) typeName(str string) string {
74 return g.gen.TypeName(g.objectNamed(str))
75 }
76
77
78 func (g *grpc) P(args ...interface{}) { g.gen.P(args...) }
79
80
81 func (g *grpc) Generate(file *generator.FileDescriptor) {
82 if len(file.FileDescriptorProto.Service) == 0 {
83 return
84 }
85
86 contextPkg = string(g.gen.AddImport(contextPkgPath))
87 grpcPkg = string(g.gen.AddImport(grpcPkgPath))
88
89 g.P("// Reference imports to suppress errors if they are not otherwise used.")
90 g.P("var _ ", contextPkg, ".Context")
91 g.P("var _ ", grpcPkg, ".ClientConnInterface")
92 g.P()
93
94
95 g.P("// This is a compile-time assertion to ensure that this generated file")
96 g.P("// is compatible with the grpc package it is being compiled against.")
97 g.P("const _ = ", grpcPkg, ".SupportPackageIsVersion", generatedCodeVersion)
98 g.P()
99
100 for i, service := range file.FileDescriptorProto.Service {
101 g.generateService(file, service, i)
102 }
103 }
104
105
// GenerateImports intentionally does nothing: the imports this plugin needs
// are registered through AddImport while the code is being generated.
func (g *grpc) GenerateImports(file *generator.FileDescriptor) {
}
108
109
// reservedClientName is consulted by the signature generators: a CamelCased
// method name found in this map gets a "_" suffix appended. The map is
// declared empty here, so no method name is renamed unless entries are added.
var reservedClientName = map[string]bool{}
113
// unexport returns s with its first character lower-cased, turning an
// exported Go identifier into an unexported one. The empty string is returned
// unchanged, and the first character is treated as a whole rune so that a
// multi-byte leading character is not split.
func unexport(s string) string {
	if s == "" {
		return s
	}
	// Ranging over a string yields the byte offset of each rune, so the first
	// non-zero offset is the width of the leading rune. A single-rune string
	// never yields one, in which case the whole string is the leading rune.
	width := len(s)
	for i := range s {
		if i > 0 {
			width = i
			break
		}
	}
	return strings.ToLower(s[:width]) + s[width:]
}
115
116
117
// deprecationComment is the comment line printed in the generated code
// above declarations whose service or method has the deprecated option set.
var deprecationComment = "// Deprecated: Do not use."
119
120
121 func (g *grpc) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) {
122 path := fmt.Sprintf("6,%d", index)
123
124 origServName := service.GetName()
125 fullServName := origServName
126 if pkg := file.GetPackage(); pkg != "" {
127 fullServName = pkg + "." + fullServName
128 }
129 servName := generator.CamelCase(origServName)
130 deprecated := service.GetOptions().GetDeprecated()
131
132 g.P()
133 g.P(fmt.Sprintf(`// %sClient is the client API for %s service.
134 //
135 // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.`, servName, servName))
136
137
138 if deprecated {
139 g.P("//")
140 g.P(deprecationComment)
141 }
142 g.P("type ", servName, "Client interface {")
143 for i, method := range service.Method {
144 g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i))
145 if method.GetOptions().GetDeprecated() {
146 g.P("//")
147 g.P(deprecationComment)
148 }
149 g.P(g.generateClientSignature(servName, method))
150 }
151 g.P("}")
152 g.P()
153
154
155 g.P("type ", unexport(servName), "Client struct {")
156 g.P("cc ", grpcPkg, ".ClientConnInterface")
157 g.P("}")
158 g.P()
159
160
161 if deprecated {
162 g.P(deprecationComment)
163 }
164 g.P("func New", servName, "Client (cc ", grpcPkg, ".ClientConnInterface) ", servName, "Client {")
165 g.P("return &", unexport(servName), "Client{cc}")
166 g.P("}")
167 g.P()
168
169 var methodIndex, streamIndex int
170 serviceDescVar := "_" + servName + "_serviceDesc"
171
172 for _, method := range service.Method {
173 var descExpr string
174 if !method.GetServerStreaming() && !method.GetClientStreaming() {
175
176 descExpr = fmt.Sprintf("&%s.Methods[%d]", serviceDescVar, methodIndex)
177 methodIndex++
178 } else {
179
180 descExpr = fmt.Sprintf("&%s.Streams[%d]", serviceDescVar, streamIndex)
181 streamIndex++
182 }
183 g.generateClientMethod(servName, fullServName, serviceDescVar, method, descExpr)
184 }
185
186
187 serverType := servName + "Server"
188 g.P("// ", serverType, " is the server API for ", servName, " service.")
189 if deprecated {
190 g.P("//")
191 g.P(deprecationComment)
192 }
193 g.P("type ", serverType, " interface {")
194 for i, method := range service.Method {
195 g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i))
196 if method.GetOptions().GetDeprecated() {
197 g.P("//")
198 g.P(deprecationComment)
199 }
200 g.P(g.generateServerSignature(servName, method))
201 }
202 g.P("}")
203 g.P()
204
205
206 if deprecated {
207 g.P(deprecationComment)
208 }
209 g.generateUnimplementedServer(servName, service)
210
211
212 if deprecated {
213 g.P(deprecationComment)
214 }
215 g.P("func Register", servName, "Server(s *", grpcPkg, ".Server, srv ", serverType, ") {")
216 g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
217 g.P("}")
218 g.P()
219
220
221 var handlerNames []string
222 for _, method := range service.Method {
223 hname := g.generateServerMethod(servName, fullServName, method)
224 handlerNames = append(handlerNames, hname)
225 }
226
227
228 g.P("var ", serviceDescVar, " = ", grpcPkg, ".ServiceDesc {")
229 g.P("ServiceName: ", strconv.Quote(fullServName), ",")
230 g.P("HandlerType: (*", serverType, ")(nil),")
231 g.P("Methods: []", grpcPkg, ".MethodDesc{")
232 for i, method := range service.Method {
233 if method.GetServerStreaming() || method.GetClientStreaming() {
234 continue
235 }
236 g.P("{")
237 g.P("MethodName: ", strconv.Quote(method.GetName()), ",")
238 g.P("Handler: ", handlerNames[i], ",")
239 g.P("},")
240 }
241 g.P("},")
242 g.P("Streams: []", grpcPkg, ".StreamDesc{")
243 for i, method := range service.Method {
244 if !method.GetServerStreaming() && !method.GetClientStreaming() {
245 continue
246 }
247 g.P("{")
248 g.P("StreamName: ", strconv.Quote(method.GetName()), ",")
249 g.P("Handler: ", handlerNames[i], ",")
250 if method.GetServerStreaming() {
251 g.P("ServerStreams: true,")
252 }
253 if method.GetClientStreaming() {
254 g.P("ClientStreams: true,")
255 }
256 g.P("},")
257 }
258 g.P("},")
259 g.P("Metadata: \"", file.GetName(), "\",")
260 g.P("}")
261 g.P()
262 }
263
264
265 func (g *grpc) generateUnimplementedServer(servName string, service *pb.ServiceDescriptorProto) {
266 serverType := servName + "Server"
267 g.P("// Unimplemented", serverType, " can be embedded to have forward compatible implementations.")
268 g.P("type Unimplemented", serverType, " struct {")
269 g.P("}")
270 g.P()
271
272 for _, method := range service.Method {
273 g.generateServerMethodConcrete(servName, method)
274 }
275 g.P()
276 }
277
278
279 func (g *grpc) generateServerMethodConcrete(servName string, method *pb.MethodDescriptorProto) {
280 header := g.generateServerSignatureWithParamNames(servName, method)
281 g.P("func (*Unimplemented", servName, "Server) ", header, " {")
282 var nilArg string
283 if !method.GetServerStreaming() && !method.GetClientStreaming() {
284 nilArg = "nil, "
285 }
286 methName := generator.CamelCase(method.GetName())
287 statusPkg := string(g.gen.AddImport(statusPkgPath))
288 codePkg := string(g.gen.AddImport(codePkgPath))
289 g.P("return ", nilArg, statusPkg, `.Errorf(`, codePkg, `.Unimplemented, "method `, methName, ` not implemented")`)
290 g.P("}")
291 }
292
293
294 func (g *grpc) generateClientSignature(servName string, method *pb.MethodDescriptorProto) string {
295 origMethName := method.GetName()
296 methName := generator.CamelCase(origMethName)
297 if reservedClientName[methName] {
298 methName += "_"
299 }
300 reqArg := ", in *" + g.typeName(method.GetInputType())
301 if method.GetClientStreaming() {
302 reqArg = ""
303 }
304 respName := "*" + g.typeName(method.GetOutputType())
305 if method.GetServerStreaming() || method.GetClientStreaming() {
306 respName = servName + "_" + generator.CamelCase(origMethName) + "Client"
307 }
308 return fmt.Sprintf("%s(ctx %s.Context%s, opts ...%s.CallOption) (%s, error)", methName, contextPkg, reqArg, grpcPkg, respName)
309 }
310
311 func (g *grpc) generateClientMethod(servName, fullServName, serviceDescVar string, method *pb.MethodDescriptorProto, descExpr string) {
312 sname := fmt.Sprintf("/%s/%s", fullServName, method.GetName())
313 methName := generator.CamelCase(method.GetName())
314 inType := g.typeName(method.GetInputType())
315 outType := g.typeName(method.GetOutputType())
316
317 if method.GetOptions().GetDeprecated() {
318 g.P(deprecationComment)
319 }
320 g.P("func (c *", unexport(servName), "Client) ", g.generateClientSignature(servName, method), "{")
321 if !method.GetServerStreaming() && !method.GetClientStreaming() {
322 g.P("out := new(", outType, ")")
323
324 g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`)
325 g.P("if err != nil { return nil, err }")
326 g.P("return out, nil")
327 g.P("}")
328 g.P()
329 return
330 }
331 streamType := unexport(servName) + methName + "Client"
332 g.P("stream, err := c.cc.NewStream(ctx, ", descExpr, `, "`, sname, `", opts...)`)
333 g.P("if err != nil { return nil, err }")
334 g.P("x := &", streamType, "{stream}")
335 if !method.GetClientStreaming() {
336 g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
337 g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
338 }
339 g.P("return x, nil")
340 g.P("}")
341 g.P()
342
343 genSend := method.GetClientStreaming()
344 genRecv := method.GetServerStreaming()
345 genCloseAndRecv := !method.GetServerStreaming()
346
347
348 g.P("type ", servName, "_", methName, "Client interface {")
349 if genSend {
350 g.P("Send(*", inType, ") error")
351 }
352 if genRecv {
353 g.P("Recv() (*", outType, ", error)")
354 }
355 if genCloseAndRecv {
356 g.P("CloseAndRecv() (*", outType, ", error)")
357 }
358 g.P(grpcPkg, ".ClientStream")
359 g.P("}")
360 g.P()
361
362 g.P("type ", streamType, " struct {")
363 g.P(grpcPkg, ".ClientStream")
364 g.P("}")
365 g.P()
366
367 if genSend {
368 g.P("func (x *", streamType, ") Send(m *", inType, ") error {")
369 g.P("return x.ClientStream.SendMsg(m)")
370 g.P("}")
371 g.P()
372 }
373 if genRecv {
374 g.P("func (x *", streamType, ") Recv() (*", outType, ", error) {")
375 g.P("m := new(", outType, ")")
376 g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
377 g.P("return m, nil")
378 g.P("}")
379 g.P()
380 }
381 if genCloseAndRecv {
382 g.P("func (x *", streamType, ") CloseAndRecv() (*", outType, ", error) {")
383 g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
384 g.P("m := new(", outType, ")")
385 g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
386 g.P("return m, nil")
387 g.P("}")
388 g.P()
389 }
390 }
391
392
393 func (g *grpc) generateServerSignatureWithParamNames(servName string, method *pb.MethodDescriptorProto) string {
394 origMethName := method.GetName()
395 methName := generator.CamelCase(origMethName)
396 if reservedClientName[methName] {
397 methName += "_"
398 }
399
400 var reqArgs []string
401 ret := "error"
402 if !method.GetServerStreaming() && !method.GetClientStreaming() {
403 reqArgs = append(reqArgs, "ctx "+contextPkg+".Context")
404 ret = "(*" + g.typeName(method.GetOutputType()) + ", error)"
405 }
406 if !method.GetClientStreaming() {
407 reqArgs = append(reqArgs, "req *"+g.typeName(method.GetInputType()))
408 }
409 if method.GetServerStreaming() || method.GetClientStreaming() {
410 reqArgs = append(reqArgs, "srv "+servName+"_"+generator.CamelCase(origMethName)+"Server")
411 }
412
413 return methName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
414 }
415
416
417 func (g *grpc) generateServerSignature(servName string, method *pb.MethodDescriptorProto) string {
418 origMethName := method.GetName()
419 methName := generator.CamelCase(origMethName)
420 if reservedClientName[methName] {
421 methName += "_"
422 }
423
424 var reqArgs []string
425 ret := "error"
426 if !method.GetServerStreaming() && !method.GetClientStreaming() {
427 reqArgs = append(reqArgs, contextPkg+".Context")
428 ret = "(*" + g.typeName(method.GetOutputType()) + ", error)"
429 }
430 if !method.GetClientStreaming() {
431 reqArgs = append(reqArgs, "*"+g.typeName(method.GetInputType()))
432 }
433 if method.GetServerStreaming() || method.GetClientStreaming() {
434 reqArgs = append(reqArgs, servName+"_"+generator.CamelCase(origMethName)+"Server")
435 }
436
437 return methName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
438 }
439
440 func (g *grpc) generateServerMethod(servName, fullServName string, method *pb.MethodDescriptorProto) string {
441 methName := generator.CamelCase(method.GetName())
442 hname := fmt.Sprintf("_%s_%s_Handler", servName, methName)
443 inType := g.typeName(method.GetInputType())
444 outType := g.typeName(method.GetOutputType())
445
446 if !method.GetServerStreaming() && !method.GetClientStreaming() {
447 g.P("func ", hname, "(srv interface{}, ctx ", contextPkg, ".Context, dec func(interface{}) error, interceptor ", grpcPkg, ".UnaryServerInterceptor) (interface{}, error) {")
448 g.P("in := new(", inType, ")")
449 g.P("if err := dec(in); err != nil { return nil, err }")
450 g.P("if interceptor == nil { return srv.(", servName, "Server).", methName, "(ctx, in) }")
451 g.P("info := &", grpcPkg, ".UnaryServerInfo{")
452 g.P("Server: srv,")
453 g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", fullServName, methName)), ",")
454 g.P("}")
455 g.P("handler := func(ctx ", contextPkg, ".Context, req interface{}) (interface{}, error) {")
456 g.P("return srv.(", servName, "Server).", methName, "(ctx, req.(*", inType, "))")
457 g.P("}")
458 g.P("return interceptor(ctx, in, info, handler)")
459 g.P("}")
460 g.P()
461 return hname
462 }
463 streamType := unexport(servName) + methName + "Server"
464 g.P("func ", hname, "(srv interface{}, stream ", grpcPkg, ".ServerStream) error {")
465 if !method.GetClientStreaming() {
466 g.P("m := new(", inType, ")")
467 g.P("if err := stream.RecvMsg(m); err != nil { return err }")
468 g.P("return srv.(", servName, "Server).", methName, "(m, &", streamType, "{stream})")
469 } else {
470 g.P("return srv.(", servName, "Server).", methName, "(&", streamType, "{stream})")
471 }
472 g.P("}")
473 g.P()
474
475 genSend := method.GetServerStreaming()
476 genSendAndClose := !method.GetServerStreaming()
477 genRecv := method.GetClientStreaming()
478
479
480 g.P("type ", servName, "_", methName, "Server interface {")
481 if genSend {
482 g.P("Send(*", outType, ") error")
483 }
484 if genSendAndClose {
485 g.P("SendAndClose(*", outType, ") error")
486 }
487 if genRecv {
488 g.P("Recv() (*", inType, ", error)")
489 }
490 g.P(grpcPkg, ".ServerStream")
491 g.P("}")
492 g.P()
493
494 g.P("type ", streamType, " struct {")
495 g.P(grpcPkg, ".ServerStream")
496 g.P("}")
497 g.P()
498
499 if genSend {
500 g.P("func (x *", streamType, ") Send(m *", outType, ") error {")
501 g.P("return x.ServerStream.SendMsg(m)")
502 g.P("}")
503 g.P()
504 }
505 if genSendAndClose {
506 g.P("func (x *", streamType, ") SendAndClose(m *", outType, ") error {")
507 g.P("return x.ServerStream.SendMsg(m)")
508 g.P("}")
509 g.P()
510 }
511 if genRecv {
512 g.P("func (x *", streamType, ") Recv() (*", inType, ", error) {")
513 g.P("m := new(", inType, ")")
514 g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
515 g.P("return m, nil")
516 g.P("}")
517 g.P()
518 }
519
520 return hname
521 }
522
View as plain text