...

Source file src/github.com/google/certificate-transparency-go/jsonclient/client_test.go

Documentation: github.com/google/certificate-transparency-go/jsonclient

     1  // Copyright 2016 Google LLC. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package jsonclient
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"encoding/pem"
    21  	"fmt"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"reflect"
    25  	"strconv"
    26  	"strings"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/google/certificate-transparency-go/testdata"
    32  )
    33  
    34  func publicKeyPEMToDER(key string) []byte {
    35  	block, _ := pem.Decode([]byte(key))
    36  	if block == nil {
    37  		panic("failed to decode public key PEM")
    38  	}
    39  	if block.Type != "PUBLIC KEY" {
    40  		panic("PEM does not have type 'PUBLIC KEY'")
    41  	}
    42  	return block.Bytes
    43  }
    44  
    45  func TestNewJSONClient(t *testing.T) {
    46  	tests := []struct {
    47  		name    string
    48  		opts    Options
    49  		wantErr string
    50  	}{
    51  		{
    52  			name:    "invalid PublicKey",
    53  			opts:    Options{PublicKey: "bogus"},
    54  			wantErr: "no PEM block",
    55  		},
    56  		{
    57  			name:    "invalid PublicKeyDER",
    58  			opts:    Options{PublicKeyDER: []byte("bogus")},
    59  			wantErr: "asn1: structure error",
    60  		},
    61  		{
    62  			name: "RSA PublicKey",
    63  			opts: Options{PublicKey: testdata.RsaPublicKeyPEM},
    64  		},
    65  		{
    66  			name: "RSA PublicKeyDER",
    67  			opts: Options{PublicKeyDER: publicKeyPEMToDER(testdata.RsaPublicKeyPEM)},
    68  		},
    69  		{
    70  			name: "ECDSA PublicKey",
    71  			opts: Options{PublicKey: testdata.EcdsaPublicKeyPEM},
    72  		},
    73  		{
    74  			name: "ECDSA PublicKeyDER",
    75  			opts: Options{PublicKeyDER: publicKeyPEMToDER(testdata.EcdsaPublicKeyPEM)},
    76  		},
    77  		{
    78  			name:    "DSA PublicKey",
    79  			opts:    Options{PublicKey: testdata.DsaPublicKeyPEM},
    80  			wantErr: "unsupported public key type",
    81  		},
    82  		{
    83  			name:    "DSA PublicKeyDER",
    84  			opts:    Options{PublicKeyDER: publicKeyPEMToDER(testdata.DsaPublicKeyPEM)},
    85  			wantErr: "unsupported public key type",
    86  		},
    87  		{
    88  			name:    "PublicKey contains trailing garbage",
    89  			opts:    Options{PublicKey: testdata.RsaPublicKeyPEM + "bogus"},
    90  			wantErr: "extra data found",
    91  		},
    92  		{
    93  			name:    "PublicKeyDER contains trailing garbage",
    94  			opts:    Options{PublicKeyDER: append(publicKeyPEMToDER(testdata.RsaPublicKeyPEM), []byte("deadbeef")...)},
    95  			wantErr: "trailing data",
    96  		},
    97  	}
    98  	for _, test := range tests {
    99  		t.Run(test.name, func(t *testing.T) {
   100  			got, err := New("http://127.0.0.1", nil, test.opts)
   101  			if err != nil {
   102  				if len(test.wantErr) == 0 {
   103  					t.Errorf("New()=nil,%v; want _,nil", err)
   104  				} else if !strings.Contains(err.Error(), test.wantErr) {
   105  					t.Errorf("New()=nil,%v; want _, error containing %q", err, test.wantErr)
   106  				}
   107  				return
   108  			}
   109  			if len(test.wantErr) > 0 {
   110  				t.Errorf("New()=_,nil; want nil, error containing %q", test.wantErr)
   111  			}
   112  			if got == nil {
   113  				t.Errorf("New()=nil,nil; want non-nil,nil")
   114  			}
   115  		})
   116  	}
   117  }
   118  
   119  type TestStruct struct {
   120  	TreeSize  int    `json:"tree_size"`
   121  	Timestamp int    `json:"timestamp"`
   122  	Data      string `json:"data"`
   123  }
   124  
   125  type TestParams struct {
   126  	RespCode int `json:"rc"`
   127  }
   128  
   129  func MockServer(t *testing.T, failCount int, retryAfter int) *httptest.Server {
   130  	t.Helper()
   131  	mu := sync.Mutex{}
   132  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   133  		mu.Lock()
   134  		defer mu.Unlock()
   135  		switch r.URL.Path {
   136  		case "/struct/path":
   137  			fmt.Fprintf(w, `{"tree_size": 11, "timestamp": 99}`)
   138  		case "/struct/params":
   139  			var s TestStruct
   140  			if r.Method == http.MethodGet {
   141  				s.TreeSize, _ = strconv.Atoi(r.FormValue("tree_size"))
   142  				s.Timestamp, _ = strconv.Atoi(r.FormValue("timestamp"))
   143  				s.Data = r.FormValue("data")
   144  			} else {
   145  				decoder := json.NewDecoder(r.Body)
   146  				err := decoder.Decode(&s)
   147  				if err != nil {
   148  					panic("Failed to decode: " + err.Error())
   149  				}
   150  				defer r.Body.Close()
   151  			}
   152  			fmt.Fprintf(w, `{"tree_size": %d, "timestamp": %d, "data": "%s"}`, s.TreeSize, s.Timestamp, s.Data)
   153  		case "/error":
   154  			var params TestParams
   155  			if r.Method == http.MethodGet {
   156  				params.RespCode, _ = strconv.Atoi(r.FormValue("rc"))
   157  			} else {
   158  				decoder := json.NewDecoder(r.Body)
   159  				err := decoder.Decode(&params)
   160  				if err != nil {
   161  					panic("Failed to decode: " + err.Error())
   162  				}
   163  				defer r.Body.Close()
   164  			}
   165  			http.Error(w, "error page", params.RespCode)
   166  		case "/malformed":
   167  			fmt.Fprintf(w, `{"tree_size": 11, "timestamp": 99`) // no closing }
   168  		case "/retry":
   169  			if failCount > 0 {
   170  				failCount--
   171  				if retryAfter != 0 {
   172  					if retryAfter > 0 {
   173  						w.Header().Add("Retry-After", strconv.Itoa(retryAfter))
   174  					}
   175  					w.WriteHeader(http.StatusServiceUnavailable)
   176  				} else {
   177  					w.WriteHeader(http.StatusRequestTimeout)
   178  				}
   179  			} else {
   180  				fmt.Fprintf(w, `{"tree_size": 11, "timestamp": 99}`)
   181  			}
   182  		case "/retry-rfc1123":
   183  			if failCount > 0 {
   184  				failCount--
   185  				w.Header().Add("Retry-After", time.Now().Add(time.Duration(retryAfter)*time.Second).Format(time.RFC1123))
   186  				w.WriteHeader(http.StatusServiceUnavailable)
   187  			} else {
   188  				fmt.Fprintf(w, `{"tree_size": 11, "timestamp": 99}`)
   189  			}
   190  		case "/useragent/banana":
   191  			if got, want := r.Header.Get("User-Agent"), "banana"; got != want {
   192  				w.WriteHeader(400)
   193  			}
   194  			fmt.Fprintf(w, `{}`)
   195  		case "/useragent/none":
   196  			if got, want := r.Header.Get("User-Agent"), ""; got != want {
   197  				w.WriteHeader(400)
   198  			}
   199  			fmt.Fprintf(w, `{}`)
   200  		default:
   201  			t.Fatalf("Unhandled URL path: %s", r.URL.Path)
   202  		}
   203  	}))
   204  }
   205  
   206  func TestGetAndParse(t *testing.T) {
   207  	tests := []struct {
   208  		uri        string
   209  		params     map[string]string
   210  		wantStatus int
   211  		want       TestStruct
   212  		wantErr    string
   213  		ua         string
   214  	}{
   215  		{uri: "/short%", wantErr: "invalid URL escape"},
   216  		{uri: "/malformed", wantStatus: http.StatusOK, wantErr: "unexpected EOF"},
   217  		{uri: "/error", params: map[string]string{"rc": "404"}, wantErr: "404 Not Found"},
   218  		{uri: "/error", params: map[string]string{"rc": "403"}, wantErr: "403 Forbidden"},
   219  		{uri: "/struct/path", wantStatus: http.StatusOK, want: TestStruct{11, 99, ""}},
   220  		{uri: "/useragent/banana", wantStatus: http.StatusOK, ua: "banana"},
   221  		{uri: "/useragent/banana", wantErr: "400 Bad Request", ua: "not-a-banana"},
   222  		{
   223  			uri:        "/struct/params",
   224  			params:     map[string]string{"tree_size": "42", "timestamp": "88", "data": "abcd"},
   225  			wantStatus: http.StatusOK,
   226  			want:       TestStruct{42, 88, "abcd"},
   227  		},
   228  	}
   229  
   230  	ts := MockServer(t, -1, 0)
   231  	defer ts.Close()
   232  
   233  	ctx := context.Background()
   234  
   235  	for _, test := range tests {
   236  		logClient, err := New(ts.URL, &http.Client{}, Options{UserAgent: test.ua})
   237  		if err != nil {
   238  			t.Fatal(err)
   239  		}
   240  		var got TestStruct
   241  		httpRsp, body, err := logClient.GetAndParse(ctx, test.uri, test.params, &got)
   242  		var gotStatus int
   243  		if httpRsp != nil {
   244  			gotStatus = httpRsp.StatusCode
   245  		} else if rspErr, ok := err.(RspError); ok {
   246  			gotStatus = rspErr.StatusCode
   247  		}
   248  
   249  		if err != nil {
   250  			if len(test.wantErr) == 0 {
   251  				t.Errorf("GetAndParse(%q)=_,_,%q; want _, _, nil", test.uri, err.Error())
   252  			} else if !strings.Contains(err.Error(), test.wantErr) {
   253  				t.Errorf("GetAndParse(%q)=_,_,%q; want _, _, error containing %q", test.uri, err.Error(), test.wantErr)
   254  			}
   255  			continue
   256  		}
   257  
   258  		if len(test.wantErr) > 0 {
   259  			t.Errorf("GetAndParse(%q)=%+v,_,nil; want error matching %q", test.uri, got, test.wantErr)
   260  		}
   261  		if gotStatus != test.wantStatus {
   262  			t.Errorf("GetAndParse('%s') got status %d; want %d", test.uri, gotStatus, test.wantStatus)
   263  		}
   264  
   265  		if body == nil {
   266  			t.Errorf("GetAndParse(%q)=_,nil,_; want _,non-nil,_", test.uri)
   267  		}
   268  		if test.wantStatus == http.StatusOK {
   269  			if !reflect.DeepEqual(got, test.want) {
   270  				t.Errorf("GetAndParse(%q)=%+v,_,nil; want %+v", test.uri, got, test.want)
   271  			}
   272  		}
   273  	}
   274  }
   275  
   276  func TestPostAndParse(t *testing.T) {
   277  	tests := []struct {
   278  		uri        string
   279  		request    interface{}
   280  		wantStatus int
   281  		want       TestStruct
   282  		wantErr    string
   283  		ua         string
   284  	}{
   285  		{uri: "/short%", wantErr: "invalid URL escape"},
   286  		{uri: "/struct/params", request: json.Number(`invalid`), wantErr: "invalid number literal"},
   287  		{uri: "/malformed", wantStatus: http.StatusOK, wantErr: "unexpected end of JSON"},
   288  		{uri: "/error", request: TestParams{RespCode: 404}, wantStatus: http.StatusNotFound},
   289  		{uri: "/error", request: TestParams{RespCode: 403}, wantStatus: http.StatusForbidden},
   290  		{uri: "/struct/path", wantStatus: http.StatusOK, want: TestStruct{11, 99, ""}},
   291  		{uri: "/useragent/banana", wantStatus: http.StatusOK, ua: "banana"},
   292  		{uri: "/useragent/banana", wantStatus: 400, ua: "not-a-banana"},
   293  		{
   294  			uri:        "/struct/params",
   295  			wantStatus: http.StatusOK,
   296  			request:    TestStruct{42, 88, "abcd"},
   297  			want:       TestStruct{42, 88, "abcd"},
   298  		},
   299  	}
   300  
   301  	ts := MockServer(t, -1, 0)
   302  	defer ts.Close()
   303  
   304  	ctx := context.Background()
   305  
   306  	for _, test := range tests {
   307  		logClient, err := New(ts.URL, &http.Client{}, Options{UserAgent: test.ua})
   308  		if err != nil {
   309  			t.Fatal(err)
   310  		}
   311  		var got TestStruct
   312  		httpRsp, body, err := logClient.PostAndParse(ctx, test.uri, test.request, &got)
   313  		var gotStatus int
   314  		if httpRsp != nil {
   315  			gotStatus = httpRsp.StatusCode
   316  		} else if rspErr, ok := err.(RspError); ok {
   317  			gotStatus = rspErr.StatusCode
   318  		}
   319  
   320  		if err != nil {
   321  			if len(test.wantErr) == 0 {
   322  				t.Errorf("PostAndParse(%q)=_,_,%q; want _, _, nil", test.uri, err.Error())
   323  			} else if !strings.Contains(err.Error(), test.wantErr) {
   324  				t.Errorf("PostAndParse(%q)=nil,%q; want error matching %q", test.uri, err.Error(), test.wantErr)
   325  			}
   326  			continue
   327  		}
   328  
   329  		if len(test.wantErr) > 0 {
   330  			t.Errorf("PostAndParse(%q)=%+v,nil; want error matching %q", test.uri, got, test.wantErr)
   331  		}
   332  		if gotStatus != test.wantStatus {
   333  			t.Errorf("PostAndParse('%s') got status %d; want %d", test.uri, gotStatus, test.wantStatus)
   334  		}
   335  		if body == nil {
   336  			t.Errorf("PostAndParse(%q)=_,nil,_; want _,non-nil,_ ", test.uri)
   337  		}
   338  		if test.wantStatus == http.StatusOK {
   339  			if !reflect.DeepEqual(got, test.want) {
   340  				t.Errorf("PostAndParse(%q)=%+v,nil; want %+v", test.uri, got, test.want)
   341  			}
   342  		}
   343  	}
   344  }
   345  
   346  // mockBackoff is not safe for concurrent usage
   347  type mockBackoff struct {
   348  	override time.Duration
   349  }
   350  
   351  func (mb *mockBackoff) set(o *time.Duration) time.Duration {
   352  	if o != nil {
   353  		mb.override = *o
   354  	}
   355  	return 0
   356  }
   357  func (mb *mockBackoff) decreaseMultiplier() {}
   358  func (mb *mockBackoff) until() time.Time    { return time.Time{} }
   359  
   360  func TestPostAndParseWithRetry(t *testing.T) {
   361  	tests := []struct {
   362  		uri             string
   363  		request         interface{}
   364  		deadlineSecs    int // -1 indicates no deadline
   365  		retryAfter      int // -1 indicates generate 503 with no Retry-After
   366  		failCount       int
   367  		wantErr         string
   368  		expectedBackoff time.Duration // 0 indicates no expected backoff override set
   369  	}{
   370  		{
   371  			uri:             "/error",
   372  			request:         TestParams{RespCode: 418},
   373  			deadlineSecs:    -1,
   374  			retryAfter:      0,
   375  			failCount:       0,
   376  			wantErr:         "teapot",
   377  			expectedBackoff: 0,
   378  		},
   379  		{
   380  			uri:             "/short%",
   381  			request:         nil,
   382  			deadlineSecs:    0,
   383  			retryAfter:      0,
   384  			failCount:       0,
   385  			wantErr:         "deadline exceeded",
   386  			expectedBackoff: 0,
   387  		},
   388  		{
   389  			uri:             "/retry",
   390  			request:         nil,
   391  			deadlineSecs:    -1,
   392  			retryAfter:      0,
   393  			failCount:       1,
   394  			wantErr:         "",
   395  			expectedBackoff: 0,
   396  		},
   397  		{
   398  			uri:             "/retry",
   399  			request:         nil,
   400  			deadlineSecs:    -1,
   401  			retryAfter:      5,
   402  			failCount:       1,
   403  			wantErr:         "",
   404  			expectedBackoff: 5 * time.Second,
   405  		},
   406  		{
   407  			uri:             "/retry-rfc1123",
   408  			request:         nil,
   409  			deadlineSecs:    -1,
   410  			retryAfter:      5,
   411  			failCount:       1,
   412  			wantErr:         "",
   413  			expectedBackoff: 5 * time.Second,
   414  		},
   415  	}
   416  	for _, test := range tests {
   417  		t.Run(test.uri, func(t *testing.T) {
   418  			ts := MockServer(t, test.failCount, test.retryAfter)
   419  			defer ts.Close()
   420  
   421  			logClient, err := New(ts.URL, &http.Client{}, Options{})
   422  			if err != nil {
   423  				t.Fatal(err)
   424  			}
   425  			mb := mockBackoff{}
   426  			logClient.backoff = &mb
   427  			ctx := context.Background()
   428  			if test.deadlineSecs >= 0 {
   429  				var cancel context.CancelFunc
   430  				ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(test.deadlineSecs)*time.Second))
   431  				defer cancel()
   432  			}
   433  
   434  			var got TestStruct
   435  			httpRsp, _, err := logClient.PostAndParseWithRetry(ctx, test.uri, test.request, &got)
   436  			if test.wantErr != "" {
   437  				if err == nil {
   438  					t.Errorf("PostAndParseWithRetry()=%+v,nil; want error %q", got, test.wantErr)
   439  				} else if !strings.Contains(err.Error(), test.wantErr) {
   440  					t.Errorf("PostAndParseWithRetry()=nil,%q; want error %q", err.Error(), test.wantErr)
   441  				} else if _, isRspError := err.(RspError); !isRspError && err != context.DeadlineExceeded {
   442  					// We expect all non-nil errors to be either a RspError instance or to
   443  					// be the context DeadlineExceeded error.
   444  					t.Errorf("PostAndParseWithRetry()=%T; want jsonClient.RspError or context.DeadlineExceeded", err)
   445  				}
   446  				return
   447  			}
   448  			if err != nil {
   449  				t.Errorf("PostAndParseWithRetry()=nil,%q; want no error", err.Error())
   450  			} else if httpRsp.StatusCode != http.StatusOK {
   451  				t.Errorf("PostAndParseWithRetry() got status %d; want OK(404)", httpRsp.StatusCode)
   452  			}
   453  			if test.expectedBackoff > 0 && !fuzzyDurationEquals(test.expectedBackoff, mb.override, time.Second) {
   454  				t.Errorf("Unexpected backoff override set: got: %s, wanted: %s", mb.override, test.expectedBackoff)
   455  			}
   456  		})
   457  	}
   458  }
   459  
   460  // nolint:staticcheck
   461  func TestContextRequired(t *testing.T) {
   462  	ts := MockServer(t, -1, 0)
   463  	defer ts.Close()
   464  
   465  	logClient, err := New(ts.URL, &http.Client{}, Options{})
   466  	if err != nil {
   467  		t.Fatal(err)
   468  	}
   469  	var result TestStruct
   470  	_, _, err = logClient.GetAndParse(nil, "/struct/path", nil, &result)
   471  	if err == nil {
   472  		t.Errorf("GetAndParse() succeeded with empty Context")
   473  	}
   474  	_, _, err = logClient.PostAndParse(nil, "/struct/path", nil, &result)
   475  	if err == nil {
   476  		t.Errorf("PostAndParse() succeeded with empty Context")
   477  	}
   478  	_, _, err = logClient.PostAndParseWithRetry(nil, "/struct/path", nil, &result)
   479  	if err == nil {
   480  		t.Errorf("PostAndParseWithRetry() succeeded with empty Context")
   481  	}
   482  }
   483  
   484  func TestCancelledContext(t *testing.T) {
   485  	ts := MockServer(t, -1, 0)
   486  	defer ts.Close()
   487  	logClient, err := New(ts.URL, &http.Client{}, Options{})
   488  	if err != nil {
   489  		t.Fatal(err)
   490  	}
   491  	ctx, cancel := context.WithCancel(context.Background())
   492  	cancel()
   493  
   494  	var result TestStruct
   495  	_, _, err = logClient.GetAndParse(ctx, "/struct/path", nil, &result)
   496  	if err != context.Canceled {
   497  		t.Errorf("GetAndParse() = (_,_,%v), want %q", err, context.Canceled)
   498  	}
   499  	_, _, err = logClient.PostAndParse(ctx, "/struct/path", nil, &result)
   500  	if err != context.Canceled {
   501  		t.Errorf("PostAndParse() = (_,_,%v), want %q", err, context.Canceled)
   502  	}
   503  	_, _, err = logClient.PostAndParseWithRetry(ctx, "/struct/path", nil, &result)
   504  	if err != context.Canceled {
   505  		t.Errorf("PostAndParseWithRetry() = (_,_,%v), want %q", err, context.Canceled)
   506  	}
   507  }
   508  

View as plain text