1 package http
2
3 import (
4 "context"
5 "encoding/json"
6 "net/http"
7
8 "github.com/go-kit/kit/endpoint"
9 "github.com/go-kit/kit/transport"
10 "github.com/go-kit/log"
11 )
12
13
// Server wraps an endpoint and implements http.Handler (see ServeHTTP).
// Construct it with NewServer and configure it with ServerOption values.
type Server struct {
	e            endpoint.Endpoint        // invoked with the decoded request
	dec          DecodeRequestFunc        // turns the *http.Request into the endpoint request
	enc          EncodeResponseFunc       // writes a successful endpoint response to the client
	before       []RequestFunc            // run in order on the context before decoding
	after        []ServerResponseFunc     // run in order after the endpoint succeeds, before encoding
	errorEncoder ErrorEncoder             // writes decode/endpoint/encode errors to the client
	finalizer    []ServerFinalizerFunc    // run (deferred) at the end of every request, if any are set
	errorHandler transport.ErrorHandler   // notified of every error before errorEncoder runs
}
24
25
26
27 func NewServer(
28 e endpoint.Endpoint,
29 dec DecodeRequestFunc,
30 enc EncodeResponseFunc,
31 options ...ServerOption,
32 ) *Server {
33 s := &Server{
34 e: e,
35 dec: dec,
36 enc: enc,
37 errorEncoder: DefaultErrorEncoder,
38 errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()),
39 }
40 for _, option := range options {
41 option(s)
42 }
43 return s
44 }
45
46
// ServerOption mutates a Server under construction. NewServer applies the
// options it receives in order, after installing its defaults.
type ServerOption func(*Server)
48
49
50
51 func ServerBefore(before ...RequestFunc) ServerOption {
52 return func(s *Server) { s.before = append(s.before, before...) }
53 }
54
55
56
57 func ServerAfter(after ...ServerResponseFunc) ServerOption {
58 return func(s *Server) { s.after = append(s.after, after...) }
59 }
60
61
62
63
64
65 func ServerErrorEncoder(ee ErrorEncoder) ServerOption {
66 return func(s *Server) { s.errorEncoder = ee }
67 }
68
69
70
71
72
73
74
75 func ServerErrorLogger(logger log.Logger) ServerOption {
76 return func(s *Server) { s.errorHandler = transport.NewLogErrorHandler(logger) }
77 }
78
79
80
81
82
83
84 func ServerErrorHandler(errorHandler transport.ErrorHandler) ServerOption {
85 return func(s *Server) { s.errorHandler = errorHandler }
86 }
87
88
89
90 func ServerFinalizer(f ...ServerFinalizerFunc) ServerOption {
91 return func(s *Server) { s.finalizer = append(s.finalizer, f...) }
92 }
93
94
// ServeHTTP implements http.Handler.
//
// The request passes through these stages and stops at the first failure:
//  1. each before func, in order, may replace the request context;
//  2. the request is decoded with the DecodeRequestFunc;
//  3. the endpoint is invoked with the decoded request;
//  4. each after func, in order, may replace the context (success only);
//  5. the endpoint response is written with the EncodeResponseFunc.
//
// A failure in stage 2, 3, or 5 goes to the error handler and is then written
// to the client by the error encoder.
//
// When finalizers are configured, the ResponseWriter is wrapped in an
// intercepting writer. The finalizers run in a deferred call, so they fire on
// every return path. They receive the context as it stood when ServeHTTP
// returned, extended with ContextKeyResponseHeaders and ContextKeyResponseSize
// taken from the wrapper, and the status code the wrapper recorded (initially
// http.StatusOK).
func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	if len(s.finalizer) > 0 {
		tracker := &interceptingWriter{w, http.StatusOK, 0}
		defer func() {
			// ctx is captured as a variable, so the finalizers see every
			// update made by the before/after funcs below.
			ctx = context.WithValue(ctx, ContextKeyResponseHeaders, tracker.Header())
			ctx = context.WithValue(ctx, ContextKeyResponseSize, tracker.written)
			for _, finalize := range s.finalizer {
				finalize(ctx, tracker.code, r)
			}
		}()
		w = tracker.reimplementInterfaces()
	}

	for _, fn := range s.before {
		ctx = fn(ctx, r)
	}

	// fail reports err to the error handler and then writes it to the client.
	// It reads ctx and w when called, so it always uses the latest context.
	fail := func(err error) {
		s.errorHandler.Handle(ctx, err)
		s.errorEncoder(ctx, err, w)
	}

	request, err := s.dec(ctx, r)
	if err != nil {
		fail(err)
		return
	}

	response, err := s.e(ctx, request)
	if err != nil {
		fail(err)
		return
	}

	for _, fn := range s.after {
		ctx = fn(ctx, w)
	}

	if encErr := s.enc(ctx, w, response); encErr != nil {
		fail(encErr)
	}
}
138
139
140
141
142
// ErrorEncoder writes err to the client through w. Server.ServeHTTP calls it,
// after notifying the error handler, whenever decoding, the endpoint, or
// response encoding fails. DefaultErrorEncoder is the implementation a Server
// uses unless ServerErrorEncoder overrides it.
type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter)
144
145
146
147
148
149
// ServerFinalizerFunc runs at the end of every request handled by a Server
// configured with ServerFinalizer. ServeHTTP calls it from a deferred function.
// ctx carries ContextKeyResponseHeaders and ContextKeyResponseSize values taken
// from the intercepting writer. code is the status code that writer recorded
// (it starts at http.StatusOK). r is the original request.
type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request)
151
152
153
154 func NopRequestDecoder(ctx context.Context, r *http.Request) (interface{}, error) {
155 return nil, nil
156 }
157
158
159
160
161
162
// EncodeJSONResponse is an EncodeResponseFunc that writes response to w as
// JSON.
//
// It sets Content-Type to "application/json; charset=utf-8". If response
// implements Headerer, those headers are then added with Header().Add, so a
// Headerer that also supplies Content-Type produces an extra value rather than
// a replacement. If response implements StatusCoder, its status code is used
// instead of http.StatusOK.
//
// No body is written for statuses that forbid one (204 No Content and
// 304 Not Modified). The returned error is whatever the JSON encoder reports
// while writing the body, or nil when no body is written.
func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	if headerer, ok := response.(Headerer); ok {
		for k, values := range headerer.Headers() {
			for _, v := range values {
				w.Header().Add(k, v)
			}
		}
	}

	code := http.StatusOK
	if sc, ok := response.(StatusCoder); ok {
		code = sc.StatusCode()
	}
	w.WriteHeader(code)

	// net/http rejects body writes for these statuses with
	// http.ErrBodyNotAllowed. Encoding anyway would turn a valid response
	// (e.g. a conditional-GET 304) into a spurious encode error, which the
	// server would then report and try to encode a second time.
	switch code {
	case http.StatusNoContent, http.StatusNotModified:
		return nil
	}
	return json.NewEncoder(w).Encode(response)
}
182
183
184
185
186
187
188
189
// DefaultErrorEncoder is the ErrorEncoder a Server uses unless one is
// configured.
//
// By default the body is err.Error() sent as "text/plain; charset=utf-8". If
// err implements json.Marshaler and MarshalJSON succeeds, the marshaled bytes
// are sent as "application/json; charset=utf-8" instead. If MarshalJSON fails,
// the plain-text form is kept.
//
// If err implements Headerer, its headers are added after Content-Type is set.
// If err implements StatusCoder, that status code is written; otherwise the
// status is http.StatusInternalServerError.
func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
	body := []byte(err.Error())
	contentType := "text/plain; charset=utf-8"
	if m, isMarshaler := err.(json.Marshaler); isMarshaler {
		if encoded, marshalErr := m.MarshalJSON(); marshalErr == nil {
			body = encoded
			contentType = "application/json; charset=utf-8"
		}
	}

	header := w.Header()
	header.Set("Content-Type", contentType)
	if h, isHeaderer := err.(Headerer); isHeaderer {
		for key, vals := range h.Headers() {
			for _, val := range vals {
				header.Add(key, val)
			}
		}
	}

	status := http.StatusInternalServerError
	if c, isCoder := err.(StatusCoder); isCoder {
		status = c.StatusCode()
	}
	w.WriteHeader(status)
	// Best effort: once the status line is out there is no better channel
	// through which to report a failed body write.
	_, _ = w.Write(body)
}
212
213
214
215
// StatusCoder is checked by DefaultErrorEncoder (on errors) and
// EncodeJSONResponse (on responses). If the value implements it, StatusCode is
// written as the HTTP status. Otherwise errors get
// http.StatusInternalServerError and responses get http.StatusOK.
type StatusCoder interface {
	StatusCode() int
}
219
220
221
222
// Headerer is checked by DefaultErrorEncoder (on errors) and EncodeJSONResponse
// (on responses). If the value implements it, every returned header value is
// added to the response with Header().Add, after Content-Type has been set.
type Headerer interface {
	Headers() http.Header
}
226
View as plain text