...

Source file src/github.com/containerd/ttrpc/cmd/protoc-gen-go-ttrpc/generator.go

Documentation: github.com/containerd/ttrpc/cmd/protoc-gen-go-ttrpc

     1  /*
     2     Copyright The containerd Authors.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"fmt"
    21  	"strings"
    22  
    23  	"google.golang.org/protobuf/compiler/protogen"
    24  )
    25  
    26  // generator is a Go code generator that uses ttrpc.Server and ttrpc.Client.
    27  // Unlike the original gogo version, this doesn't generate serializers for message types and
    28  // let protoc-gen-go handle them.
    29  type generator struct {
    30  	out *protogen.GeneratedFile
    31  
    32  	ident struct {
    33  		context     string
    34  		server      string
    35  		client      string
    36  		method      string
    37  		stream      string
    38  		serviceDesc string
    39  		streamDesc  string
    40  
    41  		streamServerIdent protogen.GoIdent
    42  		streamClientIdent protogen.GoIdent
    43  
    44  		streamServer string
    45  		streamClient string
    46  	}
    47  }
    48  
    49  func newGenerator(out *protogen.GeneratedFile) *generator {
    50  	gen := generator{out: out}
    51  	gen.ident.context = out.QualifiedGoIdent(protogen.GoIdent{
    52  		GoImportPath: "context",
    53  		GoName:       "Context",
    54  	})
    55  	gen.ident.server = out.QualifiedGoIdent(protogen.GoIdent{
    56  		GoImportPath: "github.com/containerd/ttrpc",
    57  		GoName:       "Server",
    58  	})
    59  	gen.ident.client = out.QualifiedGoIdent(protogen.GoIdent{
    60  		GoImportPath: "github.com/containerd/ttrpc",
    61  		GoName:       "Client",
    62  	})
    63  	gen.ident.method = out.QualifiedGoIdent(protogen.GoIdent{
    64  		GoImportPath: "github.com/containerd/ttrpc",
    65  		GoName:       "Method",
    66  	})
    67  	gen.ident.stream = out.QualifiedGoIdent(protogen.GoIdent{
    68  		GoImportPath: "github.com/containerd/ttrpc",
    69  		GoName:       "Stream",
    70  	})
    71  	gen.ident.serviceDesc = out.QualifiedGoIdent(protogen.GoIdent{
    72  		GoImportPath: "github.com/containerd/ttrpc",
    73  		GoName:       "ServiceDesc",
    74  	})
    75  	gen.ident.streamDesc = out.QualifiedGoIdent(protogen.GoIdent{
    76  		GoImportPath: "github.com/containerd/ttrpc",
    77  		GoName:       "StreamDesc",
    78  	})
    79  
    80  	gen.ident.streamServerIdent = protogen.GoIdent{
    81  		GoImportPath: "github.com/containerd/ttrpc",
    82  		GoName:       "StreamServer",
    83  	}
    84  	gen.ident.streamClientIdent = protogen.GoIdent{
    85  		GoImportPath: "github.com/containerd/ttrpc",
    86  		GoName:       "ClientStream",
    87  	}
    88  	gen.ident.streamServer = out.QualifiedGoIdent(gen.ident.streamServerIdent)
    89  	gen.ident.streamClient = out.QualifiedGoIdent(gen.ident.streamClientIdent)
    90  	return &gen
    91  }
    92  
    93  func generate(plugin *protogen.Plugin, input *protogen.File, servicePrefix string) error {
    94  	if len(input.Services) == 0 {
    95  		// Only generate a Go file if the file has some services.
    96  		return nil
    97  	}
    98  
    99  	file := plugin.NewGeneratedFile(input.GeneratedFilenamePrefix+"_ttrpc.pb.go", input.GoImportPath)
   100  	file.P("// Code generated by protoc-gen-go-ttrpc. DO NOT EDIT.")
   101  	file.P("// source: ", input.Desc.Path())
   102  	file.P("package ", input.GoPackageName)
   103  
   104  	gen := newGenerator(file)
   105  	for _, service := range input.Services {
   106  		service.GoName = servicePrefix + service.GoName
   107  		gen.genService(service)
   108  	}
   109  	return nil
   110  }
   111  
   112  func (gen *generator) genService(service *protogen.Service) {
   113  	fullName := service.Desc.FullName()
   114  	p := gen.out
   115  
   116  	var methods []*protogen.Method
   117  	var streams []*protogen.Method
   118  
   119  	serviceName := service.GoName + "Service"
   120  	p.P("type ", serviceName, " interface{")
   121  	for _, method := range service.Methods {
   122  		var sendArgs, retArgs string
   123  		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
   124  			streams = append(streams, method)
   125  			sendArgs = fmt.Sprintf("%s_%sServer", service.GoName, method.GoName)
   126  			if !method.Desc.IsStreamingClient() {
   127  				sendArgs = fmt.Sprintf("*%s, %s", p.QualifiedGoIdent(method.Input.GoIdent), sendArgs)
   128  			}
   129  			if method.Desc.IsStreamingServer() {
   130  				retArgs = "error"
   131  			} else {
   132  				retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
   133  			}
   134  		} else {
   135  			methods = append(methods, method)
   136  			sendArgs = fmt.Sprintf("*%s", p.QualifiedGoIdent(method.Input.GoIdent))
   137  			retArgs = fmt.Sprintf("(*%s, error)", p.QualifiedGoIdent(method.Output.GoIdent))
   138  		}
   139  		p.P(method.GoName, "(", gen.ident.context, ", ", sendArgs, ") ", retArgs)
   140  	}
   141  	p.P("}")
   142  	p.P()
   143  
   144  	for _, method := range streams {
   145  		structName := strings.ToLower(service.GoName) + method.GoName + "Server"
   146  
   147  		p.P("type ", service.GoName, "_", method.GoName, "Server interface {")
   148  		if method.Desc.IsStreamingServer() {
   149  			p.P("Send(*", method.Output.GoIdent, ") error")
   150  		}
   151  		if method.Desc.IsStreamingClient() {
   152  			p.P("Recv() (*", method.Input.GoIdent, ", error)")
   153  
   154  		}
   155  		p.P(gen.ident.streamServer)
   156  		p.P("}")
   157  		p.P()
   158  
   159  		p.P("type ", structName, " struct {")
   160  		p.P(gen.ident.streamServer)
   161  		p.P("}")
   162  		p.P()
   163  
   164  		if method.Desc.IsStreamingServer() {
   165  			p.P("func (x *", structName, ") Send(m *", method.Output.GoIdent, ") error {")
   166  			p.P("return x.StreamServer.SendMsg(m)")
   167  			p.P("}")
   168  			p.P()
   169  		}
   170  
   171  		if method.Desc.IsStreamingClient() {
   172  			p.P("func (x *", structName, ") Recv() (*", method.Input.GoIdent, ", error) {")
   173  			p.P("m := new(", method.Input.GoIdent, ")")
   174  			p.P("if err := x.StreamServer.RecvMsg(m); err != nil {")
   175  			p.P("return nil, err")
   176  			p.P("}")
   177  			p.P("return m, nil")
   178  			p.P("}")
   179  			p.P()
   180  		}
   181  	}
   182  
   183  	// registration method
   184  	p.P("func Register", serviceName, "(srv *", gen.ident.server, ", svc ", serviceName, "){")
   185  	p.P(`srv.RegisterService("`, fullName, `", &`, gen.ident.serviceDesc, "{")
   186  	if len(methods) > 0 {
   187  		p.P(`Methods: map[string]`, gen.ident.method, "{")
   188  		for _, method := range methods {
   189  			p.P(`"`, method.GoName, `": func(ctx `, gen.ident.context, ", unmarshal func(interface{}) error)(interface{}, error){")
   190  			p.P("var req ", method.Input.GoIdent)
   191  			p.P("if err := unmarshal(&req); err != nil {")
   192  			p.P("return nil, err")
   193  			p.P("}")
   194  			p.P("return svc.", method.GoName, "(ctx, &req)")
   195  			p.P("},")
   196  		}
   197  		p.P("},")
   198  	}
   199  	if len(streams) > 0 {
   200  		p.P(`Streams: map[string]`, gen.ident.stream, "{")
   201  		for _, method := range streams {
   202  			p.P(`"`, method.GoName, `": {`)
   203  			p.P(`Handler: func(ctx `, gen.ident.context, ", stream ", gen.ident.streamServer, ") (interface{}, error) {")
   204  
   205  			structName := strings.ToLower(service.GoName) + method.GoName + "Server"
   206  			var sendArg string
   207  			if !method.Desc.IsStreamingClient() {
   208  				sendArg = "m, "
   209  				p.P("m := new(", method.Input.GoIdent, ")")
   210  				p.P("if err := stream.RecvMsg(m); err != nil {")
   211  				p.P("return nil, err")
   212  				p.P("}")
   213  			}
   214  			if method.Desc.IsStreamingServer() {
   215  				p.P("return nil, svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
   216  			} else {
   217  				p.P("return svc.", method.GoName, "(ctx, ", sendArg, "&", structName, "{stream})")
   218  
   219  			}
   220  			p.P("},")
   221  			if method.Desc.IsStreamingClient() {
   222  				p.P("StreamingClient: true,")
   223  			} else {
   224  				p.P("StreamingClient: false,")
   225  			}
   226  			if method.Desc.IsStreamingServer() {
   227  				p.P("StreamingServer: true,")
   228  			} else {
   229  				p.P("StreamingServer: false,")
   230  			}
   231  			p.P("},")
   232  		}
   233  		p.P("},")
   234  	}
   235  	p.P("})")
   236  	p.P("}")
   237  	p.P()
   238  
   239  	clientType := service.GoName + "Client"
   240  
   241  	// For consistency with ttrpc 1.0 without streaming, just use
   242  	// the service name if no streams are defined
   243  	clientInterface := serviceName
   244  	if len(streams) > 0 {
   245  		clientInterface = clientType
   246  		// Stream client interfaces are different than the server interface
   247  		p.P("type ", clientInterface, " interface{")
   248  		for _, method := range service.Methods {
   249  			if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
   250  				streams = append(streams, method)
   251  				var sendArg string
   252  				if !method.Desc.IsStreamingClient() {
   253  					sendArg = fmt.Sprintf("*%s, ", p.QualifiedGoIdent(method.Input.GoIdent))
   254  				}
   255  				p.P(method.GoName,
   256  					"(", gen.ident.context, ", ", sendArg,
   257  					") (", service.GoName, "_", method.GoName, "Client, error)")
   258  			} else {
   259  				methods = append(methods, method)
   260  				p.P(method.GoName,
   261  					"(", gen.ident.context, ", ",
   262  					"*", method.Input.GoIdent, ")",
   263  					"(*", method.Output.GoIdent, ", error)")
   264  			}
   265  		}
   266  		p.P("}")
   267  		p.P()
   268  	}
   269  
   270  	clientStructType := strings.ToLower(service.GoName) + "Client"
   271  	p.P("type ", clientStructType, " struct{")
   272  	p.P("client *", gen.ident.client)
   273  	p.P("}")
   274  	p.P("func New", clientType, "(client *", gen.ident.client, ")", clientInterface, "{")
   275  	p.P("return &", clientStructType, "{")
   276  	p.P("client:client,")
   277  	p.P("}")
   278  	p.P("}")
   279  	p.P()
   280  
   281  	for _, method := range service.Methods {
   282  		var sendArg string
   283  		if !method.Desc.IsStreamingClient() {
   284  			sendArg = ", req *" + gen.out.QualifiedGoIdent(method.Input.GoIdent)
   285  		}
   286  
   287  		intName := service.GoName + "_" + method.GoName + "Client"
   288  		var retArg string
   289  		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
   290  			retArg = intName
   291  		} else {
   292  			retArg = "*" + gen.out.QualifiedGoIdent(method.Output.GoIdent)
   293  		}
   294  
   295  		p.P("func (c *", clientStructType, ") ", method.GoName,
   296  			"(ctx ", gen.ident.context, "", sendArg, ") ",
   297  			"(", retArg, ", error) {")
   298  
   299  		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
   300  			var streamingClient, streamingServer, req string
   301  			if method.Desc.IsStreamingClient() {
   302  				streamingClient = "true"
   303  				req = "nil"
   304  			} else {
   305  				streamingClient = "false"
   306  				req = "req"
   307  			}
   308  			if method.Desc.IsStreamingServer() {
   309  				streamingServer = "true"
   310  			} else {
   311  				streamingServer = "false"
   312  			}
   313  			p.P("stream, err := c.client.NewStream(ctx, &", gen.ident.streamDesc, "{")
   314  			p.P("StreamingClient: ", streamingClient, ",")
   315  			p.P("StreamingServer: ", streamingServer, ",")
   316  			p.P("}, ", `"`+fullName+`", `, `"`+method.GoName+`", `, req, `)`)
   317  			p.P("if err != nil {")
   318  			p.P("return nil, err")
   319  			p.P("}")
   320  
   321  			structName := strings.ToLower(service.GoName) + method.GoName + "Client"
   322  
   323  			p.P("x := &", structName, "{stream}")
   324  
   325  			p.P("return x, nil")
   326  			p.P("}")
   327  			p.P()
   328  
   329  			// Create interface
   330  			p.P("type ", intName, " interface {")
   331  			if method.Desc.IsStreamingClient() {
   332  				p.P("Send(*", method.Input.GoIdent, ") error")
   333  			}
   334  			if method.Desc.IsStreamingServer() {
   335  				p.P("Recv() (*", method.Output.GoIdent, ", error)")
   336  			} else {
   337  				p.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
   338  			}
   339  
   340  			p.P(gen.ident.streamClient)
   341  			p.P("}")
   342  			p.P()
   343  
   344  			// Create struct
   345  			p.P("type ", structName, " struct {")
   346  			p.P(gen.ident.streamClient)
   347  			p.P("}")
   348  			p.P()
   349  
   350  			if method.Desc.IsStreamingClient() {
   351  				p.P("func (x *", structName, ") Send(m *", method.Input.GoIdent, ") error {")
   352  				p.P("return x.", gen.ident.streamClientIdent.GoName, ".SendMsg(m)")
   353  				p.P("}")
   354  				p.P()
   355  			}
   356  
   357  			if method.Desc.IsStreamingServer() {
   358  				p.P("func (x *", structName, ") Recv() (*", method.Output.GoIdent, ", error) {")
   359  				p.P("m := new(", method.Output.GoIdent, ")")
   360  				p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
   361  				p.P("return nil, err")
   362  				p.P("}")
   363  				p.P("return m, nil")
   364  				p.P("}")
   365  				p.P()
   366  			} else {
   367  				p.P("func (x *", structName, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
   368  				p.P("if err := x.ClientStream.CloseSend(); err != nil {")
   369  				p.P("return nil, err")
   370  				p.P("}")
   371  				p.P("m := new(", method.Output.GoIdent, ")")
   372  				p.P("if err := x.ClientStream.RecvMsg(m); err != nil {")
   373  				p.P("return nil, err")
   374  				p.P("}")
   375  				p.P("return m, nil")
   376  				p.P("}")
   377  				p.P()
   378  			}
   379  		} else {
   380  			p.P("var resp ", method.Output.GoIdent)
   381  			p.P(`if err := c.client.Call(ctx, "`, fullName, `", "`, method.Desc.Name(), `", req, &resp); err != nil {`)
   382  			p.P("return nil, err")
   383  			p.P("}")
   384  			p.P("return &resp, nil")
   385  			p.P("}")
   386  			p.P()
   387  		}
   388  	}
   389  }
   390  

View as plain text