1 package context
2
3 import (
4 "context"
5 "errors"
6 "net"
7 "net/http"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/docker/distribution/uuid"
13 "github.com/gorilla/mux"
14 log "github.com/sirupsen/logrus"
15 )
16
17
// Sentinel errors returned by the context accessors in this file.
var (
	// ErrNoRequestContext is returned by GetRequest when the context does not
	// hold a non-nil *http.Request under the "http.request" key.
	ErrNoRequestContext = errors.New("no http request in context")

	// ErrNoResponseWriterContext is returned by GetResponseWriter when the
	// context does not hold an http.ResponseWriter under the "http.response"
	// key.
	ErrNoResponseWriterContext = errors.New("no http response in context")
)
22
23 func parseIP(ipStr string) net.IP {
24 ip := net.ParseIP(ipStr)
25 if ip == nil {
26 log.Warnf("invalid remote IP address: %q", ipStr)
27 }
28 return ip
29 }
30
31
32
33 func RemoteAddr(r *http.Request) string {
34 if prior := r.Header.Get("X-Forwarded-For"); prior != "" {
35 proxies := strings.Split(prior, ",")
36 if len(proxies) > 0 {
37 remoteAddr := strings.Trim(proxies[0], " ")
38 if parseIP(remoteAddr) != nil {
39 return remoteAddr
40 }
41 }
42 }
43
44
45 if realIP := r.Header.Get("X-Real-Ip"); realIP != "" {
46 if parseIP(realIP) != nil {
47 return realIP
48 }
49 }
50
51 return r.RemoteAddr
52 }
53
54
55
56 func RemoteIP(r *http.Request) string {
57 addr := RemoteAddr(r)
58
59
60 if ip, _, err := net.SplitHostPort(addr); err == nil {
61 return ip
62 }
63
64 return addr
65 }
66
67
68
69
70
71
72 func WithRequest(ctx context.Context, r *http.Request) context.Context {
73 if ctx.Value("http.request") != nil {
74
75
76
77 panic("only one request per context")
78 }
79
80 return &httpRequestContext{
81 Context: ctx,
82 startedAt: time.Now(),
83 id: uuid.Generate().String(),
84 r: r,
85 }
86 }
87
88
89
90
91 func GetRequest(ctx context.Context) (*http.Request, error) {
92 if r, ok := ctx.Value("http.request").(*http.Request); r != nil && ok {
93 return r, nil
94 }
95 return nil, ErrNoRequestContext
96 }
97
98
99
100 func GetRequestID(ctx context.Context) string {
101 return GetStringValue(ctx, "http.request.id")
102 }
103
104
105
106 func WithResponseWriter(ctx context.Context, w http.ResponseWriter) (context.Context, http.ResponseWriter) {
107 irw := instrumentedResponseWriter{
108 ResponseWriter: w,
109 Context: ctx,
110 }
111 return &irw, &irw
112 }
113
114
115
116
117 func GetResponseWriter(ctx context.Context) (http.ResponseWriter, error) {
118 v := ctx.Value("http.response")
119
120 rw, ok := v.(http.ResponseWriter)
121 if !ok || rw == nil {
122 return nil, ErrNoResponseWriterContext
123 }
124
125 return rw, nil
126 }
127
128
129
// getVarsFromRequest is the function WithVars uses to obtain the route
// variables of a request. It is a package-level variable initialised to
// mux.Vars; presumably kept as a variable so it can be substituted (for
// example in tests) — confirm against callers.
var getVarsFromRequest = mux.Vars
131
132
133
134
135
136
137 func WithVars(ctx context.Context, r *http.Request) context.Context {
138 return &muxVarsContext{
139 Context: ctx,
140 vars: getVarsFromRequest(r),
141 }
142 }
143
144
145
146
147 func GetRequestLogger(ctx context.Context) Logger {
148 return GetLogger(ctx,
149 "http.request.id",
150 "http.request.method",
151 "http.request.host",
152 "http.request.uri",
153 "http.request.referer",
154 "http.request.useragent",
155 "http.request.remoteaddr",
156 "http.request.contenttype")
157 }
158
159
160
161
162
163 func GetResponseLogger(ctx context.Context) Logger {
164 l := getLogrusLogger(ctx,
165 "http.response.written",
166 "http.response.status",
167 "http.response.contenttype")
168
169 duration := Since(ctx, "http.request.startedat")
170
171 if duration > 0 {
172 l = l.WithField("http.response.duration", duration.String())
173 }
174
175 return l
176 }
177
178
179 type httpRequestContext struct {
180 context.Context
181
182 startedAt time.Time
183 id string
184 r *http.Request
185 }
186
187
188
189
190 func (ctx *httpRequestContext) Value(key interface{}) interface{} {
191 if keyStr, ok := key.(string); ok {
192 if keyStr == "http.request" {
193 return ctx.r
194 }
195
196 if !strings.HasPrefix(keyStr, "http.request.") {
197 goto fallback
198 }
199
200 parts := strings.Split(keyStr, ".")
201
202 if len(parts) != 3 {
203 goto fallback
204 }
205
206 switch parts[2] {
207 case "uri":
208 return ctx.r.RequestURI
209 case "remoteaddr":
210 return RemoteAddr(ctx.r)
211 case "method":
212 return ctx.r.Method
213 case "host":
214 return ctx.r.Host
215 case "referer":
216 referer := ctx.r.Referer()
217 if referer != "" {
218 return referer
219 }
220 case "useragent":
221 return ctx.r.UserAgent()
222 case "id":
223 return ctx.id
224 case "startedat":
225 return ctx.startedAt
226 case "contenttype":
227 ct := ctx.r.Header.Get("Content-Type")
228 if ct != "" {
229 return ct
230 }
231 }
232 }
233
234 fallback:
235 return ctx.Context.Value(key)
236 }
237
// muxVarsContext is a context.Context that carries the route variables
// captured for a request.
type muxVarsContext struct {
	context.Context
	vars map[string]string
}

