...

Source file src/github.com/google/s2a-go/internal/v2/s2av2.go

Documentation: github.com/google/s2a-go/internal/v2

     1  /*
     2   *
     3   * Copyright 2022 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package v2 provides the S2Av2 transport credentials used by a gRPC
    20  // application.
    21  package v2
    22  
    23  import (
    24  	"context"
    25  	"crypto/tls"
    26  	"errors"
    27  	"net"
    28  	"os"
    29  	"time"
    30  
    31  	"github.com/golang/protobuf/proto"
    32  	"github.com/google/s2a-go/fallback"
    33  	"github.com/google/s2a-go/internal/handshaker/service"
    34  	"github.com/google/s2a-go/internal/tokenmanager"
    35  	"github.com/google/s2a-go/internal/v2/tlsconfigstore"
    36  	"github.com/google/s2a-go/retry"
    37  	"github.com/google/s2a-go/stream"
    38  	"google.golang.org/grpc"
    39  	"google.golang.org/grpc/credentials"
    40  	"google.golang.org/grpc/grpclog"
    41  
    42  	commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
    43  	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
    44  )
    45  
    46  const (
    47  	s2aSecurityProtocol = "tls"
    48  	defaultS2ATimeout   = 6 * time.Second
    49  )
    50  
    51  // An environment variable, which sets the timeout enforced on the connection to the S2A service for handshake.
    52  const s2aTimeoutEnv = "S2A_TIMEOUT"
    53  
    54  type s2av2TransportCreds struct {
    55  	info           *credentials.ProtocolInfo
    56  	isClient       bool
    57  	serverName     string
    58  	s2av2Address   string
    59  	transportCreds credentials.TransportCredentials
    60  	tokenManager   *tokenmanager.AccessTokenManager
    61  	// localIdentity should only be used by the client.
    62  	localIdentity *commonpbv1.Identity
    63  	// localIdentities should only be used by the server.
    64  	localIdentities           []*commonpbv1.Identity
    65  	verificationMode          s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
    66  	fallbackClientHandshake   fallback.ClientHandshake
    67  	getS2AStream              func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)
    68  	serverAuthorizationPolicy []byte
    69  }
    70  
    71  // NewClientCreds returns a client-side transport credentials object that uses
    72  // the S2Av2 to establish a secure connection with a server.
    73  func NewClientCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentity *commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error), serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
    74  	// Create an AccessTokenManager instance to use to authenticate to S2Av2.
    75  	accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
    76  
    77  	creds := &s2av2TransportCreds{
    78  		info: &credentials.ProtocolInfo{
    79  			SecurityProtocol: s2aSecurityProtocol,
    80  		},
    81  		isClient:                  true,
    82  		serverName:                "",
    83  		s2av2Address:              s2av2Address,
    84  		transportCreds:            transportCreds,
    85  		localIdentity:             localIdentity,
    86  		verificationMode:          verificationMode,
    87  		fallbackClientHandshake:   fallbackClientHandshakeFunc,
    88  		getS2AStream:              getS2AStream,
    89  		serverAuthorizationPolicy: serverAuthorizationPolicy,
    90  	}
    91  	if err != nil {
    92  		creds.tokenManager = nil
    93  	} else {
    94  		creds.tokenManager = &accessTokenManager
    95  	}
    96  	if grpclog.V(1) {
    97  		grpclog.Info("Created client S2Av2 transport credentials.")
    98  	}
    99  	return creds, nil
   100  }
   101  
   102  // NewServerCreds returns a server-side transport credentials object that uses
   103  // the S2Av2 to establish a secure connection with a client.
   104  func NewServerCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentities []*commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (credentials.TransportCredentials, error) {
   105  	// Create an AccessTokenManager instance to use to authenticate to S2Av2.
   106  	accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
   107  	creds := &s2av2TransportCreds{
   108  		info: &credentials.ProtocolInfo{
   109  			SecurityProtocol: s2aSecurityProtocol,
   110  		},
   111  		isClient:         false,
   112  		s2av2Address:     s2av2Address,
   113  		transportCreds:   transportCreds,
   114  		localIdentities:  localIdentities,
   115  		verificationMode: verificationMode,
   116  		getS2AStream:     getS2AStream,
   117  	}
   118  	if err != nil {
   119  		creds.tokenManager = nil
   120  	} else {
   121  		creds.tokenManager = &accessTokenManager
   122  	}
   123  	if grpclog.V(1) {
   124  		grpclog.Info("Created server S2Av2 transport credentials.")
   125  	}
   126  	return creds, nil
   127  }
   128  
   129  // ClientHandshake performs a client-side mTLS handshake using the S2Av2.
   130  func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   131  	if !c.isClient {
   132  		return nil, nil, errors.New("client handshake called using server transport credentials")
   133  	}
   134  	// Remove the port from serverAuthority.
   135  	serverName := removeServerNamePort(serverAuthority)
   136  	timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
   137  	defer cancel()
   138  	var s2AStream stream.S2AStream
   139  	var err error
   140  	retry.Run(timeoutCtx,
   141  		func() error {
   142  			s2AStream, err = createStream(timeoutCtx, c.s2av2Address, c.transportCreds, c.getS2AStream)
   143  			return err
   144  		})
   145  	if err != nil {
   146  		grpclog.Infof("Failed to connect to S2Av2: %v", err)
   147  		if c.fallbackClientHandshake != nil {
   148  			return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
   149  		}
   150  		return nil, nil, err
   151  	}
   152  	defer s2AStream.CloseSend()
   153  	if grpclog.V(1) {
   154  		grpclog.Infof("Connected to S2Av2.")
   155  	}
   156  	var config *tls.Config
   157  
   158  	var tokenManager tokenmanager.AccessTokenManager
   159  	if c.tokenManager == nil {
   160  		tokenManager = nil
   161  	} else {
   162  		tokenManager = *c.tokenManager
   163  	}
   164  
   165  	sn := serverName
   166  	if c.serverName != "" {
   167  		sn = c.serverName
   168  	}
   169  	retry.Run(timeoutCtx,
   170  		func() error {
   171  			config, err = tlsconfigstore.GetTLSConfigurationForClient(sn, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
   172  			return err
   173  		})
   174  	if err != nil {
   175  		grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
   176  		if c.fallbackClientHandshake != nil {
   177  			return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
   178  		}
   179  		return nil, nil, err
   180  	}
   181  	if grpclog.V(1) {
   182  		grpclog.Infof("Got client TLS config from S2Av2.")
   183  	}
   184  
   185  	creds := credentials.NewTLS(config)
   186  	var conn net.Conn
   187  	var authInfo credentials.AuthInfo
   188  	retry.Run(timeoutCtx,
   189  		func() error {
   190  			conn, authInfo, err = creds.ClientHandshake(timeoutCtx, serverName, rawConn)
   191  			return err
   192  		})
   193  	if err != nil {
   194  		grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
   195  		if c.fallbackClientHandshake != nil {
   196  			return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
   197  		}
   198  		return nil, nil, err
   199  	}
   200  	grpclog.Infof("Successfully done client handshake using S2Av2 to: %s", serverName)
   201  
   202  	return conn, authInfo, err
   203  }
   204  
   205  // ServerHandshake performs a server-side mTLS handshake using the S2Av2.
   206  func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   207  	if c.isClient {
   208  		return nil, nil, errors.New("server handshake called using client transport credentials")
   209  	}
   210  	ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
   211  	defer cancel()
   212  	var s2AStream stream.S2AStream
   213  	var err error
   214  	retry.Run(ctx,
   215  		func() error {
   216  			s2AStream, err = createStream(ctx, c.s2av2Address, c.transportCreds, c.getS2AStream)
   217  			return err
   218  		})
   219  	if err != nil {
   220  		grpclog.Infof("Failed to connect to S2Av2: %v", err)
   221  		return nil, nil, err
   222  	}
   223  	defer s2AStream.CloseSend()
   224  	if grpclog.V(1) {
   225  		grpclog.Infof("Connected to S2Av2.")
   226  	}
   227  
   228  	var tokenManager tokenmanager.AccessTokenManager
   229  	if c.tokenManager == nil {
   230  		tokenManager = nil
   231  	} else {
   232  		tokenManager = *c.tokenManager
   233  	}
   234  
   235  	var config *tls.Config
   236  	retry.Run(ctx,
   237  		func() error {
   238  			config, err = tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
   239  			return err
   240  		})
   241  	if err != nil {
   242  		grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
   243  		return nil, nil, err
   244  	}
   245  	if grpclog.V(1) {
   246  		grpclog.Infof("Got server TLS config from S2Av2.")
   247  	}
   248  
   249  	creds := credentials.NewTLS(config)
   250  	var conn net.Conn
   251  	var authInfo credentials.AuthInfo
   252  	retry.Run(ctx,
   253  		func() error {
   254  			conn, authInfo, err = creds.ServerHandshake(rawConn)
   255  			return err
   256  		})
   257  	if err != nil {
   258  		grpclog.Infof("Failed to do server handshake using S2Av2: %v", err)
   259  		return nil, nil, err
   260  	}
   261  	return conn, authInfo, err
   262  }
   263  
   264  // Info returns protocol info of s2av2TransportCreds.
   265  func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
   266  	return *c.info
   267  }
   268  
   269  // Clone makes a deep copy of s2av2TransportCreds.
   270  func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
   271  	info := *c.info
   272  	serverName := c.serverName
   273  	fallbackClientHandshake := c.fallbackClientHandshake
   274  
   275  	s2av2Address := c.s2av2Address
   276  	var tokenManager tokenmanager.AccessTokenManager
   277  	if c.tokenManager == nil {
   278  		tokenManager = nil
   279  	} else {
   280  		tokenManager = *c.tokenManager
   281  	}
   282  	verificationMode := c.verificationMode
   283  	var localIdentity *commonpbv1.Identity
   284  	if c.localIdentity != nil {
   285  		localIdentity = proto.Clone(c.localIdentity).(*commonpbv1.Identity)
   286  	}
   287  	var localIdentities []*commonpbv1.Identity
   288  	if c.localIdentities != nil {
   289  		localIdentities = make([]*commonpbv1.Identity, len(c.localIdentities))
   290  		for i, localIdentity := range c.localIdentities {
   291  			localIdentities[i] = proto.Clone(localIdentity).(*commonpbv1.Identity)
   292  		}
   293  	}
   294  	creds := &s2av2TransportCreds{
   295  		info:                    &info,
   296  		isClient:                c.isClient,
   297  		serverName:              serverName,
   298  		fallbackClientHandshake: fallbackClientHandshake,
   299  		s2av2Address:            s2av2Address,
   300  		localIdentity:           localIdentity,
   301  		localIdentities:         localIdentities,
   302  		verificationMode:        verificationMode,
   303  	}
   304  	if c.tokenManager == nil {
   305  		creds.tokenManager = nil
   306  	} else {
   307  		creds.tokenManager = &tokenManager
   308  	}
   309  	return creds
   310  }
   311  
   312  // NewClientTLSConfig returns a tls.Config instance that uses S2Av2 to establish a TLS connection as
   313  // a client. The tls.Config MUST only be used to establish a single TLS connection.
   314  func NewClientTLSConfig(
   315  	ctx context.Context,
   316  	s2av2Address string,
   317  	transportCreds credentials.TransportCredentials,
   318  	tokenManager tokenmanager.AccessTokenManager,
   319  	verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
   320  	serverName string,
   321  	serverAuthorizationPolicy []byte) (*tls.Config, error) {
   322  	s2AStream, err := createStream(ctx, s2av2Address, transportCreds, nil)
   323  	if err != nil {
   324  		grpclog.Infof("Failed to connect to S2Av2: %v", err)
   325  		return nil, err
   326  	}
   327  
   328  	return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
   329  }
   330  
   331  // OverrideServerName sets the ServerName in the s2av2TransportCreds protocol
   332  // info. The ServerName MUST be a hostname.
   333  func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
   334  	serverName := removeServerNamePort(serverNameOverride)
   335  	c.info.ServerName = serverName
   336  	c.serverName = serverName
   337  	return nil
   338  }
   339  
   340  // Remove the trailing port from server name.
   341  func removeServerNamePort(serverName string) string {
   342  	name, _, err := net.SplitHostPort(serverName)
   343  	if err != nil {
   344  		name = serverName
   345  	}
   346  	return name
   347  }
   348  
   349  type s2AGrpcStream struct {
   350  	stream s2av2pb.S2AService_SetUpSessionClient
   351  }
   352  
   353  func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
   354  	return x.stream.Send(m)
   355  }
   356  
   357  func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
   358  	return x.stream.Recv()
   359  }
   360  
   361  func (x s2AGrpcStream) CloseSend() error {
   362  	return x.stream.CloseSend()
   363  }
   364  
   365  func createStream(ctx context.Context, s2av2Address string, transportCreds credentials.TransportCredentials, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (stream.S2AStream, error) {
   366  	if getS2AStream != nil {
   367  		return getS2AStream(ctx, s2av2Address)
   368  	}
   369  	// TODO(rmehta19): Consider whether to close the connection to S2Av2.
   370  	conn, err := service.Dial(ctx, s2av2Address, transportCreds)
   371  	if err != nil {
   372  		return nil, err
   373  	}
   374  	client := s2av2pb.NewS2AServiceClient(conn)
   375  	gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
   376  	if err != nil {
   377  		return nil, err
   378  	}
   379  	return &s2AGrpcStream{
   380  		stream: gRPCStream,
   381  	}, nil
   382  }
   383  
   384  // GetS2ATimeout returns the timeout enforced on the connection to the S2A service for handshake.
   385  func GetS2ATimeout() time.Duration {
   386  	timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
   387  	if err != nil {
   388  		return defaultS2ATimeout
   389  	}
   390  	return timeout
   391  }
   392  

View as plain text