...

Source file src/github.com/google/certificate-transparency-go/fixchain/roundtrip_test.go

Documentation: github.com/google/certificate-transparency-go/fixchain

     1  // Copyright 2016 Google LLC. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package fixchain
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/base64"
    20  	"encoding/json"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"net/http"
    25  	"strings"
    26  	"testing"
    27  
    28  	ct "github.com/google/certificate-transparency-go"
    29  	"github.com/google/certificate-transparency-go/testdata"
    30  	"github.com/google/certificate-transparency-go/tls"
    31  	"github.com/google/certificate-transparency-go/x509"
    32  )
    33  
    34  type testRoundTripper struct {
    35  	t         *testing.T
    36  	test      *fixAndLogTest
    37  	testIndex int
    38  	seen      []bool
    39  }
    40  
    41  func (rt testRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
    42  	url := fmt.Sprintf("%s://%s%s", request.URL.Scheme, request.URL.Host, request.URL.Path)
    43  	switch url {
    44  	case "https://ct.googleapis.com/pilot/ct/v1/get-roots":
    45  		b := stringRootsToJSON([]string{verisignRoot, testRoot})
    46  		return &http.Response{
    47  			Status:        "200 OK",
    48  			StatusCode:    200,
    49  			Proto:         request.Proto,
    50  			ProtoMajor:    request.ProtoMajor,
    51  			ProtoMinor:    request.ProtoMinor,
    52  			Body:          &bytesReadCloser{bytes.NewReader(b)},
    53  			ContentLength: int64(len(b)),
    54  			Request:       request,
    55  		}, nil
    56  	case "https://ct.googleapis.com/pilot/ct/v1/add-chain":
    57  		body, err := io.ReadAll(request.Body)
    58  		request.Body.Close()
    59  		if err != nil {
    60  			errStr := fmt.Sprintf("#%d: Could not read request body: %s", rt.testIndex, err.Error())
    61  			rt.t.Error(errStr)
    62  			return nil, errors.New(errStr)
    63  		}
    64  
    65  		type Chain struct {
    66  			Chain [][]byte
    67  		}
    68  		var chainBytes Chain
    69  		err = json.Unmarshal(body, &chainBytes)
    70  		if err != nil {
    71  			errStr := fmt.Sprintf("#%d: Could not unmarshal json: %s", rt.testIndex, err.Error())
    72  			rt.t.Error(errStr)
    73  			return nil, errors.New(errStr)
    74  		}
    75  		var chain []*x509.Certificate
    76  		for _, certBytes := range chainBytes.Chain {
    77  			cert, err := x509.ParseCertificate(certBytes)
    78  			if x509.IsFatal(err) {
    79  				errStr := fmt.Sprintf("#%d: Could not parse certificate: %s", rt.testIndex, err.Error())
    80  				rt.t.Error(errStr)
    81  				return nil, errors.New(errStr)
    82  			}
    83  			chain = append(chain, cert)
    84  		}
    85  
    86  	TryNextExpected:
    87  		for i, expChain := range rt.test.expLoggedChains {
    88  			if rt.seen[i] || len(chain) != len(expChain) {
    89  				continue
    90  			}
    91  			for j, cert := range chain {
    92  				if !strings.Contains(nameToKey(&cert.Subject), expChain[j]) {
    93  					continue TryNextExpected
    94  				}
    95  			}
    96  			rt.seen[i] = true
    97  			goto Return
    98  		}
    99  		rt.t.Errorf("#%d: Logged chain was not expected: %s", rt.testIndex, chainToDebugString(chain))
   100  	Return:
   101  		return &http.Response{
   102  			Status:        "200 OK",
   103  			StatusCode:    200,
   104  			Proto:         request.Proto,
   105  			ProtoMajor:    request.ProtoMajor,
   106  			ProtoMinor:    request.ProtoMinor,
   107  			Body:          &bytesReadCloser{bytes.NewReader(validAddChainRsp())},
   108  			ContentLength: 0,
   109  			Request:       request,
   110  		}, nil
   111  	default:
   112  		var cert string
   113  		switch url {
   114  		case "http://www.thawte.com/repository/Thawte_SGC_CA.crt":
   115  			cert = thawteIntermediate
   116  		case "http://crt.comodoca.com/EssentialSSLCA_2.crt":
   117  			cert = comodoIntermediate
   118  		case "http://crt.comodoca.com/ComodoUTNSGCCA.crt":
   119  			cert = comodoRoot
   120  		case "http://www.example.com/intermediate2.crt":
   121  			cert = testIntermediate2
   122  		case "http://www.example.com/intermediate1.crt":
   123  			cert = testIntermediate1
   124  		case "http://www.example.com/ca.crt":
   125  			cert = testRoot
   126  		case "http://www.example.com/a.crt":
   127  			cert = testA
   128  		case "http://www.example.com/b.crt":
   129  			cert = testB
   130  		default:
   131  			return nil, fmt.Errorf("can't reach url %s", url)
   132  		}
   133  
   134  		return &http.Response{
   135  			Status:        "200 OK",
   136  			StatusCode:    200,
   137  			Proto:         request.Proto,
   138  			ProtoMajor:    request.ProtoMajor,
   139  			ProtoMinor:    request.ProtoMinor,
   140  			Body:          &bytesReadCloser{bytes.NewReader([]byte(cert))},
   141  			ContentLength: int64(len([]byte(cert))),
   142  			Request:       request,
   143  		}, nil
   144  	}
   145  }
   146  
   147  // The round tripper used during testing of PostChainToLog() is used to check
   148  // that the http requests sent by PostChainToLog() contain the right information
   149  // for a Certificate Transparency log to be able to log the given chain
   150  // (assuming the chain is valid).
   151  type postTestRoundTripper struct {
   152  	t         *testing.T
   153  	test      *postTest
   154  	testIndex int
   155  }
   156  
   157  func (rt postTestRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
   158  	if strings.Contains(request.URL.Path, "/ct/v1/get-roots") {
   159  		b := stringRootsToJSON([]string{verisignRoot})
   160  		return &http.Response{
   161  			Status:        "200 OK",
   162  			StatusCode:    200,
   163  			Proto:         request.Proto,
   164  			ProtoMajor:    request.ProtoMajor,
   165  			ProtoMinor:    request.ProtoMinor,
   166  			Body:          &bytesReadCloser{bytes.NewReader(b)},
   167  			ContentLength: int64(len(b)),
   168  			Request:       request,
   169  		}, nil
   170  	}
   171  	// For tests that are checking the correct FixError type is returned:
   172  	if rt.test.ferr.Type == LogPostFailed {
   173  		return &http.Response{
   174  			Status:        "501 Not Implemented",
   175  			StatusCode:    501,
   176  			Proto:         request.Proto,
   177  			ProtoMajor:    request.ProtoMajor,
   178  			ProtoMinor:    request.ProtoMinor,
   179  			Body:          &bytesReadCloser{bytes.NewReader([]byte(""))},
   180  			ContentLength: 0,
   181  			Request:       request,
   182  		}, nil
   183  	}
   184  
   185  	// For tests to check request sent to log looks right:
   186  	// Check method used
   187  	if request.Method != "POST" {
   188  		rt.t.Errorf("#%d: expected request method to be POST, received %s", rt.testIndex, request.Method)
   189  	}
   190  
   191  	// Check URL
   192  	if request.URL.Scheme != rt.test.urlScheme {
   193  		rt.t.Errorf("#%d: Scheme: received %s, expected %s", rt.testIndex, request.URL.Scheme, rt.test.urlScheme)
   194  	}
   195  	if request.URL.Host != rt.test.urlHost {
   196  		rt.t.Errorf("#%d: Host: received %s, expected %s", rt.testIndex, request.URL.Host, rt.test.urlHost)
   197  	}
   198  	if request.URL.Path != rt.test.urlPath {
   199  		rt.t.Errorf("#%d: Path: received %s, expected %s", rt.testIndex, request.URL.Path, rt.test.urlPath)
   200  	}
   201  
   202  	// Check Body
   203  	body, err := io.ReadAll(request.Body)
   204  	request.Body.Close()
   205  	if err != nil {
   206  		errStr := fmt.Sprintf("#%d: Could not read request body: %s", rt.testIndex, err.Error())
   207  		rt.t.Error(errStr)
   208  		return nil, errors.New(errStr)
   209  	}
   210  
   211  	// Create string in the format that the Certificate Transparency logs expect
   212  	// the body of an add-chain request to be in.
   213  	var encode = base64.StdEncoding.EncodeToString
   214  	expStr := "{\"chain\":"
   215  	if rt.test.chain == nil {
   216  		expStr += "null"
   217  	} else {
   218  		expStr += "["
   219  		for i, cert := range rt.test.chain {
   220  			expStr += "\"" + encode(GetTestCertificateFromPEM(rt.t, cert).Raw) + "\""
   221  			if i != len(rt.test.chain)-1 {
   222  				expStr += ","
   223  			}
   224  		}
   225  		expStr += "]"
   226  	}
   227  	expStr += "}"
   228  
   229  	if string(body) != expStr {
   230  		rt.t.Errorf("#%d: incorrect format of request body.  Received %s, expected %s", rt.testIndex, string(body), expStr)
   231  	}
   232  
   233  	rspData := []byte("")
   234  	if strings.Contains(request.URL.Path, "/ct/v1/add-chain") {
   235  		rspData = validAddChainRsp()
   236  	}
   237  
   238  	// Return a response
   239  	return &http.Response{
   240  		Status:        "200 OK",
   241  		StatusCode:    200,
   242  		Proto:         request.Proto,
   243  		ProtoMajor:    request.ProtoMajor,
   244  		ProtoMinor:    request.ProtoMinor,
   245  		Body:          &bytesReadCloser{bytes.NewReader(rspData)},
   246  		ContentLength: 0,
   247  		Request:       request,
   248  	}, nil
   249  }
   250  
   251  func validAddChainRsp() []byte {
   252  	var sct ct.SignedCertificateTimestamp
   253  	_, err := tls.Unmarshal(testdata.TestCertProof, &sct)
   254  	if err != nil {
   255  		panic(fmt.Sprintf("failed to tls-unmarshal test certificate proof: %v", err))
   256  	}
   257  	sig, err := tls.Marshal(sct.Signature)
   258  	if err != nil {
   259  		panic(fmt.Sprintf("failed to marshal signature: %v", err))
   260  	}
   261  	rsp := ct.AddChainResponse{
   262  		SCTVersion: sct.SCTVersion,
   263  		Timestamp:  sct.Timestamp,
   264  		ID:         sct.LogID.KeyID[:],
   265  		Extensions: base64.StdEncoding.EncodeToString(sct.Extensions),
   266  		Signature:  sig,
   267  	}
   268  	rspData, err := json.Marshal(rsp)
   269  	if err != nil {
   270  		panic(fmt.Sprintf("failed to json-marshal test certificate proof: %v", err))
   271  	}
   272  	return rspData
   273  }
   274  
   275  type newLoggerTestRoundTripper struct{}
   276  
   277  func (rt newLoggerTestRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
   278  	// Return a response
   279  	b := validAddChainRsp()
   280  	return &http.Response{
   281  		Status:        "200 OK",
   282  		StatusCode:    200,
   283  		Proto:         request.Proto,
   284  		ProtoMajor:    request.ProtoMajor,
   285  		ProtoMinor:    request.ProtoMinor,
   286  		Body:          &bytesReadCloser{bytes.NewReader(b)},
   287  		ContentLength: int64(len(b)),
   288  		Request:       request,
   289  	}, nil
   290  }
   291  
   292  type rootCertsTestRoundTripper struct{}
   293  
   294  func (rt rootCertsTestRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
   295  	url := fmt.Sprintf("%s://%s%s", request.URL.Scheme, request.URL.Host, request.URL.Path)
   296  	if url == "https://ct.googleapis.com/pilot/ct/v1/get-roots" {
   297  		b := stringRootsToJSON([]string{verisignRoot, comodoRoot})
   298  		return &http.Response{
   299  			Status:        "200 OK",
   300  			StatusCode:    200,
   301  			Proto:         request.Proto,
   302  			ProtoMajor:    request.ProtoMajor,
   303  			ProtoMinor:    request.ProtoMinor,
   304  			Body:          &bytesReadCloser{bytes.NewReader(b)},
   305  			ContentLength: int64(len(b)),
   306  			Request:       request,
   307  		}, nil
   308  	}
   309  	return nil, errors.New("")
   310  }
   311  

View as plain text