1
16
17 package main
18
19 import (
20 "fmt"
21 "strings"
22
23 "google.golang.org/protobuf/compiler/protogen"
24 )
25
26
27
28
29 type generator struct {
30 out *protogen.GeneratedFile
31
32 ident struct {
33 context string
34 server string
35 client string
36 method string
37 stream string
38 serviceDesc string
39 streamDesc string
40
41 streamServerIdent protogen.GoIdent
42 streamClientIdent protogen.GoIdent
43
44 streamServer string
45 streamClient string
46 }
47 }
48
49 func newGenerator(out *protogen.GeneratedFile) *generator {
50 gen := generator{out: out}
51 gen.ident.context = out.QualifiedGoIdent(protogen.GoIdent{
52 GoImportPath: "context",
53 GoName: "Context",
54 })
55 gen.ident.server = out.QualifiedGoIdent(protogen.GoIdent{
56 GoImportPath: "github.com/containerd/ttrpc",
57 GoName: "Server",
58 })
59 gen.ident.client = out.QualifiedGoIdent(protogen.GoIdent{
60 GoImportPath: "github.com/containerd/ttrpc",
61 GoName: "Client",
62 })
63 gen.ident.method = out.QualifiedGoIdent(protogen.GoIdent{
64 GoImportPath: "github.com/containerd/ttrpc",
65 GoName: "Method",
66 })
67 gen.ident.stream = out.QualifiedGoIdent(protogen.GoIdent{
68 GoImportPath: "github.com/containerd/ttrpc",
69 GoName: "Stream",
70 })
71 gen.ident.serviceDesc = out.QualifiedGoIdent(protogen.GoIdent{
72 GoImportPath: "github.com/containerd/ttrpc",
73 GoName: "ServiceDesc",
74 })
75 gen.ident.streamDesc = out.QualifiedGoIdent(protogen.GoIdent{
76 GoImportPath: "github.com/containerd/ttrpc",
77 GoName: "StreamDesc",
78 })
79
80 gen.ident.streamServerIdent = protogen.GoIdent{
81 GoImportPath: "github.com/containerd/ttrpc",
82 GoName: "StreamServer",
83 }
84 gen.ident.streamClientIdent = protogen.GoIdent{
85 GoImportPath: "github.com/containerd/ttrpc",
86 GoName: "ClientStream",
87 }
88 gen.ident.streamServer = out.QualifiedGoIdent(gen.ident.streamServerIdent)
89 gen.ident.streamClient = out.QualifiedGoIdent(gen.ident.streamClientIdent)
90 return &gen
91 }
92
93 func generate(plugin *protogen.Plugin, input *protogen.File, servicePrefix string) error {
94 if len(input.Services) == 0 {
95
96 return nil
97 }
98
99 file := plugin.NewGeneratedFile(input.GeneratedFilenamePrefix+"_ttrpc.pb.go", input.GoImportPath)
100 file.P("// Code generated by protoc-gen-go-ttrpc. DO NOT EDIT.")
101 file.P("// source: ", input.Desc.Path())
102 file.P("package ", input.GoPackageName)
103
104 gen := newGenerator(file)
105 for _, service := range input.Services {
106 service.GoName = servicePrefix + service.GoName
107 gen.genService(service)
108 }
109 return nil
110 }
111
112 func (gen *generator) genService(service *protogen.Service) {
113 fullName := service.Desc.FullName()
114 p := gen.out
115
116 var methods []*protogen.Method
117 var streams []*protogen.Method
118
119 serviceName := service.GoName + "Service"
120 p.P("type ", serviceName, " interface{")
121 for _, method := range service.Methods {
122 var sendArgs, retArgs string
123 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
124 streams = append(streams, method)
125 sendArgs = fmt.Sprintf("%s_%sServer", service.GoName, method.GoName)
126 if !method.Desc.IsStreamingClient() {
127 sendArgs = fmt.Sprintf("*%s, %s", p.QualifiedGoIdent(method.Input.GoIdent), sendArgs)
128 }
129 if method.Desc.IsStreamingServer() {
130 retArgs = "error"
131 } else {
132 retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
133 }
134 } else {
135 methods = append(methods, method)
136 sendArgs = fmt.Sprintf("*%s", p.QualifiedGoIdent(method.Input.GoIdent))
137 retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
138 }
139 p.P(method.GoName, "(", gen.ident.context, ", ", sendArgs, ") ", retArgs)
140 }
141 p.P("}")
142 p.P()
143
144 for _, method := range streams {
145 structName := strings.ToLower(service.GoName) + method.GoName + "Server"
146
147 p.P("type ", service.GoName, "_", method.GoName, "Server interface {")
148 if method.Desc.IsStreamingServer() {
149 p.P("Send(*", method.Output.GoIdent, ") error")
150 }
151 if method.Desc.IsStreamingClient() {
152 p.P("Recv() (*", method.Input.GoIdent, ", error)")
153
154 }
155 p.P(gen.ident.streamServer)
156 p.P("}")
157 p.P()
158
159 p.P("type ", structName, " struct {")
160 p.P(gen.ident.streamServer)
161 p.P("}")
162 p.P()
163
164 if method.Desc.IsStreamingServer() {
165 p.P("func (x *", structName, ") Send(m *", method.Output.GoIdent, ") error {")
166 p.P("return x.StreamServer.SendMsg(m)")
167 p.P("}")
168 p.P()
169 }
170
171 if method.Desc.IsStreamingClient() {
172 p.P("func (x *", structName, ") Recv() (*", method.Input.GoIdent, ", error) {")
173 p.P("m := new(", method.Input.GoIdent, ")")
174 p.P("if err := x.StreamServer.RecvMsg(m); err != nil {")
175 p.P("return nil, err")
176 p.P("}")
177 p.P("return m, nil")
178 p.P("}")
179 p.P()
180 }
181 }
182
183
184 p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){")
185 p.P(`srv.RegisterService("`, fullName, `", &`, gen.ident.serviceDesc, "{")
186 if len(methods) > 0 {
187 p.P(`Methods: map[string]`, gen.ident.method, "{")
188 for _, method := range methods {
189 p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){")
190 p.P("var req ", method.Input.GoIdent)
191 p.P("if err := unmarshal(&req); err != nil {")
192 p.P("return nil, err")
193 p.P("}")
194 p.P("return svc.", method.GoName, "(ctx, &req)")
195 p.P("},")
196 }
197 p.P("},")
198 }
199 if len(streams) > 0 {
200 p.P(`Streams: map[string]`, gen.ident.stream, "{")
201 for _, method := range streams {
202 p.P(`"`, method.GoName, `": {`)
203 p.P(`Handler: func(ctx `, gen.ident.context, ", stream ", gen.ident.streamServer, ") (interface{}, error) {")
204
205 structName := strings.ToLower(service.GoName) + method.GoName + "Server"
206 var sendArg string
207 if !method.Desc.IsStreamingClient() {
208 sendArg = "m, "
209 p.P("m := new(", method.Input.GoIdent, ")")
210 p.P("if err := stream.RecvMsg(m); err != nil {")
211 p.P("return nil, err")
212 p.P("}")
213 }
214 if method.Desc.IsStreamingServer() {
215 p.P("return nil, svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
216 } else {
217 p.P("return svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
218
219 }
220 p.P("},")
221 if method.Desc.IsStreamingClient() {
222 p.P("StreamingClient: true,")
223 } else {
224 p.P("StreamingClient: false,")
225 }
226 if method.Desc.IsStreamingServer() {
227 p.P("StreamingServer: true,")
228 } else {
229 p.P("StreamingServer: false,")
230 }
231 p.P("},")
232 }
233 p.P("},")
234 }
235 p.P("})")
236 p.P("}")
237 p.P()
238
239 clientType := service.GoName + "Client"
240
241
242
243 clientInterface := serviceName
244 if len(streams) > 0 {
245 clientInterface = clientType
246
247 p.P("type ", clientInterface, " interface{")
248 for _, method := range service.Methods {
249 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
250 streams = append(streams, method)
251 var sendArg string
252 if !method.Desc.IsStreamingClient() {
253 sendArg = fmt.Sprintf("*%s, ", p.QualifiedGoIdent(method.Input.GoIdent))
254 }
255 p.P(method.GoName,
256 "(", gen.ident.context, ", ", sendArg,
257 ") (", service.GoName, "_", method.GoName, "Client, error)")
258 } else {
259 methods = append(methods, method)
260 p.P(method.GoName,
261 "(", gen.ident.context, ", ",
262 "*", method.Input.GoIdent, ")",
263 "(*", method.Output.GoIdent, ", error)")
264 }
265 }
266 p.P("}")
267 p.P()
268 }
269
270 clientStructType := strings.ToLower(service.GoName) + "Client"
271 p.P("type ", clientStructType, " struct{")
272 p.P("client *", gen.ident.client)
273 p.P("}")
274 p.P("func New", clientType, "(client *", gen.ident.client, ")", clientInterface, "{")
275 p.P("return &", clientStructType, "{")
276 p.P("client:client,")
277 p.P("}")
278 p.P("}")
279 p.P()
280
281 for _, method := range service.Methods {
282 var sendArg string
283 if !method.Desc.IsStreamingClient() {
284 sendArg = ", req *" + gen.out.QualifiedGoIdent(method.Input.GoIdent)
285 }
286
287 intName := service.GoName + "_" + method.GoName + "Client"
288 var retArg string
289 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
290 retArg = intName
291 } else {
292 retArg = "*" + gen.out.QualifiedGoIdent(method.Output.GoIdent)
293 }
294
295 p.P("func (c *", clientStructType, ") ", method.GoName,
296 "(ctx ", gen.ident.context, "", sendArg, ") ",
297 "(", retArg, ", error) {")
298
299 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
300 var streamingClient, streamingServer, req string
301 if method.Desc.IsStreamingClient() {
302 streamingClient = "true"
303 req = "nil"
304 } else {
305 streamingClient = "false"
306 req = "req"
307 }
308 if method.Desc.IsStreamingServer() {
309 streamingServer = "true"
310 } else {
311 streamingServer = "false"
312 }
313 p.P("stream, err := c.client.NewStream(ctx, &", gen.ident.streamDesc, "{")
314 p.P("StreamingClient: ", streamingClient, ",")
315 p.P("StreamingServer: ", streamingServer, ",")
316 p.P("}, ", `"`+fullName+`", `, `"`+method.GoName+`", `, req, `)`)
317 p.P("if err != nil {")
318 p.P("return nil, err")
319 p.P("}")
320
321 structName := strings.ToLower(service.GoName) + method.GoName + "Client"
322
323 p.P("x := &", structName, "{stream}")
324
325 p.P("return x, nil")
326 p.P("}")
327 p.P()
328
329
330 p.P("type ", intName, " interface {")
331 if method.Desc.IsStreamingClient() {
332 p.P("Send(*", method.Input.GoIdent, ") error")
333 }
334 if method.Desc.IsStreamingServer() {
335 p.P("Recv() (*", method.Output.GoIdent, ", error)")
336 } else {
337 p.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
338 }
339
340 p.P(gen.ident.streamClient)
341 p.P("}")
342 p.P()
343
344
345 p.P("type ", structName, " struct {")
346 p.P(gen.ident.streamClient)
347 p.P("}")
348 p.P()
349
350 if method.Desc.IsStreamingClient() {
351 p.P("func (x *", structName, ") Send(m *", method.Input.GoIdent, ") error {")
352 p.P("return x.", gen.ident.streamClientIdent.GoName, ".SendMsg(m)")
353 p.P("}")
354 p.P()
355 }
356
357 if method.Desc.IsStreamingServer() {
358 p.P("func (x *", structName, ") Recv() (*", method.Output.GoIdent, ", error) {")
359 p.P("m := new(", method.Output.GoIdent, ")")
360 p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
361 p.P("return nil, err")
362 p.P("}")
363 p.P("return m, nil")
364 p.P("}")
365 p.P()
366 } else {
367 p.P("func (x *", structName, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
368 p.P("if err := x.ClientStream.CloseSend(); err != nil {")
369 p.P("return nil, err")
370 p.P("}")
371 p.P("m := new(", method.Output.GoIdent, ")")
372 p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
373 p.P("return nil, err")
374 p.P("}")
375 p.P("return m, nil")
376 p.P("}")
377 p.P()
378 }
379 } else {
380 p.P("var resp ", method.Output.GoIdent)
381 p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`)
382 p.P("return nil, err")
383 p.P("}")
384 p.P("return &resp, nil")
385 p.P("}")
386 p.P()
387 }
388 }
389 }
390
View as plain text