...

Source file src/google.golang.org/grpc/credentials/tls/certprovider/pemfile/watcher_test.go

Documentation: google.golang.org/grpc/credentials/tls/certprovider/pemfile

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package pemfile
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"os"
    25  	"path"
    26  	"testing"
    27  	"time"
    28  
    29  	"google.golang.org/grpc/credentials/tls/certprovider"
    30  	"google.golang.org/grpc/internal/grpctest"
    31  	"google.golang.org/grpc/internal/testutils"
    32  	"google.golang.org/grpc/testdata"
    33  )
    34  
    35  const (
    36  	// These are the names of files inside temporary directories, which the
    37  	// plugin is asked to watch.
    38  	certFile = "cert.pem"
    39  	keyFile  = "key.pem"
    40  	rootFile = "ca.pem"
    41  
    42  	defaultTestRefreshDuration = 100 * time.Millisecond
    43  	defaultTestTimeout         = 5 * time.Second
    44  )
    45  
    46  type s struct {
    47  	grpctest.Tester
    48  }
    49  
    50  func Test(t *testing.T) {
    51  	grpctest.RunSubTests(t, s{})
    52  }
    53  
    54  func compareKeyMaterial(got, want *certprovider.KeyMaterial) error {
    55  	if len(got.Certs) != len(want.Certs) {
    56  		return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
    57  	}
    58  	for i := 0; i < len(got.Certs); i++ {
    59  		if !got.Certs[i].Leaf.Equal(want.Certs[i].Leaf) {
    60  			return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
    61  		}
    62  	}
    63  
    64  	if gotR, wantR := got.Roots, want.Roots; !gotR.Equal(wantR) {
    65  		return fmt.Errorf("keyMaterial roots = %v, want %v", gotR, wantR)
    66  	}
    67  
    68  	return nil
    69  }
    70  
    71  // TestNewProvider tests the NewProvider() function with different inputs.
    72  func (s) TestNewProvider(t *testing.T) {
    73  	tests := []struct {
    74  		desc      string
    75  		options   Options
    76  		wantError bool
    77  	}{
    78  		{
    79  			desc:      "No credential files specified",
    80  			options:   Options{},
    81  			wantError: true,
    82  		},
    83  		{
    84  			desc: "Only identity cert is specified",
    85  			options: Options{
    86  				CertFile: testdata.Path("x509/client1_cert.pem"),
    87  			},
    88  			wantError: true,
    89  		},
    90  		{
    91  			desc: "Only identity key is specified",
    92  			options: Options{
    93  				KeyFile: testdata.Path("x509/client1_key.pem"),
    94  			},
    95  			wantError: true,
    96  		},
    97  		{
    98  			desc: "Identity cert/key pair is specified",
    99  			options: Options{
   100  				KeyFile:  testdata.Path("x509/client1_key.pem"),
   101  				CertFile: testdata.Path("x509/client1_cert.pem"),
   102  			},
   103  		},
   104  		{
   105  			desc: "Only root certs are specified",
   106  			options: Options{
   107  				RootFile: testdata.Path("x509/client_ca_cert.pem"),
   108  			},
   109  		},
   110  		{
   111  			desc: "Everything is specified",
   112  			options: Options{
   113  				KeyFile:  testdata.Path("x509/client1_key.pem"),
   114  				CertFile: testdata.Path("x509/client1_cert.pem"),
   115  				RootFile: testdata.Path("x509/client_ca_cert.pem"),
   116  			},
   117  			wantError: false,
   118  		},
   119  	}
   120  	for _, test := range tests {
   121  		t.Run(test.desc, func(t *testing.T) {
   122  			provider, err := NewProvider(test.options)
   123  			if (err != nil) != test.wantError {
   124  				t.Fatalf("NewProvider(%v) = %v, want %v", test.options, err, test.wantError)
   125  			}
   126  			if err != nil {
   127  				return
   128  			}
   129  			provider.Close()
   130  		})
   131  	}
   132  }
   133  
   134  // wrappedDistributor wraps a distributor and pushes on a channel whenever new
   135  // key material is pushed to the distributor.
   136  type wrappedDistributor struct {
   137  	*certprovider.Distributor
   138  	distCh *testutils.Channel
   139  }
   140  
   141  func newWrappedDistributor(distCh *testutils.Channel) *wrappedDistributor {
   142  	return &wrappedDistributor{
   143  		distCh:      distCh,
   144  		Distributor: certprovider.NewDistributor(),
   145  	}
   146  }
   147  
   148  func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) {
   149  	wd.Distributor.Set(km, err)
   150  	wd.distCh.Send(nil)
   151  }
   152  
   153  func createTmpFile(t *testing.T, src, dst string) {
   154  	t.Helper()
   155  
   156  	data, err := os.ReadFile(src)
   157  	if err != nil {
   158  		t.Fatalf("os.ReadFile(%q) failed: %v", src, err)
   159  	}
   160  	if err := os.WriteFile(dst, data, os.ModePerm); err != nil {
   161  		t.Fatalf("os.WriteFile(%q) failed: %v", dst, err)
   162  	}
   163  	t.Logf("Wrote file at: %s", dst)
   164  	t.Logf("%s", string(data))
   165  }
   166  
   167  // createTempDirWithFiles creates a temporary directory under the system default
   168  // tempDir with the given dirSuffix. It also reads from certSrc, keySrc and
   169  // rootSrc files are creates appropriate files under the newly create tempDir.
   170  // Returns the name of the created tempDir.
   171  func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string {
   172  	t.Helper()
   173  
   174  	// Create a temp directory. Passing an empty string for the first argument
   175  	// uses the system temp directory.
   176  	dir, err := os.MkdirTemp("", dirSuffix)
   177  	if err != nil {
   178  		t.Fatalf("os.MkdirTemp() failed: %v", err)
   179  	}
   180  	t.Logf("Using tmpdir: %s", dir)
   181  
   182  	createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile))
   183  	createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile))
   184  	createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile))
   185  	return dir
   186  }
   187  
   188  // initializeProvider performs setup steps common to all tests (except the one
   189  // which uses symlinks).
   190  func initializeProvider(t *testing.T, testName string) (string, certprovider.Provider, *testutils.Channel, func()) {
   191  	t.Helper()
   192  
   193  	// Override the newDistributor to one which pushes on a channel that we
   194  	// can block on.
   195  	origDistributorFunc := newDistributor
   196  	distCh := testutils.NewChannel()
   197  	d := newWrappedDistributor(distCh)
   198  	newDistributor = func() distributor { return d }
   199  
   200  	// Create a new provider to watch the files in tmpdir.
   201  	dir := createTmpDirWithFiles(t, testName+"*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
   202  	opts := Options{
   203  		CertFile:        path.Join(dir, certFile),
   204  		KeyFile:         path.Join(dir, keyFile),
   205  		RootFile:        path.Join(dir, rootFile),
   206  		RefreshDuration: defaultTestRefreshDuration,
   207  	}
   208  	prov, err := NewProvider(opts)
   209  	if err != nil {
   210  		t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
   211  	}
   212  
   213  	// Make sure the provider picks up the files and pushes the key material on
   214  	// to the distributors.
   215  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   216  	defer cancel()
   217  	for i := 0; i < 2; i++ {
   218  		// Since we have root and identity certs, we need to make sure the
   219  		// update is pushed on both of them.
   220  		if _, err := distCh.Receive(ctx); err != nil {
   221  			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
   222  		}
   223  	}
   224  
   225  	return dir, prov, distCh, func() {
   226  		newDistributor = origDistributorFunc
   227  		prov.Close()
   228  	}
   229  }
   230  
   231  // TestProvider_NoUpdate tests the case where a file watcher plugin is created
   232  // successfully, and the underlying files do not change. Verifies that the
   233  // plugin does not push new updates to the distributor in this case.
   234  func (s) TestProvider_NoUpdate(t *testing.T) {
   235  	_, prov, distCh, cancel := initializeProvider(t, "no_update")
   236  	defer cancel()
   237  
   238  	// Make sure the provider is healthy and returns key material.
   239  	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
   240  	defer cc()
   241  	if _, err := prov.KeyMaterial(ctx); err != nil {
   242  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   243  	}
   244  
   245  	// Files haven't change. Make sure no updates are pushed by the provider.
   246  	sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
   247  	defer sc()
   248  	if _, err := distCh.Receive(sCtx); err == nil {
   249  		t.Fatal("new key material pushed to distributor when underlying files did not change")
   250  	}
   251  }
   252  
   253  // TestProvider_UpdateSuccess tests the case where a file watcher plugin is
   254  // created successfully and the underlying files change. Verifies that the
   255  // changes are picked up by the provider.
   256  func (s) TestProvider_UpdateSuccess(t *testing.T) {
   257  	dir, prov, distCh, cancel := initializeProvider(t, "update_success")
   258  	defer cancel()
   259  
   260  	// Make sure the provider is healthy and returns key material.
   261  	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
   262  	defer cc()
   263  	km1, err := prov.KeyMaterial(ctx)
   264  	if err != nil {
   265  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   266  	}
   267  
   268  	// Change only the root file.
   269  	createTmpFile(t, testdata.Path("x509/server_ca_cert.pem"), path.Join(dir, rootFile))
   270  	if _, err := distCh.Receive(ctx); err != nil {
   271  		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
   272  	}
   273  
   274  	// Make sure update is picked up.
   275  	km2, err := prov.KeyMaterial(ctx)
   276  	if err != nil {
   277  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   278  	}
   279  	if err := compareKeyMaterial(km1, km2); err == nil {
   280  		t.Fatal("expected provider to return new key material after update to underlying file")
   281  	}
   282  
   283  	// Change only cert/key files.
   284  	createTmpFile(t, testdata.Path("x509/client2_cert.pem"), path.Join(dir, certFile))
   285  	createTmpFile(t, testdata.Path("x509/client2_key.pem"), path.Join(dir, keyFile))
   286  	if _, err := distCh.Receive(ctx); err != nil {
   287  		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
   288  	}
   289  
   290  	// Make sure update is picked up.
   291  	km3, err := prov.KeyMaterial(ctx)
   292  	if err != nil {
   293  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   294  	}
   295  	if err := compareKeyMaterial(km2, km3); err == nil {
   296  		t.Fatal("expected provider to return new key material after update to underlying file")
   297  	}
   298  }
   299  
   300  // TestProvider_UpdateSuccessWithSymlink tests the case where a file watcher
   301  // plugin is created successfully to watch files through a symlink and the
   302  // symlink is updates to point to new files. Verifies that the changes are
   303  // picked up by the provider.
   304  func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) {
   305  	// Override the newDistributor to one which pushes on a channel that we
   306  	// can block on.
   307  	origDistributorFunc := newDistributor
   308  	distCh := testutils.NewChannel()
   309  	d := newWrappedDistributor(distCh)
   310  	newDistributor = func() distributor { return d }
   311  	defer func() { newDistributor = origDistributorFunc }()
   312  
   313  	// Create two tempDirs with different files.
   314  	dir1 := createTmpDirWithFiles(t, "update_with_symlink1_*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
   315  	dir2 := createTmpDirWithFiles(t, "update_with_symlink2_*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/server_ca_cert.pem")
   316  
   317  	// Create a symlink under a new tempdir, and make it point to dir1.
   318  	tmpdir, err := os.MkdirTemp("", "test_symlink_*")
   319  	if err != nil {
   320  		t.Fatalf("os.MkdirTemp() failed: %v", err)
   321  	}
   322  	symLinkName := path.Join(tmpdir, "test_symlink")
   323  	if err := os.Symlink(dir1, symLinkName); err != nil {
   324  		t.Fatalf("failed to create symlink to %q: %v", dir1, err)
   325  	}
   326  
   327  	// Create a provider which watches the files pointed to by the symlink.
   328  	opts := Options{
   329  		CertFile:        path.Join(symLinkName, certFile),
   330  		KeyFile:         path.Join(symLinkName, keyFile),
   331  		RootFile:        path.Join(symLinkName, rootFile),
   332  		RefreshDuration: defaultTestRefreshDuration,
   333  	}
   334  	prov, err := NewProvider(opts)
   335  	if err != nil {
   336  		t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
   337  	}
   338  	defer prov.Close()
   339  
   340  	// Make sure the provider picks up the files and pushes the key material on
   341  	// to the distributors.
   342  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   343  	defer cancel()
   344  	for i := 0; i < 2; i++ {
   345  		// Since we have root and identity certs, we need to make sure the
   346  		// update is pushed on both of them.
   347  		if _, err := distCh.Receive(ctx); err != nil {
   348  			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
   349  		}
   350  	}
   351  	km1, err := prov.KeyMaterial(ctx)
   352  	if err != nil {
   353  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   354  	}
   355  
   356  	// Update the symlink to point to dir2.
   357  	symLinkTmpName := path.Join(tmpdir, "test_symlink.tmp")
   358  	if err := os.Symlink(dir2, symLinkTmpName); err != nil {
   359  		t.Fatalf("failed to create symlink to %q: %v", dir2, err)
   360  	}
   361  	if err := os.Rename(symLinkTmpName, symLinkName); err != nil {
   362  		t.Fatalf("failed to update symlink: %v", err)
   363  	}
   364  
   365  	// Make sure the provider picks up the new files and pushes the key material
   366  	// on to the distributors.
   367  	for i := 0; i < 2; i++ {
   368  		// Since we have root and identity certs, we need to make sure the
   369  		// update is pushed on both of them.
   370  		if _, err := distCh.Receive(ctx); err != nil {
   371  			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
   372  		}
   373  	}
   374  	km2, err := prov.KeyMaterial(ctx)
   375  	if err != nil {
   376  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   377  	}
   378  
   379  	if err := compareKeyMaterial(km1, km2); err == nil {
   380  		t.Fatal("expected provider to return new key material after symlink update")
   381  	}
   382  }
   383  
   384  // TestProvider_UpdateFailure_ThenSuccess tests the case where updating cert/key
   385  // files fail. Verifies that the failed update does not push anything on the
   386  // distributor. Then the update succeeds, and the test verifies that the key
   387  // material is updated.
   388  func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) {
   389  	dir, prov, distCh, cancel := initializeProvider(t, "update_failure")
   390  	defer cancel()
   391  
   392  	// Make sure the provider is healthy and returns key material.
   393  	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
   394  	defer cc()
   395  	km1, err := prov.KeyMaterial(ctx)
   396  	if err != nil {
   397  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   398  	}
   399  
   400  	// Update only the cert file. The key file is left unchanged. This should
   401  	// lead to these two files being not compatible with each other. This
   402  	// simulates the case where the watching goroutine might catch the files in
   403  	// the midst of an update.
   404  	createTmpFile(t, testdata.Path("x509/server1_cert.pem"), path.Join(dir, certFile))
   405  
   406  	// Since the last update left the files in an incompatible state, the update
   407  	// should not be picked up by our provider.
   408  	sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
   409  	defer sc()
   410  	if _, err := distCh.Receive(sCtx); err == nil {
   411  		t.Fatal("new key material pushed to distributor when underlying files did not change")
   412  	}
   413  
   414  	// The provider should return key material corresponding to the old state.
   415  	km2, err := prov.KeyMaterial(ctx)
   416  	if err != nil {
   417  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   418  	}
   419  	if err := compareKeyMaterial(km1, km2); err != nil {
   420  		t.Fatalf("expected provider to not update key material: %v", err)
   421  	}
   422  
   423  	// Update the key file to match the cert file.
   424  	createTmpFile(t, testdata.Path("x509/server1_key.pem"), path.Join(dir, keyFile))
   425  
   426  	// Make sure update is picked up.
   427  	if _, err := distCh.Receive(ctx); err != nil {
   428  		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
   429  	}
   430  	km3, err := prov.KeyMaterial(ctx)
   431  	if err != nil {
   432  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   433  	}
   434  	if err := compareKeyMaterial(km2, km3); err == nil {
   435  		t.Fatal("expected provider to return new key material after update to underlying file")
   436  	}
   437  }
   438  

View as plain text