...

Source file src/github.com/grpc-ecosystem/grpc-gateway/runtime/handler.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/runtime

     1  package runtime
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/textproto"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	"github.com/grpc-ecosystem/grpc-gateway/internal"
    13  	"google.golang.org/grpc/grpclog"
    14  )
    15  
    16  var errEmptyResponse = errors.New("empty response")
    17  
    18  // ForwardResponseStream forwards the stream from gRPC server to REST client.
    19  func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
    20  	f, ok := w.(http.Flusher)
    21  	if !ok {
    22  		grpclog.Infof("Flush not supported in %T", w)
    23  		http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
    24  		return
    25  	}
    26  
    27  	md, ok := ServerMetadataFromContext(ctx)
    28  	if !ok {
    29  		grpclog.Infof("Failed to extract ServerMetadata from context")
    30  		http.Error(w, "unexpected error", http.StatusInternalServerError)
    31  		return
    32  	}
    33  	handleForwardResponseServerMetadata(w, mux, md)
    34  
    35  	w.Header().Set("Transfer-Encoding", "chunked")
    36  	w.Header().Set("Content-Type", marshaler.ContentType())
    37  	if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
    38  		HTTPError(ctx, mux, marshaler, w, req, err)
    39  		return
    40  	}
    41  
    42  	var delimiter []byte
    43  	if d, ok := marshaler.(Delimited); ok {
    44  		delimiter = d.Delimiter()
    45  	} else {
    46  		delimiter = []byte("\n")
    47  	}
    48  
    49  	var wroteHeader bool
    50  	for {
    51  		resp, err := recv()
    52  		if err == io.EOF {
    53  			return
    54  		}
    55  		if err != nil {
    56  			handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
    57  			return
    58  		}
    59  		if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
    60  			handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
    61  			return
    62  		}
    63  
    64  		var buf []byte
    65  		switch {
    66  		case resp == nil:
    67  			buf, err = marshaler.Marshal(errorChunk(streamError(ctx, mux.streamErrorHandler, errEmptyResponse)))
    68  		default:
    69  			result := map[string]interface{}{"result": resp}
    70  			if rb, ok := resp.(responseBody); ok {
    71  				result["result"] = rb.XXX_ResponseBody()
    72  			}
    73  
    74  			buf, err = marshaler.Marshal(result)
    75  		}
    76  
    77  		if err != nil {
    78  			grpclog.Infof("Failed to marshal response chunk: %v", err)
    79  			handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
    80  			return
    81  		}
    82  		if _, err = w.Write(buf); err != nil {
    83  			grpclog.Infof("Failed to send response chunk: %v", err)
    84  			return
    85  		}
    86  		wroteHeader = true
    87  		if _, err = w.Write(delimiter); err != nil {
    88  			grpclog.Infof("Failed to send delimiter chunk: %v", err)
    89  			return
    90  		}
    91  		f.Flush()
    92  	}
    93  }
    94  
    95  func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
    96  	for k, vs := range md.HeaderMD {
    97  		if h, ok := mux.outgoingHeaderMatcher(k); ok {
    98  			for _, v := range vs {
    99  				w.Header().Add(h, v)
   100  			}
   101  		}
   102  	}
   103  }
   104  
   105  func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
   106  	for k := range md.TrailerMD {
   107  		tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
   108  		w.Header().Add("Trailer", tKey)
   109  	}
   110  }
   111  
   112  func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
   113  	for k, vs := range md.TrailerMD {
   114  		tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
   115  		for _, v := range vs {
   116  			w.Header().Add(tKey, v)
   117  		}
   118  	}
   119  }
   120  
// responseBody is implemented by response messages whose HTTP body should be a
// single field rather than the whole message. The method is generated for
// response structs from the value of `response_body` in the
// `google.api.HttpRule`.
//
// XXX_ResponseBody returns the value to marshal in place of the message.
// ForwardResponseMessage marshals it as the whole body, and
// ForwardResponseStream uses it as the "result" value of each chunk.
type responseBody interface {
	XXX_ResponseBody() interface{}
}
   126  
   127  // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
   128  func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
   129  	md, ok := ServerMetadataFromContext(ctx)
   130  	if !ok {
   131  		grpclog.Infof("Failed to extract ServerMetadata from context")
   132  	}
   133  
   134  	handleForwardResponseServerMetadata(w, mux, md)
   135  	handleForwardResponseTrailerHeader(w, md)
   136  
   137  	contentType := marshaler.ContentType()
   138  	// Check marshaler on run time in order to keep backwards compatibility
   139  	// An interface param needs to be added to the ContentType() function on
   140  	// the Marshal interface to be able to remove this check
   141  	if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok {
   142  		contentType = typeMarshaler.ContentTypeFromMessage(resp)
   143  	}
   144  	w.Header().Set("Content-Type", contentType)
   145  
   146  	if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
   147  		HTTPError(ctx, mux, marshaler, w, req, err)
   148  		return
   149  	}
   150  	var buf []byte
   151  	var err error
   152  	if rb, ok := resp.(responseBody); ok {
   153  		buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
   154  	} else {
   155  		buf, err = marshaler.Marshal(resp)
   156  	}
   157  	if err != nil {
   158  		grpclog.Infof("Marshal error: %v", err)
   159  		HTTPError(ctx, mux, marshaler, w, req, err)
   160  		return
   161  	}
   162  
   163  	if _, err = w.Write(buf); err != nil {
   164  		grpclog.Infof("Failed to write response: %v", err)
   165  	}
   166  
   167  	handleForwardResponseTrailer(w, md)
   168  }
   169  
   170  func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
   171  	if len(opts) == 0 {
   172  		return nil
   173  	}
   174  	for _, opt := range opts {
   175  		if err := opt(ctx, w, resp); err != nil {
   176  			grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
   177  			return err
   178  		}
   179  	}
   180  	return nil
   181  }
   182  
   183  func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
   184  	serr := streamError(ctx, mux.streamErrorHandler, err)
   185  	if !wroteHeader {
   186  		w.WriteHeader(int(serr.HttpCode))
   187  	}
   188  	buf, merr := marshaler.Marshal(errorChunk(serr))
   189  	if merr != nil {
   190  		grpclog.Infof("Failed to marshal an error: %v", merr)
   191  		return
   192  	}
   193  	if _, werr := w.Write(buf); werr != nil {
   194  		grpclog.Infof("Failed to notify error to client: %v", werr)
   195  		return
   196  	}
   197  }
   198  
   199  // streamError returns the payload for the final message in a response stream
   200  // that represents the given err.
   201  func streamError(ctx context.Context, errHandler StreamErrorHandlerFunc, err error) *StreamError {
   202  	serr := errHandler(ctx, err)
   203  	if serr != nil {
   204  		return serr
   205  	}
   206  	// TODO: log about misbehaving stream error handler?
   207  	return DefaultHTTPStreamErrorHandler(ctx, err)
   208  }
   209  
   210  func errorChunk(err *StreamError) map[string]proto.Message {
   211  	return map[string]proto.Message{"error": (*internal.StreamError)(err)}
   212  }
   213  

View as plain text