...

Source file src/github.com/google/certificate-transparency-go/ctutil/ctutil_test.go

Documentation: github.com/google/certificate-transparency-go/ctutil

     1  // Copyright 2018 Google LLC. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ctutil
    16  
    17  import (
    18  	"encoding/base64"
    19  	"testing"
    20  
    21  	ct "github.com/google/certificate-transparency-go"
    22  	"github.com/google/certificate-transparency-go/testdata"
    23  	"github.com/google/certificate-transparency-go/tls"
    24  	"github.com/google/certificate-transparency-go/x509util"
    25  )
    26  
    27  func TestLeafHash(t *testing.T) {
    28  	tests := []struct {
    29  		desc     string
    30  		chainPEM string
    31  		sct      []byte
    32  		embedded bool
    33  		want     string
    34  	}{
    35  		{
    36  			desc:     "cert",
    37  			chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
    38  			sct:      testdata.TestCertProof,
    39  			want:     testdata.TestCertB64LeafHash,
    40  		},
    41  		{
    42  			desc:     "precert",
    43  			chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
    44  			sct:      testdata.TestPreCertProof,
    45  			want:     testdata.TestPreCertB64LeafHash,
    46  		},
    47  		{
    48  			desc:     "cert with embedded SCT",
    49  			chainPEM: testdata.TestEmbeddedCertPEM + testdata.CACertPEM,
    50  			sct:      testdata.TestPreCertProof,
    51  			embedded: true,
    52  			want:     testdata.TestPreCertB64LeafHash,
    53  		},
    54  	}
    55  
    56  	for _, test := range tests {
    57  		t.Run(test.desc, func(t *testing.T) {
    58  			// Parse chain
    59  			chain, err := x509util.CertificatesFromPEM([]byte(test.chainPEM))
    60  			if err != nil {
    61  				t.Fatalf("error parsing certificate chain: %s", err)
    62  			}
    63  
    64  			// Parse SCT
    65  			var sct ct.SignedCertificateTimestamp
    66  			if _, err = tls.Unmarshal(test.sct, &sct); err != nil {
    67  				t.Fatalf("error tls-unmarshalling sct: %s", err)
    68  			}
    69  
    70  			// Test LeafHash()
    71  			wantSl, err := base64.StdEncoding.DecodeString(test.want)
    72  			if err != nil {
    73  				t.Fatalf("error base64-decoding leaf hash %q: %s", test.want, err)
    74  			}
    75  			var want [32]byte
    76  			copy(want[:], wantSl)
    77  
    78  			got, err := LeafHash(chain, &sct, test.embedded)
    79  			if got != want || err != nil {
    80  				t.Errorf("LeafHash(_,_) = %v, %v, want %v, nil", got, err, want)
    81  			}
    82  
    83  			// Test LeafHashB64()
    84  			gotB64, err := LeafHashB64(chain, &sct, test.embedded)
    85  			if gotB64 != test.want || err != nil {
    86  				t.Errorf("LeafHashB64(_,_) = %v, %v, want %v, nil", gotB64, err, test.want)
    87  			}
    88  		})
    89  	}
    90  }
    91  
    92  func TestLeafHashErrors(t *testing.T) {
    93  	tests := []struct {
    94  		desc     string
    95  		chainPEM string
    96  		sct      []byte
    97  		embedded bool
    98  	}{
    99  		{
   100  			desc:     "empty chain",
   101  			chainPEM: "",
   102  			sct:      testdata.TestCertProof,
   103  		},
   104  		{
   105  			desc:     "nil SCT",
   106  			chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
   107  			sct:      nil,
   108  		},
   109  		{
   110  			desc:     "no SCTs embedded in cert, embedded true",
   111  			chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
   112  			sct:      testdata.TestInvalidProof,
   113  			embedded: true,
   114  		},
   115  		{
   116  			desc:     "cert contains embedded SCTs, but not the SCT provided",
   117  			chainPEM: testdata.TestEmbeddedCertPEM + testdata.CACertPEM,
   118  			sct:      testdata.TestInvalidProof,
   119  			embedded: true,
   120  		},
   121  	}
   122  
   123  	for _, test := range tests {
   124  		t.Run(test.desc, func(t *testing.T) {
   125  			// Parse chain
   126  			chain, err := x509util.CertificatesFromPEM([]byte(test.chainPEM))
   127  			if err != nil {
   128  				t.Fatalf("error parsing certificate chain: %s", err)
   129  			}
   130  
   131  			// Parse SCT
   132  			var sct *ct.SignedCertificateTimestamp
   133  			if test.sct != nil {
   134  				sct = &ct.SignedCertificateTimestamp{}
   135  				if _, err = tls.Unmarshal(test.sct, sct); err != nil {
   136  					t.Fatalf("error tls-unmarshalling sct: %s", err)
   137  				}
   138  			}
   139  
   140  			// Test LeafHash()
   141  			got, err := LeafHash(chain, sct, test.embedded)
   142  			if got != emptyHash || err == nil {
   143  				t.Errorf("LeafHash(_,_) = %s, %v, want %v, error", got, err, emptyHash)
   144  			}
   145  
   146  			// Test LeafHashB64()
   147  			gotB64, err := LeafHashB64(chain, sct, test.embedded)
   148  			if gotB64 != "" || err == nil {
   149  				t.Errorf("LeafHashB64(_,_) = %s, %v, want \"\", error", gotB64, err)
   150  			}
   151  		})
   152  	}
   153  }
   154  
   155  func TestVerifySCT(t *testing.T) {
   156  	tests := []struct {
   157  		desc     string
   158  		chainPEM string
   159  		sct      []byte
   160  		embedded bool
   161  		wantErr  bool
   162  	}{
   163  		{
   164  			desc:     "cert",
   165  			chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
   166  			sct:      testdata.TestCertProof,
   167  		},
   168  		{
   169  			desc:     "precert",
   170  			chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
   171  			sct:      testdata.TestPreCertProof,
   172  		},
   173  		{
   174  			desc:     "invalid SCT",
   175  			chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
   176  			sct:      testdata.TestCertProof,
   177  			wantErr:  true,
   178  		},
   179  		{
   180  			desc:     "cert with embedded SCT",
   181  			chainPEM: testdata.TestEmbeddedCertPEM + testdata.CACertPEM,
   182  			sct:      testdata.TestPreCertProof,
   183  			embedded: true,
   184  		},
   185  		{
   186  			desc:     "cert with invalid embedded SCT",
   187  			chainPEM: testdata.TestInvalidEmbeddedCertPEM + testdata.CACertPEM,
   188  			sct:      testdata.TestInvalidProof,
   189  			embedded: true,
   190  			wantErr:  true,
   191  		},
   192  	}
   193  
   194  	for _, test := range tests {
   195  		t.Run(test.desc, func(t *testing.T) {
   196  			// Parse chain
   197  			chain, err := x509util.CertificatesFromPEM([]byte(test.chainPEM))
   198  			if err != nil {
   199  				t.Fatalf("error parsing certificate chain: %s", err)
   200  			}
   201  
   202  			// Parse SCT
   203  			var sct ct.SignedCertificateTimestamp
   204  			if _, err = tls.Unmarshal(test.sct, &sct); err != nil {
   205  				t.Fatalf("error tls-unmarshalling sct: %s", err)
   206  			}
   207  
   208  			// Test VerifySCT()
   209  			pk, err := ct.PublicKeyFromB64(testdata.LogPublicKeyB64)
   210  			if err != nil {
   211  				t.Errorf("error parsing public key: %s", err)
   212  			}
   213  
   214  			err = VerifySCT(pk, chain, &sct, test.embedded)
   215  			if gotErr := err != nil; gotErr != test.wantErr {
   216  				t.Errorf("VerifySCT(_,_,_, %t) = %v, want error? %t", test.embedded, err, test.wantErr)
   217  			}
   218  		})
   219  	}
   220  }
   221  
   222  func TestVerifySCTWithVerifier(t *testing.T) {
   223  	// Parse public key
   224  	pk, err := ct.PublicKeyFromB64(testdata.LogPublicKeyB64)
   225  	if err != nil {
   226  		t.Errorf("error parsing public key: %s", err)
   227  	}
   228  
   229  	// Create signature verifier
   230  	sv, err := ct.NewSignatureVerifier(pk)
   231  	if err != nil {
   232  		t.Errorf("couldn't create signature verifier: %s", err)
   233  	}
   234  
   235  	tests := []struct {
   236  		desc     string
   237  		sv       *ct.SignatureVerifier
   238  		chainPEM string
   239  		sct      []byte
   240  		embedded bool
   241  		wantErr  bool
   242  	}{
   243  		{
   244  			desc:     "nil signature verifier",
   245  			sv:       nil,
   246  			chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
   247  			sct:      testdata.TestCertProof,
   248  			wantErr:  true,
   249  		},
   250  		{
   251  			desc:     "cert",
   252  			sv:       sv,
   253  			chainPEM: testdata.TestCertPEM + testdata.CACertPEM,
   254  			sct:      testdata.TestCertProof,
   255  		},
   256  		{
   257  			desc:     "precert",
   258  			sv:       sv,
   259  			chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
   260  			sct:      testdata.TestPreCertProof,
   261  		},
   262  		{
   263  			desc:     "invalid SCT",
   264  			sv:       sv,
   265  			chainPEM: testdata.TestPreCertPEM + testdata.CACertPEM,
   266  			sct:      testdata.TestCertProof,
   267  			wantErr:  true,
   268  		},
   269  		{
   270  			desc:     "cert with embedded SCT",
   271  			sv:       sv,
   272  			chainPEM: testdata.TestEmbeddedCertPEM + testdata.CACertPEM,
   273  			sct:      testdata.TestPreCertProof,
   274  			embedded: true,
   275  		},
   276  		{
   277  			desc:     "cert with invalid embedded SCT",
   278  			sv:       sv,
   279  			chainPEM: testdata.TestInvalidEmbeddedCertPEM + testdata.CACertPEM,
   280  			sct:      testdata.TestInvalidProof,
   281  			embedded: true,
   282  			wantErr:  true,
   283  		},
   284  	}
   285  
   286  	for _, test := range tests {
   287  		t.Run(test.desc, func(t *testing.T) {
   288  			// Parse chain
   289  			chain, err := x509util.CertificatesFromPEM([]byte(test.chainPEM))
   290  			if err != nil {
   291  				t.Fatalf("error parsing certificate chain: %s", err)
   292  			}
   293  
   294  			// Parse SCT
   295  			var sct ct.SignedCertificateTimestamp
   296  			if _, err = tls.Unmarshal(test.sct, &sct); err != nil {
   297  				t.Fatalf("error tls-unmarshalling sct: %s", err)
   298  			}
   299  
   300  			// Test VerifySCTWithVerifier()
   301  			err = VerifySCTWithVerifier(test.sv, chain, &sct, test.embedded)
   302  			if gotErr := err != nil; gotErr != test.wantErr {
   303  				t.Errorf("VerifySCT(_,_,_, %t) = %v, want error? %t", test.embedded, err, test.wantErr)
   304  			}
   305  		})
   306  	}
   307  }
   308  
   309  func TestContainsSCT(t *testing.T) {
   310  	tests := []struct {
   311  		desc    string
   312  		certPEM string
   313  		sct     []byte
   314  		want    bool
   315  	}{
   316  		{
   317  			desc:    "cert doesn't contain any SCTs",
   318  			certPEM: testdata.TestCertPEM,
   319  			sct:     testdata.TestPreCertProof,
   320  			want:    false,
   321  		},
   322  		{
   323  			desc:    "cert contains SCT but not specified SCT",
   324  			certPEM: testdata.TestEmbeddedCertPEM,
   325  			sct:     testdata.TestInvalidProof,
   326  			want:    false,
   327  		},
   328  		{
   329  			desc:    "cert contains SCT",
   330  			certPEM: testdata.TestEmbeddedCertPEM,
   331  			sct:     testdata.TestPreCertProof,
   332  			want:    true,
   333  		},
   334  	}
   335  
   336  	for _, test := range tests {
   337  		t.Run(test.desc, func(t *testing.T) {
   338  			// Parse cert
   339  			cert, err := x509util.CertificateFromPEM([]byte(test.certPEM))
   340  			if err != nil {
   341  				t.Fatalf("error parsing certificate: %s", err)
   342  			}
   343  
   344  			// Parse SCT
   345  			var sct ct.SignedCertificateTimestamp
   346  			if _, err = tls.Unmarshal(test.sct, &sct); err != nil {
   347  				t.Fatalf("error tls-unmarshalling sct: %s", err)
   348  			}
   349  
   350  			// Test ContainsSCT()
   351  			got, err := ContainsSCT(cert, &sct)
   352  			if err != nil {
   353  				t.Fatalf("ContainsSCT(_,_) = false, %s, want no error", err)
   354  			}
   355  
   356  			if got != test.want {
   357  				t.Errorf("ContainsSCT(_,_) = %t, nil, want %t, nil", got, test.want)
   358  			}
   359  		})
   360  	}
   361  }
   362  

View as plain text