...

Source file src/github.com/transparency-dev/merkle/testonly/tree_test.go

Documentation: github.com/transparency-dev/merkle/testonly

     1  // Copyright 2022 Google LLC. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testonly
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"math/rand"
    21  	"strconv"
    22  	"testing"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	"github.com/google/go-cmp/cmp/cmpopts"
    26  	"github.com/transparency-dev/merkle/rfc6962"
    27  )
    28  
    29  func validateTree(t *testing.T, mt *Tree, size uint64) {
    30  	t.Helper()
    31  	if got, want := mt.Size(), size; got != want {
    32  		t.Errorf("Size: %d, want %d", got, want)
    33  	}
    34  	roots := RootHashes()
    35  	if got, want := mt.Hash(), roots[size]; !bytes.Equal(got, want) {
    36  		t.Errorf("Hash(%d): %x, want %x", size, got, want)
    37  	}
    38  	for s := uint64(0); s <= size; s++ {
    39  		if got, want := mt.HashAt(s), roots[s]; !bytes.Equal(got, want) {
    40  			t.Errorf("HashAt(%d/%d): %x, want %x", s, size, got, want)
    41  		}
    42  	}
    43  }
    44  
    45  func TestBuildTreeBuildOneAtATime(t *testing.T) {
    46  	mt := newTree(nil)
    47  	validateTree(t, mt, 0)
    48  	for i, entry := range LeafInputs() {
    49  		mt.AppendData(entry)
    50  		validateTree(t, mt, uint64(i+1))
    51  	}
    52  }
    53  
    54  func TestBuildTreeBuildTwoChunks(t *testing.T) {
    55  	entries := LeafInputs()
    56  	mt := newTree(nil)
    57  	mt.AppendData(entries[:3]...)
    58  	validateTree(t, mt, 3)
    59  	mt.AppendData(entries[3:8]...)
    60  	validateTree(t, mt, 8)
    61  }
    62  
    63  func TestBuildTreeBuildAllAtOnce(t *testing.T) {
    64  	mt := newTree(nil)
    65  	mt.AppendData(LeafInputs()...)
    66  	validateTree(t, mt, 8)
    67  }
    68  
    69  func TestTreeHashAt(t *testing.T) {
    70  	test := func(desc string, entries [][]byte) {
    71  		t.Run(desc, func(t *testing.T) {
    72  			mt := newTree(entries)
    73  			for size := 0; size <= len(entries); size++ {
    74  				got := mt.HashAt(uint64(size))
    75  				want := refRootHash(entries[:size], mt.hasher)
    76  				if !bytes.Equal(got, want) {
    77  					t.Errorf("HashAt(%d): %x, want %x", size, got, want)
    78  				}
    79  			}
    80  		})
    81  	}
    82  
    83  	entries := LeafInputs()
    84  	for size := 0; size <= len(entries); size++ {
    85  		test(fmt.Sprintf("size:%d", size), entries[:size])
    86  	}
    87  	test("generated", genEntries(256))
    88  }
    89  
    90  func TestTreeInclusionProof(t *testing.T) {
    91  	test := func(desc string, entries [][]byte) {
    92  		t.Run(desc, func(t *testing.T) {
    93  			mt := newTree(entries)
    94  			for index, size := uint64(0), uint64(len(entries)); index < size; index++ {
    95  				got, err := mt.InclusionProof(index, size)
    96  				if err != nil {
    97  					t.Fatalf("InclusionProof(%d, %d): %v", index, size, err)
    98  				}
    99  				want := refInclusionProof(entries[:size], index, mt.hasher)
   100  				if diff := cmp.Diff(got, want, cmpopts.EquateEmpty()); diff != "" {
   101  					t.Fatalf("InclusionProof(%d, %d): diff (-got +want)\n%s", index, size, diff)
   102  				}
   103  			}
   104  		})
   105  	}
   106  
   107  	test("generated", genEntries(256))
   108  	entries := LeafInputs()
   109  	for size := 0; size < len(entries); size++ {
   110  		test(fmt.Sprintf("golden:%d", size), entries[:size])
   111  	}
   112  }
   113  
   114  func TestTreeConsistencyProof(t *testing.T) {
   115  	entries := LeafInputs()
   116  	mt := newTree(entries)
   117  	validateTree(t, mt, 8)
   118  
   119  	if _, err := mt.ConsistencyProof(6, 3); err == nil {
   120  		t.Error("ConsistencyProof(6, 3) succeeded unexpectedly")
   121  	}
   122  
   123  	for size1 := uint64(0); size1 <= 8; size1++ {
   124  		for size2 := size1; size2 <= 8; size2++ {
   125  			t.Run(fmt.Sprintf("%d:%d", size1, size2), func(t *testing.T) {
   126  				got, err := mt.ConsistencyProof(size1, size2)
   127  				if err != nil {
   128  					t.Fatalf("ConsistencyProof: %v", err)
   129  				}
   130  				want := refConsistencyProof(entries[:size2], size2, size1, mt.hasher, true)
   131  				if diff := cmp.Diff(got, want, cmpopts.EquateEmpty()); diff != "" {
   132  					t.Errorf("ConsistencyProof: diff (-got +want)\n%s", diff)
   133  				}
   134  			})
   135  		}
   136  	}
   137  }
   138  
   139  // Make random proof queries and check against the reference implementation.
   140  func TestTreeConsistencyProofFuzz(t *testing.T) {
   141  	entries := genEntries(256)
   142  
   143  	for treeSize := int64(1); treeSize <= 256; treeSize++ {
   144  		mt := newTree(entries[:treeSize])
   145  		for i := 0; i < 8; i++ {
   146  			size2 := uint64(rand.Int63n(treeSize + 1))
   147  			size1 := uint64(rand.Int63n(int64(size2) + 1))
   148  
   149  			got, err := mt.ConsistencyProof(size1, size2)
   150  			if err != nil {
   151  				t.Fatalf("ConsistencyProof: %v", err)
   152  			}
   153  			want := refConsistencyProof(entries[:size2], size2, size1, mt.hasher, true)
   154  			if diff := cmp.Diff(got, want, cmpopts.EquateEmpty()); diff != "" {
   155  				t.Errorf("ConsistencyProof: diff (-got +want)\n%s", diff)
   156  			}
   157  		}
   158  	}
   159  }
   160  
   161  func TestTreeAppend(t *testing.T) {
   162  	entries := genEntries(256)
   163  	mt1 := newTree(entries)
   164  
   165  	mt2 := newTree(nil)
   166  	for _, entry := range entries {
   167  		mt2.Append(rfc6962.DefaultHasher.HashLeaf(entry))
   168  	}
   169  
   170  	if diff := cmp.Diff(mt1, mt2, cmp.AllowUnexported(Tree{})); diff != "" {
   171  		t.Errorf("Trees built with AppendData and Append mismatch: diff (-mt1 +mt2)\n%s", diff)
   172  	}
   173  }
   174  
   175  func TestTreeAppendAssociativity(t *testing.T) {
   176  	entries := genEntries(256)
   177  	mt1 := newTree(nil)
   178  	mt1.AppendData(entries...)
   179  
   180  	mt2 := newTree(nil)
   181  	for _, entry := range entries {
   182  		mt2.AppendData(entry)
   183  	}
   184  
   185  	if diff := cmp.Diff(mt1, mt2, cmp.AllowUnexported(Tree{})); diff != "" {
   186  		t.Errorf("AppendData is not associative: diff (-mt1 +mt2)\n%s", diff)
   187  	}
   188  }
   189  
   190  func newTree(entries [][]byte) *Tree {
   191  	tree := New(rfc6962.DefaultHasher)
   192  	tree.AppendData(entries...)
   193  	return tree
   194  }
   195  
   196  // genEntries a slice of entries of the given size.
   197  func genEntries(size uint64) [][]byte {
   198  	entries := make([][]byte, size)
   199  	for i := range entries {
   200  		entries[i] = []byte(strconv.Itoa(i))
   201  	}
   202  	return entries
   203  }
   204  

View as plain text