// Value resolves the key "vars" to the complete variable map and keys of the
// form "vars.<name>" to the single variable <name>. Every other key,
// including "vars.<name>" for a variable that does not exist, is looked up in
// the wrapped context.
func (ctx *muxVarsContext) Value(key interface{}) interface{} {
	const varsPrefix = "vars."

	if keyStr, ok := key.(string); ok {
		if keyStr == "vars" {
			return ctx.vars
		}

		// Only keys that really carry the prefix address route variables.
		// Stripping the prefix unconditionally would let a bare key such as
		// "name" match the route variable "name" and shadow whatever the
		// wrapped context stores under that key.
		if strings.HasPrefix(keyStr, varsPrefix) {
			if v, ok := ctx.vars[keyStr[len(varsPrefix):]]; ok {
				return v
			}
		}
	}

	return ctx.Context.Value(key)
}
256
257
258
259
// instrumentedResponseWriter wraps an http.ResponseWriter and a
// context.Context in one value. It records the response status and the
// number of body bytes written, and exposes them (plus the response content
// type) through the context's Value method.
type instrumentedResponseWriter struct {
	http.ResponseWriter
	context.Context

	mu      sync.Mutex // guards status and written
	status  int
	written int64
}

// headerCommitted reports whether a final response status has already been
// recorded. Informational 1xx codes other than 101 Switching Protocols may be
// followed by another header, so they do not count as final. The caller must
// hold irw.mu.
func (irw *instrumentedResponseWriter) headerCommitted() bool {
	return irw.status >= http.StatusOK || irw.status == http.StatusSwitchingProtocols
}

// Write forwards p to the wrapped writer and adds the number of bytes it
// reports as written to the running total. If no final status has been
// recorded yet, the write implies 200 OK and that status is recorded.
func (irw *instrumentedResponseWriter) Write(p []byte) (n int, err error) {
	n, err = irw.ResponseWriter.Write(p)

	irw.mu.Lock()
	defer irw.mu.Unlock()

	irw.written += int64(n)
	if !irw.headerCommitted() {
		irw.status = http.StatusOK
	}

	return n, err
}

// WriteHeader forwards status to the wrapped writer and records it, unless a
// final status was already recorded. Per the http.ResponseWriter contract
// only the first final header is sent, so later calls must not overwrite the
// recorded value; otherwise the logged status could differ from the one the
// client received.
func (irw *instrumentedResponseWriter) WriteHeader(status int) {
	irw.ResponseWriter.WriteHeader(status)

	irw.mu.Lock()
	defer irw.mu.Unlock()

	if !irw.headerCommitted() {
		irw.status = status
	}
}

// Flush calls Flush on the wrapped writer when it implements http.Flusher and
// is a no-op otherwise.
func (irw *instrumentedResponseWriter) Flush() {
	if flusher, ok := irw.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Value serves response-related keys and defers everything else to the
// wrapped context. "http.response" yields the writer itself,
// "http.response.written" the byte count (int64), "http.response.status" the
// recorded status (int, 0 if none yet) and "http.response.contenttype" the
// Content-Type response header. An empty content type is treated as absent
// and looked up in the wrapped context instead.
func (irw *instrumentedResponseWriter) Value(key interface{}) interface{} {
	if keyStr, ok := key.(string); ok {
		switch keyStr {
		case "http.response":
			return irw
		case "http.response.written":
			irw.mu.Lock()
			defer irw.mu.Unlock()
			return irw.written
		case "http.response.status":
			irw.mu.Lock()
			defer irw.mu.Unlock()
			return irw.status
		case "http.response.contenttype":
			if contentType := irw.Header().Get("Content-Type"); contentType != "" {
				return contentType
			}
		}
	}

	return irw.Context.Value(key)
}
334
View as plain text