1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package jsonclient
16
17 import (
18 "context"
19 "encoding/json"
20 "encoding/pem"
21 "fmt"
22 "net/http"
23 "net/http/httptest"
24 "reflect"
25 "strconv"
26 "strings"
27 "sync"
28 "testing"
29 "time"
30
31 "github.com/google/certificate-transparency-go/testdata"
32 )
33
// publicKeyPEMToDER decodes the first PEM block in key and returns its raw
// DER bytes. It panics if no PEM block can be decoded or if the block is not
// of type "PUBLIC KEY"; it is a test helper intended for known-good fixtures.
func publicKeyPEMToDER(key string) []byte {
	decoded, _ := pem.Decode([]byte(key))
	switch {
	case decoded == nil:
		panic("failed to decode public key PEM")
	case decoded.Type != "PUBLIC KEY":
		panic("PEM does not have type 'PUBLIC KEY'")
	}
	return decoded.Bytes
}
44
45 func TestNewJSONClient(t *testing.T) {
46 tests := []struct {
47 name string
48 opts Options
49 wantErr string
50 }{
51 {
52 name: "invalid PublicKey",
53 opts: Options{PublicKey: "bogus"},
54 wantErr: "no PEM block",
55 },
56 {
57 name: "invalid PublicKeyDER",
58 opts: Options{PublicKeyDER: []byte("bogus")},
59 wantErr: "asn1: structure error",
60 },
61 {
62 name: "RSA PublicKey",
63 opts: Options{PublicKey: testdata.RsaPublicKeyPEM},
64 },
65 {
66 name: "RSA PublicKeyDER",
67 opts: Options{PublicKeyDER: publicKeyPEMToDER(testdata.RsaPublicKeyPEM)},
68 },
69 {
70 name: "ECDSA PublicKey",
71 opts: Options{PublicKey: testdata.EcdsaPublicKeyPEM},
72 },
73 {
74 name: "ECDSA PublicKeyDER",
75 opts: Options{PublicKeyDER: publicKeyPEMToDER(testdata.EcdsaPublicKeyPEM)},
76 },
77 {
78 name: "DSA PublicKey",
79 opts: Options{PublicKey: testdata.DsaPublicKeyPEM},
80 wantErr: "unsupported public key type",
81 },
82 {
83 name: "DSA PublicKeyDER",
84 opts: Options{PublicKeyDER: publicKeyPEMToDER(testdata.DsaPublicKeyPEM)},
85 wantErr: "unsupported public key type",
86 },
87 {
88 name: "PublicKey contains trailing garbage",
89 opts: Options{PublicKey: testdata.RsaPublicKeyPEM + "bogus"},
90 wantErr: "extra data found",
91 },
92 {
93 name: "PublicKeyDER contains trailing garbage",
94 opts: Options{PublicKeyDER: append(publicKeyPEMToDER(testdata.RsaPublicKeyPEM), []byte("deadbeef")...)},
95 wantErr: "trailing data",
96 },
97 }
98 for _, test := range tests {
99 t.Run(test.name, func(t *testing.T) {
100 got, err := New("http://127.0.0.1", nil, test.opts)
101 if err != nil {
102 if len(test.wantErr) == 0 {
103 t.Errorf("New()=nil,%v; want _,nil", err)
104 } else if !strings.Contains(err.Error(), test.wantErr) {
105 t.Errorf("New()=nil,%v; want _, error containing %q", err, test.wantErr)
106 }
107 return
108 }
109 if len(test.wantErr) > 0 {
110 t.Errorf("New()=_,nil; want nil, error containing %q", test.wantErr)
111 }
112 if got == nil {
113 t.Errorf("New()=nil,nil; want non-nil,nil")
114 }
115 })
116 }
117 }
118
// TestStruct is the JSON payload exchanged with the mock server in these
// tests. The field order is significant: test tables construct values with
// positional composite literals.
type TestStruct struct {
	// TreeSize is carried under the "tree_size" JSON key.
	TreeSize int `json:"tree_size"`
	// Timestamp is carried under the "timestamp" JSON key.
	Timestamp int `json:"timestamp"`
	// Data is carried under the "data" JSON key.
	Data string `json:"data"`
}
124
// TestParams is the request payload understood by the mock server's "/error"
// endpoint.
type TestParams struct {
	// RespCode is the HTTP status code the mock server should reply with;
	// it is carried under the "rc" JSON key.
	RespCode int `json:"rc"`
}
128
129 func MockServer(t *testing.T, failCount int, retryAfter int) *httptest.Server {
130 t.Helper()
131 mu := sync.Mutex{}
132 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
133 mu.Lock()
134 defer mu.Unlock()
135 switch r.URL.Path {
136 case "/struct/path":
137 fmt.Fprintf(w, `{"tree_size": 11, "timestamp": 99}`)
138 case "/struct/params":
139 var s TestStruct
140 if r.Method == http.MethodGet {
141 s.TreeSize, _ = strconv.Atoi(r.FormValue("tree_size"))
142 s.Timestamp, _ = strconv.Atoi(r.FormValue("timestamp"))
143 s.Data = r.FormValue("data")
144 } else {
145 decoder := json.NewDecoder(r.Body)
146 err := decoder.Decode(&s)
147 if err != nil {
148 panic("Failed to decode: " + err.Error())
149 }
150 defer r.Body.Close()
151 }
152 fmt.Fprintf(w, `{"tree_size": %d, "timestamp": %d, "data": "%s"}`, s.TreeSize, s.Timestamp, s.Data)
153 case "/error":
154 var params TestParams
155 if r.Method == http.MethodGet {
156 params.RespCode, _ = strconv.Atoi(r.FormValue("rc"))
157 } else {
158 decoder := json.NewDecoder(r.Body)
159 err := decoder.Decode(¶ms)
160 if err != nil {
161 panic("Failed to decode: " + err.Error())
162 }
163 defer r.Body.Close()
164 }
165 http.Error(w, "error page", params.RespCode)
166 case "/malformed":
167 fmt.Fprintf(w, `{"tree_size": 11, "timestamp": 99`)
168 case "/retry":
169 if failCount > 0 {
170 failCount--
171 if retryAfter != 0 {
172 if retryAfter > 0 {
173 w.Header().Add("Retry-After", strconv.Itoa(retryAfter))
174 }
175 w.WriteHeader(http.StatusServiceUnavailable)
176 } else {
177 w.WriteHeader(http.StatusRequestTimeout)
178 }
179 } else {
180 fmt.Fprintf(w, `{"tree_size": 11, "timestamp": 99}`)
181 }
182 case "/retry-rfc1123":
183 if failCount > 0 {
184 failCount--
185 w.Header().Add("Retry-After", time.Now().Add(time.Duration(retryAfter)*time.Second).Format(time.RFC1123))
186 w.WriteHeader(http.StatusServiceUnavailable)
187 } else {
188 fmt.Fprintf(w, `{"tree_size": 11, "timestamp": 99}`)
189 }
190 case "/useragent/banana":
191 if got, want := r.Header.Get("User-Agent"), "banana"; got != want {
192 w.WriteHeader(400)
193 }
194 fmt.Fprintf(w, `{}`)
195 case "/useragent/none":
196 if got, want := r.Header.Get("User-Agent"), ""; got != want {
197 w.WriteHeader(400)
198 }
199 fmt.Fprintf(w, `{}`)
200 default:
201 t.Fatalf("Unhandled URL path: %s", r.URL.Path)
202 }
203 }))
204 }
205
206 func TestGetAndParse(t *testing.T) {
207 tests := []struct {
208 uri string
209 params map[string]string
210 wantStatus int
211 want TestStruct
212 wantErr string
213 ua string
214 }{
215 {uri: "/short%", wantErr: "invalid URL escape"},
216 {uri: "/malformed", wantStatus: http.StatusOK, wantErr: "unexpected EOF"},
217 {uri: "/error", params: map[string]string{"rc": "404"}, wantErr: "404 Not Found"},
218 {uri: "/error", params: map[string]string{"rc": "403"}, wantErr: "403 Forbidden"},
219 {uri: "/struct/path", wantStatus: http.StatusOK, want: TestStruct{11, 99, ""}},
220 {uri: "/useragent/banana", wantStatus: http.StatusOK, ua: "banana"},
221 {uri: "/useragent/banana", wantErr: "400 Bad Request", ua: "not-a-banana"},
222 {
223 uri: "/struct/params",
224 params: map[string]string{"tree_size": "42", "timestamp": "88", "data": "abcd"},
225 wantStatus: http.StatusOK,
226 want: TestStruct{42, 88, "abcd"},
227 },
228 }
229
230 ts := MockServer(t, -1, 0)
231 defer ts.Close()
232
233 ctx := context.Background()
234
235 for _, test := range tests {
236 logClient, err := New(ts.URL, &http.Client{}, Options{UserAgent: test.ua})
237 if err != nil {
238 t.Fatal(err)
239 }
240 var got TestStruct
241 httpRsp, body, err := logClient.GetAndParse(ctx, test.uri, test.params, &got)
242 var gotStatus int
243 if httpRsp != nil {
244 gotStatus = httpRsp.StatusCode
245 } else if rspErr, ok := err.(RspError); ok {
246 gotStatus = rspErr.StatusCode
247 }
248
249 if err != nil {
250 if len(test.wantErr) == 0 {
251 t.Errorf("GetAndParse(%q)=_,_,%q; want _, _, nil", test.uri, err.Error())
252 } else if !strings.Contains(err.Error(), test.wantErr) {
253 t.Errorf("GetAndParse(%q)=_,_,%q; want _, _, error containing %q", test.uri, err.Error(), test.wantErr)
254 }
255 continue
256 }
257
258 if len(test.wantErr) > 0 {
259 t.Errorf("GetAndParse(%q)=%+v,_,nil; want error matching %q", test.uri, got, test.wantErr)
260 }
261 if gotStatus != test.wantStatus {
262 t.Errorf("GetAndParse('%s') got status %d; want %d", test.uri, gotStatus, test.wantStatus)
263 }
264
265 if body == nil {
266 t.Errorf("GetAndParse(%q)=_,nil,_; want _,non-nil,_", test.uri)
267 }
268 if test.wantStatus == http.StatusOK {
269 if !reflect.DeepEqual(got, test.want) {
270 t.Errorf("GetAndParse(%q)=%+v,_,nil; want %+v", test.uri, got, test.want)
271 }
272 }
273 }
274 }
275
276 func TestPostAndParse(t *testing.T) {
277 tests := []struct {
278 uri string
279 request interface{}
280 wantStatus int
281 want TestStruct
282 wantErr string
283 ua string
284 }{
285 {uri: "/short%", wantErr: "invalid URL escape"},
286 {uri: "/struct/params", request: json.Number(`invalid`), wantErr: "invalid number literal"},
287 {uri: "/malformed", wantStatus: http.StatusOK, wantErr: "unexpected end of JSON"},
288 {uri: "/error", request: TestParams{RespCode: 404}, wantStatus: http.StatusNotFound},
289 {uri: "/error", request: TestParams{RespCode: 403}, wantStatus: http.StatusForbidden},
290 {uri: "/struct/path", wantStatus: http.StatusOK, want: TestStruct{11, 99, ""}},
291 {uri: "/useragent/banana", wantStatus: http.StatusOK, ua: "banana"},
292 {uri: "/useragent/banana", wantStatus: 400, ua: "not-a-banana"},
293 {
294 uri: "/struct/params",
295 wantStatus: http.StatusOK,
296 request: TestStruct{42, 88, "abcd"},
297 want: TestStruct{42, 88, "abcd"},
298 },
299 }
300
301 ts := MockServer(t, -1, 0)
302 defer ts.Close()
303
304 ctx := context.Background()
305
306 for _, test := range tests {
307 logClient, err := New(ts.URL, &http.Client{}, Options{UserAgent: test.ua})
308 if err != nil {
309 t.Fatal(err)
310 }
311 var got TestStruct
312 httpRsp, body, err := logClient.PostAndParse(ctx, test.uri, test.request, &got)
313 var gotStatus int
314 if httpRsp != nil {
315 gotStatus = httpRsp.StatusCode
316 } else if rspErr, ok := err.(RspError); ok {
317 gotStatus = rspErr.StatusCode
318 }
319
320 if err != nil {
321 if len(test.wantErr) == 0 {
322 t.Errorf("PostAndParse(%q)=_,_,%q; want _, _, nil", test.uri, err.Error())
323 } else if !strings.Contains(err.Error(), test.wantErr) {
324 t.Errorf("PostAndParse(%q)=nil,%q; want error matching %q", test.uri, err.Error(), test.wantErr)
325 }
326 continue
327 }
328
329 if len(test.wantErr) > 0 {
330 t.Errorf("PostAndParse(%q)=%+v,nil; want error matching %q", test.uri, got, test.wantErr)
331 }
332 if gotStatus != test.wantStatus {
333 t.Errorf("PostAndParse('%s') got status %d; want %d", test.uri, gotStatus, test.wantStatus)
334 }
335 if body == nil {
336 t.Errorf("PostAndParse(%q)=_,nil,_; want _,non-nil,_ ", test.uri)
337 }
338 if test.wantStatus == http.StatusOK {
339 if !reflect.DeepEqual(got, test.want) {
340 t.Errorf("PostAndParse(%q)=%+v,nil; want %+v", test.uri, got, test.want)
341 }
342 }
343 }
344 }
345
346
// mockBackoff is a test double installed as a client's backoff. It never
// delays; it only remembers the most recent override passed to set.
type mockBackoff struct {
	override time.Duration
}

