...

Source file src/github.com/google/s2a-go/internal/v2/certverifier/certverifier_test.go

Documentation: github.com/google/s2a-go/internal/v2/certverifier

     1  /*
     2   *
     3   * Copyright 2022 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package certverifier
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"log"
    26  	"net"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	_ "embed"
    32  
    33  	"github.com/google/s2a-go/internal/v2/fakes2av2"
    34  	"google.golang.org/grpc"
    35  	"google.golang.org/grpc/credentials/insecure"
    36  
    37  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
    38  )
    39  
    40  const (
    41  	defaultTimeout                = 10.0 * time.Second
    42  	fakeServerAuthorizationPolicy = "fake server authorization policy"
    43  )
    44  
    45  var (
    46  	//go:embed testdata/client_root_cert.der
    47  	clientRootDERCert []byte
    48  	//go:embed testdata/client_intermediate_cert.der
    49  	clientIntermediateDERCert []byte
    50  	//go:embed testdata/client_leaf_cert.der
    51  	clientLeafDERCert []byte
    52  	//go:embed testdata/server_root_cert.der
    53  	serverRootDERCert []byte
    54  	//go:embed testdata/server_intermediate_cert.der
    55  	serverIntermediateDERCert []byte
    56  	//go:embed testdata/server_leaf_cert.der
    57  	serverLeafDERCert []byte
    58  )
    59  
    60  func startFakeS2Av2Server(wg *sync.WaitGroup, enableServerAuthorizationPolicyCheck bool) (stop func(), address string, err error) {
    61  	listener, err := net.Listen("tcp", ":0")
    62  	if err != nil {
    63  		log.Fatalf("Failed to listen on address %s: %v", address, err)
    64  	}
    65  	address = listener.Addr().String()
    66  	s := grpc.NewServer()
    67  	log.Printf("Server: started gRPC fake S2Av2 Server on address: %s", address)
    68  	if enableServerAuthorizationPolicyCheck {
    69  		s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{
    70  			ServerAuthorizationPolicy: []byte(fakeServerAuthorizationPolicy),
    71  		})
    72  	} else {
    73  		s2av2pb.RegisterS2AServiceServer(s, &fakes2av2.Server{})
    74  	}
    75  	go func() {
    76  		wg.Done()
    77  		if err := s.Serve(listener); err != nil {
    78  			log.Printf("Failed to serve: %v", err)
    79  		}
    80  	}()
    81  	return func() { s.Stop() }, address, nil
    82  }
    83  
    84  // TestVerifyClientCertChain runs unit tests for VerifyClientCertificateChain.
    85  func TestVerifyClientCertChain(t *testing.T) {
    86  	// Start up fake S2Av2 server.
    87  	var wg sync.WaitGroup
    88  	wg.Add(1)
    89  	stop, address, err := startFakeS2Av2Server(&wg, false)
    90  	wg.Wait()
    91  	if err != nil {
    92  		t.Fatalf("Error starting fake S2Av2 Server: %v", err)
    93  	}
    94  
    95  	for _, tc := range []struct {
    96  		description string
    97  		rawCerts    [][]byte
    98  		expectedErr error
    99  	}{
   100  		{
   101  			description: "empty chain",
   102  			rawCerts:    nil,
   103  			expectedErr: errors.New("client cert verification failed: client peer verification failed: client cert chain is empty"),
   104  		},
   105  		{
   106  			description: "chain of length 1",
   107  			rawCerts:    [][]byte{clientRootDERCert},
   108  			expectedErr: nil,
   109  		},
   110  		{
   111  			description: "chain of length 2 correct",
   112  			rawCerts:    [][]byte{clientLeafDERCert, clientIntermediateDERCert},
   113  			expectedErr: nil,
   114  		},
   115  		{
   116  			description: "chain of length 2 error: missing intermediate",
   117  			rawCerts:    [][]byte{clientLeafDERCert, clientRootDERCert},
   118  			expectedErr: errors.New("failed to offload client cert verification to S2A: 3, client peer verification failed: x509: certificate signed by unknown authority (possibly because of \"crypto/rsa: verification error\" while trying to verify candidate authority certificate \"s2a_test_cert\")"),
   119  		},
   120  	} {
   121  		t.Run(tc.description, func(t *testing.T) {
   122  			// Create new stream to S2Av2.
   123  			opts := []grpc.DialOption{
   124  				grpc.WithTransportCredentials(insecure.NewCredentials()),
   125  				grpc.WithReturnConnectionError(),
   126  				grpc.WithBlock(),
   127  			}
   128  			conn, err := grpc.Dial(address, opts...)
   129  			if err != nil {
   130  				t.Fatalf("Client: failed to connect: %v", err)
   131  			}
   132  			defer conn.Close()
   133  			c := s2av2pb.NewS2AServiceClient(conn)
   134  			log.Printf("Client: connected to: %s", address)
   135  			ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
   136  			defer cancel()
   137  
   138  			// Setup bidrectional streaming session.
   139  			callOpts := []grpc.CallOption{}
   140  			cstream, err := c.SetUpSession(ctx, callOpts...)
   141  			if err != nil {
   142  				t.Fatalf("Client: failed to setup bidirectional streaming RPC session: %v", err)
   143  			}
   144  			log.Printf("Client: set up bidirectional streaming RPC session.")
   145  
   146  			// TODO(rmehta19): Add verificationMode to struct, and vary between tests.
   147  			VerifyPeerCertificateFunc := VerifyClientCertificateChain(s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, cstream)
   148  			got, want := VerifyPeerCertificateFunc(tc.rawCerts, nil), tc.expectedErr
   149  			if want == nil {
   150  				if got != nil {
   151  					t.Errorf("Peer certificate verification failed, got: %v, want: %v", got, want)
   152  				}
   153  			} else {
   154  				if got == nil {
   155  					t.Errorf("Peer certificate verification failed, got: %v, want: %v", got, want)
   156  				}
   157  				if got.Error() != want.Error() {
   158  					t.Errorf("Peer certificate verification failed, got: %v, want: %v", got, want)
   159  				}
   160  			}
   161  		})
   162  	}
   163  	stop()
   164  }
   165  
   166  // TestVerifyServerCertChainWithServerAuthorizationPolicy runs unit tests for VerifyServerCertificateChain with server authorization policy check.
   167  func TestVerifyServerCertChainWithServerAuthorizationPolicy(t *testing.T) {
   168  	// Start up fake S2Av2 server.
   169  	var wg sync.WaitGroup
   170  	wg.Add(1)
   171  	stop, address, err := startFakeS2Av2Server(&wg, true)
   172  	wg.Wait()
   173  	if err != nil {
   174  		t.Fatalf("Error starting fake S2Av2 Server: %v", err)
   175  	}
   176  
   177  	for _, tc := range []struct {
   178  		description               string
   179  		hostname                  string
   180  		rawCerts                  [][]byte
   181  		expectedErr               error
   182  		serverAuthorizationPolicy []byte
   183  	}{
   184  		{
   185  			description:               "empty chain",
   186  			hostname:                  "host",
   187  			rawCerts:                  nil,
   188  			expectedErr:               errors.New("server cert verification failed: server peer verification failed: server cert chain is empty"),
   189  			serverAuthorizationPolicy: []byte(fakeServerAuthorizationPolicy),
   190  		},
   191  		{
   192  			description:               "invalid server authorization policy",
   193  			hostname:                  "host",
   194  			rawCerts:                  [][]byte{serverRootDERCert},
   195  			expectedErr:               fmt.Errorf("rpc error: code = Unknown desc = server peer verification failed: invalid server authorization policy, expected: %s, got: ", fakeServerAuthorizationPolicy),
   196  			serverAuthorizationPolicy: nil,
   197  		},
   198  		{
   199  			description:               "chain of length 1",
   200  			hostname:                  "host",
   201  			rawCerts:                  [][]byte{serverRootDERCert},
   202  			expectedErr:               nil,
   203  			serverAuthorizationPolicy: []byte(fakeServerAuthorizationPolicy),
   204  		},
   205  		{
   206  			description:               "chain of length 2 correct",
   207  			hostname:                  "host",
   208  			rawCerts:                  [][]byte{serverLeafDERCert, serverIntermediateDERCert},
   209  			expectedErr:               nil,
   210  			serverAuthorizationPolicy: []byte(fakeServerAuthorizationPolicy),
   211  		},
   212  	} {
   213  		t.Run(tc.description, func(t *testing.T) {
   214  			// Create new stream to S2Av2.
   215  			opts := []grpc.DialOption{
   216  				grpc.WithTransportCredentials(insecure.NewCredentials()),
   217  				grpc.WithReturnConnectionError(),
   218  				grpc.WithBlock(),
   219  			}
   220  			conn, err := grpc.Dial(address, opts...)
   221  			if err != nil {
   222  				t.Fatalf("Client: failed to connect: %v", err)
   223  			}
   224  			defer conn.Close()
   225  			c := s2av2pb.NewS2AServiceClient(conn)
   226  			log.Printf("Client: connected to: %s", address)
   227  			ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
   228  			defer cancel()
   229  
   230  			// Setup bidrectional streaming session.
   231  			callOpts := []grpc.CallOption{}
   232  			cstream, err := c.SetUpSession(ctx, callOpts...)
   233  			if err != nil {
   234  				t.Fatalf("Client: failed to setup bidirectional streaming RPC session: %v", err)
   235  			}
   236  			log.Printf("Client: set up bidirectional streaming RPC session.")
   237  
   238  			// TODO(rmehta19): Add verificationMode to struct, and vary between tests.
   239  			VerifyPeerCertificateFunc := VerifyServerCertificateChain(tc.hostname, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, cstream, tc.serverAuthorizationPolicy)
   240  			got, want := VerifyPeerCertificateFunc(tc.rawCerts, nil), tc.expectedErr
   241  			if want == nil {
   242  				if got != nil {
   243  					t.Errorf("Peer certificate verification failed, got: %v, want: %v", got, want)
   244  				}
   245  			} else {
   246  				if got == nil {
   247  					t.Errorf("Peer certificate verification failed, got: %v, want: %v", got, want)
   248  				}
   249  				if got.Error() != want.Error() {
   250  					t.Errorf("Peer certificate verification failed, got: %v, want: %v", got, want)
   251  				}
   252  			}
   253  		})
   254  	}
   255  	stop()
   256  }
   257  
   258  // TestVerifyServerCertChainWithoutServerAuthorizationPolicy runs unit tests for VerifyServerCertificateChain without server authorization policy check.
   259  func TestVerifyServerCertChainWithoutServerAuthorizationPolicy(t *testing.T) {
   260  	// Start up fake S2Av2 server.
   261  	var wg sync.WaitGroup
   262  	wg.Add(1)
   263  	stop, address, err := startFakeS2Av2Server(&wg, false)
   264  	wg.Wait()
   265  	if err != nil {
   266  		t.Fatalf("Error starting fake S2Av2 Server: %v", err)
   267  	}
   268  
   269  	for _, tc := range []struct {
   270  		description               string
   271  		hostname                  string
   272  		rawCerts                  [][]byte
   273  		expectedErr               error
   274  		serverAuthorizationPolicy []byte
   275  	}{
   276  		{
   277  			description:               "empty chain",
   278  			hostname:                  "host",
   279  			rawCerts:                  nil,
   280  			expectedErr:               errors.New("server cert verification failed: server peer verification failed: server cert chain is empty"),
   281  			serverAuthorizationPolicy: []byte(fakeServerAuthorizationPolicy),
   282  		},
   283  		{
   284  			description:               "chain of length 1",
   285  			hostname:                  "host",
   286  			rawCerts:                  [][]byte{serverRootDERCert},
   287  			expectedErr:               nil,
   288  			serverAuthorizationPolicy: nil,
   289  		},
   290  		{
   291  			description:               "chain of length 2 correct",
   292  			hostname:                  "host",
   293  			rawCerts:                  [][]byte{serverLeafDERCert, serverIntermediateDERCert},
   294  			expectedErr:               nil,
   295  			serverAuthorizationPolicy: nil,
   296  		},
   297  	} {
   298  		t.Run(tc.description, func(t *testing.T) {
   299  			// Create new stream to S2Av2.
   300  			opts := []grpc.DialOption{
   301  				grpc.WithTransportCredentials(insecure.NewCredentials()),
   302  				grpc.WithReturnConnectionError(),
   303  				grpc.WithBlock(),
   304  			}
   305  			conn, err := grpc.Dial(address, opts...)
   306  			if err != nil {
   307  				t.Fatalf("Client: failed to connect: %v", err)
   308  			}
   309  			defer conn.Close()
   310  			c := s2av2pb.NewS2AServiceClient(conn)
   311  			log.Printf("Client: connected to: %s", address)
   312  			ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
   313  			defer cancel()
   314  
   315  			// Setup bidrectional streaming session.
   316  			callOpts := []grpc.CallOption{}
   317  			cstream, err := c.SetUpSession(ctx, callOpts...)
   318  			if err != nil {
   319  				t.Fatalf("Client: failed to setup bidirectional streaming RPC session: %v", err)
   320  			}
   321  			log.Printf("Client: set up bidirectional streaming RPC session.")
   322  
   323  			VerifyPeerCertificateFunc := VerifyServerCertificateChain(tc.hostname, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, cstream, tc.serverAuthorizationPolicy)
   324  			got, want := VerifyPeerCertificateFunc(tc.rawCerts, nil), tc.expectedErr
   325  			if want == nil {
   326  				if got != nil {
   327  					t.Errorf("Peer certificate verification failed, got: %v, want: %v", got, want)
   328  				}
   329  			} else {
   330  				if got == nil {
   331  					t.Errorf("Peer certificate verification failed, got: %v, want: %v", got, want)
   332  				}
   333  				if got.Error() != want.Error() {
   334  					t.Errorf("Peer certificate verification failed, got: %v, want: %v", got, want)
   335  				}
   336  			}
   337  		})
   338  	}
   339  	stop()
   340  }
   341  

View as plain text