1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package httpproxy
16
17 import (
18 "bytes"
19 "errors"
20 "io/ioutil"
21 "net/http"
22 "net/http/httptest"
23 "net/url"
24 "reflect"
25 "testing"
26
27 "go.uber.org/zap"
28 )
29
// staticRoundTripper is an http.RoundTripper stub for tests: it answers every
// request with the same preset response/error pair.
type staticRoundTripper struct {
	// res is the response handed back by RoundTrip.
	res *http.Response
	// err is the error handed back by RoundTrip.
	err error
}

// RoundTrip disregards the incoming request and returns the preset response
// and error unchanged.
func (s *staticRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	return s.res, s.err
}
38
39 func TestReverseProxyServe(t *testing.T) {
40 u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"}
41 lg := zap.NewExample()
42
43 tests := []struct {
44 eps []*endpoint
45 rt http.RoundTripper
46 want int
47 }{
48
49 {
50 eps: []*endpoint{},
51 rt: &staticRoundTripper{
52 res: &http.Response{
53 StatusCode: http.StatusCreated,
54 Body: ioutil.NopCloser(&bytes.Reader{}),
55 },
56 },
57 want: http.StatusServiceUnavailable,
58 },
59
60
61 {
62 eps: []*endpoint{{URL: u, Available: true}},
63 rt: &staticRoundTripper{err: errors.New("what a bad trip")},
64 want: http.StatusBadGateway,
65 },
66
67
68 {
69 eps: []*endpoint{{URL: u, Available: true}},
70 rt: &staticRoundTripper{
71 res: &http.Response{
72 StatusCode: http.StatusCreated,
73 Body: ioutil.NopCloser(&bytes.Reader{}),
74 Header: map[string][]string{"Content-Type": {"application/json"}},
75 },
76 },
77 want: http.StatusCreated,
78 },
79 }
80
81 for i, tt := range tests {
82 rp := reverseProxy{
83 lg: lg,
84 director: &director{lg: lg, ep: tt.eps},
85 transport: tt.rt,
86 }
87
88 req, _ := http.NewRequest("GET", "http://192.0.2.2:2379", nil)
89 rr := httptest.NewRecorder()
90 rp.ServeHTTP(rr, req)
91
92 if rr.Code != tt.want {
93 t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code)
94 }
95 if gct := rr.Header().Get("Content-Type"); gct != "application/json" {
96 t.Errorf("#%d: Content-Type = %s, want %s", i, gct, "application/json")
97 }
98 }
99 }
100
101 func TestRedirectRequest(t *testing.T) {
102 loc := url.URL{
103 Scheme: "http",
104 Host: "bar.example.com",
105 }
106
107 req := &http.Request{
108 Method: "GET",
109 Host: "foo.example.com",
110 URL: &url.URL{
111 Host: "foo.example.com",
112 Path: "/v2/keys/baz",
113 },
114 }
115
116 redirectRequest(req, loc)
117
118 want := &http.Request{
119 Method: "GET",
120
121 Host: "foo.example.com",
122 URL: &url.URL{
123
124 Scheme: "http",
125
126 Host: "bar.example.com",
127 Path: "/v2/keys/baz",
128 },
129 }
130
131 if !reflect.DeepEqual(want, req) {
132 t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
133 }
134 }
135
136 func TestMaybeSetForwardedFor(t *testing.T) {
137 tests := []struct {
138 raddr string
139 fwdFor string
140 want string
141 }{
142 {"192.0.2.3:8002", "", "192.0.2.3"},
143 {"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"},
144 {"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"},
145 {"example.com:8002", "", "example.com"},
146
147
148
149 {":8002", "", ""},
150 {"192.0.2.3", "", ""},
151
152
153 {"12", "", ""},
154 {"12", "192.0.2.3", "192.0.2.3"},
155 }
156
157 for i, tt := range tests {
158 req := &http.Request{
159 RemoteAddr: tt.raddr,
160 Header: make(http.Header),
161 }
162
163 if tt.fwdFor != "" {
164 req.Header.Set("X-Forwarded-For", tt.fwdFor)
165 }
166
167 maybeSetForwardedFor(req)
168 got := req.Header.Get("X-Forwarded-For")
169 if tt.want != got {
170 t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got)
171 }
172 }
173 }
174
175 func TestRemoveSingleHopHeaders(t *testing.T) {
176 hdr := http.Header(map[string][]string{
177
178 "Connection": {"close"},
179 "Keep-Alive": {"foo"},
180 "Proxy-Authenticate": {"Basic realm=example.com"},
181 "Proxy-Authorization": {"foo"},
182 "Te": {"deflate,gzip"},
183 "Trailers": {"ETag"},
184 "Transfer-Encoding": {"chunked"},
185 "Upgrade": {"WebSocket"},
186
187
188 "Accept": {"application/json"},
189 "X-Foo": {"Bar"},
190 })
191
192 removeSingleHopHeaders(&hdr)
193
194 want := http.Header(map[string][]string{
195 "Accept": {"application/json"},
196 "X-Foo": {"Bar"},
197 })
198
199 if !reflect.DeepEqual(want, hdr) {
200 t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr)
201 }
202 }
203
204 func TestCopyHeader(t *testing.T) {
205 tests := []struct {
206 src http.Header
207 dst http.Header
208 want http.Header
209 }{
210 {
211 src: http.Header(map[string][]string{
212 "Foo": {"bar", "baz"},
213 }),
214 dst: http.Header(map[string][]string{}),
215 want: http.Header(map[string][]string{
216 "Foo": {"bar", "baz"},
217 }),
218 },
219 {
220 src: http.Header(map[string][]string{
221 "Foo": {"bar"},
222 "Ping": {"pong"},
223 }),
224 dst: http.Header(map[string][]string{}),
225 want: http.Header(map[string][]string{
226 "Foo": {"bar"},
227 "Ping": {"pong"},
228 }),
229 },
230 {
231 src: http.Header(map[string][]string{
232 "Foo": {"bar", "baz"},
233 }),
234 dst: http.Header(map[string][]string{
235 "Foo": {"qux"},
236 }),
237 want: http.Header(map[string][]string{
238 "Foo": {"qux", "bar", "baz"},
239 }),
240 },
241 }
242
243 for i, tt := range tests {
244 copyHeader(tt.dst, tt.src)
245 if !reflect.DeepEqual(tt.dst, tt.want) {
246 t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst)
247 }
248 }
249 }
250
View as plain text