1 package runtime
2
3 import (
4 "context"
5 "errors"
6 "io"
7 "net/http"
8
9 "google.golang.org/grpc/codes"
10 "google.golang.org/grpc/grpclog"
11 "google.golang.org/grpc/status"
12 )
13
14
15 type ErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error)
16
17
18 type StreamErrorHandlerFunc func(context.Context, error) *status.Status
19
20
21 type RoutingErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, int)
22
23
24
// HTTPStatusError wraps Err together with an explicit HTTP status.
// DefaultHTTPErrorHandler renders Err as the response body but writes
// HTTPStatus instead of the status derived from Err's gRPC code.
type HTTPStatusError struct {
	// HTTPStatus is the HTTP status code to report.
	HTTPStatus int
	// Err is the underlying error; it may be nil.
	Err error
}

// Error returns the wrapped error's message. When Err is nil it returns the
// standard text for HTTPStatus (empty for unknown codes) instead of
// panicking on a nil dereference.
func (e *HTTPStatusError) Error() string {
	if e.Err == nil {
		return http.StatusText(e.HTTPStatus)
	}
	return e.Err.Error()
}

// Unwrap returns the wrapped error so errors.Is and errors.As can inspect it.
func (e *HTTPStatusError) Unwrap() error {
	return e.Err
}
33
34
35
36 func HTTPStatusFromCode(code codes.Code) int {
37 switch code {
38 case codes.OK:
39 return http.StatusOK
40 case codes.Canceled:
41 return 499
42 case codes.Unknown:
43 return http.StatusInternalServerError
44 case codes.InvalidArgument:
45 return http.StatusBadRequest
46 case codes.DeadlineExceeded:
47 return http.StatusGatewayTimeout
48 case codes.NotFound:
49 return http.StatusNotFound
50 case codes.AlreadyExists:
51 return http.StatusConflict
52 case codes.PermissionDenied:
53 return http.StatusForbidden
54 case codes.Unauthenticated:
55 return http.StatusUnauthorized
56 case codes.ResourceExhausted:
57 return http.StatusTooManyRequests
58 case codes.FailedPrecondition:
59
60 return http.StatusBadRequest
61 case codes.Aborted:
62 return http.StatusConflict
63 case codes.OutOfRange:
64 return http.StatusBadRequest
65 case codes.Unimplemented:
66 return http.StatusNotImplemented
67 case codes.Internal:
68 return http.StatusInternalServerError
69 case codes.Unavailable:
70 return http.StatusServiceUnavailable
71 case codes.DataLoss:
72 return http.StatusInternalServerError
73 default:
74 grpclog.Infof("Unknown gRPC error code: %v", code)
75 return http.StatusInternalServerError
76 }
77 }
78
79
80 func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
81 mux.errorHandler(ctx, mux, marshaler, w, r, err)
82 }
83
84
85
86
87
88
89
90
91
92
93 func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
94
95 const fallback = `{"code": 13, "message": "failed to marshal error message"}`
96
97 var customStatus *HTTPStatusError
98 if errors.As(err, &customStatus) {
99 err = customStatus.Err
100 }
101
102 s := status.Convert(err)
103 pb := s.Proto()
104
105 w.Header().Del("Trailer")
106 w.Header().Del("Transfer-Encoding")
107
108 contentType := marshaler.ContentType(pb)
109 w.Header().Set("Content-Type", contentType)
110
111 if s.Code() == codes.Unauthenticated {
112 w.Header().Set("WWW-Authenticate", s.Message())
113 }
114
115 buf, merr := marshaler.Marshal(pb)
116 if merr != nil {
117 grpclog.Infof("Failed to marshal error message %q: %v", s, merr)
118 w.WriteHeader(http.StatusInternalServerError)
119 if _, err := io.WriteString(w, fallback); err != nil {
120 grpclog.Infof("Failed to write response: %v", err)
121 }
122 return
123 }
124
125 md, ok := ServerMetadataFromContext(ctx)
126 if !ok {
127 grpclog.Infof("Failed to extract ServerMetadata from context")
128 }
129
130 handleForwardResponseServerMetadata(w, mux, md)
131
132
133
134
135
136
137 doForwardTrailers := requestAcceptsTrailers(r)
138
139 if doForwardTrailers {
140 handleForwardResponseTrailerHeader(w, mux, md)
141 w.Header().Set("Transfer-Encoding", "chunked")
142 }
143
144 st := HTTPStatusFromCode(s.Code())
145 if customStatus != nil {
146 st = customStatus.HTTPStatus
147 }
148
149 w.WriteHeader(st)
150 if _, err := w.Write(buf); err != nil {
151 grpclog.Infof("Failed to write response: %v", err)
152 }
153
154 if doForwardTrailers {
155 handleForwardResponseTrailer(w, mux, md)
156 }
157 }
158
159 func DefaultStreamErrorHandler(_ context.Context, err error) *status.Status {
160 return status.Convert(err)
161 }
162
163
164
165
166
167
168
169
170 func DefaultRoutingErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, httpStatus int) {
171 sterr := status.Error(codes.Internal, "Unexpected routing error")
172 switch httpStatus {
173 case http.StatusBadRequest:
174 sterr = status.Error(codes.InvalidArgument, http.StatusText(httpStatus))
175 case http.StatusMethodNotAllowed:
176 sterr = status.Error(codes.Unimplemented, http.StatusText(httpStatus))
177 case http.StatusNotFound:
178 sterr = status.Error(codes.NotFound, http.StatusText(httpStatus))
179 }
180 mux.errorHandler(ctx, mux, marshaler, w, r, sterr)
181 }
182
View as plain text