...

Source file src/github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-grpc-gateway/internal/gengateway/template.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-grpc-gateway/internal/gengateway

     1  package gengateway
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"strings"
     8  	"text/template"
     9  
    10  	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/casing"
    11  	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor"
    12  	"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
    13  	"google.golang.org/grpc/grpclog"
    14  )
    15  
    16  type param struct {
    17  	*descriptor.File
    18  	Imports            []descriptor.GoPackage
    19  	UseRequestContext  bool
    20  	RegisterFuncSuffix string
    21  	AllowPatchFeature  bool
    22  	OmitPackageDoc     bool
    23  }
    24  
    25  type binding struct {
    26  	*descriptor.Binding
    27  	Registry          *descriptor.Registry
    28  	AllowPatchFeature bool
    29  }
    30  
    31  // GetBodyFieldPath returns the binding body's field path.
    32  func (b binding) GetBodyFieldPath() string {
    33  	if b.Body != nil && len(b.Body.FieldPath) != 0 {
    34  		return b.Body.FieldPath.String()
    35  	}
    36  	return "*"
    37  }
    38  
    39  // GetBodyFieldStructName returns the binding body's struct field name.
    40  func (b binding) GetBodyFieldStructName() (string, error) {
    41  	if b.Body != nil && len(b.Body.FieldPath) != 0 {
    42  		return casing.Camel(b.Body.FieldPath.String()), nil
    43  	}
    44  	return "", errors.New("no body field found")
    45  }
    46  
    47  // HasQueryParam determines if the binding needs parameters in query string.
    48  //
    49  // It sometimes returns true even though actually the binding does not need.
    50  // But it is not serious because it just results in a small amount of extra codes generated.
    51  func (b binding) HasQueryParam() bool {
    52  	if b.Body != nil && len(b.Body.FieldPath) == 0 {
    53  		return false
    54  	}
    55  	fields := make(map[string]bool)
    56  	for _, f := range b.Method.RequestType.Fields {
    57  		fields[f.GetName()] = true
    58  	}
    59  	if b.Body != nil {
    60  		delete(fields, b.Body.FieldPath.String())
    61  	}
    62  	for _, p := range b.PathParams {
    63  		delete(fields, p.FieldPath.String())
    64  	}
    65  	return len(fields) > 0
    66  }
    67  
    68  func (b binding) QueryParamFilter() queryParamFilter {
    69  	var seqs [][]string
    70  	if b.Body != nil {
    71  		seqs = append(seqs, strings.Split(b.Body.FieldPath.String(), "."))
    72  	}
    73  	for _, p := range b.PathParams {
    74  		seqs = append(seqs, strings.Split(p.FieldPath.String(), "."))
    75  	}
    76  	return queryParamFilter{utilities.NewDoubleArray(seqs)}
    77  }
    78  
    79  // HasEnumPathParam returns true if the path parameter slice contains a parameter
    80  // that maps to an enum proto field that is not repeated, if not false is returned.
    81  func (b binding) HasEnumPathParam() bool {
    82  	return b.hasEnumPathParam(false)
    83  }
    84  
    85  // HasRepeatedEnumPathParam returns true if the path parameter slice contains a parameter
    86  // that maps to a repeated enum proto field, if not false is returned.
    87  func (b binding) HasRepeatedEnumPathParam() bool {
    88  	return b.hasEnumPathParam(true)
    89  }
    90  
    91  // hasEnumPathParam returns true if the path parameter slice contains a parameter
    92  // that maps to a enum proto field and that the enum proto field is or isn't repeated
    93  // based on the provided 'repeated' parameter.
    94  func (b binding) hasEnumPathParam(repeated bool) bool {
    95  	for _, p := range b.PathParams {
    96  		if p.IsEnum() && p.IsRepeated() == repeated {
    97  			return true
    98  		}
    99  	}
   100  	return false
   101  }
   102  
   103  // LookupEnum looks up a enum type by path parameter.
   104  func (b binding) LookupEnum(p descriptor.Parameter) *descriptor.Enum {
   105  	e, err := b.Registry.LookupEnum("", p.Target.GetTypeName())
   106  	if err != nil {
   107  		return nil
   108  	}
   109  	return e
   110  }
   111  
   112  // FieldMaskField returns the golang-style name of the variable for a FieldMask, if there is exactly one of that type in
   113  // the message. Otherwise, it returns an empty string.
   114  func (b binding) FieldMaskField() string {
   115  	var fieldMaskField *descriptor.Field
   116  	for _, f := range b.Method.RequestType.Fields {
   117  		if f.GetTypeName() == ".google.protobuf.FieldMask" {
   118  			// if there is more than 1 FieldMask for this request, then return none
   119  			if fieldMaskField != nil {
   120  				return ""
   121  			}
   122  			fieldMaskField = f
   123  		}
   124  	}
   125  	if fieldMaskField != nil {
   126  		return casing.Camel(fieldMaskField.GetName())
   127  	}
   128  	return ""
   129  }
   130  
   131  // queryParamFilter is a wrapper of utilities.DoubleArray which provides String() to output DoubleArray.Encoding in a stable and predictable format.
   132  type queryParamFilter struct {
   133  	*utilities.DoubleArray
   134  }
   135  
   136  func (f queryParamFilter) String() string {
   137  	encodings := make([]string, len(f.Encoding))
   138  	for str, enc := range f.Encoding {
   139  		encodings[enc] = fmt.Sprintf("%q: %d", str, enc)
   140  	}
   141  	e := strings.Join(encodings, ", ")
   142  	return fmt.Sprintf("&utilities.DoubleArray{Encoding: map[string]int{%s}, Base: %#v, Check: %#v}", e, f.Base, f.Check)
   143  }
   144  
   145  type trailerParams struct {
   146  	Services           []*descriptor.Service
   147  	UseRequestContext  bool
   148  	RegisterFuncSuffix string
   149  }
   150  
   151  func applyTemplate(p param, reg *descriptor.Registry) (string, error) {
   152  	w := bytes.NewBuffer(nil)
   153  	if err := headerTemplate.Execute(w, p); err != nil {
   154  		return "", err
   155  	}
   156  	var targetServices []*descriptor.Service
   157  
   158  	for _, msg := range p.Messages {
   159  		msgName := casing.Camel(*msg.Name)
   160  		msg.Name = &msgName
   161  	}
   162  
   163  	for _, svc := range p.Services {
   164  		var methodWithBindingsSeen bool
   165  		svcName := casing.Camel(*svc.Name)
   166  		svc.Name = &svcName
   167  
   168  		for _, meth := range svc.Methods {
   169  			if grpclog.V(2) {
   170  				grpclog.Infof("Processing %s.%s", svc.GetName(), meth.GetName())
   171  			}
   172  			methName := casing.Camel(*meth.Name)
   173  			meth.Name = &methName
   174  			for _, b := range meth.Bindings {
   175  				if err := reg.CheckDuplicateAnnotation(b.HTTPMethod, b.PathTmpl.Template, svc); err != nil {
   176  					return "", err
   177  				}
   178  
   179  				methodWithBindingsSeen = true
   180  				if err := handlerTemplate.Execute(w, binding{
   181  					Binding:           b,
   182  					Registry:          reg,
   183  					AllowPatchFeature: p.AllowPatchFeature,
   184  				}); err != nil {
   185  					return "", err
   186  				}
   187  
   188  				// Local
   189  				if err := localHandlerTemplate.Execute(w, binding{
   190  					Binding:           b,
   191  					Registry:          reg,
   192  					AllowPatchFeature: p.AllowPatchFeature,
   193  				}); err != nil {
   194  					return "", err
   195  				}
   196  			}
   197  		}
   198  		if methodWithBindingsSeen {
   199  			targetServices = append(targetServices, svc)
   200  		}
   201  	}
   202  	if len(targetServices) == 0 {
   203  		return "", errNoTargetService
   204  	}
   205  
   206  	tp := trailerParams{
   207  		Services:           targetServices,
   208  		UseRequestContext:  p.UseRequestContext,
   209  		RegisterFuncSuffix: p.RegisterFuncSuffix,
   210  	}
   211  	// Local
   212  	if err := localTrailerTemplate.Execute(w, tp); err != nil {
   213  		return "", err
   214  	}
   215  
   216  	if err := trailerTemplate.Execute(w, tp); err != nil {
   217  		return "", err
   218  	}
   219  	return w.String(), nil
   220  }
   221  
   222  var (
   223  	headerTemplate = template.Must(template.New("header").Parse(`
   224  // Code generated by protoc-gen-grpc-gateway. DO NOT EDIT.
   225  // source: {{.GetName}}
   226  
   227  {{if not .OmitPackageDoc}}/*
   228  Package {{.GoPkg.Name}} is a reverse proxy.
   229  
   230  It translates gRPC into RESTful JSON APIs.
   231  */{{end}}
   232  package {{.GoPkg.Name}}
   233  import (
   234  	{{range $i := .Imports}}{{if $i.Standard}}{{$i | printf "%s\n"}}{{end}}{{end}}
   235  
   236  	{{range $i := .Imports}}{{if not $i.Standard}}{{$i | printf "%s\n"}}{{end}}{{end}}
   237  )
   238  
   239  // Suppress "imported and not used" errors
   240  var _ codes.Code
   241  var _ io.Reader
   242  var _ status.Status
   243  var _ = runtime.String
   244  var _ = utilities.NewDoubleArray
   245  var _ = metadata.Join
   246  `))
   247  
   248  	handlerTemplate = template.Must(template.New("handler").Parse(`
   249  {{if and .Method.GetClientStreaming .Method.GetServerStreaming}}
   250  {{template "bidi-streaming-request-func" .}}
   251  {{else if .Method.GetClientStreaming}}
   252  {{template "client-streaming-request-func" .}}
   253  {{else}}
   254  {{template "client-rpc-request-func" .}}
   255  {{end}}
   256  `))
   257  
   258  	_ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.ReplaceAll(`
   259  {{if .Method.GetServerStreaming}}
   260  func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, error)
   261  {{else}}
   262  func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error)
   263  {{end}}`, "\n", "")))
   264  
   265  	_ = template.Must(handlerTemplate.New("client-streaming-request-func").Parse(`
   266  {{template "request-func-signature" .}} {
   267  	var metadata runtime.ServerMetadata
   268  	stream, err := client.{{.Method.GetName}}(ctx)
   269  	if err != nil {
   270  		grpclog.Infof("Failed to start streaming: %v", err)
   271  		return nil, metadata, err
   272  	}
   273  	dec := marshaler.NewDecoder(req.Body)
   274  	for {
   275  		var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
   276  		err = dec.Decode(&protoReq)
   277  		if err == io.EOF {
   278  			break
   279  		}
   280  		if err != nil {
   281  			grpclog.Infof("Failed to decode request: %v", err)
   282  			return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   283  		}
   284  		if err = stream.Send(&protoReq); err != nil {
   285  			if err == io.EOF {
   286  				break
   287  			}
   288  			grpclog.Infof("Failed to send request: %v", err)
   289  			return nil, metadata, err
   290  		}
   291  	}
   292  
   293  	if err := stream.CloseSend(); err != nil {
   294  		grpclog.Infof("Failed to terminate client stream: %v", err)
   295  		return nil, metadata, err
   296  	}
   297  	header, err := stream.Header()
   298  	if err != nil {
   299  		grpclog.Infof("Failed to get header from client: %v", err)
   300  		return nil, metadata, err
   301  	}
   302  	metadata.HeaderMD = header
   303  {{if .Method.GetServerStreaming}}
   304  	return stream, metadata, nil
   305  {{else}}
   306  	msg, err := stream.CloseAndRecv()
   307  	metadata.TrailerMD = stream.Trailer()
   308  	return msg, metadata, err
   309  {{end}}
   310  }
   311  `))
   312  
   313  	funcMap template.FuncMap = map[string]interface{}{
   314  		"camelIdentifier": casing.CamelIdentifier,
   315  	}
   316  
   317  	_ = template.Must(handlerTemplate.New("client-rpc-request-func").Funcs(funcMap).Parse(`
   318  {{$AllowPatchFeature := .AllowPatchFeature}}
   319  {{if .HasQueryParam}}
   320  var (
   321  	filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}} = {{.QueryParamFilter}}
   322  )
   323  {{end}}
   324  {{template "request-func-signature" .}} {
   325  	var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
   326  	var metadata runtime.ServerMetadata
   327  {{if .Body}}
   328  	{{- $isFieldMask := and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }}
   329  	{{- if $isFieldMask }}
   330  	newReader, berr := utilities.IOReaderFactory(req.Body)
   331  	if berr != nil {
   332  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
   333  	}
   334  	{{- end}}
   335  	{{- $protoReq := .Body.AssignableExprPrep "protoReq" .Method.Service.File.GoPkg.Path -}}
   336  	{{- if ne "" $protoReq }}
   337  	{{printf "%s" $protoReq }}
   338  	{{- end}}
   339  	{{- if not $isFieldMask }}
   340  	if err := marshaler.NewDecoder(req.Body).Decode(&{{.Body.AssignableExpr "protoReq" .Method.Service.File.GoPkg.Path}}); err != nil && err != io.EOF  {
   341  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   342  	}
   343  	{{end}}
   344  	{{- if $isFieldMask }}
   345  	if err := marshaler.NewDecoder(newReader()).Decode(&{{.Body.AssignableExpr "protoReq" .Method.Service.File.GoPkg.Path}}); err != nil && err != io.EOF  {
   346  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   347  	}
   348  	if protoReq.{{.FieldMaskField}} == nil || len(protoReq.{{.FieldMaskField}}.GetPaths()) == 0 {
   349  			if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), protoReq.{{.GetBodyFieldStructName}}); err != nil {
   350  				return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   351  			} else {
   352  				protoReq.{{.FieldMaskField}} = fieldMask
   353  			}
   354  	}
   355  	{{end}}
   356  {{end}}
   357  {{if .PathParams}}
   358  	var (
   359  		val string
   360  {{- if .HasEnumPathParam}}
   361  		e int32
   362  {{- end}}
   363  {{- if .HasRepeatedEnumPathParam}}
   364  		es []int32
   365  {{- end}}
   366  		ok bool
   367  		err error
   368  		_ = err
   369  	)
   370  	{{$binding := .}}
   371  	{{range $param := .PathParams}}
   372  	{{$enum := $binding.LookupEnum $param}}
   373  	val, ok = pathParams[{{$param | printf "%q"}}]
   374  	if !ok {
   375  		return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", {{$param | printf "%q"}})
   376  	}
   377  {{if $param.IsNestedProto3}}
   378  	err = runtime.PopulateFieldFromPath(&protoReq, {{$param | printf "%q"}}, val)
   379  	if err != nil {
   380  		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
   381  	}
   382  	{{if $enum}}
   383  		e{{if $param.IsRepeated}}s{{end}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}}, {{$enum.GoType $param.Method.Service.File.GoPkg.Path | camelIdentifier}}_value)
   384  		if err != nil {
   385  			return nil, metadata, status.Errorf(codes.InvalidArgument, "could not parse path as enum value, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
   386  		}
   387  	{{end}}
   388  {{else if $enum}}
   389  	e{{if $param.IsRepeated}}s{{end}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}}, {{$enum.GoType $param.Method.Service.File.GoPkg.Path | camelIdentifier}}_value)
   390  	if err != nil {
   391  		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
   392  	}
   393  {{else}}
   394  	{{- $protoReq := $param.AssignableExprPrep "protoReq" $binding.Method.Service.File.GoPkg.Path -}}
   395  	{{- if ne "" $protoReq }}
   396  	{{printf "%s" $protoReq }}
   397  	{{- end}}
   398  	{{$param.AssignableExpr "protoReq" $binding.Method.Service.File.GoPkg.Path}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}})
   399  	if err != nil {
   400  		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
   401  	}
   402  {{end}}
   403  {{if and $enum $param.IsRepeated}}
   404  	s := make([]{{$enum.GoType $param.Method.Service.File.GoPkg.Path}}, len(es))
   405  	for i, v := range es {
   406  		s[i] = {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}(v)
   407  	}
   408  	{{$param.AssignableExpr "protoReq" $binding.Method.Service.File.GoPkg.Path}} = s
   409  {{else if $enum}}
   410  	{{$param.AssignableExpr "protoReq" $binding.Method.Service.File.GoPkg.Path}} = {{$enum.GoType $param.Method.Service.File.GoPkg.Path | camelIdentifier}}(e)
   411  {{end}}
   412  	{{end}}
   413  {{end}}
   414  {{if .HasQueryParam}}
   415  	if err := req.ParseForm(); err != nil {
   416  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   417  	}
   418  	if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}); err != nil {
   419  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   420  	}
   421  {{end}}
   422  {{if .Method.GetServerStreaming}}
   423  	stream, err := client.{{.Method.GetName}}(ctx, &protoReq)
   424  	if err != nil {
   425  		return nil, metadata, err
   426  	}
   427  	header, err := stream.Header()
   428  	if err != nil {
   429  		return nil, metadata, err
   430  	}
   431  	metadata.HeaderMD = header
   432  	return stream, metadata, nil
   433  {{else}}
   434  	msg, err := client.{{.Method.GetName}}(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
   435  	return msg, metadata, err
   436  {{end}}
   437  }`))
   438  
   439  	_ = template.Must(handlerTemplate.New("bidi-streaming-request-func").Parse(`
   440  {{template "request-func-signature" .}} {
   441  	var metadata runtime.ServerMetadata
   442  	stream, err := client.{{.Method.GetName}}(ctx)
   443  	if err != nil {
   444  		grpclog.Infof("Failed to start streaming: %v", err)
   445  		return nil, metadata, err
   446  	}
   447  	dec := marshaler.NewDecoder(req.Body)
   448  	handleSend := func() error {
   449  		var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
   450  		err := dec.Decode(&protoReq)
   451  		if err == io.EOF {
   452  			return err
   453  		}
   454  		if err != nil {
   455  			grpclog.Infof("Failed to decode request: %v", err)
   456  			return err
   457  		}
   458  		if err := stream.Send(&protoReq); err != nil {
   459  			grpclog.Infof("Failed to send request: %v", err)
   460  			return err
   461  		}
   462  		return nil
   463  	}
   464  	go func() {
   465  		for {
   466  			if err := handleSend(); err != nil {
   467  				break
   468  			}
   469  		}
   470  		if err := stream.CloseSend(); err != nil {
   471  			grpclog.Infof("Failed to terminate client stream: %v", err)
   472  		}
   473  	}()
   474  	header, err := stream.Header()
   475  	if err != nil {
   476  		grpclog.Infof("Failed to get header from client: %v", err)
   477  		return nil, metadata, err
   478  	}
   479  	metadata.HeaderMD = header
   480  	return stream, metadata, nil
   481  }
   482  `))
   483  
   484  	localHandlerTemplate = template.Must(template.New("local-handler").Parse(`
   485  {{if and .Method.GetClientStreaming .Method.GetServerStreaming}}
   486  {{else if .Method.GetClientStreaming}}
   487  {{else if .Method.GetServerStreaming}}
   488  {{else}}
   489  {{template "local-client-rpc-request-func" .}}
   490  {{end}}
   491  `))
   492  
   493  	_ = template.Must(localHandlerTemplate.New("local-request-func-signature").Parse(strings.ReplaceAll(`
   494  {{if .Method.GetServerStreaming}}
   495  {{else}}
   496  func local_request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, server {{.Method.Service.InstanceName}}Server, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error)
   497  {{end}}`, "\n", "")))
   498  
   499  	_ = template.Must(localHandlerTemplate.New("local-client-rpc-request-func").Funcs(funcMap).Parse(`
   500  {{$AllowPatchFeature := .AllowPatchFeature}}
   501  {{template "local-request-func-signature" .}} {
   502  	var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
   503  	var metadata runtime.ServerMetadata
   504  {{if .Body}}
   505  	{{- $isFieldMask := and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }}
   506  	{{- if $isFieldMask }}
   507  	newReader, berr := utilities.IOReaderFactory(req.Body)
   508  	if berr != nil {
   509  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
   510  	}
   511  	{{- end}}
   512  	{{- $protoReq := .Body.AssignableExprPrep "protoReq" .Method.Service.File.GoPkg.Path -}}
   513  	{{- if ne "" $protoReq }}
   514  	{{printf "%s" $protoReq }}
   515  	{{- end}}
   516  	{{- if not $isFieldMask }}
   517  	if err := marshaler.NewDecoder(req.Body).Decode(&{{.Body.AssignableExpr "protoReq" .Method.Service.File.GoPkg.Path}}); err != nil && err != io.EOF  {
   518  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   519  	}
   520  	{{end}}
   521  	{{- if $isFieldMask }}
   522  	if err := marshaler.NewDecoder(newReader()).Decode(&{{.Body.AssignableExpr "protoReq" .Method.Service.File.GoPkg.Path}}); err != nil && err != io.EOF  {
   523  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   524  	}
   525  	if protoReq.{{.FieldMaskField}} == nil || len(protoReq.{{.FieldMaskField}}.GetPaths()) == 0 {
   526  			if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), protoReq.{{.GetBodyFieldStructName}}); err != nil {
   527  				return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   528  			} else {
   529  				protoReq.{{.FieldMaskField}} = fieldMask
   530  			}
   531  	}
   532  	{{end}}
   533  {{end}}
   534  {{if .PathParams}}
   535  	var (
   536  		val string
   537  {{- if .HasEnumPathParam}}
   538  		e int32
   539  {{- end}}
   540  {{- if .HasRepeatedEnumPathParam}}
   541  		es []int32
   542  {{- end}}
   543  		ok bool
   544  		err error
   545  		_ = err
   546  	)
   547  	{{$binding := .}}
   548  	{{range $param := .PathParams}}
   549  	{{$enum := $binding.LookupEnum $param}}
   550  	val, ok = pathParams[{{$param | printf "%q"}}]
   551  	if !ok {
   552  		return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", {{$param | printf "%q"}})
   553  	}
   554  {{if $param.IsNestedProto3}}
   555  	err = runtime.PopulateFieldFromPath(&protoReq, {{$param | printf "%q"}}, val)
   556  	if err != nil {
   557  		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
   558  	}
   559  	{{if $enum}}
   560  		e{{if $param.IsRepeated}}s{{end}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}}, {{$enum.GoType $param.Method.Service.File.GoPkg.Path | camelIdentifier}}_value)
   561  		if err != nil {
   562  			return nil, metadata, status.Errorf(codes.InvalidArgument, "could not parse path as enum value, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
   563  		}
   564  	{{end}}
   565  {{else if $enum}}
   566  	e{{if $param.IsRepeated}}s{{end}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}}, {{$enum.GoType  $param.Method.Service.File.GoPkg.Path | camelIdentifier}}_value)
   567  	if err != nil {
   568  		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
   569  	}
   570  {{else}}
   571  	{{- $protoReq := $param.AssignableExprPrep "protoReq" $binding.Method.Service.File.GoPkg.Path -}}
   572  	{{- if ne "" $protoReq }}
   573  	{{printf "%s" $protoReq }}
   574  	{{- end}}
   575  	{{$param.AssignableExpr "protoReq" $binding.Method.Service.File.GoPkg.Path}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}})
   576  	if err != nil {
   577  		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
   578  	}
   579  {{end}}
   580  
   581  {{if and $enum $param.IsRepeated}}
   582  	s := make([]{{$enum.GoType $param.Method.Service.File.GoPkg.Path}}, len(es))
   583  	for i, v := range es {
   584  		s[i] = {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}(v)
   585  	}
   586  	{{$param.AssignableExpr "protoReq" $binding.Method.Service.File.GoPkg.Path}} = s
   587  {{else if $enum}}
   588  	{{$param.AssignableExpr "protoReq" $binding.Method.Service.File.GoPkg.Path}} = {{$enum.GoType $param.Method.Service.File.GoPkg.Path | camelIdentifier}}(e)
   589  {{end}}
   590  	{{end}}
   591  {{end}}
   592  {{if .HasQueryParam}}
   593  	if err := req.ParseForm(); err != nil {
   594  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   595  	}
   596  	if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}); err != nil {
   597  		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
   598  	}
   599  {{end}}
   600  {{if .Method.GetServerStreaming}}
   601  	// TODO
   602  {{else}}
   603  	msg, err := server.{{.Method.GetName}}(ctx, &protoReq)
   604  	return msg, metadata, err
   605  {{end}}
   606  }`))
   607  
   608  	localTrailerTemplate = template.Must(template.New("local-trailer").Parse(`
   609  {{$UseRequestContext := .UseRequestContext}}
   610  {{range $svc := .Services}}
   611  // Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Server registers the http handlers for service {{$svc.GetName}} to "mux".
   612  // UnaryRPC     :call {{$svc.GetName}}Server directly.
   613  // StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906.
   614  // Note that using this registration option will cause many gRPC library features to stop working. Consider using Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}FromEndpoint instead.
   615  func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Server(ctx context.Context, mux *runtime.ServeMux, server {{$svc.InstanceName}}Server) error {
   616  	{{range $m := $svc.Methods}}
   617  	{{range $b := $m.Bindings}}
   618  	{{if or $m.GetClientStreaming $m.GetServerStreaming}}
   619  	mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
   620  		err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")
   621  		_, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
   622  		runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
   623  		return
   624  	})
   625  	{{else}}
   626  	mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
   627  	{{- if $UseRequestContext }}
   628  		ctx, cancel := context.WithCancel(req.Context())
   629  	{{- else -}}
   630  		ctx, cancel := context.WithCancel(ctx)
   631  	{{- end }}
   632  		defer cancel()
   633  		var stream runtime.ServerTransportStream
   634  		ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
   635  		inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
   636  		var err error
   637  		var annotatedContext context.Context
   638  		{{- if $b.PathTmpl }}
   639  		annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/{{$svc.File.GetPackage}}.{{$svc.GetName}}/{{$m.GetName}}", runtime.WithHTTPPathPattern("{{$b.PathTmpl.Template}}"))
   640  		{{- else -}}
   641  		annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/{{$svc.File.GetPackage}}.{{$svc.GetName}}/{{$m.GetName}}")
   642  		{{- end }}
   643  		if err != nil {
   644  			runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
   645  			return
   646  		}
   647  		resp, md, err := local_request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, server, req, pathParams)
   648  		md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
   649  		annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
   650  		if err != nil {
   651  			runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
   652  			return
   653  		}
   654  
   655  		{{ if $b.ResponseBody }}
   656  		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
   657  		{{ else }}
   658  		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
   659  		{{end}}
   660  	})
   661  	{{end}}
   662  	{{end}}
   663  	{{end}}
   664  	return nil
   665  }
   666  {{end}}`))
   667  
   668  	trailerTemplate = template.Must(template.New("trailer").Parse(`
   669  {{$UseRequestContext := .UseRequestContext}}
   670  {{range $svc := .Services}}
   671  // Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}FromEndpoint is same as Register{{$svc.GetName}}{{$.RegisterFuncSuffix}} but
   672  // automatically dials to "endpoint" and closes the connection when "ctx" gets done.
   673  func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}FromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) {
   674  	conn, err := grpc.DialContext(ctx, endpoint, opts...)
   675  	if err != nil {
   676  		return err
   677  	}
   678  	defer func() {
   679  		if err != nil {
   680  			if cerr := conn.Close(); cerr != nil {
   681  				grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
   682  			}
   683  			return
   684  		}
   685  		go func() {
   686  			<-ctx.Done()
   687  			if cerr := conn.Close(); cerr != nil {
   688  				grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
   689  			}
   690  		}()
   691  	}()
   692  
   693  	return Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}(ctx, mux, conn)
   694  }
   695  
   696  // Register{{$svc.GetName}}{{$.RegisterFuncSuffix}} registers the http handlers for service {{$svc.GetName}} to "mux".
   697  // The handlers forward requests to the grpc endpoint over "conn".
   698  func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {
   699  	return Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx, mux, {{$svc.ClientConstructorName}}(conn))
   700  }
   701  
   702  // Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client registers the http handlers for service {{$svc.GetName}}
   703  // to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "{{$svc.InstanceName}}Client".
   704  // Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "{{$svc.InstanceName}}Client"
   705  // doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in
   706  // "{{$svc.InstanceName}}Client" to call the correct interceptors.
   707  func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, mux *runtime.ServeMux, client {{$svc.InstanceName}}Client) error {
   708  	{{range $m := $svc.Methods}}
   709  	{{range $b := $m.Bindings}}
   710  	mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
   711  	{{- if $UseRequestContext }}
   712  		ctx, cancel := context.WithCancel(req.Context())
   713  	{{- else -}}
   714  		ctx, cancel := context.WithCancel(ctx)
   715  	{{- end }}
   716  		defer cancel()
   717  		inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
   718  		var err error
   719  		var annotatedContext context.Context
   720  		{{- if $b.PathTmpl }}
   721  		annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/{{$svc.File.GetPackage}}.{{$svc.GetName}}/{{$m.GetName}}", runtime.WithHTTPPathPattern("{{$b.PathTmpl.Template}}"))
   722  		{{- else -}}
   723  		annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/{{$svc.File.GetPackage}}.{{$svc.GetName}}/{{$m.GetName}}")
   724  		{{- end }}
   725  		if err != nil {
   726  			runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
   727  			return
   728  		}
   729  		resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams)
   730  		annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
   731  		if err != nil {
   732  			runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
   733  			return
   734  		}
   735  		{{if $m.GetServerStreaming}}
   736  		{{ if $b.ResponseBody }}
   737  		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) {
   738  			res, err := resp.Recv()
   739  			return response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{res}, err
   740  		}, mux.GetForwardResponseOptions()...)
   741  		{{ else }}
   742  		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...)
   743  		{{end}}
   744  		{{else}}
   745  		{{ if $b.ResponseBody }}
   746  		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
   747  		{{ else }}
   748  		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
   749  		{{end}}
   750  		{{end}}
   751  	})
   752  	{{end}}
   753  	{{end}}
   754  	return nil
   755  }
   756  
   757  {{range $m := $svc.Methods}}
   758  {{range $b := $m.Bindings}}
   759  {{if $b.ResponseBody}}
   760  type response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} struct {
   761  	proto.Message
   762  }
   763  
   764  func (m response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}) XXX_ResponseBody() interface{} {
   765  	response := m.Message.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})
   766  	return {{$b.ResponseBody.AssignableExpr "response" $m.Service.File.GoPkg.Path}}
   767  }
   768  {{end}}
   769  {{end}}
   770  {{end}}
   771  
   772  var (
   773  	{{range $m := $svc.Methods}}
   774  	{{range $b := $m.Bindings}}
   775  	pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}}))
   776  	{{end}}
   777  	{{end}}
   778  )
   779  
   780  var (
   781  	{{range $m := $svc.Methods}}
   782  	{{range $b := $m.Bindings}}
   783  	forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = {{if $m.GetServerStreaming}}runtime.ForwardResponseStream{{else}}runtime.ForwardResponseMessage{{end}}
   784  	{{end}}
   785  	{{end}}
   786  )
   787  {{end}}`))
   788  )
   789  

View as plain text