// set records *o as the override when o is non-nil and always returns a zero
// wait duration.
func (mb *mockBackoff) set(o *time.Duration) time.Duration {
	if o == nil {
		return 0
	}
	mb.override = *o
	return 0
}

// decreaseMultiplier is a no-op.
func (mb *mockBackoff) decreaseMultiplier() {}

// until always returns the zero time.Time.
func (mb *mockBackoff) until() time.Time {
	var zero time.Time
	return zero
}
359
360 func TestPostAndParseWithRetry(t *testing.T) {
361 tests := []struct {
362 uri string
363 request interface{}
364 deadlineSecs int
365 retryAfter int
366 failCount int
367 wantErr string
368 expectedBackoff time.Duration
369 }{
370 {
371 uri: "/error",
372 request: TestParams{RespCode: 418},
373 deadlineSecs: -1,
374 retryAfter: 0,
375 failCount: 0,
376 wantErr: "teapot",
377 expectedBackoff: 0,
378 },
379 {
380 uri: "/short%",
381 request: nil,
382 deadlineSecs: 0,
383 retryAfter: 0,
384 failCount: 0,
385 wantErr: "deadline exceeded",
386 expectedBackoff: 0,
387 },
388 {
389 uri: "/retry",
390 request: nil,
391 deadlineSecs: -1,
392 retryAfter: 0,
393 failCount: 1,
394 wantErr: "",
395 expectedBackoff: 0,
396 },
397 {
398 uri: "/retry",
399 request: nil,
400 deadlineSecs: -1,
401 retryAfter: 5,
402 failCount: 1,
403 wantErr: "",
404 expectedBackoff: 5 * time.Second,
405 },
406 {
407 uri: "/retry-rfc1123",
408 request: nil,
409 deadlineSecs: -1,
410 retryAfter: 5,
411 failCount: 1,
412 wantErr: "",
413 expectedBackoff: 5 * time.Second,
414 },
415 }
416 for _, test := range tests {
417 t.Run(test.uri, func(t *testing.T) {
418 ts := MockServer(t, test.failCount, test.retryAfter)
419 defer ts.Close()
420
421 logClient, err := New(ts.URL, &http.Client{}, Options{})
422 if err != nil {
423 t.Fatal(err)
424 }
425 mb := mockBackoff{}
426 logClient.backoff = &mb
427 ctx := context.Background()
428 if test.deadlineSecs >= 0 {
429 var cancel context.CancelFunc
430 ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(test.deadlineSecs)*time.Second))
431 defer cancel()
432 }
433
434 var got TestStruct
435 httpRsp, _, err := logClient.PostAndParseWithRetry(ctx, test.uri, test.request, &got)
436 if test.wantErr != "" {
437 if err == nil {
438 t.Errorf("PostAndParseWithRetry()=%+v,nil; want error %q", got, test.wantErr)
439 } else if !strings.Contains(err.Error(), test.wantErr) {
440 t.Errorf("PostAndParseWithRetry()=nil,%q; want error %q", err.Error(), test.wantErr)
441 } else if _, isRspError := err.(RspError); !isRspError && err != context.DeadlineExceeded {
442
443
444 t.Errorf("PostAndParseWithRetry()=%T; want jsonClient.RspError or context.DeadlineExceeded", err)
445 }
446 return
447 }
448 if err != nil {
449 t.Errorf("PostAndParseWithRetry()=nil,%q; want no error", err.Error())
450 } else if httpRsp.StatusCode != http.StatusOK {
451 t.Errorf("PostAndParseWithRetry() got status %d; want OK(404)", httpRsp.StatusCode)
452 }
453 if test.expectedBackoff > 0 && !fuzzyDurationEquals(test.expectedBackoff, mb.override, time.Second) {
454 t.Errorf("Unexpected backoff override set: got: %s, wanted: %s", mb.override, test.expectedBackoff)
455 }
456 })
457 }
458 }
459
460
461 func TestContextRequired(t *testing.T) {
462 ts := MockServer(t, -1, 0)
463 defer ts.Close()
464
465 logClient, err := New(ts.URL, &http.Client{}, Options{})
466 if err != nil {
467 t.Fatal(err)
468 }
469 var result TestStruct
470 _, _, err = logClient.GetAndParse(nil, "/struct/path", nil, &result)
471 if err == nil {
472 t.Errorf("GetAndParse() succeeded with empty Context")
473 }
474 _, _, err = logClient.PostAndParse(nil, "/struct/path", nil, &result)
475 if err == nil {
476 t.Errorf("PostAndParse() succeeded with empty Context")
477 }
478 _, _, err = logClient.PostAndParseWithRetry(nil, "/struct/path", nil, &result)
479 if err == nil {
480 t.Errorf("PostAndParseWithRetry() succeeded with empty Context")
481 }
482 }
483
484 func TestCancelledContext(t *testing.T) {
485 ts := MockServer(t, -1, 0)
486 defer ts.Close()
487 logClient, err := New(ts.URL, &http.Client{}, Options{})
488 if err != nil {
489 t.Fatal(err)
490 }
491 ctx, cancel := context.WithCancel(context.Background())
492 cancel()
493
494 var result TestStruct
495 _, _, err = logClient.GetAndParse(ctx, "/struct/path", nil, &result)
496 if err != context.Canceled {
497 t.Errorf("GetAndParse() = (_,_,%v), want %q", err, context.Canceled)
498 }
499 _, _, err = logClient.PostAndParse(ctx, "/struct/path", nil, &result)
500 if err != context.Canceled {
501 t.Errorf("PostAndParse() = (_,_,%v), want %q", err, context.Canceled)
502 }
503 _, _, err = logClient.PostAndParseWithRetry(ctx, "/struct/path", nil, &result)
504 if err != context.Canceled {
505 t.Errorf("PostAndParseWithRetry() = (_,_,%v), want %q", err, context.Canceled)
506 }
507 }
508
View as plain text