...

Source file src/github.com/sigstore/rekor/pkg/trillianclient/trillian_client.go

Documentation: github.com/sigstore/rekor/pkg/trillianclient

     1  //
     2  // Copyright 2021 The Sigstore Authors.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //     http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  
    16  package trillianclient
    17  
    18  import (
    19  	"context"
    20  	"encoding/hex"
    21  	"fmt"
    22  	"time"
    23  
    24  	"github.com/sigstore/rekor/pkg/log"
    25  	"github.com/transparency-dev/merkle/proof"
    26  	"github.com/transparency-dev/merkle/rfc6962"
    27  
    28  	"google.golang.org/grpc/codes"
    29  	"google.golang.org/grpc/status"
    30  	"google.golang.org/protobuf/types/known/durationpb"
    31  
    32  	"github.com/google/trillian"
    33  	"github.com/google/trillian/client"
    34  	"github.com/google/trillian/types"
    35  )
    36  
    37  // TrillianClient provides a wrapper around the Trillian client
    38  type TrillianClient struct {
    39  	client  trillian.TrillianLogClient
    40  	logID   int64
    41  	context context.Context
    42  }
    43  
    44  // NewTrillianClient creates a TrillianClient with the given Trillian client and log/tree ID.
    45  func NewTrillianClient(ctx context.Context, logClient trillian.TrillianLogClient, logID int64) TrillianClient {
    46  	return TrillianClient{
    47  		client:  logClient,
    48  		logID:   logID,
    49  		context: ctx,
    50  	}
    51  }
    52  
    53  // Response includes a status code, an optional error message, and one of the results based on the API call
    54  type Response struct {
    55  	// Status is the status code of the response
    56  	Status codes.Code
    57  	// Error contains an error on request or client failure
    58  	Err error
    59  	// GetAddResult contains the response from queueing a leaf in Trillian
    60  	GetAddResult *trillian.QueueLeafResponse
    61  	// GetLeafAndProofResult contains the response for fetching an inclusion proof and leaf
    62  	GetLeafAndProofResult *trillian.GetEntryAndProofResponse
    63  	// GetLatestResult contains the response for the latest checkpoint
    64  	GetLatestResult *trillian.GetLatestSignedLogRootResponse
    65  	// GetConsistencyProofResult contains the response for a consistency proof between two log sizes
    66  	GetConsistencyProofResult *trillian.GetConsistencyProofResponse
    67  	// getProofResult contains the response for an inclusion proof fetched by leaf hash
    68  	getProofResult *trillian.GetInclusionProofByHashResponse
    69  }
    70  
    71  func unmarshalLogRoot(logRoot []byte) (types.LogRootV1, error) {
    72  	var root types.LogRootV1
    73  	if err := root.UnmarshalBinary(logRoot); err != nil {
    74  		return types.LogRootV1{}, err
    75  	}
    76  	return root, nil
    77  }
    78  
    79  func (t *TrillianClient) root() (types.LogRootV1, error) {
    80  	rqst := &trillian.GetLatestSignedLogRootRequest{
    81  		LogId: t.logID,
    82  	}
    83  	resp, err := t.client.GetLatestSignedLogRoot(t.context, rqst)
    84  	if err != nil {
    85  		return types.LogRootV1{}, err
    86  	}
    87  	return unmarshalLogRoot(resp.SignedLogRoot.LogRoot)
    88  }
    89  
    90  func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
    91  	leaf := &trillian.LogLeaf{
    92  		LeafValue: byteValue,
    93  	}
    94  	rqst := &trillian.QueueLeafRequest{
    95  		LogId: t.logID,
    96  		Leaf:  leaf,
    97  	}
    98  	resp, err := t.client.QueueLeaf(t.context, rqst)
    99  
   100  	// check for error
   101  	if err != nil || (resp.QueuedLeaf.Status != nil && resp.QueuedLeaf.Status.Code != int32(codes.OK)) {
   102  		return &Response{
   103  			Status:       status.Code(err),
   104  			Err:          err,
   105  			GetAddResult: resp,
   106  		}
   107  	}
   108  
   109  	root, err := t.root()
   110  	if err != nil {
   111  		return &Response{
   112  			Status:       status.Code(err),
   113  			Err:          err,
   114  			GetAddResult: resp,
   115  		}
   116  	}
   117  	v := client.NewLogVerifier(rfc6962.DefaultHasher)
   118  	logClient := client.New(t.logID, t.client, v, root)
   119  
   120  	waitForInclusion := func(ctx context.Context, _ []byte) *Response {
   121  		if logClient.MinMergeDelay > 0 {
   122  			select {
   123  			case <-ctx.Done():
   124  				return &Response{
   125  					Status: codes.DeadlineExceeded,
   126  					Err:    ctx.Err(),
   127  				}
   128  			case <-time.After(logClient.MinMergeDelay):
   129  			}
   130  		}
   131  		for {
   132  			root = *logClient.GetRoot()
   133  			if root.TreeSize >= 1 {
   134  				proofResp := t.getProofByHash(resp.QueuedLeaf.Leaf.MerkleLeafHash)
   135  				// if this call succeeds or returns an error other than "not found", return
   136  				if proofResp.Err == nil || (proofResp.Err != nil && status.Code(proofResp.Err) != codes.NotFound) {
   137  					return proofResp
   138  				}
   139  				// otherwise wait for a root update before trying again
   140  			}
   141  
   142  			if _, err := logClient.WaitForRootUpdate(ctx); err != nil {
   143  				return &Response{
   144  					Status: codes.Unknown,
   145  					Err:    err,
   146  				}
   147  			}
   148  		}
   149  	}
   150  
   151  	proofResp := waitForInclusion(t.context, resp.QueuedLeaf.Leaf.MerkleLeafHash)
   152  	if proofResp.Err != nil {
   153  		return &Response{
   154  			Status:       status.Code(proofResp.Err),
   155  			Err:          proofResp.Err,
   156  			GetAddResult: resp,
   157  		}
   158  	}
   159  
   160  	proofs := proofResp.getProofResult.Proof
   161  	if len(proofs) != 1 {
   162  		err := fmt.Errorf("expected 1 proof from getProofByHash for %v, found %v", hex.EncodeToString(resp.QueuedLeaf.Leaf.MerkleLeafHash), len(proofs))
   163  		return &Response{
   164  			Status:       status.Code(err),
   165  			Err:          err,
   166  			GetAddResult: resp,
   167  		}
   168  	}
   169  
   170  	leafIndex := proofs[0].LeafIndex
   171  	leafResp := t.GetLeafAndProofByIndex(leafIndex)
   172  	if leafResp.Err != nil {
   173  		return &Response{
   174  			Status:       status.Code(leafResp.Err),
   175  			Err:          leafResp.Err,
   176  			GetAddResult: resp,
   177  		}
   178  	}
   179  
   180  	// overwrite queued leaf that doesn't have index set
   181  	resp.QueuedLeaf.Leaf = leafResp.GetLeafAndProofResult.Leaf
   182  
   183  	return &Response{
   184  		Status:       status.Code(err),
   185  		Err:          err,
   186  		GetAddResult: resp,
   187  		// include getLeafAndProofResult for inclusion proof
   188  		GetLeafAndProofResult: leafResp.GetLeafAndProofResult,
   189  	}
   190  }
   191  
   192  func (t *TrillianClient) GetLeafAndProofByHash(hash []byte) *Response {
   193  	// get inclusion proof for hash, extract index, then fetch leaf using index
   194  	proofResp := t.getProofByHash(hash)
   195  	if proofResp.Err != nil {
   196  		return &Response{
   197  			Status: status.Code(proofResp.Err),
   198  			Err:    proofResp.Err,
   199  		}
   200  	}
   201  
   202  	proofs := proofResp.getProofResult.Proof
   203  	if len(proofs) != 1 {
   204  		err := fmt.Errorf("expected 1 proof from getProofByHash for %v, found %v", hex.EncodeToString(hash), len(proofs))
   205  		return &Response{
   206  			Status: status.Code(err),
   207  			Err:    err,
   208  		}
   209  	}
   210  
   211  	return t.GetLeafAndProofByIndex(proofs[0].LeafIndex)
   212  }
   213  
   214  func (t *TrillianClient) GetLeafAndProofByIndex(index int64) *Response {
   215  	ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
   216  	defer cancel()
   217  
   218  	rootResp := t.GetLatest(0)
   219  	if rootResp.Err != nil {
   220  		return &Response{
   221  			Status: status.Code(rootResp.Err),
   222  			Err:    rootResp.Err,
   223  		}
   224  	}
   225  
   226  	root, err := unmarshalLogRoot(rootResp.GetLatestResult.SignedLogRoot.LogRoot)
   227  	if err != nil {
   228  		return &Response{
   229  			Status: status.Code(rootResp.Err),
   230  			Err:    rootResp.Err,
   231  		}
   232  	}
   233  
   234  	resp, err := t.client.GetEntryAndProof(ctx,
   235  		&trillian.GetEntryAndProofRequest{
   236  			LogId:     t.logID,
   237  			LeafIndex: index,
   238  			TreeSize:  int64(root.TreeSize),
   239  		})
   240  
   241  	if resp != nil && resp.Proof != nil {
   242  		if err := proof.VerifyInclusion(rfc6962.DefaultHasher, uint64(index), root.TreeSize, resp.GetLeaf().MerkleLeafHash, resp.Proof.Hashes, root.RootHash); err != nil {
   243  			return &Response{
   244  				Status: status.Code(err),
   245  				Err:    err,
   246  			}
   247  		}
   248  		return &Response{
   249  			Status: status.Code(err),
   250  			Err:    err,
   251  			GetLeafAndProofResult: &trillian.GetEntryAndProofResponse{
   252  				Proof:         resp.Proof,
   253  				Leaf:          resp.Leaf,
   254  				SignedLogRoot: rootResp.GetLatestResult.SignedLogRoot,
   255  			},
   256  		}
   257  	}
   258  
   259  	return &Response{
   260  		Status: status.Code(err),
   261  		Err:    err,
   262  	}
   263  }
   264  
   265  func (t *TrillianClient) GetLatest(leafSizeInt int64) *Response {
   266  
   267  	ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
   268  	defer cancel()
   269  
   270  	resp, err := t.client.GetLatestSignedLogRoot(ctx,
   271  		&trillian.GetLatestSignedLogRootRequest{
   272  			LogId:         t.logID,
   273  			FirstTreeSize: leafSizeInt,
   274  		})
   275  
   276  	return &Response{
   277  		Status:          status.Code(err),
   278  		Err:             err,
   279  		GetLatestResult: resp,
   280  	}
   281  }
   282  
   283  func (t *TrillianClient) GetConsistencyProof(firstSize, lastSize int64) *Response {
   284  
   285  	ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
   286  	defer cancel()
   287  
   288  	resp, err := t.client.GetConsistencyProof(ctx,
   289  		&trillian.GetConsistencyProofRequest{
   290  			LogId:          t.logID,
   291  			FirstTreeSize:  firstSize,
   292  			SecondTreeSize: lastSize,
   293  		})
   294  
   295  	return &Response{
   296  		Status:                    status.Code(err),
   297  		Err:                       err,
   298  		GetConsistencyProofResult: resp,
   299  	}
   300  }
   301  
   302  func (t *TrillianClient) getProofByHash(hashValue []byte) *Response {
   303  	ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
   304  	defer cancel()
   305  
   306  	rootResp := t.GetLatest(0)
   307  	if rootResp.Err != nil {
   308  		return &Response{
   309  			Status: status.Code(rootResp.Err),
   310  			Err:    rootResp.Err,
   311  		}
   312  	}
   313  	root, err := unmarshalLogRoot(rootResp.GetLatestResult.SignedLogRoot.LogRoot)
   314  	if err != nil {
   315  		return &Response{
   316  			Status: status.Code(rootResp.Err),
   317  			Err:    rootResp.Err,
   318  		}
   319  	}
   320  
   321  	// issue 1308: if the tree is empty, there's no way we can return a proof
   322  	if root.TreeSize == 0 {
   323  		return &Response{
   324  			Status: codes.NotFound,
   325  			Err:    status.Error(codes.NotFound, "tree is empty"),
   326  		}
   327  	}
   328  
   329  	resp, err := t.client.GetInclusionProofByHash(ctx,
   330  		&trillian.GetInclusionProofByHashRequest{
   331  			LogId:    t.logID,
   332  			LeafHash: hashValue,
   333  			TreeSize: int64(root.TreeSize),
   334  		})
   335  
   336  	if resp != nil {
   337  		v := client.NewLogVerifier(rfc6962.DefaultHasher)
   338  		for _, proof := range resp.Proof {
   339  			if err := v.VerifyInclusionByHash(&root, hashValue, proof); err != nil {
   340  				return &Response{
   341  					Status: status.Code(err),
   342  					Err:    err,
   343  				}
   344  			}
   345  		}
   346  		// Return an inclusion proof response with the requested
   347  		return &Response{
   348  			Status: status.Code(err),
   349  			Err:    err,
   350  			getProofResult: &trillian.GetInclusionProofByHashResponse{
   351  				Proof:         resp.Proof,
   352  				SignedLogRoot: rootResp.GetLatestResult.SignedLogRoot,
   353  			},
   354  		}
   355  	}
   356  
   357  	return &Response{
   358  		Status: status.Code(err),
   359  		Err:    err,
   360  	}
   361  }
   362  
   363  func CreateAndInitTree(ctx context.Context, adminClient trillian.TrillianAdminClient, logClient trillian.TrillianLogClient) (*trillian.Tree, error) {
   364  	t, err := adminClient.CreateTree(ctx, &trillian.CreateTreeRequest{
   365  		Tree: &trillian.Tree{
   366  			TreeType:        trillian.TreeType_LOG,
   367  			TreeState:       trillian.TreeState_ACTIVE,
   368  			MaxRootDuration: durationpb.New(time.Hour),
   369  		},
   370  	})
   371  	if err != nil {
   372  		return nil, fmt.Errorf("create tree: %w", err)
   373  	}
   374  
   375  	if err := client.InitLog(ctx, t, logClient); err != nil {
   376  		return nil, fmt.Errorf("init log: %w", err)
   377  	}
   378  	log.Logger.Infof("Created new tree with ID: %v", t.TreeId)
   379  	return t, nil
   380  }
   381  

View as plain text