...

Source file src/github.com/google/s2a-go/s2a.go

Documentation: github.com/google/s2a-go

     1  /*
     2   *
     3   * Copyright 2021 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package s2a provides the S2A transport credentials used by a gRPC
    20  // application.
    21  package s2a
    22  
    23  import (
    24  	"context"
    25  	"crypto/tls"
    26  	"errors"
    27  	"fmt"
    28  	"net"
    29  	"sync"
    30  	"time"
    31  
    32  	"github.com/golang/protobuf/proto"
    33  	"github.com/google/s2a-go/fallback"
    34  	"github.com/google/s2a-go/internal/handshaker"
    35  	"github.com/google/s2a-go/internal/handshaker/service"
    36  	"github.com/google/s2a-go/internal/tokenmanager"
    37  	"github.com/google/s2a-go/internal/v2"
    38  	"github.com/google/s2a-go/retry"
    39  	"google.golang.org/grpc/credentials"
    40  	"google.golang.org/grpc/grpclog"
    41  
    42  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
    43  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
    44  )
    45  
    46  const (
    47  	s2aSecurityProtocol = "tls"
    48  	// defaultTimeout specifies the default server handshake timeout.
    49  	defaultTimeout = 30.0 * time.Second
    50  )
    51  
// s2aTransportCreds are the transport credentials required for establishing
// a secure connection using the S2A. They implement the
// credentials.TransportCredentials interface.
//
// In this file, values of this type are only constructed for legacy mode
// (by NewClientCreds and NewServerCreds when EnableLegacyMode is set) and by
// Clone; non-legacy credentials come from the internal v2 package.
type s2aTransportCreds struct {
	// info is returned by value from Info and mutated in place by
	// OverrideServerName.
	info *credentials.ProtocolInfo
	// minTLSVersion and maxTLSVersion bound the TLS versions requested from
	// the S2A; the constructors in this file set both to TLS 1.3.
	minTLSVersion commonpb.TLSVersion
	maxTLSVersion commonpb.TLSVersion
	// tlsCiphersuites contains the ciphersuites used in the S2A connection.
	// Note that these are currently unconfigurable.
	tlsCiphersuites []commonpb.Ciphersuite
	// localIdentity should only be used by the client.
	localIdentity *commonpb.Identity
	// localIdentities should only be used by the server.
	localIdentities []*commonpb.Identity
	// targetIdentities should only be used by the client.
	targetIdentities []*commonpb.Identity
	// isClient records which side these credentials were built for;
	// ClientHandshake and ServerHandshake each reject the other side.
	isClient bool
	// s2aAddr is the address dialed to reach the S2A for each handshake.
	s2aAddr string
	// ensureProcessSessionTickets is forwarded to the client handshaker
	// options; only the client constructor sets it.
	ensureProcessSessionTickets *sync.WaitGroup
}
    72  
    73  // NewClientCreds returns a client-side transport credentials object that uses
    74  // the S2A to establish a secure connection with a server.
    75  func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
    76  	if opts == nil {
    77  		return nil, errors.New("nil client options")
    78  	}
    79  	var targetIdentities []*commonpb.Identity
    80  	for _, targetIdentity := range opts.TargetIdentities {
    81  		protoTargetIdentity, err := toProtoIdentity(targetIdentity)
    82  		if err != nil {
    83  			return nil, err
    84  		}
    85  		targetIdentities = append(targetIdentities, protoTargetIdentity)
    86  	}
    87  	localIdentity, err := toProtoIdentity(opts.LocalIdentity)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	if opts.EnableLegacyMode {
    92  		return &s2aTransportCreds{
    93  			info: &credentials.ProtocolInfo{
    94  				SecurityProtocol: s2aSecurityProtocol,
    95  			},
    96  			minTLSVersion: commonpb.TLSVersion_TLS1_3,
    97  			maxTLSVersion: commonpb.TLSVersion_TLS1_3,
    98  			tlsCiphersuites: []commonpb.Ciphersuite{
    99  				commonpb.Ciphersuite_AES_128_GCM_SHA256,
   100  				commonpb.Ciphersuite_AES_256_GCM_SHA384,
   101  				commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
   102  			},
   103  			localIdentity:               localIdentity,
   104  			targetIdentities:            targetIdentities,
   105  			isClient:                    true,
   106  			s2aAddr:                     opts.S2AAddress,
   107  			ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
   108  		}, nil
   109  	}
   110  	verificationMode := getVerificationMode(opts.VerificationMode)
   111  	var fallbackFunc fallback.ClientHandshake
   112  	if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
   113  		fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
   114  	}
   115  	return v2.NewClientCreds(opts.S2AAddress, opts.TransportCreds, localIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
   116  }
   117  
   118  // NewServerCreds returns a server-side transport credentials object that uses
   119  // the S2A to establish a secure connection with a client.
   120  func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
   121  	if opts == nil {
   122  		return nil, errors.New("nil server options")
   123  	}
   124  	var localIdentities []*commonpb.Identity
   125  	for _, localIdentity := range opts.LocalIdentities {
   126  		protoLocalIdentity, err := toProtoIdentity(localIdentity)
   127  		if err != nil {
   128  			return nil, err
   129  		}
   130  		localIdentities = append(localIdentities, protoLocalIdentity)
   131  	}
   132  	if opts.EnableLegacyMode {
   133  		return &s2aTransportCreds{
   134  			info: &credentials.ProtocolInfo{
   135  				SecurityProtocol: s2aSecurityProtocol,
   136  			},
   137  			minTLSVersion: commonpb.TLSVersion_TLS1_3,
   138  			maxTLSVersion: commonpb.TLSVersion_TLS1_3,
   139  			tlsCiphersuites: []commonpb.Ciphersuite{
   140  				commonpb.Ciphersuite_AES_128_GCM_SHA256,
   141  				commonpb.Ciphersuite_AES_256_GCM_SHA384,
   142  				commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
   143  			},
   144  			localIdentities: localIdentities,
   145  			isClient:        false,
   146  			s2aAddr:         opts.S2AAddress,
   147  		}, nil
   148  	}
   149  	verificationMode := getVerificationMode(opts.VerificationMode)
   150  	return v2.NewServerCreds(opts.S2AAddress, opts.TransportCreds, localIdentities, verificationMode, opts.getS2AStream)
   151  }
   152  
   153  // ClientHandshake initiates a client-side TLS handshake using the S2A.
   154  func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   155  	if !c.isClient {
   156  		return nil, nil, errors.New("client handshake called using server transport credentials")
   157  	}
   158  
   159  	var cancel context.CancelFunc
   160  	ctx, cancel = context.WithCancel(ctx)
   161  	defer cancel()
   162  
   163  	// Connect to the S2A.
   164  	hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
   165  	if err != nil {
   166  		grpclog.Infof("Failed to connect to S2A: %v", err)
   167  		return nil, nil, err
   168  	}
   169  
   170  	opts := &handshaker.ClientHandshakerOptions{
   171  		MinTLSVersion:               c.minTLSVersion,
   172  		MaxTLSVersion:               c.maxTLSVersion,
   173  		TLSCiphersuites:             c.tlsCiphersuites,
   174  		TargetIdentities:            c.targetIdentities,
   175  		LocalIdentity:               c.localIdentity,
   176  		TargetName:                  serverAuthority,
   177  		EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
   178  	}
   179  	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
   180  	if err != nil {
   181  		grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
   182  		return nil, nil, err
   183  	}
   184  	defer func() {
   185  		if err != nil {
   186  			if closeErr := chs.Close(); closeErr != nil {
   187  				grpclog.Infof("Close failed unexpectedly: %v", err)
   188  				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
   189  			}
   190  		}
   191  	}()
   192  
   193  	secConn, authInfo, err := chs.ClientHandshake(context.Background())
   194  	if err != nil {
   195  		grpclog.Infof("Handshake failed: %v", err)
   196  		return nil, nil, err
   197  	}
   198  	return secConn, authInfo, nil
   199  }
   200  
   201  // ServerHandshake initiates a server-side TLS handshake using the S2A.
   202  func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   203  	if c.isClient {
   204  		return nil, nil, errors.New("server handshake called using client transport credentials")
   205  	}
   206  
   207  	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
   208  	defer cancel()
   209  
   210  	// Connect to the S2A.
   211  	hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
   212  	if err != nil {
   213  		grpclog.Infof("Failed to connect to S2A: %v", err)
   214  		return nil, nil, err
   215  	}
   216  
   217  	opts := &handshaker.ServerHandshakerOptions{
   218  		MinTLSVersion:   c.minTLSVersion,
   219  		MaxTLSVersion:   c.maxTLSVersion,
   220  		TLSCiphersuites: c.tlsCiphersuites,
   221  		LocalIdentities: c.localIdentities,
   222  	}
   223  	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
   224  	if err != nil {
   225  		grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
   226  		return nil, nil, err
   227  	}
   228  	defer func() {
   229  		if err != nil {
   230  			if closeErr := shs.Close(); closeErr != nil {
   231  				grpclog.Infof("Close failed unexpectedly: %v", err)
   232  				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
   233  			}
   234  		}
   235  	}()
   236  
   237  	secConn, authInfo, err := shs.ServerHandshake(context.Background())
   238  	if err != nil {
   239  		grpclog.Infof("Handshake failed: %v", err)
   240  		return nil, nil, err
   241  	}
   242  	return secConn, authInfo, nil
   243  }
   244  
   245  func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
   246  	return *c.info
   247  }
   248  
   249  func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
   250  	info := *c.info
   251  	var localIdentity *commonpb.Identity
   252  	if c.localIdentity != nil {
   253  		localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
   254  	}
   255  	var localIdentities []*commonpb.Identity
   256  	if c.localIdentities != nil {
   257  		localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
   258  		for i, localIdentity := range c.localIdentities {
   259  			localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
   260  		}
   261  	}
   262  	var targetIdentities []*commonpb.Identity
   263  	if c.targetIdentities != nil {
   264  		targetIdentities = make([]*commonpb.Identity, len(c.targetIdentities))
   265  		for i, targetIdentity := range c.targetIdentities {
   266  			targetIdentities[i] = proto.Clone(targetIdentity).(*commonpb.Identity)
   267  		}
   268  	}
   269  	return &s2aTransportCreds{
   270  		info:             &info,
   271  		minTLSVersion:    c.minTLSVersion,
   272  		maxTLSVersion:    c.maxTLSVersion,
   273  		tlsCiphersuites:  c.tlsCiphersuites,
   274  		localIdentity:    localIdentity,
   275  		localIdentities:  localIdentities,
   276  		targetIdentities: targetIdentities,
   277  		isClient:         c.isClient,
   278  		s2aAddr:          c.s2aAddr,
   279  	}
   280  }
   281  
   282  func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
   283  	c.info.ServerName = serverNameOverride
   284  	return nil
   285  }
   286  
// TLSClientConfigOptions specifies parameters for creating client TLS config.
type TLSClientConfigOptions struct {
	// ServerName is required by s2a as the expected name when verifying the
	// hostname found in the server's certificate. The factory in this package
	// passes an empty string through when it is unset. For example:
	//
	//	tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
	//		ServerName: "example.com",
	//	})
	ServerName string
}
   295  
// TLSClientConfigFactory defines the interface for a client TLS config factory.
type TLSClientConfigFactory interface {
	// Build returns a client *tls.Config configured according to opts.
	Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
}
   300  
   301  // NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
   302  func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
   303  	if opts == nil {
   304  		return nil, fmt.Errorf("opts must be non-nil")
   305  	}
   306  	if opts.EnableLegacyMode {
   307  		return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
   308  	}
   309  	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
   310  	if err != nil {
   311  		// The only possible error is: access token not set in the environment,
   312  		// which is okay in environments other than serverless.
   313  		grpclog.Infof("Access token manager not initialized: %v", err)
   314  		return &s2aTLSClientConfigFactory{
   315  			s2av2Address:              opts.S2AAddress,
   316  			transportCreds:            opts.TransportCreds,
   317  			tokenManager:              nil,
   318  			verificationMode:          getVerificationMode(opts.VerificationMode),
   319  			serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
   320  		}, nil
   321  	}
   322  	return &s2aTLSClientConfigFactory{
   323  		s2av2Address:              opts.S2AAddress,
   324  		transportCreds:            opts.TransportCreds,
   325  		tokenManager:              tokenManager,
   326  		verificationMode:          getVerificationMode(opts.VerificationMode),
   327  		serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
   328  	}, nil
   329  }
   330  
// s2aTLSClientConfigFactory is the TLSClientConfigFactory returned by
// NewTLSClientConfigFactory. Build forwards these fields to
// v2.NewClientTLSConfig. tokenManager is nil when no access token manager
// could be created at construction time.
type s2aTLSClientConfigFactory struct {
	s2av2Address              string
	transportCreds            credentials.TransportCredentials
	tokenManager              tokenmanager.AccessTokenManager
	verificationMode          s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
	serverAuthorizationPolicy []byte
}
   338  
   339  func (f *s2aTLSClientConfigFactory) Build(
   340  	ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
   341  	serverName := ""
   342  	if opts != nil && opts.ServerName != "" {
   343  		serverName = opts.ServerName
   344  	}
   345  	return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
   346  }
   347  
   348  func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
   349  	switch verificationMode {
   350  	case ConnectToGoogle:
   351  		return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
   352  	case Spiffe:
   353  		return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
   354  	default:
   355  		return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
   356  	}
   357  }
   358  
   359  // NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
   360  // Example use with http.RoundTripper:
   361  //
   362  //		dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
   363  //			S2AAddress:         s2aAddress, // required
   364  //		})
   365  //	 	transport := http.DefaultTransport
   366  //	 	transport.DialTLSContext = dialTLSContext
   367  func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
   368  
   369  	return func(ctx context.Context, network, addr string) (net.Conn, error) {
   370  
   371  		fallback := func(err error) (net.Conn, error) {
   372  			if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
   373  				opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
   374  				fbDialer := opts.FallbackOpts.FallbackDialer
   375  				grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
   376  				fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
   377  				if fbErr != nil {
   378  					return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
   379  				}
   380  				return fbConn, nil
   381  			}
   382  			return nil, err
   383  		}
   384  
   385  		factory, err := NewTLSClientConfigFactory(opts)
   386  		if err != nil {
   387  			grpclog.Infof("error creating S2A client config factory: %v", err)
   388  			return fallback(err)
   389  		}
   390  
   391  		serverName, _, err := net.SplitHostPort(addr)
   392  		if err != nil {
   393  			serverName = addr
   394  		}
   395  		timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
   396  		defer cancel()
   397  
   398  		var s2aTLSConfig *tls.Config
   399  		retry.Run(timeoutCtx,
   400  			func() error {
   401  				s2aTLSConfig, err = factory.Build(timeoutCtx, &TLSClientConfigOptions{
   402  					ServerName: serverName,
   403  				})
   404  				return err
   405  			})
   406  		if err != nil {
   407  			grpclog.Infof("error building S2A TLS config: %v", err)
   408  			return fallback(err)
   409  		}
   410  
   411  		s2aDialer := &tls.Dialer{
   412  			Config: s2aTLSConfig,
   413  		}
   414  		var c net.Conn
   415  		retry.Run(timeoutCtx,
   416  			func() error {
   417  				c, err = s2aDialer.DialContext(timeoutCtx, network, addr)
   418  				return err
   419  			})
   420  		if err != nil {
   421  			grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
   422  			return fallback(err)
   423  		}
   424  		grpclog.Infof("success dialing MTLS to %s with S2A", addr)
   425  		return c, nil
   426  	}
   427  }
   428  

View as plain text