...

Source file src/google.golang.org/grpc/internal/xds/bootstrap/tlscreds/bundle_ext_test.go

Documentation: google.golang.org/grpc/internal/xds/bootstrap/tlscreds

     1  /*
     2   *
     3   * Copyright 2023 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package tlscreds_test
    20  
    21  import (
    22  	"context"
    23  	"crypto/tls"
    24  	"encoding/json"
    25  	"fmt"
    26  	"os"
    27  	"strings"
    28  	"testing"
    29  	"time"
    30  
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/internal/grpctest"
    34  	"google.golang.org/grpc/internal/stubserver"
    35  	"google.golang.org/grpc/internal/testutils/xds/e2e"
    36  	"google.golang.org/grpc/internal/xds/bootstrap/tlscreds"
    37  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    38  	testpb "google.golang.org/grpc/interop/grpc_testing"
    39  	"google.golang.org/grpc/status"
    40  	"google.golang.org/grpc/testdata"
    41  )
    42  
    43  const defaultTestTimeout = 5 * time.Second
    44  
    45  type s struct {
    46  	grpctest.Tester
    47  }
    48  
    49  func Test(t *testing.T) {
    50  	grpctest.RunSubTests(t, s{})
    51  }
    52  
    53  type Closable interface {
    54  	Close()
    55  }
    56  
    57  func (s) TestValidTlsBuilder(t *testing.T) {
    58  	caCert := testdata.Path("x509/server_ca_cert.pem")
    59  	clientCert := testdata.Path("x509/client1_cert.pem")
    60  	clientKey := testdata.Path("x509/client1_key.pem")
    61  	tests := []struct {
    62  		name string
    63  		jd   string
    64  	}{
    65  		{
    66  			name: "Absent configuration",
    67  			jd:   `null`,
    68  		},
    69  		{
    70  			name: "Empty configuration",
    71  			jd:   `{}`,
    72  		},
    73  		{
    74  			name: "Only CA certificate chain",
    75  			jd:   fmt.Sprintf(`{"ca_certificate_file": "%s"}`, caCert),
    76  		},
    77  		{
    78  			name: "Only private key and certificate chain",
    79  			jd:   fmt.Sprintf(`{"certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey),
    80  		},
    81  		{
    82  			name: "CA chain, private key and certificate chain",
    83  			jd:   fmt.Sprintf(`{"ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey),
    84  		},
    85  		{
    86  			name: "Only refresh interval", jd: `{"refresh_interval": "1s"}`,
    87  		},
    88  		{
    89  			name: "Refresh interval and CA certificate chain",
    90  			jd:   fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file": "%s"}`, caCert),
    91  		},
    92  		{
    93  			name: "Refresh interval, private key and certificate chain",
    94  			jd:   fmt.Sprintf(`{"refresh_interval": "1s","certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey),
    95  		},
    96  		{
    97  			name: "Refresh interval, CA chain, private key and certificate chain",
    98  			jd:   fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey),
    99  		},
   100  		{
   101  			name: "Unknown field",
   102  			jd:   `{"unknown_field": "foo"}`,
   103  		},
   104  	}
   105  
   106  	for _, test := range tests {
   107  		t.Run(test.name, func(t *testing.T) {
   108  			msg := json.RawMessage(test.jd)
   109  			_, stop, err := tlscreds.NewBundle(msg)
   110  			if err != nil {
   111  				t.Fatalf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err)
   112  			}
   113  			stop()
   114  		})
   115  	}
   116  }
   117  
   118  func (s) TestInvalidTlsBuilder(t *testing.T) {
   119  	tests := []struct {
   120  		name, jd, wantErrPrefix string
   121  	}{
   122  		{
   123  			name:          "Wrong type in json",
   124  			jd:            `{"ca_certificate_file": 1}`,
   125  			wantErrPrefix: "failed to unmarshal config:"},
   126  		{
   127  			name:          "Missing private key",
   128  			jd:            fmt.Sprintf(`{"certificate_file":"%s"}`, testdata.Path("x509/server_cert.pem")),
   129  			wantErrPrefix: "pemfile: private key file and identity cert file should be both specified or not specified",
   130  		},
   131  	}
   132  
   133  	for _, test := range tests {
   134  		t.Run(test.name, func(t *testing.T) {
   135  			msg := json.RawMessage(test.jd)
   136  			_, stop, err := tlscreds.NewBundle(msg)
   137  			if err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) {
   138  				if stop != nil {
   139  					stop()
   140  				}
   141  				t.Fatalf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix)
   142  			}
   143  		})
   144  	}
   145  }
   146  
   147  func (s) TestCaReloading(t *testing.T) {
   148  	serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem"))
   149  	if err != nil {
   150  		t.Fatalf("Failed to read test CA cert: %s", err)
   151  	}
   152  
   153  	// Write CA certs to a temporary file so that we can modify it later.
   154  	caPath := t.TempDir() + "/ca.pem"
   155  	if err = os.WriteFile(caPath, serverCa, 0644); err != nil {
   156  		t.Fatalf("Failed to write test CA cert: %v", err)
   157  	}
   158  	cfg := fmt.Sprintf(`{
   159  		"ca_certificate_file": "%s",
   160  		"refresh_interval": ".01s"
   161  	}`, caPath)
   162  	tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg))
   163  	if err != nil {
   164  		t.Fatalf("Failed to create TLS bundle: %v", err)
   165  	}
   166  	defer stop()
   167  
   168  	serverCredentials := grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.NoClientCert))
   169  	server := stubserver.StartTestService(t, nil, serverCredentials)
   170  
   171  	conn, err := grpc.NewClient(
   172  		server.Address,
   173  		grpc.WithCredentialsBundle(tlsBundle),
   174  		grpc.WithAuthority("x.test.example.com"),
   175  	)
   176  	if err != nil {
   177  		t.Fatalf("Error dialing: %v", err)
   178  	}
   179  	defer conn.Close()
   180  
   181  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   182  	defer cancel()
   183  
   184  	client := testgrpc.NewTestServiceClient(conn)
   185  	if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   186  		t.Errorf("Error calling EmptyCall: %v", err)
   187  	}
   188  	// close the server and create a new one to force client to do a new
   189  	// handshake.
   190  	server.Stop()
   191  
   192  	invalidCa, err := os.ReadFile(testdata.Path("ca.pem"))
   193  	if err != nil {
   194  		t.Fatalf("Failed to read test CA cert: %v", err)
   195  	}
   196  	// unload root cert
   197  	err = os.WriteFile(caPath, invalidCa, 0644)
   198  	if err != nil {
   199  		t.Fatalf("Failed to write test CA cert: %v", err)
   200  	}
   201  
   202  	for ; ctx.Err() == nil; <-time.After(10 * time.Millisecond) {
   203  		ss := stubserver.StubServer{
   204  			Address:    server.Address,
   205  			EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil },
   206  		}
   207  		server = stubserver.StartTestService(t, &ss, serverCredentials)
   208  
   209  		// Client handshake should eventually fail because the client CA was
   210  		// reloaded, and thus the server cert is signed by an unknown CA.
   211  		t.Log(server)
   212  		_, err = client.EmptyCall(ctx, &testpb.Empty{})
   213  		const wantErr = "certificate signed by unknown authority"
   214  		if status.Code(err) == codes.Unavailable && strings.Contains(err.Error(), wantErr) {
   215  			// Certs have reloaded.
   216  			server.Stop()
   217  			break
   218  		}
   219  		t.Logf("EmptyCall() got err: %s, want code: %s, want err: %s", err, codes.Unavailable, wantErr)
   220  		server.Stop()
   221  	}
   222  	if ctx.Err() != nil {
   223  		t.Errorf("Timed out waiting for CA certs reloading")
   224  	}
   225  }
   226  
   227  func (s) TestMTLS(t *testing.T) {
   228  	s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert)))
   229  	defer s.Stop()
   230  
   231  	cfg := fmt.Sprintf(`{
   232  		"ca_certificate_file": "%s",
   233  		"certificate_file": "%s",
   234  		"private_key_file": "%s"
   235  	}`,
   236  		testdata.Path("x509/server_ca_cert.pem"),
   237  		testdata.Path("x509/client1_cert.pem"),
   238  		testdata.Path("x509/client1_key.pem"))
   239  	tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg))
   240  	if err != nil {
   241  		t.Fatalf("Failed to create TLS bundle: %v", err)
   242  	}
   243  	defer stop()
   244  	conn, err := grpc.NewClient(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com"))
   245  	if err != nil {
   246  		t.Fatalf("Error dialing: %v", err)
   247  	}
   248  	defer conn.Close()
   249  	client := testgrpc.NewTestServiceClient(conn)
   250  	if _, err = client.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
   251  		t.Errorf("EmptyCall(): got error %v when expected to succeed", err)
   252  	}
   253  }
   254  

View as plain text