...

Source file src/github.com/google/s2a-go/internal/v2/tlsconfigstore/tlsconfigstore.go

Documentation: github.com/google/s2a-go/internal/v2/tlsconfigstore

     1  /*
     2   *
     3   * Copyright 2022 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package tlsconfigstore offloads operations to S2Av2.
    20  package tlsconfigstore
    21  
    22  import (
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"encoding/pem"
    26  	"errors"
    27  	"fmt"
    28  
    29  	"github.com/google/s2a-go/internal/tokenmanager"
    30  	"github.com/google/s2a-go/internal/v2/certverifier"
    31  	"github.com/google/s2a-go/internal/v2/remotesigner"
    32  	"github.com/google/s2a-go/stream"
    33  	"google.golang.org/grpc/codes"
    34  	"google.golang.org/grpc/grpclog"
    35  
    36  	commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
    37  	commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
    38  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
    39  )
    40  
    41  const (
    42  	// HTTP/2
    43  	h2 = "h2"
    44  )
    45  
    46  // GetTLSConfigurationForClient returns a tls.Config instance for use by a client application.
    47  func GetTLSConfigurationForClient(serverHostname string, s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentity *commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, serverAuthorizationPolicy []byte) (*tls.Config, error) {
    48  	authMechanisms := getAuthMechanisms(tokenManager, []*commonpbv1.Identity{localIdentity})
    49  
    50  	if grpclog.V(1) {
    51  		grpclog.Infof("Sending request to S2Av2 for client TLS config.")
    52  	}
    53  	// Send request to S2Av2 for config.
    54  	if err := s2AStream.Send(&s2av2pb.SessionReq{
    55  		LocalIdentity:            localIdentity,
    56  		AuthenticationMechanisms: authMechanisms,
    57  		ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
    58  			GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
    59  				ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_CLIENT,
    60  			},
    61  		},
    62  	}); err != nil {
    63  		grpclog.Infof("Failed to send request to S2Av2 for client TLS config")
    64  		return nil, err
    65  	}
    66  
    67  	// Get the response containing config from S2Av2.
    68  	resp, err := s2AStream.Recv()
    69  	if err != nil {
    70  		grpclog.Infof("Failed to receive client TLS config response from S2Av2.")
    71  		return nil, err
    72  	}
    73  
    74  	// TODO(rmehta19): Add unit test for this if statement.
    75  	if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
    76  		return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
    77  	}
    78  
    79  	// Extract TLS configiguration from SessionResp.
    80  	tlsConfig := resp.GetGetTlsConfigurationResp().GetClientTlsConfiguration()
    81  
    82  	var cert tls.Certificate
    83  	for i, v := range tlsConfig.CertificateChain {
    84  		// Populate Certificates field.
    85  		block, _ := pem.Decode([]byte(v))
    86  		if block == nil {
    87  			return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
    88  		}
    89  		x509Cert, err := x509.ParseCertificate(block.Bytes)
    90  		if err != nil {
    91  			return nil, err
    92  		}
    93  		cert.Certificate = append(cert.Certificate, x509Cert.Raw)
    94  		if i == 0 {
    95  			cert.Leaf = x509Cert
    96  		}
    97  	}
    98  
    99  	if len(tlsConfig.CertificateChain) > 0 {
   100  		cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
   101  		if cert.PrivateKey == nil {
   102  			return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
   103  		}
   104  	}
   105  
   106  	minVersion, maxVersion, err := getTLSMinMaxVersionsClient(tlsConfig)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	// Create mTLS credentials for client.
   112  	config := &tls.Config{
   113  		VerifyPeerCertificate:  certverifier.VerifyServerCertificateChain(serverHostname, verificationMode, s2AStream, serverAuthorizationPolicy),
   114  		ServerName:             serverHostname,
   115  		InsecureSkipVerify:     true, // NOLINT
   116  		ClientSessionCache:     nil,
   117  		SessionTicketsDisabled: true,
   118  		MinVersion:             minVersion,
   119  		MaxVersion:             maxVersion,
   120  		NextProtos:             []string{h2},
   121  	}
   122  	if len(tlsConfig.CertificateChain) > 0 {
   123  		config.Certificates = []tls.Certificate{cert}
   124  	}
   125  	return config, nil
   126  }
   127  
   128  // GetTLSConfigurationForServer returns a tls.Config instance for use by a server application.
   129  func GetTLSConfigurationForServer(s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode) (*tls.Config, error) {
   130  	return &tls.Config{
   131  		GetConfigForClient: ClientConfig(tokenManager, localIdentities, verificationMode, s2AStream),
   132  	}, nil
   133  }
   134  
   135  // ClientConfig builds a TLS config for a server to establish a secure
   136  // connection with a client, based on SNI communicated during ClientHello.
   137  // Ensures that server presents the correct certificate to establish a TLS
   138  // connection.
   139  func ClientConfig(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, s2AStream stream.S2AStream) func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
   140  	return func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
   141  		tlsConfig, err := getServerConfigFromS2Av2(tokenManager, localIdentities, chi.ServerName, s2AStream)
   142  		if err != nil {
   143  			return nil, err
   144  		}
   145  
   146  		var cert tls.Certificate
   147  		for i, v := range tlsConfig.CertificateChain {
   148  			// Populate Certificates field.
   149  			block, _ := pem.Decode([]byte(v))
   150  			if block == nil {
   151  				return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
   152  			}
   153  			x509Cert, err := x509.ParseCertificate(block.Bytes)
   154  			if err != nil {
   155  				return nil, err
   156  			}
   157  			cert.Certificate = append(cert.Certificate, x509Cert.Raw)
   158  			if i == 0 {
   159  				cert.Leaf = x509Cert
   160  			}
   161  		}
   162  
   163  		cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
   164  		if cert.PrivateKey == nil {
   165  			return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
   166  		}
   167  
   168  		minVersion, maxVersion, err := getTLSMinMaxVersionsServer(tlsConfig)
   169  		if err != nil {
   170  			return nil, err
   171  		}
   172  
   173  		clientAuth := getTLSClientAuthType(tlsConfig)
   174  
   175  		var cipherSuites []uint16
   176  		cipherSuites = getCipherSuites(tlsConfig.Ciphersuites)
   177  
   178  		// Create mTLS credentials for server.
   179  		return &tls.Config{
   180  			Certificates:           []tls.Certificate{cert},
   181  			VerifyPeerCertificate:  certverifier.VerifyClientCertificateChain(verificationMode, s2AStream),
   182  			ClientAuth:             clientAuth,
   183  			CipherSuites:           cipherSuites,
   184  			SessionTicketsDisabled: true,
   185  			MinVersion:             minVersion,
   186  			MaxVersion:             maxVersion,
   187  			NextProtos:             []string{h2},
   188  		}, nil
   189  	}
   190  }
   191  
   192  func getCipherSuites(tlsConfigCipherSuites []commonpb.Ciphersuite) []uint16 {
   193  	var tlsGoCipherSuites []uint16
   194  	for _, v := range tlsConfigCipherSuites {
   195  		s := getTLSCipherSuite(v)
   196  		if s != 0xffff {
   197  			tlsGoCipherSuites = append(tlsGoCipherSuites, s)
   198  		}
   199  	}
   200  	return tlsGoCipherSuites
   201  }
   202  
   203  func getTLSCipherSuite(tlsCipherSuite commonpb.Ciphersuite) uint16 {
   204  	switch tlsCipherSuite {
   205  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
   206  		return tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
   207  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
   208  		return tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
   209  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
   210  		return tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
   211  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
   212  		return tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
   213  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384:
   214  		return tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
   215  	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256:
   216  		return tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
   217  	default:
   218  		return 0xffff
   219  	}
   220  }
   221  
   222  func getServerConfigFromS2Av2(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpbv1.Identity, sni string, s2AStream stream.S2AStream) (*s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration, error) {
   223  	authMechanisms := getAuthMechanisms(tokenManager, localIdentities)
   224  	var locID *commonpbv1.Identity
   225  	if localIdentities != nil {
   226  		locID = localIdentities[0]
   227  	}
   228  
   229  	if err := s2AStream.Send(&s2av2pb.SessionReq{
   230  		LocalIdentity:            locID,
   231  		AuthenticationMechanisms: authMechanisms,
   232  		ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
   233  			GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
   234  				ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_SERVER,
   235  				Sni:            sni,
   236  			},
   237  		},
   238  	}); err != nil {
   239  		return nil, err
   240  	}
   241  
   242  	resp, err := s2AStream.Recv()
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  
   247  	// TODO(rmehta19): Add unit test for this if statement.
   248  	if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
   249  		return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
   250  	}
   251  
   252  	return resp.GetGetTlsConfigurationResp().GetServerTlsConfiguration(), nil
   253  }
   254  
   255  func getTLSClientAuthType(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) tls.ClientAuthType {
   256  	var clientAuth tls.ClientAuthType
   257  	switch x := tlsConfig.RequestClientCertificate; x {
   258  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_DONT_REQUEST_CLIENT_CERTIFICATE:
   259  		clientAuth = tls.NoClientCert
   260  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
   261  		clientAuth = tls.RequestClientCert
   262  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
   263  		// This case actually maps to tls.VerifyClientCertIfGiven. However this
   264  		// mapping triggers normal verification, followed by custom verification,
   265  		// specified in VerifyPeerCertificate. To bypass normal verification, and
   266  		// only do custom verification we set clientAuth to RequireAnyClientCert or
   267  		// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
   268  		// discussion.
   269  		clientAuth = tls.RequireAnyClientCert
   270  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
   271  		clientAuth = tls.RequireAnyClientCert
   272  	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
   273  		// This case actually maps to tls.RequireAndVerifyClientCert. However this
   274  		// mapping triggers normal verification, followed by custom verification,
   275  		// specified in VerifyPeerCertificate. To bypass normal verification, and
   276  		// only do custom verification we set clientAuth to RequireAnyClientCert or
   277  		// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
   278  		// discussion.
   279  		clientAuth = tls.RequireAnyClientCert
   280  	default:
   281  		clientAuth = tls.RequireAnyClientCert
   282  	}
   283  	return clientAuth
   284  }
   285  
   286  func getAuthMechanisms(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpbv1.Identity) []*s2av2pb.AuthenticationMechanism {
   287  	if tokenManager == nil {
   288  		return nil
   289  	}
   290  	if len(localIdentities) == 0 {
   291  		token, err := tokenManager.DefaultToken()
   292  		if err != nil {
   293  			grpclog.Infof("Unable to get token for empty local identity: %v", err)
   294  			return nil
   295  		}
   296  		return []*s2av2pb.AuthenticationMechanism{
   297  			{
   298  				MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
   299  					Token: token,
   300  				},
   301  			},
   302  		}
   303  	}
   304  	var authMechanisms []*s2av2pb.AuthenticationMechanism
   305  	for _, localIdentity := range localIdentities {
   306  		if localIdentity == nil {
   307  			token, err := tokenManager.DefaultToken()
   308  			if err != nil {
   309  				grpclog.Infof("Unable to get default token for local identity %v: %v", localIdentity, err)
   310  				continue
   311  			}
   312  			authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
   313  				Identity: localIdentity,
   314  				MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
   315  					Token: token,
   316  				},
   317  			})
   318  		} else {
   319  			token, err := tokenManager.Token(localIdentity)
   320  			if err != nil {
   321  				grpclog.Infof("Unable to get token for local identity %v: %v", localIdentity, err)
   322  				continue
   323  			}
   324  			authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
   325  				Identity: localIdentity,
   326  				MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
   327  					Token: token,
   328  				},
   329  			})
   330  		}
   331  	}
   332  	return authMechanisms
   333  }
   334  
   335  // TODO(rmehta19): refactor switch statements into a helper function.
   336  func getTLSMinMaxVersionsClient(tlsConfig *s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration) (uint16, uint16, error) {
   337  	// Map S2Av2 TLSVersion to consts defined in tls package.
   338  	var minVersion uint16
   339  	var maxVersion uint16
   340  	switch x := tlsConfig.MinTlsVersion; x {
   341  	case commonpb.TLSVersion_TLS_VERSION_1_0:
   342  		minVersion = tls.VersionTLS10
   343  	case commonpb.TLSVersion_TLS_VERSION_1_1:
   344  		minVersion = tls.VersionTLS11
   345  	case commonpb.TLSVersion_TLS_VERSION_1_2:
   346  		minVersion = tls.VersionTLS12
   347  	case commonpb.TLSVersion_TLS_VERSION_1_3:
   348  		minVersion = tls.VersionTLS13
   349  	default:
   350  		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
   351  	}
   352  
   353  	switch x := tlsConfig.MaxTlsVersion; x {
   354  	case commonpb.TLSVersion_TLS_VERSION_1_0:
   355  		maxVersion = tls.VersionTLS10
   356  	case commonpb.TLSVersion_TLS_VERSION_1_1:
   357  		maxVersion = tls.VersionTLS11
   358  	case commonpb.TLSVersion_TLS_VERSION_1_2:
   359  		maxVersion = tls.VersionTLS12
   360  	case commonpb.TLSVersion_TLS_VERSION_1_3:
   361  		maxVersion = tls.VersionTLS13
   362  	default:
   363  		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
   364  	}
   365  	if minVersion > maxVersion {
   366  		return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
   367  	}
   368  	return minVersion, maxVersion, nil
   369  }
   370  
   371  func getTLSMinMaxVersionsServer(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) (uint16, uint16, error) {
   372  	// Map S2Av2 TLSVersion to consts defined in tls package.
   373  	var minVersion uint16
   374  	var maxVersion uint16
   375  	switch x := tlsConfig.MinTlsVersion; x {
   376  	case commonpb.TLSVersion_TLS_VERSION_1_0:
   377  		minVersion = tls.VersionTLS10
   378  	case commonpb.TLSVersion_TLS_VERSION_1_1:
   379  		minVersion = tls.VersionTLS11
   380  	case commonpb.TLSVersion_TLS_VERSION_1_2:
   381  		minVersion = tls.VersionTLS12
   382  	case commonpb.TLSVersion_TLS_VERSION_1_3:
   383  		minVersion = tls.VersionTLS13
   384  	default:
   385  		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
   386  	}
   387  
   388  	switch x := tlsConfig.MaxTlsVersion; x {
   389  	case commonpb.TLSVersion_TLS_VERSION_1_0:
   390  		maxVersion = tls.VersionTLS10
   391  	case commonpb.TLSVersion_TLS_VERSION_1_1:
   392  		maxVersion = tls.VersionTLS11
   393  	case commonpb.TLSVersion_TLS_VERSION_1_2:
   394  		maxVersion = tls.VersionTLS12
   395  	case commonpb.TLSVersion_TLS_VERSION_1_3:
   396  		maxVersion = tls.VersionTLS13
   397  	default:
   398  		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
   399  	}
   400  	if minVersion > maxVersion {
   401  		return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
   402  	}
   403  	return minVersion, maxVersion, nil
   404  }
   405  

View as plain text