1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package sharding
17
18 import (
19 "context"
20 "encoding/json"
21 "errors"
22 "os"
23 "path/filepath"
24 "reflect"
25 "testing"
26
27 "github.com/golang/mock/gomock"
28 "github.com/google/trillian/testonly"
29
30 "github.com/google/trillian"
31 "google.golang.org/grpc"
32 "gopkg.in/yaml.v2"
33 )
34
35 func TestNewLogRanges(t *testing.T) {
36 contents := `
37 - treeID: 0001
38 treeLength: 3
39 encodedPublicKey: c2hhcmRpbmcK
40 - treeID: 0002
41 treeLength: 4`
42 file := filepath.Join(t.TempDir(), "sharding-config")
43 if err := os.WriteFile(file, []byte(contents), 0o644); err != nil {
44 t.Fatal(err)
45 }
46 treeID := uint(45)
47 expected := LogRanges{
48 inactive: []LogRange{
49 {
50 TreeID: 1,
51 TreeLength: 3,
52 EncodedPublicKey: "c2hhcmRpbmcK",
53 decodedPublicKey: "sharding\n",
54 }, {
55 TreeID: 2,
56 TreeLength: 4,
57 },
58 },
59 active: int64(45),
60 }
61 ctx := context.Background()
62 tc := trillian.NewTrillianLogClient(&grpc.ClientConn{})
63 got, err := NewLogRanges(ctx, tc, file, treeID)
64 if err != nil {
65 t.Fatal(err)
66 }
67 if expected.ActiveTreeID() != got.ActiveTreeID() {
68 t.Fatalf("expected tree id %d got %d", expected.ActiveTreeID(), got.ActiveTreeID())
69 }
70 if !reflect.DeepEqual(expected.GetInactive(), got.GetInactive()) {
71 t.Fatalf("expected %v got %v", expected.GetInactive(), got.GetInactive())
72 }
73 }
74
75 func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
76 lrs := LogRanges{
77 inactive: []LogRange{
78 {TreeID: 1, TreeLength: 17},
79 {TreeID: 2, TreeLength: 1},
80 {TreeID: 3, TreeLength: 100},
81 },
82 active: 4,
83 }
84
85 for _, tt := range []struct {
86 Index int
87 WantTreeID int64
88 WantIndex int64
89 }{
90 {
91 Index: 3,
92 WantTreeID: 1, WantIndex: 3,
93 },
94
95 {
96 Index: 17,
97 WantTreeID: 2, WantIndex: 0,
98 },
99
100 {
101 Index: 3000,
102 WantTreeID: 4, WantIndex: 2882,
103 },
104 } {
105 tree, index := lrs.ResolveVirtualIndex(tt.Index)
106 if tree != tt.WantTreeID {
107 t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID)
108 }
109 if index != tt.WantIndex {
110 t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex)
111 }
112 }
113 }
114
115 func TestPublicKey(t *testing.T) {
116 ranges := LogRanges{
117 active: 45,
118 inactive: []LogRange{
119 {
120 TreeID: 10,
121 TreeLength: 10,
122 decodedPublicKey: "sharding",
123 }, {
124 TreeID: 20,
125 TreeLength: 20,
126 },
127 },
128 }
129 activePubKey := "activekey"
130 tests := []struct {
131 description string
132 treeID string
133 expectedPubKey string
134 shouldErr bool
135 }{
136 {
137 description: "empty tree ID",
138 expectedPubKey: "activekey",
139 }, {
140 description: "tree id with decoded public key",
141 treeID: "10",
142 expectedPubKey: "sharding",
143 }, {
144 description: "tree id without decoded public key",
145 treeID: "20",
146 expectedPubKey: "activekey",
147 }, {
148 description: "invalid tree id",
149 treeID: "34",
150 shouldErr: true,
151 }, {
152 description: "pass in active tree id",
153 treeID: "45",
154 expectedPubKey: "activekey",
155 },
156 }
157
158 for _, test := range tests {
159 t.Run(test.description, func(t *testing.T) {
160 got, err := ranges.PublicKey(activePubKey, test.treeID)
161 if err != nil && !test.shouldErr {
162 t.Fatal(err)
163 }
164 if test.shouldErr {
165 return
166 }
167 if got != test.expectedPubKey {
168 t.Fatalf("got %s doesn't match expected %s", got, test.expectedPubKey)
169 }
170 })
171 }
172 }
173
174 func TestLogRanges_String(t *testing.T) {
175 type fields struct {
176 inactive Ranges
177 active int64
178 }
179 tests := []struct {
180 name string
181 fields fields
182 want string
183 }{
184 {
185 name: "empty",
186 fields: fields{
187 inactive: Ranges{},
188 active: 0,
189 },
190 want: "active=0",
191 },
192 {
193 name: "one",
194 fields: fields{
195 inactive: Ranges{
196 {
197 TreeID: 1,
198 TreeLength: 2,
199 },
200 },
201 active: 3,
202 },
203 want: "1=2,active=3",
204 },
205 {
206 name: "two",
207 fields: fields{
208 inactive: Ranges{
209 {
210 TreeID: 1,
211 TreeLength: 2,
212 },
213 {
214 TreeID: 2,
215 TreeLength: 3,
216 },
217 },
218 active: 4,
219 },
220 want: "1=2,2=3,active=4",
221 },
222 }
223 for _, tt := range tests {
224 t.Run(tt.name, func(t *testing.T) {
225 l := &LogRanges{
226 inactive: tt.fields.inactive,
227 active: tt.fields.active,
228 }
229 if got := l.String(); got != tt.want {
230 t.Errorf("String() = %v, want %v", got, tt.want)
231 }
232 })
233 }
234 }
235
236 func TestLogRanges_TotalInactiveLength(t *testing.T) {
237 type fields struct {
238 inactive Ranges
239 active int64
240 }
241 tests := []struct {
242 name string
243 fields fields
244 want int64
245 }{
246 {
247 name: "empty",
248 fields: fields{
249 inactive: Ranges{},
250 active: 0,
251 },
252 want: 0,
253 },
254 {
255 name: "one",
256 fields: fields{
257 inactive: Ranges{
258 {
259 TreeID: 1,
260 TreeLength: 2,
261 },
262 },
263 active: 3,
264 },
265 want: 2,
266 },
267 }
268 for _, tt := range tests {
269 t.Run(tt.name, func(t *testing.T) {
270 l := &LogRanges{
271 inactive: tt.fields.inactive,
272 active: tt.fields.active,
273 }
274 if got := l.TotalInactiveLength(); got != tt.want {
275 t.Errorf("TotalInactiveLength() = %v, want %v", got, tt.want)
276 }
277 })
278 }
279 }
280
281 func TestLogRanges_AllShards(t *testing.T) {
282 type fields struct {
283 inactive Ranges
284 active int64
285 }
286 tests := []struct {
287 name string
288 fields fields
289 want []int64
290 }{
291 {
292 name: "empty",
293 fields: fields{
294 inactive: Ranges{},
295 active: 0,
296 },
297 want: []int64{0},
298 },
299 {
300 name: "one",
301 fields: fields{
302 inactive: Ranges{
303 {
304 TreeID: 1,
305 TreeLength: 2,
306 },
307 },
308 active: 3,
309 },
310 want: []int64{3, 1},
311 },
312 {
313 name: "two",
314 fields: fields{
315 inactive: Ranges{
316 {
317 TreeID: 1,
318 TreeLength: 2,
319 },
320 {
321 TreeID: 2,
322 TreeLength: 3,
323 },
324 },
325 active: 4,
326 },
327 want: []int64{4, 1, 2},
328 },
329 }
330 for _, tt := range tests {
331 t.Run(tt.name, func(t *testing.T) {
332 l := &LogRanges{
333 inactive: tt.fields.inactive,
334 active: tt.fields.active,
335 }
336 if got := l.AllShards(); !reflect.DeepEqual(got, tt.want) {
337 t.Errorf("AllShards() = %v, want %v", got, tt.want)
338 }
339 })
340 }
341 }
342
343 func TestLogRangesFromPath(t *testing.T) {
344 type args struct {
345 path string
346 }
347 tests := []struct {
348 name string
349 args args
350 want Ranges
351 content string
352 wantJSON bool
353 wantYaml bool
354 wantInvalidJSON bool
355 wantErr bool
356 }{
357 {
358 name: "empty",
359 args: args{
360 path: "",
361 },
362 want: Ranges{},
363 wantErr: true,
364 },
365 {
366 name: "empty file",
367 args: args{
368 path: "one",
369 },
370 want: Ranges{},
371 wantErr: false,
372 },
373 {
374 name: "valid json",
375 args: args{
376 path: "one",
377 },
378 want: Ranges{
379 {
380 TreeID: 1,
381 TreeLength: 2,
382 },
383 },
384 wantJSON: true,
385 wantErr: false,
386 },
387 {
388 name: "valid yaml",
389 args: args{
390 path: "one",
391 },
392 want: Ranges{
393 {
394 TreeID: 1,
395 TreeLength: 2,
396 },
397 },
398 wantYaml: true,
399 wantErr: false,
400 },
401 {
402 name: "invalid json",
403 args: args{
404 path: "one",
405 },
406 want: Ranges{},
407 wantInvalidJSON: true,
408 wantErr: true,
409 },
410 }
411 for _, tt := range tests {
412 t.Run(tt.name, func(t *testing.T) {
413 if tt.args.path != "" {
414 f, err := os.CreateTemp("", tt.args.path)
415 if err != nil {
416 t.Fatalf("Failed to create temp file: %v", err)
417 }
418 switch {
419 case tt.wantJSON:
420 if err := json.NewEncoder(f).Encode(tt.want); err != nil {
421 t.Fatalf("Failed to encode json: %v", err)
422 }
423 case tt.wantYaml:
424 if err := yaml.NewEncoder(f).Encode(tt.want); err != nil {
425 t.Fatalf("Failed to encode yaml: %v", err)
426 }
427 case tt.wantInvalidJSON:
428 if _, err := f.WriteString("invalid json"); err != nil {
429 t.Fatalf("Failed to write invalid json: %v", err)
430 }
431 }
432 if _, err := f.Write([]byte(tt.content)); err != nil {
433 t.Fatalf("Failed to write to temp file: %v", err)
434 }
435 defer f.Close()
436 defer os.Remove(f.Name())
437 tt.args.path = f.Name()
438 }
439 got, err := logRangesFromPath(tt.args.path)
440 if (err != nil) != tt.wantErr {
441 t.Errorf("logRangesFromPath() error = %v, wantErr %v", err, tt.wantErr)
442 return
443 }
444 if !reflect.DeepEqual(got, tt.want) {
445 t.Errorf("logRangesFromPath() got = %v, want %v", got, tt.want)
446 }
447 })
448 }
449 }
450
451 func TestUpdateRange(t *testing.T) {
452 type args struct {
453 ctx context.Context
454 r LogRange
455 }
456 tests := []struct {
457 name string
458 args args
459 want LogRange
460 wantErr bool
461 rootResponse *trillian.GetLatestSignedLogRootResponse
462 signedLogError error
463 }{
464 {
465 name: "empty",
466 args: args{
467 ctx: context.Background(),
468 r: LogRange{},
469 },
470 want: LogRange{},
471 wantErr: true,
472 rootResponse: &trillian.GetLatestSignedLogRootResponse{
473 SignedLogRoot: &trillian.SignedLogRoot{},
474 },
475 signedLogError: nil,
476 },
477 {
478 name: "error in GetLatestSignedLogRoot",
479 args: args{
480 ctx: context.Background(),
481 r: LogRange{},
482 },
483 want: LogRange{},
484 wantErr: true,
485 rootResponse: &trillian.GetLatestSignedLogRootResponse{
486 SignedLogRoot: &trillian.SignedLogRoot{},
487 },
488 signedLogError: errors.New("error"),
489 },
490 }
491
492 mockCtl := gomock.NewController(t)
493 defer mockCtl.Finish()
494 for _, tt := range tests {
495 t.Run(tt.name, func(t *testing.T) {
496 s, fakeServer, err := testonly.NewMockServer(mockCtl)
497 if err != nil {
498 t.Fatalf("Failed to create mock server: %v", err)
499 }
500 defer fakeServer()
501
502 s.Log.EXPECT().GetLatestSignedLogRoot(
503 gomock.Any(), gomock.Any()).Return(tt.rootResponse, tt.signedLogError).AnyTimes()
504 got, err := updateRange(tt.args.ctx, s.LogClient, tt.args.r)
505
506 if (err != nil) != tt.wantErr {
507 t.Errorf("updateRange() error = %v, wantErr %v", err, tt.wantErr)
508 return
509 }
510 if !reflect.DeepEqual(got, tt.want) {
511 t.Errorf("updateRange() got = %v, want %v", got, tt.want)
512 }
513 })
514 }
515 }
516
517 func TestNewLogRangesWithMock(t *testing.T) {
518 type args struct {
519 ctx context.Context
520 path string
521 treeID uint
522 }
523 tests := []struct {
524 name string
525 args args
526 want LogRanges
527 wantErr bool
528 }{
529 {
530 name: "empty path",
531 args: args{
532 ctx: context.Background(),
533 path: "",
534 treeID: 1,
535 },
536 want: LogRanges{},
537 wantErr: false,
538 },
539 {
540 name: "treeID 0",
541 args: args{
542 ctx: context.Background(),
543 path: "x",
544 treeID: 0,
545 },
546 want: LogRanges{},
547 wantErr: true,
548 },
549 }
550
551 mockCtl := gomock.NewController(t)
552 defer mockCtl.Finish()
553 for _, tt := range tests {
554 t.Run(tt.name, func(t *testing.T) {
555
556 s, fakeServer, err := testonly.NewMockServer(mockCtl)
557 if err != nil {
558 t.Fatalf("Failed to create mock server: %v", err)
559 }
560 defer fakeServer()
561 got, err := NewLogRanges(tt.args.ctx, s.LogClient, tt.args.path, tt.args.treeID)
562 if (err != nil) != tt.wantErr {
563 t.Errorf("NewLogRanges() error = %v, wantErr %v", err, tt.wantErr)
564 return
565 }
566 if !reflect.DeepEqual(got, tt.want) {
567 t.Errorf("NewLogRanges() got = %v, want %v", got, tt.want)
568 }
569 })
570 }
571 }
572
View as plain text