1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package trillianclient
17
18 import (
19 "context"
20 "encoding/hex"
21 "fmt"
22 "time"
23
24 "github.com/sigstore/rekor/pkg/log"
25 "github.com/transparency-dev/merkle/proof"
26 "github.com/transparency-dev/merkle/rfc6962"
27
28 "google.golang.org/grpc/codes"
29 "google.golang.org/grpc/status"
30 "google.golang.org/protobuf/types/known/durationpb"
31
32 "github.com/google/trillian"
33 "github.com/google/trillian/client"
34 "github.com/google/trillian/types"
35 )
36
37
38 type TrillianClient struct {
39 client trillian.TrillianLogClient
40 logID int64
41 context context.Context
42 }
43
44
45 func NewTrillianClient(ctx context.Context, logClient trillian.TrillianLogClient, logID int64) TrillianClient {
46 return TrillianClient{
47 client: logClient,
48 logID: logID,
49 context: ctx,
50 }
51 }
52
53
54 type Response struct {
55
56 Status codes.Code
57
58 Err error
59
60 GetAddResult *trillian.QueueLeafResponse
61
62 GetLeafAndProofResult *trillian.GetEntryAndProofResponse
63
64 GetLatestResult *trillian.GetLatestSignedLogRootResponse
65
66 GetConsistencyProofResult *trillian.GetConsistencyProofResponse
67
68 getProofResult *trillian.GetInclusionProofByHashResponse
69 }
70
71 func unmarshalLogRoot(logRoot []byte) (types.LogRootV1, error) {
72 var root types.LogRootV1
73 if err := root.UnmarshalBinary(logRoot); err != nil {
74 return types.LogRootV1{}, err
75 }
76 return root, nil
77 }
78
// root fetches the latest signed log root for t.logID using t.context and
// decodes it. A response that carries no SignedLogRoot is reported as an
// error rather than dereferenced.
func (t *TrillianClient) root() (types.LogRootV1, error) {
	rqst := &trillian.GetLatestSignedLogRootRequest{
		LogId: t.logID,
	}
	resp, err := t.client.GetLatestSignedLogRoot(t.context, rqst)
	if err != nil {
		return types.LogRootV1{}, err
	}
	// The getters are nil-safe, so a nil resp also ends up here.
	slr := resp.GetSignedLogRoot()
	if slr == nil {
		return types.LogRootV1{}, fmt.Errorf("no signed log root returned for log %d", t.logID)
	}
	return unmarshalLogRoot(slr.GetLogRoot())
}
89
// AddLeaf queues byteValue as a new leaf in the log. It waits for the leaf to
// become provable, then reads the leaf back by index together with a verified
// inclusion proof.
//
// Status and Err describe call-level failures. If QueueLeaf succeeds but
// reports a non-OK per-leaf status in QueuedLeaf.Status, AddLeaf returns
// immediately with Status OK, Err nil and GetAddResult set. In that case
// callers must inspect GetAddResult.QueuedLeaf.Status themselves.
//
// On success, GetAddResult.QueuedLeaf.Leaf is replaced by the leaf read back
// from the log, and GetLeafAndProofResult holds that leaf with its proof.
//
// Waiting has no timeout of its own. It ends when a proof lookup returns
// anything other than NotFound, or when waiting on t.context fails.
func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
	leaf := &trillian.LogLeaf{
		LeafValue: byteValue,
	}
	rqst := &trillian.QueueLeafRequest{
		LogId: t.logID,
		Leaf:  leaf,
	}
	resp, err := t.client.QueueLeaf(t.context, rqst)

	// Stop on a transport error or on a non-OK per-leaf status. In the latter
	// case err is nil, so Status is OK; callers rely on that and read
	// QueuedLeaf.Status. The getters return zero values for nil messages, so
	// a missing status reads as code 0 (OK) instead of panicking.
	if err != nil || resp.GetQueuedLeaf().GetStatus().GetCode() != int32(codes.OK) {
		return &Response{
			Status:       status.Code(err),
			Err:          err,
			GetAddResult: resp,
		}
	}

	// Everything below needs the queued leaf's Merkle hash; without the leaf
	// the original code dereferenced nil.
	if resp.GetQueuedLeaf().GetLeaf() == nil {
		err := status.Error(codes.Internal, "QueueLeaf response did not include the queued leaf")
		return &Response{
			Status:       status.Code(err),
			Err:          err,
			GetAddResult: resp,
		}
	}
	leafHash := resp.QueuedLeaf.Leaf.MerkleLeafHash

	root, err := t.root()
	if err != nil {
		return &Response{
			Status:       status.Code(err),
			Err:          err,
			GetAddResult: resp,
		}
	}
	v := client.NewLogVerifier(rfc6962.DefaultHasher)
	logClient := client.New(t.logID, t.client, v, root)

	// waitForInclusion polls for an inclusion proof of hash, waiting for a new
	// log root between attempts. A NotFound error means the leaf is not yet
	// provable; any other outcome (success or error) is returned as final.
	waitForInclusion := func(ctx context.Context, hash []byte) *Response {
		if logClient.MinMergeDelay > 0 {
			select {
			case <-ctx.Done():
				return &Response{
					Status: codes.DeadlineExceeded,
					Err:    ctx.Err(),
				}
			case <-time.After(logClient.MinMergeDelay):
			}
		}
		for {
			if logClient.GetRoot().TreeSize >= 1 {
				proofResp := t.getProofByHash(hash)
				// status.Code(nil) is OK, so success also returns here.
				if status.Code(proofResp.Err) != codes.NotFound {
					return proofResp
				}
			}

			if _, err := logClient.WaitForRootUpdate(ctx); err != nil {
				return &Response{
					Status: codes.Unknown,
					Err:    err,
				}
			}
		}
	}

	proofResp := waitForInclusion(t.context, leafHash)
	if proofResp.Err != nil {
		return &Response{
			Status:       status.Code(proofResp.Err),
			Err:          proofResp.Err,
			GetAddResult: resp,
		}
	}

	// GetProof is nil-safe: a missing result yields zero proofs and is
	// reported by the length check below instead of panicking.
	proofs := proofResp.getProofResult.GetProof()
	if len(proofs) != 1 {
		err := fmt.Errorf("expected 1 proof from getProofByHash for %v, found %v", hex.EncodeToString(leafHash), len(proofs))
		return &Response{
			Status:       status.Code(err),
			Err:          err,
			GetAddResult: resp,
		}
	}

	leafResp := t.GetLeafAndProofByIndex(proofs[0].LeafIndex)
	if leafResp.Err != nil {
		return &Response{
			Status:       status.Code(leafResp.Err),
			Err:          leafResp.Err,
			GetAddResult: resp,
		}
	}
	if leafResp.GetLeafAndProofResult == nil {
		err := status.Error(codes.Internal, "GetLeafAndProofByIndex returned no result")
		return &Response{
			Status:       status.Code(err),
			Err:          err,
			GetAddResult: resp,
		}
	}

	// Replace the queued leaf with the copy read back from the log by index.
	resp.QueuedLeaf.Leaf = leafResp.GetLeafAndProofResult.Leaf

	return &Response{
		Status:                codes.OK,
		GetAddResult:          resp,
		GetLeafAndProofResult: leafResp.GetLeafAndProofResult,
	}
}
191
// GetLeafAndProofByHash looks up the leaf whose Merkle leaf hash is hash and
// returns it with an inclusion proof. The hash is first resolved to a leaf
// index via getProofByHash, and the work is then delegated to
// GetLeafAndProofByIndex. Exactly one proof is expected for the hash.
func (t *TrillianClient) GetLeafAndProofByHash(hash []byte) *Response {
	byHash := t.getProofByHash(hash)
	if byHash.Err != nil {
		return &Response{
			Status: status.Code(byHash.Err),
			Err:    byHash.Err,
		}
	}

	found := byHash.getProofResult.Proof
	if len(found) == 1 {
		return t.GetLeafAndProofByIndex(found[0].LeafIndex)
	}

	mismatch := fmt.Errorf("expected 1 proof from getProofByHash for %v, found %v", hex.EncodeToString(hash), len(found))
	return &Response{
		Status: status.Code(mismatch),
		Err:    mismatch,
	}
}
213
// GetLeafAndProofByIndex fetches the leaf at index with an inclusion proof
// against the latest signed log root, and verifies that proof locally before
// returning it. On success, GetLeafAndProofResult carries the proof, the leaf
// and the signed log root the proof was verified against.
//
// The entry RPC is bounded by a 20 second timeout derived from t.context.
func (t *TrillianClient) GetLeafAndProofByIndex(index int64) *Response {
	ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
	defer cancel()

	rootResp := t.GetLatest(0)
	if rootResp.Err != nil {
		return &Response{
			Status: status.Code(rootResp.Err),
			Err:    rootResp.Err,
		}
	}

	// Getters keep a missing SignedLogRoot from panicking; empty bytes then
	// fail to decode below.
	signedRoot := rootResp.GetLatestResult.GetSignedLogRoot()
	root, err := unmarshalLogRoot(signedRoot.GetLogRoot())
	if err != nil {
		// Report the decode error itself. rootResp.Err is nil here, so
		// returning it would have signalled success with no result.
		return &Response{
			Status: status.Code(err),
			Err:    err,
		}
	}

	resp, err := t.client.GetEntryAndProof(ctx,
		&trillian.GetEntryAndProofRequest{
			LogId:     t.logID,
			LeafIndex: index,
			TreeSize:  int64(root.TreeSize),
		})
	if err != nil {
		return &Response{
			Status: status.Code(err),
			Err:    err,
		}
	}
	if resp.GetProof() == nil {
		// Previously this returned OK with no result, which callers then
		// dereferenced.
		err := status.Errorf(codes.Internal, "no inclusion proof returned for leaf index %d", index)
		return &Response{
			Status: status.Code(err),
			Err:    err,
		}
	}

	// A missing leaf yields a nil hash, which VerifyInclusion rejects.
	if err := proof.VerifyInclusion(rfc6962.DefaultHasher, uint64(index), root.TreeSize, resp.GetLeaf().GetMerkleLeafHash(), resp.GetProof().GetHashes(), root.RootHash); err != nil {
		return &Response{
			Status: status.Code(err),
			Err:    err,
		}
	}

	return &Response{
		Status: codes.OK,
		GetLeafAndProofResult: &trillian.GetEntryAndProofResponse{
			Proof:         resp.Proof,
			Leaf:          resp.Leaf,
			SignedLogRoot: signedRoot,
		},
	}
}
264
// GetLatest requests the latest signed log root for the log, passing
// leafSizeInt as the request's FirstTreeSize. The RPC is bounded by a
// 20 second timeout derived from t.context. The raw response is returned in
// GetLatestResult alongside the call's status and error.
func (t *TrillianClient) GetLatest(leafSizeInt int64) *Response {
	req := &trillian.GetLatestSignedLogRootRequest{
		LogId:         t.logID,
		FirstTreeSize: leafSizeInt,
	}

	ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
	defer cancel()

	latest, err := t.client.GetLatestSignedLogRoot(ctx, req)

	out := &Response{GetLatestResult: latest}
	out.Status = status.Code(err)
	out.Err = err
	return out
}
282
// GetConsistencyProof requests a consistency proof between tree sizes
// firstSize and lastSize. The RPC is bounded by a 20 second timeout derived
// from t.context. The response is returned as-is, without local verification.
func (t *TrillianClient) GetConsistencyProof(firstSize, lastSize int64) *Response {
	req := &trillian.GetConsistencyProofRequest{
		LogId:          t.logID,
		FirstTreeSize:  firstSize,
		SecondTreeSize: lastSize,
	}

	ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
	defer cancel()

	consistency, err := t.client.GetConsistencyProof(ctx, req)

	out := &Response{GetConsistencyProofResult: consistency}
	out.Status = status.Code(err)
	out.Err = err
	return out
}
301
// getProofByHash fetches inclusion proofs for the leaf with Merkle leaf hash
// hashValue against the latest signed log root, and verifies each of them
// locally. An empty tree is reported as codes.NotFound.
//
// The proof RPC is bounded by a 20 second timeout derived from t.context.
func (t *TrillianClient) getProofByHash(hashValue []byte) *Response {
	ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
	defer cancel()

	rootResp := t.GetLatest(0)
	if rootResp.Err != nil {
		return &Response{
			Status: status.Code(rootResp.Err),
			Err:    rootResp.Err,
		}
	}

	// Getters keep a missing SignedLogRoot from panicking; empty bytes then
	// fail to decode below.
	signedRoot := rootResp.GetLatestResult.GetSignedLogRoot()
	root, err := unmarshalLogRoot(signedRoot.GetLogRoot())
	if err != nil {
		// Report the decode error itself. rootResp.Err is nil here, so
		// returning it would have signalled success with no result.
		return &Response{
			Status: status.Code(err),
			Err:    err,
		}
	}

	// Nothing can be proven against an empty tree. NotFound is the code
	// AddLeaf treats as "not yet included, keep waiting".
	if root.TreeSize == 0 {
		return &Response{
			Status: codes.NotFound,
			Err:    status.Error(codes.NotFound, "tree is empty"),
		}
	}

	resp, err := t.client.GetInclusionProofByHash(ctx,
		&trillian.GetInclusionProofByHashRequest{
			LogId:    t.logID,
			LeafHash: hashValue,
			TreeSize: int64(root.TreeSize),
		})

	if resp == nil {
		return &Response{
			Status: status.Code(err),
			Err:    err,
		}
	}

	v := client.NewLogVerifier(rfc6962.DefaultHasher)
	// Named p rather than proof so the imported proof package is not shadowed.
	for _, p := range resp.Proof {
		if err := v.VerifyInclusionByHash(&root, hashValue, p); err != nil {
			return &Response{
				Status: status.Code(err),
				Err:    err,
			}
		}
	}

	return &Response{
		Status: status.Code(err),
		Err:    err,
		getProofResult: &trillian.GetInclusionProofByHashResponse{
			Proof:         resp.Proof,
			SignedLogRoot: signedRoot,
		},
	}
}
362
// CreateAndInitTree creates a new ACTIVE tree of type LOG, with a one-hour
// MaxRootDuration, through adminClient. It then initializes the log through
// logClient, logs the new tree ID and returns the created tree. Errors from
// either step are wrapped with "create tree" or "init log" respectively.
func CreateAndInitTree(ctx context.Context, adminClient trillian.TrillianAdminClient, logClient trillian.TrillianLogClient) (*trillian.Tree, error) {
	spec := &trillian.Tree{
		TreeType:        trillian.TreeType_LOG,
		TreeState:       trillian.TreeState_ACTIVE,
		MaxRootDuration: durationpb.New(time.Hour),
	}

	tree, err := adminClient.CreateTree(ctx, &trillian.CreateTreeRequest{Tree: spec})
	if err != nil {
		return nil, fmt.Errorf("create tree: %w", err)
	}

	if err = client.InitLog(ctx, tree, logClient); err != nil {
		return nil, fmt.Errorf("init log: %w", err)
	}

	log.Logger.Infof("Created new tree with ID: %v", tree.TreeId)
	return tree, nil
}
381
View as plain text