...

Source file src/github.com/sigstore/rekor/pkg/sharding/ranges_test.go

Documentation: github.com/sigstore/rekor/pkg/sharding

     1  //
     2  // Copyright 2021 The Sigstore Authors.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //     http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  
    16  package sharding
    17  
    18  import (
    19  	"context"
    20  	"encoding/json"
    21  	"errors"
    22  	"os"
    23  	"path/filepath"
    24  	"reflect"
    25  	"testing"
    26  
    27  	"github.com/golang/mock/gomock"
    28  	"github.com/google/trillian/testonly"
    29  
    30  	"github.com/google/trillian"
    31  	"google.golang.org/grpc"
    32  	"gopkg.in/yaml.v2"
    33  )
    34  
    35  func TestNewLogRanges(t *testing.T) {
    36  	contents := `
    37  - treeID: 0001
    38    treeLength: 3
    39    encodedPublicKey: c2hhcmRpbmcK
    40  - treeID: 0002
    41    treeLength: 4`
    42  	file := filepath.Join(t.TempDir(), "sharding-config")
    43  	if err := os.WriteFile(file, []byte(contents), 0o644); err != nil {
    44  		t.Fatal(err)
    45  	}
    46  	treeID := uint(45)
    47  	expected := LogRanges{
    48  		inactive: []LogRange{
    49  			{
    50  				TreeID:           1,
    51  				TreeLength:       3,
    52  				EncodedPublicKey: "c2hhcmRpbmcK",
    53  				decodedPublicKey: "sharding\n",
    54  			}, {
    55  				TreeID:     2,
    56  				TreeLength: 4,
    57  			},
    58  		},
    59  		active: int64(45),
    60  	}
    61  	ctx := context.Background()
    62  	tc := trillian.NewTrillianLogClient(&grpc.ClientConn{})
    63  	got, err := NewLogRanges(ctx, tc, file, treeID)
    64  	if err != nil {
    65  		t.Fatal(err)
    66  	}
    67  	if expected.ActiveTreeID() != got.ActiveTreeID() {
    68  		t.Fatalf("expected tree id %d got %d", expected.ActiveTreeID(), got.ActiveTreeID())
    69  	}
    70  	if !reflect.DeepEqual(expected.GetInactive(), got.GetInactive()) {
    71  		t.Fatalf("expected %v got %v", expected.GetInactive(), got.GetInactive())
    72  	}
    73  }
    74  
    75  func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
    76  	lrs := LogRanges{
    77  		inactive: []LogRange{
    78  			{TreeID: 1, TreeLength: 17},
    79  			{TreeID: 2, TreeLength: 1},
    80  			{TreeID: 3, TreeLength: 100},
    81  		},
    82  		active: 4,
    83  	}
    84  
    85  	for _, tt := range []struct {
    86  		Index      int
    87  		WantTreeID int64
    88  		WantIndex  int64
    89  	}{
    90  		{
    91  			Index:      3,
    92  			WantTreeID: 1, WantIndex: 3,
    93  		},
    94  		// This is the first (0th) entry in the next tree
    95  		{
    96  			Index:      17,
    97  			WantTreeID: 2, WantIndex: 0,
    98  		},
    99  		// Overflow
   100  		{
   101  			Index:      3000,
   102  			WantTreeID: 4, WantIndex: 2882,
   103  		},
   104  	} {
   105  		tree, index := lrs.ResolveVirtualIndex(tt.Index)
   106  		if tree != tt.WantTreeID {
   107  			t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID)
   108  		}
   109  		if index != tt.WantIndex {
   110  			t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex)
   111  		}
   112  	}
   113  }
   114  
   115  func TestPublicKey(t *testing.T) {
   116  	ranges := LogRanges{
   117  		active: 45,
   118  		inactive: []LogRange{
   119  			{
   120  				TreeID:           10,
   121  				TreeLength:       10,
   122  				decodedPublicKey: "sharding",
   123  			}, {
   124  				TreeID:     20,
   125  				TreeLength: 20,
   126  			},
   127  		},
   128  	}
   129  	activePubKey := "activekey"
   130  	tests := []struct {
   131  		description    string
   132  		treeID         string
   133  		expectedPubKey string
   134  		shouldErr      bool
   135  	}{
   136  		{
   137  			description:    "empty tree ID",
   138  			expectedPubKey: "activekey",
   139  		}, {
   140  			description:    "tree id with decoded public key",
   141  			treeID:         "10",
   142  			expectedPubKey: "sharding",
   143  		}, {
   144  			description:    "tree id without decoded public key",
   145  			treeID:         "20",
   146  			expectedPubKey: "activekey",
   147  		}, {
   148  			description: "invalid tree id",
   149  			treeID:      "34",
   150  			shouldErr:   true,
   151  		}, {
   152  			description:    "pass in active tree id",
   153  			treeID:         "45",
   154  			expectedPubKey: "activekey",
   155  		},
   156  	}
   157  
   158  	for _, test := range tests {
   159  		t.Run(test.description, func(t *testing.T) {
   160  			got, err := ranges.PublicKey(activePubKey, test.treeID)
   161  			if err != nil && !test.shouldErr {
   162  				t.Fatal(err)
   163  			}
   164  			if test.shouldErr {
   165  				return
   166  			}
   167  			if got != test.expectedPubKey {
   168  				t.Fatalf("got %s doesn't match expected %s", got, test.expectedPubKey)
   169  			}
   170  		})
   171  	}
   172  }
   173  
   174  func TestLogRanges_String(t *testing.T) {
   175  	type fields struct {
   176  		inactive Ranges
   177  		active   int64
   178  	}
   179  	tests := []struct {
   180  		name   string
   181  		fields fields
   182  		want   string
   183  	}{
   184  		{
   185  			name: "empty",
   186  			fields: fields{
   187  				inactive: Ranges{},
   188  				active:   0,
   189  			},
   190  			want: "active=0",
   191  		},
   192  		{
   193  			name: "one",
   194  			fields: fields{
   195  				inactive: Ranges{
   196  					{
   197  						TreeID:     1,
   198  						TreeLength: 2,
   199  					},
   200  				},
   201  				active: 3,
   202  			},
   203  			want: "1=2,active=3",
   204  		},
   205  		{
   206  			name: "two",
   207  			fields: fields{
   208  				inactive: Ranges{
   209  					{
   210  						TreeID:     1,
   211  						TreeLength: 2,
   212  					},
   213  					{
   214  						TreeID:     2,
   215  						TreeLength: 3,
   216  					},
   217  				},
   218  				active: 4,
   219  			},
   220  			want: "1=2,2=3,active=4",
   221  		},
   222  	}
   223  	for _, tt := range tests {
   224  		t.Run(tt.name, func(t *testing.T) {
   225  			l := &LogRanges{
   226  				inactive: tt.fields.inactive,
   227  				active:   tt.fields.active,
   228  			}
   229  			if got := l.String(); got != tt.want {
   230  				t.Errorf("String() = %v, want %v", got, tt.want)
   231  			}
   232  		})
   233  	}
   234  }
   235  
   236  func TestLogRanges_TotalInactiveLength(t *testing.T) {
   237  	type fields struct {
   238  		inactive Ranges
   239  		active   int64
   240  	}
   241  	tests := []struct {
   242  		name   string
   243  		fields fields
   244  		want   int64
   245  	}{
   246  		{
   247  			name: "empty",
   248  			fields: fields{
   249  				inactive: Ranges{},
   250  				active:   0,
   251  			},
   252  			want: 0,
   253  		},
   254  		{
   255  			name: "one",
   256  			fields: fields{
   257  				inactive: Ranges{
   258  					{
   259  						TreeID:     1,
   260  						TreeLength: 2,
   261  					},
   262  				},
   263  				active: 3,
   264  			},
   265  			want: 2,
   266  		},
   267  	}
   268  	for _, tt := range tests {
   269  		t.Run(tt.name, func(t *testing.T) {
   270  			l := &LogRanges{
   271  				inactive: tt.fields.inactive,
   272  				active:   tt.fields.active,
   273  			}
   274  			if got := l.TotalInactiveLength(); got != tt.want {
   275  				t.Errorf("TotalInactiveLength() = %v, want %v", got, tt.want)
   276  			}
   277  		})
   278  	}
   279  }
   280  
   281  func TestLogRanges_AllShards(t *testing.T) {
   282  	type fields struct {
   283  		inactive Ranges
   284  		active   int64
   285  	}
   286  	tests := []struct {
   287  		name   string
   288  		fields fields
   289  		want   []int64
   290  	}{
   291  		{
   292  			name: "empty",
   293  			fields: fields{
   294  				inactive: Ranges{},
   295  				active:   0,
   296  			},
   297  			want: []int64{0},
   298  		},
   299  		{
   300  			name: "one",
   301  			fields: fields{
   302  				inactive: Ranges{
   303  					{
   304  						TreeID:     1,
   305  						TreeLength: 2,
   306  					},
   307  				},
   308  				active: 3,
   309  			},
   310  			want: []int64{3, 1},
   311  		},
   312  		{
   313  			name: "two",
   314  			fields: fields{
   315  				inactive: Ranges{
   316  					{
   317  						TreeID:     1,
   318  						TreeLength: 2,
   319  					},
   320  					{
   321  						TreeID:     2,
   322  						TreeLength: 3,
   323  					},
   324  				},
   325  				active: 4,
   326  			},
   327  			want: []int64{4, 1, 2},
   328  		},
   329  	}
   330  	for _, tt := range tests {
   331  		t.Run(tt.name, func(t *testing.T) {
   332  			l := &LogRanges{
   333  				inactive: tt.fields.inactive,
   334  				active:   tt.fields.active,
   335  			}
   336  			if got := l.AllShards(); !reflect.DeepEqual(got, tt.want) {
   337  				t.Errorf("AllShards() = %v, want %v", got, tt.want)
   338  			}
   339  		})
   340  	}
   341  }
   342  
   343  func TestLogRangesFromPath(t *testing.T) {
   344  	type args struct {
   345  		path string
   346  	}
   347  	tests := []struct {
   348  		name            string
   349  		args            args
   350  		want            Ranges
   351  		content         string
   352  		wantJSON        bool
   353  		wantYaml        bool
   354  		wantInvalidJSON bool
   355  		wantErr         bool
   356  	}{
   357  		{
   358  			name: "empty",
   359  			args: args{
   360  				path: "",
   361  			},
   362  			want:    Ranges{},
   363  			wantErr: true,
   364  		},
   365  		{
   366  			name: "empty file",
   367  			args: args{
   368  				path: "one",
   369  			},
   370  			want:    Ranges{},
   371  			wantErr: false,
   372  		},
   373  		{
   374  			name: "valid json",
   375  			args: args{
   376  				path: "one",
   377  			},
   378  			want: Ranges{
   379  				{
   380  					TreeID:     1,
   381  					TreeLength: 2,
   382  				},
   383  			},
   384  			wantJSON: true,
   385  			wantErr:  false,
   386  		},
   387  		{
   388  			name: "valid yaml",
   389  			args: args{
   390  				path: "one",
   391  			},
   392  			want: Ranges{
   393  				{
   394  					TreeID:     1,
   395  					TreeLength: 2,
   396  				},
   397  			},
   398  			wantYaml: true,
   399  			wantErr:  false,
   400  		},
   401  		{
   402  			name: "invalid json",
   403  			args: args{
   404  				path: "one",
   405  			},
   406  			want:            Ranges{},
   407  			wantInvalidJSON: true,
   408  			wantErr:         true,
   409  		},
   410  	}
   411  	for _, tt := range tests {
   412  		t.Run(tt.name, func(t *testing.T) {
   413  			if tt.args.path != "" {
   414  				f, err := os.CreateTemp("", tt.args.path)
   415  				if err != nil {
   416  					t.Fatalf("Failed to create temp file: %v", err)
   417  				}
   418  				switch {
   419  				case tt.wantJSON:
   420  					if err := json.NewEncoder(f).Encode(tt.want); err != nil {
   421  						t.Fatalf("Failed to encode json: %v", err)
   422  					}
   423  				case tt.wantYaml:
   424  					if err := yaml.NewEncoder(f).Encode(tt.want); err != nil {
   425  						t.Fatalf("Failed to encode yaml: %v", err)
   426  					}
   427  				case tt.wantInvalidJSON:
   428  					if _, err := f.WriteString("invalid json"); err != nil {
   429  						t.Fatalf("Failed to write invalid json: %v", err)
   430  					}
   431  				}
   432  				if _, err := f.Write([]byte(tt.content)); err != nil {
   433  					t.Fatalf("Failed to write to temp file: %v", err)
   434  				}
   435  				defer f.Close()
   436  				defer os.Remove(f.Name())
   437  				tt.args.path = f.Name()
   438  			}
   439  			got, err := logRangesFromPath(tt.args.path)
   440  			if (err != nil) != tt.wantErr {
   441  				t.Errorf("logRangesFromPath() error = %v, wantErr %v", err, tt.wantErr)
   442  				return
   443  			}
   444  			if !reflect.DeepEqual(got, tt.want) {
   445  				t.Errorf("logRangesFromPath() got = %v, want %v", got, tt.want)
   446  			}
   447  		})
   448  	}
   449  }
   450  
   451  func TestUpdateRange(t *testing.T) {
   452  	type args struct {
   453  		ctx context.Context
   454  		r   LogRange
   455  	}
   456  	tests := []struct {
   457  		name           string
   458  		args           args
   459  		want           LogRange
   460  		wantErr        bool
   461  		rootResponse   *trillian.GetLatestSignedLogRootResponse
   462  		signedLogError error
   463  	}{
   464  		{
   465  			name: "empty",
   466  			args: args{
   467  				ctx: context.Background(),
   468  				r:   LogRange{},
   469  			},
   470  			want:    LogRange{},
   471  			wantErr: true,
   472  			rootResponse: &trillian.GetLatestSignedLogRootResponse{
   473  				SignedLogRoot: &trillian.SignedLogRoot{},
   474  			},
   475  			signedLogError: nil,
   476  		},
   477  		{
   478  			name: "error in GetLatestSignedLogRoot",
   479  			args: args{
   480  				ctx: context.Background(),
   481  				r:   LogRange{},
   482  			},
   483  			want:    LogRange{},
   484  			wantErr: true,
   485  			rootResponse: &trillian.GetLatestSignedLogRootResponse{
   486  				SignedLogRoot: &trillian.SignedLogRoot{},
   487  			},
   488  			signedLogError: errors.New("error"),
   489  		},
   490  	}
   491  
   492  	mockCtl := gomock.NewController(t)
   493  	defer mockCtl.Finish()
   494  	for _, tt := range tests {
   495  		t.Run(tt.name, func(t *testing.T) {
   496  			s, fakeServer, err := testonly.NewMockServer(mockCtl)
   497  			if err != nil {
   498  				t.Fatalf("Failed to create mock server: %v", err)
   499  			}
   500  			defer fakeServer()
   501  
   502  			s.Log.EXPECT().GetLatestSignedLogRoot(
   503  				gomock.Any(), gomock.Any()).Return(tt.rootResponse, tt.signedLogError).AnyTimes()
   504  			got, err := updateRange(tt.args.ctx, s.LogClient, tt.args.r)
   505  
   506  			if (err != nil) != tt.wantErr {
   507  				t.Errorf("updateRange() error = %v, wantErr %v", err, tt.wantErr)
   508  				return
   509  			}
   510  			if !reflect.DeepEqual(got, tt.want) {
   511  				t.Errorf("updateRange() got = %v, want %v", got, tt.want)
   512  			}
   513  		})
   514  	}
   515  }
   516  
   517  func TestNewLogRangesWithMock(t *testing.T) {
   518  	type args struct {
   519  		ctx    context.Context
   520  		path   string
   521  		treeID uint
   522  	}
   523  	tests := []struct {
   524  		name    string
   525  		args    args
   526  		want    LogRanges
   527  		wantErr bool
   528  	}{
   529  		{
   530  			name: "empty path",
   531  			args: args{
   532  				ctx:    context.Background(),
   533  				path:   "",
   534  				treeID: 1,
   535  			},
   536  			want:    LogRanges{},
   537  			wantErr: false,
   538  		},
   539  		{
   540  			name: "treeID 0",
   541  			args: args{
   542  				ctx:    context.Background(),
   543  				path:   "x",
   544  				treeID: 0,
   545  			},
   546  			want:    LogRanges{},
   547  			wantErr: true,
   548  		},
   549  	}
   550  
   551  	mockCtl := gomock.NewController(t)
   552  	defer mockCtl.Finish()
   553  	for _, tt := range tests {
   554  		t.Run(tt.name, func(t *testing.T) {
   555  
   556  			s, fakeServer, err := testonly.NewMockServer(mockCtl)
   557  			if err != nil {
   558  				t.Fatalf("Failed to create mock server: %v", err)
   559  			}
   560  			defer fakeServer()
   561  			got, err := NewLogRanges(tt.args.ctx, s.LogClient, tt.args.path, tt.args.treeID)
   562  			if (err != nil) != tt.wantErr {
   563  				t.Errorf("NewLogRanges() error = %v, wantErr %v", err, tt.wantErr)
   564  				return
   565  			}
   566  			if !reflect.DeepEqual(got, tt.want) {
   567  				t.Errorf("NewLogRanges() got = %v, want %v", got, tt.want)
   568  			}
   569  		})
   570  	}
   571  }
   572  

View as plain text