...
1 package client
2
3 import (
4 "io"
5 "net/http"
6 "sync/atomic"
7 )
8
9
10
11
12
13
14
15 func KeepAliveTransport(rt http.RoundTripper) http.RoundTripper {
16 return &keepAliveTransport{wrapped: rt}
17 }
18
19 type keepAliveTransport struct {
20 wrapped http.RoundTripper
21 }
22
23 func (k *keepAliveTransport) RoundTrip(r *http.Request) (*http.Response, error) {
24 resp, err := k.wrapped.RoundTrip(r)
25 if err != nil {
26 return resp, err
27 }
28 resp.Body = &drainingReadCloser{rdr: resp.Body}
29 return resp, nil
30 }
31
// drainingReadCloser wraps a response body and, on Close, discards any bytes
// the caller did not read before closing the underlying reader.
type drainingReadCloser struct {
	rdr io.ReadCloser
	// seenEOF is set to 1 (via sync/atomic) once rdr.Read has returned an
	// error, meaning there is nothing left for Close to drain.
	seenEOF uint32
}

// Read reads from the wrapped body and records when the body has reported an
// error. io.EOF means the body is fully consumed; any other error means
// further reads cannot make progress. Either way Close need not drain.
func (d *drainingReadCloser) Read(p []byte) (n int, err error) {
	n, err = d.rdr.Read(p)
	// Only an error ends the stream. A (0, nil) result, e.g. from a
	// zero-length p, is not EOF and must not suppress draining in Close.
	if err != nil {
		atomic.StoreUint32(&d.seenEOF, 1)
	}
	return n, err
}

// Close drains any unread bytes from the body, unless Read already observed
// an error, and then closes the wrapped body, returning its Close error.
func (d *drainingReadCloser) Close() error {
	if atomic.LoadUint32(&d.seenEOF) != 1 {
		// Best effort: even a misbehaving server may still send bytes. They
		// are discarded so the connection can be reused. A drain error is
		// deliberately ignored; the caller only needs the result of Close.
		_, _ = io.Copy(io.Discard, d.rdr)
	}
	return d.rdr.Close()
}
55
View as plain text