1 package runtime
2
3 import (
4 "context"
5 "io"
6 "net/http"
7 "strings"
8
9 "github.com/grpc-ecosystem/grpc-gateway/internal"
10 "google.golang.org/grpc/codes"
11 "google.golang.org/grpc/grpclog"
12 "google.golang.org/grpc/status"
13 )
14
15
16
17 func HTTPStatusFromCode(code codes.Code) int {
18 switch code {
19 case codes.OK:
20 return http.StatusOK
21 case codes.Canceled:
22 return http.StatusRequestTimeout
23 case codes.Unknown:
24 return http.StatusInternalServerError
25 case codes.InvalidArgument:
26 return http.StatusBadRequest
27 case codes.DeadlineExceeded:
28 return http.StatusGatewayTimeout
29 case codes.NotFound:
30 return http.StatusNotFound
31 case codes.AlreadyExists:
32 return http.StatusConflict
33 case codes.PermissionDenied:
34 return http.StatusForbidden
35 case codes.Unauthenticated:
36 return http.StatusUnauthorized
37 case codes.ResourceExhausted:
38 return http.StatusTooManyRequests
39 case codes.FailedPrecondition:
40
41 return http.StatusBadRequest
42 case codes.Aborted:
43 return http.StatusConflict
44 case codes.OutOfRange:
45 return http.StatusBadRequest
46 case codes.Unimplemented:
47 return http.StatusNotImplemented
48 case codes.Internal:
49 return http.StatusInternalServerError
50 case codes.Unavailable:
51 return http.StatusServiceUnavailable
52 case codes.DataLoss:
53 return http.StatusInternalServerError
54 }
55
56 grpclog.Infof("Unknown gRPC error code: %v", code)
57 return http.StatusInternalServerError
58 }
59
60 var (
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77 HTTPError = MuxOrGlobalHTTPError
78
79
80
81
82
83 GlobalHTTPErrorHandler = DefaultHTTPError
84
85
86
87
88
89
90
91
92
93
94
95 OtherErrorHandler = DefaultOtherErrorHandler
96 )
97
98
99 func MuxOrGlobalHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
100 if mux.protoErrorHandler != nil {
101 mux.protoErrorHandler(ctx, mux, marshaler, w, r, err)
102 } else {
103 GlobalHTTPErrorHandler(ctx, mux, marshaler, w, r, err)
104 }
105 }
106
107
108
109
110
111
112
113 func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
114 const fallback = `{"error": "failed to marshal error message"}`
115
116 s, ok := status.FromError(err)
117 if !ok {
118 s = status.New(codes.Unknown, err.Error())
119 }
120
121 w.Header().Del("Trailer")
122 w.Header().Del("Transfer-Encoding")
123
124 contentType := marshaler.ContentType()
125
126
127
128 if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok {
129 pb := s.Proto()
130 contentType = typeMarshaler.ContentTypeFromMessage(pb)
131 }
132 w.Header().Set("Content-Type", contentType)
133
134 body := &internal.Error{
135 Error: s.Message(),
136 Message: s.Message(),
137 Code: int32(s.Code()),
138 Details: s.Proto().GetDetails(),
139 }
140
141 buf, merr := marshaler.Marshal(body)
142 if merr != nil {
143 grpclog.Infof("Failed to marshal error message %q: %v", body, merr)
144 w.WriteHeader(http.StatusInternalServerError)
145 if _, err := io.WriteString(w, fallback); err != nil {
146 grpclog.Infof("Failed to write response: %v", err)
147 }
148 return
149 }
150
151 md, ok := ServerMetadataFromContext(ctx)
152 if !ok {
153 grpclog.Infof("Failed to extract ServerMetadata from context")
154 }
155
156 handleForwardResponseServerMetadata(w, mux, md)
157
158
159
160
161
162
163 var wantsTrailers bool
164
165 if te := r.Header.Get("TE"); strings.Contains(strings.ToLower(te), "trailers") {
166 wantsTrailers = true
167 handleForwardResponseTrailerHeader(w, md)
168 w.Header().Set("Transfer-Encoding", "chunked")
169 }
170
171 st := HTTPStatusFromCode(s.Code())
172 w.WriteHeader(st)
173 if _, err := w.Write(buf); err != nil {
174 grpclog.Infof("Failed to write response: %v", err)
175 }
176
177 if wantsTrailers {
178 handleForwardResponseTrailer(w, md)
179 }
180 }
181
182
183
// DefaultOtherErrorHandler is the default OtherErrorHandler.
//
// The request argument is ignored. The reply is written with net/http's
// http.Error helper: msg becomes a plain-text body (followed by a newline)
// and code becomes the HTTP status.
func DefaultOtherErrorHandler(w http.ResponseWriter, _ *http.Request, msg string, code int) {
	http.Error(w, msg, code)
}
187
View as plain text