1 package gengateway
2
3 import (
4 "bytes"
5 "errors"
6 "fmt"
7 "strings"
8 "text/template"
9
10 "github.com/golang/glog"
11 "github.com/grpc-ecosystem/grpc-gateway/internal/casing"
12 "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
13 "github.com/grpc-ecosystem/grpc-gateway/utilities"
14 )
15
// param is the data passed to headerTemplate. applyTemplate also copies its
// per-file options into the binding and trailerParams values it renders.
type param struct {
	// File is the proto file being generated. applyTemplate walks its
	// Messages and Services (renaming them to CamelCase in place).
	*descriptor.File
	// Imports are the packages written into the generated import block;
	// entries with Standard set are emitted first, the rest after a blank line.
	Imports []descriptor.GoPackage
	// UseRequestContext makes generated handlers derive their context from
	// req.Context() instead of the context passed at registration time.
	UseRequestContext bool
	// RegisterFuncSuffix is appended to the names of the generated
	// Register* functions.
	RegisterFuncSuffix string
	// AllowPatchFeature enables generating code that fills an empty FieldMask
	// from the request body for PATCH bindings with a non-"*" body.
	AllowPatchFeature bool
	// OmitPackageDoc suppresses the generated package doc comment.
	OmitPackageDoc bool
}
24
// binding is the data passed to handlerTemplate and localHandlerTemplate for
// one HTTP binding of a method.
type binding struct {
	*descriptor.Binding
	// Registry resolves the enum types of path parameters (see LookupEnum) and
	// supplies the separator used to split repeated path parameters.
	Registry *descriptor.Registry
	// AllowPatchFeature is copied from param.AllowPatchFeature by applyTemplate.
	AllowPatchFeature bool
}
30
31
32 func (b binding) GetBodyFieldPath() string {
33 if b.Body != nil && len(b.Body.FieldPath) != 0 {
34 return b.Body.FieldPath.String()
35 }
36 return "*"
37 }
38
39
40 func (b binding) GetBodyFieldStructName() (string, error) {
41 if b.Body != nil && len(b.Body.FieldPath) != 0 {
42 return casing.Camel(b.Body.FieldPath.String()), nil
43 }
44 return "", errors.New("No body field found")
45 }
46
47
48
49
50
51 func (b binding) HasQueryParam() bool {
52 if b.Body != nil && len(b.Body.FieldPath) == 0 {
53 return false
54 }
55 fields := make(map[string]bool)
56 for _, f := range b.Method.RequestType.Fields {
57 fields[f.GetName()] = true
58 }
59 if b.Body != nil {
60 delete(fields, b.Body.FieldPath.String())
61 }
62 for _, p := range b.PathParams {
63 delete(fields, p.FieldPath.String())
64 }
65 return len(fields) > 0
66 }
67
68 func (b binding) QueryParamFilter() queryParamFilter {
69 var seqs [][]string
70 if b.Body != nil {
71 seqs = append(seqs, strings.Split(b.Body.FieldPath.String(), "."))
72 }
73 for _, p := range b.PathParams {
74 seqs = append(seqs, strings.Split(p.FieldPath.String(), "."))
75 }
76 return queryParamFilter{utilities.NewDoubleArray(seqs)}
77 }
78
79
80
81 func (b binding) HasEnumPathParam() bool {
82 return b.hasEnumPathParam(false)
83 }
84
85
86
87 func (b binding) HasRepeatedEnumPathParam() bool {
88 return b.hasEnumPathParam(true)
89 }
90
91
92
93
94 func (b binding) hasEnumPathParam(repeated bool) bool {
95 for _, p := range b.PathParams {
96 if p.IsEnum() && p.IsRepeated() == repeated {
97 return true
98 }
99 }
100 return false
101 }
102
103
104 func (b binding) LookupEnum(p descriptor.Parameter) *descriptor.Enum {
105 e, err := b.Registry.LookupEnum("", p.Target.GetTypeName())
106 if err != nil {
107 return nil
108 }
109 return e
110 }
111
112
113
114 func (b binding) FieldMaskField() string {
115 var fieldMaskField *descriptor.Field
116 for _, f := range b.Method.RequestType.Fields {
117 if f.GetTypeName() == ".google.protobuf.FieldMask" {
118
119 if fieldMaskField != nil {
120 return ""
121 }
122 fieldMaskField = f
123 }
124 }
125 if fieldMaskField != nil {
126 return casing.Camel(fieldMaskField.GetName())
127 }
128 return ""
129 }
130
131
// queryParamFilter wraps a utilities.DoubleArray. Its String method renders the
// array as a Go expression, which the handler template prints as the
// initializer of a filter_* variable in generated code.
type queryParamFilter struct {
	*utilities.DoubleArray
}
135
136 func (f queryParamFilter) String() string {
137 encodings := make([]string, len(f.Encoding))
138 for str, enc := range f.Encoding {
139 encodings[enc] = fmt.Sprintf("%q: %d", str, enc)
140 }
141 e := strings.Join(encodings, ", ")
142 return fmt.Sprintf("&utilities.DoubleArray{Encoding: map[string]int{%s}, Base: %#v, Check: %#v}", e, f.Base, f.Check)
143 }
144
// trailerParams is the data passed to localTrailerTemplate and trailerTemplate.
type trailerParams struct {
	// Services are the services that have at least one HTTP binding.
	Services []*descriptor.Service
	// UseRequestContext and RegisterFuncSuffix are copied from param.
	UseRequestContext  bool
	RegisterFuncSuffix string
	// AssumeColonVerb is passed to runtime.AssumeColonVerbOpt for every
	// generated path pattern.
	AssumeColonVerb bool
}
151
152 func applyTemplate(p param, reg *descriptor.Registry) (string, error) {
153 w := bytes.NewBuffer(nil)
154 if err := headerTemplate.Execute(w, p); err != nil {
155 return "", err
156 }
157 var targetServices []*descriptor.Service
158
159 for _, msg := range p.Messages {
160 msgName := casing.Camel(*msg.Name)
161 msg.Name = &msgName
162 }
163 for _, svc := range p.Services {
164 var methodWithBindingsSeen bool
165 svcName := casing.Camel(*svc.Name)
166 svc.Name = &svcName
167 for _, meth := range svc.Methods {
168 glog.V(2).Infof("Processing %s.%s", svc.GetName(), meth.GetName())
169 methName := casing.Camel(*meth.Name)
170 meth.Name = &methName
171 for _, b := range meth.Bindings {
172 methodWithBindingsSeen = true
173 if err := handlerTemplate.Execute(w, binding{
174 Binding: b,
175 Registry: reg,
176 AllowPatchFeature: p.AllowPatchFeature,
177 }); err != nil {
178 return "", err
179 }
180
181
182 if err := localHandlerTemplate.Execute(w, binding{
183 Binding: b,
184 Registry: reg,
185 AllowPatchFeature: p.AllowPatchFeature,
186 }); err != nil {
187 return "", err
188 }
189 }
190 }
191 if methodWithBindingsSeen {
192 targetServices = append(targetServices, svc)
193 }
194 }
195 if len(targetServices) == 0 {
196 return "", errNoTargetService
197 }
198
199 assumeColonVerb := true
200 if reg != nil {
201 assumeColonVerb = !reg.GetAllowColonFinalSegments()
202 }
203 tp := trailerParams{
204 Services: targetServices,
205 UseRequestContext: p.UseRequestContext,
206 RegisterFuncSuffix: p.RegisterFuncSuffix,
207 AssumeColonVerb: assumeColonVerb,
208 }
209
210 if err := localTrailerTemplate.Execute(w, tp); err != nil {
211 return "", err
212 }
213
214 if err := trailerTemplate.Execute(w, tp); err != nil {
215 return "", err
216 }
217 return w.String(), nil
218 }
219
// Templates used by applyTemplate to render a generated gateway file.
//
// Inputs:
//   - headerTemplate receives a param;
//   - handlerTemplate and localHandlerTemplate receive one binding per HTTP
//     binding;
//   - localTrailerTemplate and trailerTemplate receive a trailerParams.
//
// The named sub-templates created with handlerTemplate.New / localHandlerTemplate.New
// share a template set with their parent and are invoked via {{template "name" .}}.
//
// For nested proto3 enum path parameters, the request functions check the error
// returned by runtime.PopulateFieldFromPath before the enum conversion runs.
// Otherwise the conversion would overwrite err and a populate failure would be
// silently ignored.
var (
	// headerTemplate renders the "Code generated" banner, the optional package
	// doc comment, the package clause, and the import block (standard-library
	// imports first). It ends with blank uses of packages so that unused imports
	// do not break compilation of the generated file.
	headerTemplate = template.Must(template.New("header").Parse(`
// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT.
// source: {{.GetName}}

{{if not .OmitPackageDoc}}/*
Package {{.GoPkg.Name}} is a reverse proxy.

It translates gRPC into RESTful JSON APIs.
*/{{end}}
package {{.GoPkg.Name}}
import (
	{{range $i := .Imports}}{{if $i.Standard}}{{$i | printf "%s\n"}}{{end}}{{end}}

	{{range $i := .Imports}}{{if not $i.Standard}}{{$i | printf "%s\n"}}{{end}}{{end}}
)

// Suppress "imported and not used" errors
var _ codes.Code
var _ io.Reader
var _ status.Status
var _ = runtime.String
var _ = utilities.NewDoubleArray
var _ = descriptor.ForMessage
var _ = metadata.Join
`))

	// handlerTemplate selects the request_* function body by the method's
	// streaming kind: bidi streaming, client streaming, or unary/server streaming.
	handlerTemplate = template.Must(template.New("handler").Parse(`
{{if and .Method.GetClientStreaming .Method.GetServerStreaming}}
{{template "bidi-streaming-request-func" .}}
{{else if .Method.GetClientStreaming}}
{{template "client-streaming-request-func" .}}
{{else}}
{{template "client-rpc-request-func" .}}
{{end}}
`))

	// request-func-signature renders the request_* function signature on one
	// line (newlines are stripped before parsing). Server-streaming methods
	// return the generated <Service>_<Method>Client stream instead of a
	// proto.Message.
	_ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.Replace(`
{{if .Method.GetServerStreaming}}
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.GetName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, error)
{{else}}
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error)
{{end}}`, "\n", "", -1)))

	// client-streaming-request-func decodes request messages from the HTTP body
	// until io.EOF, sends each one on the client stream, then closes the send
	// side and collects the response (or returns the stream when the method is
	// also server streaming).
	_ = template.Must(handlerTemplate.New("client-streaming-request-func").Parse(`
{{template "request-func-signature" .}} {
	var metadata runtime.ServerMetadata
	stream, err := client.{{.Method.GetName}}(ctx)
	if err != nil {
		grpclog.Infof("Failed to start streaming: %v", err)
		return nil, metadata, err
	}
	dec := marshaler.NewDecoder(req.Body)
	for {
		var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
		err = dec.Decode(&protoReq)
		if err == io.EOF {
			break
		}
		if err != nil {
			grpclog.Infof("Failed to decode request: %v", err)
			return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
		}
		if err = stream.Send(&protoReq); err != nil {
			if err == io.EOF {
				break
			}
			grpclog.Infof("Failed to send request: %v", err)
			return nil, metadata, err
		}
	}

	if err := stream.CloseSend(); err != nil {
		grpclog.Infof("Failed to terminate client stream: %v", err)
		return nil, metadata, err
	}
	header, err := stream.Header()
	if err != nil {
		grpclog.Infof("Failed to get header from client: %v", err)
		return nil, metadata, err
	}
	metadata.HeaderMD = header
{{if .Method.GetServerStreaming}}
	return stream, metadata, nil
{{else}}
	msg, err := stream.CloseAndRecv()
	metadata.TrailerMD = stream.Trailer()
	return msg, metadata, err
{{end}}
}
`))

	// client-rpc-request-func builds the request message for unary and
	// server-streaming methods in this order:
	//   1. decode the body (optionally deriving a FieldMask for PATCH);
	//   2. apply path parameters;
	//   3. apply remaining query parameters;
	//   4. call the client.
	_ = template.Must(handlerTemplate.New("client-rpc-request-func").Parse(`
{{$AllowPatchFeature := .AllowPatchFeature}}
{{if .HasQueryParam}}
var (
	filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}} = {{.QueryParamFilter}}
)
{{end}}
{{template "request-func-signature" .}} {
	var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
	var metadata runtime.ServerMetadata
{{if .Body}}
	newReader, berr := utilities.IOReaderFactory(req.Body)
	if berr != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
	}
	if err := marshaler.NewDecoder(newReader()).Decode(&{{.Body.AssignableExpr "protoReq"}}); err != nil && err != io.EOF {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
	}
	{{- if and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }}
	if protoReq.{{.FieldMaskField}} == nil || len(protoReq.{{.FieldMaskField}}.GetPaths()) == 0 {
		_, md := descriptor.ForMessage(protoReq.{{.GetBodyFieldStructName}})
		if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil {
			return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
		} else {
			protoReq.{{.FieldMaskField}} = fieldMask
		}
	}
	{{end}}
{{end}}
{{if .PathParams}}
	var (
		val string
{{- if .HasEnumPathParam}}
		e int32
{{- end}}
{{- if .HasRepeatedEnumPathParam}}
		es []int32
{{- end}}
		ok bool
		err error
		_ = err
	)
	{{$binding := .}}
	{{range $param := .PathParams}}
	{{$enum := $binding.LookupEnum $param}}
	val, ok = pathParams[{{$param | printf "%q"}}]
	if !ok {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", {{$param | printf "%q"}})
	}
{{if $param.IsNestedProto3}}
	err = runtime.PopulateFieldFromPath(&protoReq, {{$param | printf "%q"}}, val)
	{{if $enum}}
	if err != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
	}
	e{{if $param.IsRepeated}}s{{end}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}}, {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}_value)
	{{end}}
{{else if $enum}}
	e{{if $param.IsRepeated}}s{{end}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}}, {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}_value)
{{else}}
	{{$param.AssignableExpr "protoReq"}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}})
{{end}}
	if err != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
	}
{{if and $enum $param.IsRepeated}}
	s := make([]{{$enum.GoType $param.Method.Service.File.GoPkg.Path}}, len(es))
	for i, v := range es {
		s[i] = {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}(v)
	}
	{{$param.AssignableExpr "protoReq"}} = s
{{else if $enum}}
	{{$param.AssignableExpr "protoReq"}} = {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}(e)
{{end}}
	{{end}}
{{end}}
{{if .HasQueryParam}}
	if err := req.ParseForm(); err != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
	}
	if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}); err != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
	}
{{end}}
{{if .Method.GetServerStreaming}}
	stream, err := client.{{.Method.GetName}}(ctx, &protoReq)
	if err != nil {
		return nil, metadata, err
	}
	header, err := stream.Header()
	if err != nil {
		return nil, metadata, err
	}
	metadata.HeaderMD = header
	return stream, metadata, nil
{{else}}
	msg, err := client.{{.Method.GetName}}(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
	return msg, metadata, err
{{end}}
}`))

	// bidi-streaming-request-func sends the first decoded request
	// synchronously. It then keeps decoding and sending in a goroutine until
	// decoding or sending fails, after which it closes the send side.
	_ = template.Must(handlerTemplate.New("bidi-streaming-request-func").Parse(`
{{template "request-func-signature" .}} {
	var metadata runtime.ServerMetadata
	stream, err := client.{{.Method.GetName}}(ctx)
	if err != nil {
		grpclog.Infof("Failed to start streaming: %v", err)
		return nil, metadata, err
	}
	dec := marshaler.NewDecoder(req.Body)
	handleSend := func() error {
		var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
		err := dec.Decode(&protoReq)
		if err == io.EOF {
			return err
		}
		if err != nil {
			grpclog.Infof("Failed to decode request: %v", err)
			return err
		}
		if err := stream.Send(&protoReq); err != nil {
			grpclog.Infof("Failed to send request: %v", err)
			return err
		}
		return nil
	}
	if err := handleSend(); err != nil {
		if cerr := stream.CloseSend(); cerr != nil {
			grpclog.Infof("Failed to terminate client stream: %v", cerr)
		}
		if err == io.EOF {
			return stream, metadata, nil
		}
		return nil, metadata, err
	}
	go func() {
		for {
			if err := handleSend(); err != nil {
				break
			}
		}
		if err := stream.CloseSend(); err != nil {
			grpclog.Infof("Failed to terminate client stream: %v", err)
		}
	}()
	header, err := stream.Header()
	if err != nil {
		grpclog.Infof("Failed to get header from client: %v", err)
		return nil, metadata, err
	}
	metadata.HeaderMD = header
	return stream, metadata, nil
}
`))

	// localHandlerTemplate renders local_request_* functions, which call the
	// <Service>Server implementation directly. Only methods that are neither
	// client- nor server-streaming produce output.
	localHandlerTemplate = template.Must(template.New("local-handler").Parse(`
{{if and .Method.GetClientStreaming .Method.GetServerStreaming}}
{{else if .Method.GetClientStreaming}}
{{else if .Method.GetServerStreaming}}
{{else}}
{{template "local-client-rpc-request-func" .}}
{{end}}
`))

	// local-request-func-signature renders the local_request_* signature on one
	// line (newlines are stripped before parsing); it is empty for
	// server-streaming methods.
	_ = template.Must(localHandlerTemplate.New("local-request-func-signature").Parse(strings.Replace(`
{{if .Method.GetServerStreaming}}
{{else}}
func local_request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, server {{.Method.Service.GetName}}Server, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error)
{{end}}`, "\n", "", -1)))

	// local-client-rpc-request-func mirrors client-rpc-request-func but calls
	// the server implementation. It reuses the filter_* variable declared by
	// client-rpc-request-func for the same binding.
	_ = template.Must(localHandlerTemplate.New("local-client-rpc-request-func").Parse(`
{{$AllowPatchFeature := .AllowPatchFeature}}
{{template "local-request-func-signature" .}} {
	var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}}
	var metadata runtime.ServerMetadata
{{if .Body}}
	newReader, berr := utilities.IOReaderFactory(req.Body)
	if berr != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
	}
	if err := marshaler.NewDecoder(newReader()).Decode(&{{.Body.AssignableExpr "protoReq"}}); err != nil && err != io.EOF {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
	}
	{{- if and $AllowPatchFeature (eq (.HTTPMethod) "PATCH") (.FieldMaskField) (not (eq "*" .GetBodyFieldPath)) }}
	if protoReq.{{.FieldMaskField}} == nil || len(protoReq.{{.FieldMaskField}}.GetPaths()) == 0 {
		_, md := descriptor.ForMessage(protoReq.{{.GetBodyFieldStructName}})
		if fieldMask, err := runtime.FieldMaskFromRequestBody(newReader(), md); err != nil {
			return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
		} else {
			protoReq.{{.FieldMaskField}} = fieldMask
		}
	}
	{{end}}
{{end}}
{{if .PathParams}}
	var (
		val string
{{- if .HasEnumPathParam}}
		e int32
{{- end}}
{{- if .HasRepeatedEnumPathParam}}
		es []int32
{{- end}}
		ok bool
		err error
		_ = err
	)
	{{$binding := .}}
	{{range $param := .PathParams}}
	{{$enum := $binding.LookupEnum $param}}
	val, ok = pathParams[{{$param | printf "%q"}}]
	if !ok {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", {{$param | printf "%q"}})
	}
{{if $param.IsNestedProto3}}
	err = runtime.PopulateFieldFromPath(&protoReq, {{$param | printf "%q"}}, val)
	{{if $enum}}
	if err != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
	}
	e{{if $param.IsRepeated}}s{{end}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}}, {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}_value)
	{{end}}
{{else if $enum}}
	e{{if $param.IsRepeated}}s{{end}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}}, {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}_value)
{{else}}
	{{$param.AssignableExpr "protoReq"}}, err = {{$param.ConvertFuncExpr}}(val{{if $param.IsRepeated}}, {{$binding.Registry.GetRepeatedPathParamSeparator | printf "%c" | printf "%q"}}{{end}})
{{end}}
	if err != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", {{$param | printf "%q"}}, err)
	}
{{if and $enum $param.IsRepeated}}
	s := make([]{{$enum.GoType $param.Method.Service.File.GoPkg.Path}}, len(es))
	for i, v := range es {
		s[i] = {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}(v)
	}
	{{$param.AssignableExpr "protoReq"}} = s
{{else if $enum}}
	{{$param.AssignableExpr "protoReq"}} = {{$enum.GoType $param.Method.Service.File.GoPkg.Path}}(e)
{{end}}
	{{end}}
{{end}}
{{if .HasQueryParam}}
	if err := req.ParseForm(); err != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
	}
	if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}); err != nil {
		return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
	}
{{end}}
{{if .Method.GetServerStreaming}}
	// TODO
{{else}}
	msg, err := server.{{.Method.GetName}}(ctx, &protoReq)
	return msg, metadata, err
{{end}}
}`))

	// localTrailerTemplate renders Register<Service><Suffix>Server. It
	// registers handlers that call local_request_* directly. Streaming methods
	// get a handler that always answers with codes.Unimplemented.
	localTrailerTemplate = template.Must(template.New("local-trailer").Parse(`
{{$UseRequestContext := .UseRequestContext}}
{{range $svc := .Services}}
// Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Server registers the http handlers for service {{$svc.GetName}} to "mux".
// UnaryRPC :call {{$svc.GetName}}Server directly.
// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906.
// Note that using this registration option will cause many gRPC library features to stop working. Consider using Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}FromEndpoint instead.
func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Server(ctx context.Context, mux *runtime.ServeMux, server {{$svc.GetName}}Server) error {
	{{range $m := $svc.Methods}}
	{{range $b := $m.Bindings}}
	{{if or $m.GetClientStreaming $m.GetServerStreaming}}
	mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
		err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")
		_, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
		runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
		return
	})
	{{else}}
	mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
	{{- if $UseRequestContext }}
		ctx, cancel := context.WithCancel(req.Context())
	{{- else -}}
		ctx, cancel := context.WithCancel(ctx)
	{{- end }}
		defer cancel()
		var stream runtime.ServerTransportStream
		ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
		inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
		rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req)
		if err != nil {
			runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
			return
		}
		resp, md, err := local_request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(rctx, inboundMarshaler, server, req, pathParams)
		md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
		ctx = runtime.NewServerMetadataContext(ctx, md)
		if err != nil {
			runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
			return
		}

		{{ if $b.ResponseBody }}
		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
		{{ else }}
		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
		{{end}}
	})
	{{end}}
	{{end}}
	{{end}}
	return nil
}
{{end}}`))

	// trailerTemplate renders, per service:
	//   - Register<Service><Suffix>FromEndpoint, Register<Service><Suffix> and
	//     Register<Service><Suffix>Client;
	//   - response_* wrapper types for bindings with a response body;
	//   - the pattern_* and forward_* package variables.
	trailerTemplate = template.Must(template.New("trailer").Parse(`
{{$UseRequestContext := .UseRequestContext}}
{{range $svc := .Services}}
// Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}FromEndpoint is same as Register{{$svc.GetName}}{{$.RegisterFuncSuffix}} but
// automatically dials to "endpoint" and closes the connection when "ctx" gets done.
func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}FromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) {
	conn, err := grpc.Dial(endpoint, opts...)
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			if cerr := conn.Close(); cerr != nil {
				grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
			}
			return
		}
		go func() {
			<-ctx.Done()
			if cerr := conn.Close(); cerr != nil {
				grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
			}
		}()
	}()

	return Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}(ctx, mux, conn)
}

// Register{{$svc.GetName}}{{$.RegisterFuncSuffix}} registers the http handlers for service {{$svc.GetName}} to "mux".
// The handlers forward requests to the grpc endpoint over "conn".
func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {
	return Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx, mux, New{{$svc.GetName}}Client(conn))
}

// Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client registers the http handlers for service {{$svc.GetName}}
// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "{{$svc.GetName}}Client".
// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "{{$svc.GetName}}Client"
// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in
// "{{$svc.GetName}}Client" to call the correct interceptors.
func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, mux *runtime.ServeMux, client {{$svc.GetName}}Client) error {
	{{range $m := $svc.Methods}}
	{{range $b := $m.Bindings}}
	mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
	{{- if $UseRequestContext }}
		ctx, cancel := context.WithCancel(req.Context())
	{{- else -}}
		ctx, cancel := context.WithCancel(ctx)
	{{- end }}
		defer cancel()
		inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
		rctx, err := runtime.AnnotateContext(ctx, mux, req)
		if err != nil {
			runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
			return
		}
		resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(rctx, inboundMarshaler, client, req, pathParams)
		ctx = runtime.NewServerMetadataContext(ctx, md)
		if err != nil {
			runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
			return
		}
		{{if $m.GetServerStreaming}}
		{{ if $b.ResponseBody }}
		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, mux, outboundMarshaler, w, req, func() (proto.Message, error) {
			res, err := resp.Recv()
			return response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{res}, err
		}, mux.GetForwardResponseOptions()...)
		{{ else }}
		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, mux, outboundMarshaler, w, req, func() (proto.Message, error) { return resp.Recv() }, mux.GetForwardResponseOptions()...)
		{{end}}
		{{else}}
		{{ if $b.ResponseBody }}
		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
		{{ else }}
		forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
		{{end}}
		{{end}}
	})
	{{end}}
	{{end}}
	return nil
}

{{range $m := $svc.Methods}}
{{range $b := $m.Bindings}}
{{if $b.ResponseBody}}
type response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} struct {
	proto.Message
}

func (m response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}) XXX_ResponseBody() interface{} {
	response := m.Message.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})
	return {{$b.ResponseBody.AssignableExpr "response"}}
}
{{end}}
{{end}}
{{end}}

var (
	{{range $m := $svc.Methods}}
	{{range $b := $m.Bindings}}
	pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}}, runtime.AssumeColonVerbOpt({{$.AssumeColonVerb}})))
	{{end}}
	{{end}}
)

var (
	{{range $m := $svc.Methods}}
	{{range $b := $m.Bindings}}
	forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = {{if $m.GetServerStreaming}}runtime.ForwardResponseStream{{else}}runtime.ForwardResponseMessage{{end}}
	{{end}}
	{{end}}
)
{{end}}`))
)
732
View as plain text