1 package http
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "encoding/xml"
8 "io"
9 "io/ioutil"
10 "net/http"
11 "net/url"
12
13 "github.com/go-kit/kit/endpoint"
14 )
15
16
// HTTPClient is the single capability Client needs from its transport:
// executing one *http.Request and returning the resulting *http.Response.
// *http.Client satisfies it, so it can be swapped for any wrapper or fake
// that provides the same Do method.
type HTTPClient interface {
	Do(r *http.Request) (*http.Response, error)
}
20
21
// Client wraps a request factory and a response decoder, and exposes them as
// an endpoint via its Endpoint method. Construct it with NewClient or
// NewExplicitClient and customize it with ClientOption values.
type Client struct {
	client         HTTPClient            // sends requests; NewExplicitClient defaults it to http.DefaultClient
	req            CreateRequestFunc     // builds the outgoing *http.Request for each endpoint call
	dec            DecodeResponseFunc    // converts the *http.Response into the endpoint's response value
	before         []RequestFunc         // applied in order to (ctx, request) before the request is sent
	after          []ClientResponseFunc  // applied in order to (ctx, response) before decoding
	finalizer      []ClientFinalizerFunc // run from a deferred call at the end of every Endpoint invocation
	bufferedStream bool                  // when true, Endpoint leaves the response body open for the caller
}
31
32
33 func NewClient(method string, tgt *url.URL, enc EncodeRequestFunc, dec DecodeResponseFunc, options ...ClientOption) *Client {
34 return NewExplicitClient(makeCreateRequestFunc(method, tgt, enc), dec, options...)
35 }
36
37
38
39
40 func NewExplicitClient(req CreateRequestFunc, dec DecodeResponseFunc, options ...ClientOption) *Client {
41 c := &Client{
42 client: http.DefaultClient,
43 req: req,
44 dec: dec,
45 }
46 for _, option := range options {
47 option(c)
48 }
49 return c
50 }
51
52
// ClientOption sets an optional parameter on a Client. Options are applied by
// NewExplicitClient (and therefore NewClient) after the defaults are set.
type ClientOption func(*Client)
54
55
56
57 func SetClient(client HTTPClient) ClientOption {
58 return func(c *Client) { c.client = client }
59 }
60
61
62
63 func ClientBefore(before ...RequestFunc) ClientOption {
64 return func(c *Client) { c.before = append(c.before, before...) }
65 }
66
67
68
69
70
71 func ClientAfter(after ...ClientResponseFunc) ClientOption {
72 return func(c *Client) { c.after = append(c.after, after...) }
73 }
74
75
76
77
78 func ClientFinalizer(f ...ClientFinalizerFunc) ClientOption {
79 return func(s *Client) { s.finalizer = append(s.finalizer, f...) }
80 }
81
82
83
84
85 func BufferedStream(buffered bool) ClientOption {
86 return func(c *Client) { c.bufferedStream = buffered }
87 }
88
89
90 func (c Client) Endpoint() endpoint.Endpoint {
91 return func(ctx context.Context, request interface{}) (interface{}, error) {
92 ctx, cancel := context.WithCancel(ctx)
93
94 var (
95 resp *http.Response
96 err error
97 )
98 if c.finalizer != nil {
99 defer func() {
100 if resp != nil {
101 ctx = context.WithValue(ctx, ContextKeyResponseHeaders, resp.Header)
102 ctx = context.WithValue(ctx, ContextKeyResponseSize, resp.ContentLength)
103 }
104 for _, f := range c.finalizer {
105 f(ctx, err)
106 }
107 }()
108 }
109
110 req, err := c.req(ctx, request)
111 if err != nil {
112 cancel()
113 return nil, err
114 }
115
116 for _, f := range c.before {
117 ctx = f(ctx, req)
118 }
119
120 resp, err = c.client.Do(req.WithContext(ctx))
121 if err != nil {
122 cancel()
123 return nil, err
124 }
125
126
127
128
129 if c.bufferedStream {
130 resp.Body = bodyWithCancel{ReadCloser: resp.Body, cancel: cancel}
131 } else {
132 defer resp.Body.Close()
133 defer cancel()
134 }
135
136 for _, f := range c.after {
137 ctx = f(ctx, resp)
138 }
139
140 response, err := c.dec(ctx, resp)
141 if err != nil {
142 return nil, err
143 }
144
145 return response, nil
146 }
147 }
148
149
150
// bodyWithCancel wraps a response body together with the cancel function of
// the request's context, so that closing the body also releases the context.
type bodyWithCancel struct {
	io.ReadCloser

	cancel context.CancelFunc
}

// Close closes the wrapped body and then cancels the associated context.
// cancel is always invoked, even when the underlying Close fails. The error
// from the underlying Close is returned rather than discarded, so callers
// can observe it.
func (bwc bodyWithCancel) Close() error {
	err := bwc.ReadCloser.Close()
	bwc.cancel()
	return err
}
162
163
164
165
166
167
168
// ClientFinalizerFunc performs work at the end of a client endpoint call.
// Client.Endpoint invokes every registered finalizer from a deferred function,
// so it runs on every exit path. err is the error the endpoint is returning
// (nil on success). When a response was received, ctx additionally carries the
// response headers under ContextKeyResponseHeaders and the response
// ContentLength under ContextKeyResponseSize. Note that ctx may already be
// canceled by the time a finalizer runs.
type ClientFinalizerFunc func(ctx context.Context, err error)
170
171
172
173
174
175 func EncodeJSONRequest(c context.Context, r *http.Request, request interface{}) error {
176 r.Header.Set("Content-Type", "application/json; charset=utf-8")
177 if headerer, ok := request.(Headerer); ok {
178 for k := range headerer.Headers() {
179 r.Header.Set(k, headerer.Headers().Get(k))
180 }
181 }
182 var b bytes.Buffer
183 r.Body = ioutil.NopCloser(&b)
184 return json.NewEncoder(&b).Encode(request)
185 }
186
187
188
189
190 func EncodeXMLRequest(c context.Context, r *http.Request, request interface{}) error {
191 r.Header.Set("Content-Type", "text/xml; charset=utf-8")
192 if headerer, ok := request.(Headerer); ok {
193 for k := range headerer.Headers() {
194 r.Header.Set(k, headerer.Headers().Get(k))
195 }
196 }
197 var b bytes.Buffer
198 r.Body = ioutil.NopCloser(&b)
199 return xml.NewEncoder(&b).Encode(request)
200 }
201
202
203
204
205
206 func makeCreateRequestFunc(method string, target *url.URL, enc EncodeRequestFunc) CreateRequestFunc {
207 return func(ctx context.Context, request interface{}) (*http.Request, error) {
208 req, err := http.NewRequest(method, target.String(), nil)
209 if err != nil {
210 return nil, err
211 }
212
213 if err = enc(ctx, req, request); err != nil {
214 return nil, err
215 }
216
217 return req, nil
218 }
219 }
220
View as plain text