1
16
17 package plugin
18
19 import (
20 "strings"
21
22 "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
23 "github.com/gogo/protobuf/protoc-gen-gogo/generator"
24 )
25
26 type ttrpcGenerator struct {
27 *generator.Generator
28 generator.PluginImports
29
30 typeurlPkg generator.Single
31 ttrpcPkg generator.Single
32 contextPkg generator.Single
33 }
34
35 func init() {
36 generator.RegisterPlugin(new(ttrpcGenerator))
37 }
38
39 func (p *ttrpcGenerator) Name() string {
40 return "ttrpc"
41 }
42
43 func (p *ttrpcGenerator) Init(g *generator.Generator) {
44 p.Generator = g
45 }
46
47 func (p *ttrpcGenerator) Generate(file *generator.FileDescriptor) {
48 p.PluginImports = generator.NewPluginImports(p.Generator)
49 p.contextPkg = p.NewImport("context")
50 p.typeurlPkg = p.NewImport("github.com/containerd/typeurl")
51 p.ttrpcPkg = p.NewImport("github.com/containerd/ttrpc")
52
53 for _, service := range file.GetService() {
54 serviceName := service.GetName()
55 if pkg := file.GetPackage(); pkg != "" {
56 serviceName = pkg + "." + serviceName
57 }
58
59 p.genService(serviceName, service)
60 }
61 }
62
63 func (p *ttrpcGenerator) genService(fullName string, service *descriptor.ServiceDescriptorProto) {
64 serviceName := service.GetName() + "Service"
65 p.P()
66 p.P("type ", serviceName, " interface{")
67 p.In()
68 for _, method := range service.Method {
69 p.P(method.GetName(),
70 "(ctx ", p.contextPkg.Use(), ".Context, ",
71 "req *", p.typeName(method.GetInputType()), ") ",
72 "(*", p.typeName(method.GetOutputType()), ", error)")
73
74 }
75 p.Out()
76 p.P("}")
77
78 p.P()
79
80 p.P("func Register", serviceName, "(srv *", p.ttrpcPkg.Use(), ".Server, svc ", serviceName, ") {")
81 p.In()
82 p.P(`srv.Register("`, fullName, `", map[string]`, p.ttrpcPkg.Use(), ".Method{")
83 p.In()
84 for _, method := range service.Method {
85 p.P(`"`, method.GetName(), `": `, `func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {`)
86 p.In()
87 p.P("var req ", p.typeName(method.GetInputType()))
88 p.P(`if err := unmarshal(&req); err != nil {`)
89 p.In()
90 p.P(`return nil, err`)
91 p.Out()
92 p.P(`}`)
93 p.P("return svc.", method.GetName(), "(ctx, &req)")
94 p.Out()
95 p.P("},")
96 }
97 p.Out()
98 p.P("})")
99 p.Out()
100 p.P("}")
101
102 clientType := service.GetName() + "Client"
103 clientStructType := strings.ToLower(clientType[:1]) + clientType[1:]
104 p.P()
105 p.P("type ", clientStructType, " struct{")
106 p.In()
107 p.P("client *", p.ttrpcPkg.Use(), ".Client")
108 p.Out()
109 p.P("}")
110 p.P()
111 p.P("func New", clientType, "(client *", p.ttrpcPkg.Use(), ".Client)", serviceName, "{")
112 p.In()
113 p.P("return &", clientStructType, "{")
114 p.In()
115 p.P("client: client,")
116 p.Out()
117 p.P("}")
118 p.Out()
119 p.P("}")
120 p.P()
121 for _, method := range service.Method {
122 p.P()
123 p.P("func (c *", clientStructType, ") ", method.GetName(),
124 "(ctx ", p.contextPkg.Use(), ".Context, ",
125 "req *", p.typeName(method.GetInputType()), ") ",
126 "(*", p.typeName(method.GetOutputType()), ", error) {")
127 p.In()
128 p.P("var resp ", p.typeName(method.GetOutputType()))
129 p.P("if err := c.client.Call(ctx, ", `"`+fullName+`", `, `"`+method.GetName()+`"`, ", req, &resp); err != nil {")
130 p.In()
131 p.P("return nil, err")
132 p.Out()
133 p.P("}")
134 p.P("return &resp, nil")
135 p.Out()
136 p.P("}")
137 }
138 }
139
140 func (p *ttrpcGenerator) objectNamed(name string) generator.Object {
141 p.Generator.RecordTypeUse(name)
142 return p.Generator.ObjectNamed(name)
143 }
144
145 func (p *ttrpcGenerator) typeName(str string) string {
146 return p.Generator.TypeName(p.objectNamed(str))
147 }
148
View as plain text