...

Source file src/github.com/google/s2a-go/internal/handshaker/handshaker.go

Documentation: github.com/google/s2a-go/internal/handshaker

     1  /*
     2   *
     3   * Copyright 2021 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package handshaker communicates with the S2A handshaker service.
    20  package handshaker
    21  
    22  import (
    23  	"context"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"net"
    28  	"sync"
    29  
    30  	"github.com/google/s2a-go/internal/authinfo"
    31  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
    32  	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
    33  	"github.com/google/s2a-go/internal/record"
    34  	"github.com/google/s2a-go/internal/tokenmanager"
    35  	grpc "google.golang.org/grpc"
    36  	"google.golang.org/grpc/codes"
    37  	"google.golang.org/grpc/credentials"
    38  	"google.golang.org/grpc/grpclog"
    39  )
    40  
    41  var (
    42  	// appProtocol contains the application protocol accepted by the handshaker.
    43  	appProtocol = "grpc"
    44  	// frameLimit is the maximum size of a frame in bytes.
    45  	frameLimit = 1024 * 64
    46  	// peerNotRespondingError is the error thrown when the peer doesn't respond.
    47  	errPeerNotResponding = errors.New("peer is not responding and re-connection should be attempted")
    48  )
    49  
    50  // Handshaker defines a handshaker interface.
    51  type Handshaker interface {
    52  	// ClientHandshake starts and completes a TLS handshake from the client side,
    53  	// and returns a secure connection along with additional auth information.
    54  	ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
    55  	// ServerHandshake starts and completes a TLS handshake from the server side,
    56  	// and returns a secure connection along with additional auth information.
    57  	ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
    58  	// Close terminates the Handshaker. It should be called when the handshake
    59  	// is complete.
    60  	Close() error
    61  }
    62  
    63  // ClientHandshakerOptions contains the options needed to configure the S2A
    64  // handshaker service on the client-side.
    65  type ClientHandshakerOptions struct {
    66  	// MinTLSVersion specifies the min TLS version supported by the client.
    67  	MinTLSVersion commonpb.TLSVersion
    68  	// MaxTLSVersion specifies the max TLS version supported by the client.
    69  	MaxTLSVersion commonpb.TLSVersion
    70  	// TLSCiphersuites is the ordered list of ciphersuites supported by the
    71  	// client.
    72  	TLSCiphersuites []commonpb.Ciphersuite
    73  	// TargetIdentities contains a list of allowed server identities. One of the
    74  	// target identities should match the peer identity in the handshake
    75  	// result; otherwise, the handshake fails.
    76  	TargetIdentities []*commonpb.Identity
    77  	// LocalIdentity is the local identity of the client application. If none is
    78  	// provided, then the S2A will choose the default identity.
    79  	LocalIdentity *commonpb.Identity
    80  	// TargetName is the allowed server name, which may be used for server
    81  	// authorization check by the S2A if it is provided.
    82  	TargetName string
    83  	// EnsureProcessSessionTickets allows users to wait and ensure that all
    84  	// available session tickets are sent to S2A before a process completes.
    85  	EnsureProcessSessionTickets *sync.WaitGroup
    86  }
    87  
    88  // ServerHandshakerOptions contains the options needed to configure the S2A
    89  // handshaker service on the server-side.
    90  type ServerHandshakerOptions struct {
    91  	// MinTLSVersion specifies the min TLS version supported by the server.
    92  	MinTLSVersion commonpb.TLSVersion
    93  	// MaxTLSVersion specifies the max TLS version supported by the server.
    94  	MaxTLSVersion commonpb.TLSVersion
    95  	// TLSCiphersuites is the ordered list of ciphersuites supported by the
    96  	// server.
    97  	TLSCiphersuites []commonpb.Ciphersuite
    98  	// LocalIdentities is the list of local identities that may be assumed by
    99  	// the server. If no local identity is specified, then the S2A chooses a
   100  	// default local identity.
   101  	LocalIdentities []*commonpb.Identity
   102  }
   103  
   104  // s2aHandshaker performs a TLS handshake using the S2A handshaker service.
   105  type s2aHandshaker struct {
   106  	// stream is used to communicate with the S2A handshaker service.
   107  	stream s2apb.S2AService_SetUpSessionClient
   108  	// conn is the connection to the peer.
   109  	conn net.Conn
   110  	// clientOpts should be non-nil iff the handshaker is client-side.
   111  	clientOpts *ClientHandshakerOptions
   112  	// serverOpts should be non-nil iff the handshaker is server-side.
   113  	serverOpts *ServerHandshakerOptions
   114  	// isClient determines if the handshaker is client or server side.
   115  	isClient bool
   116  	// hsAddr stores the address of the S2A handshaker service.
   117  	hsAddr string
   118  	// tokenManager manages access tokens for authenticating to S2A.
   119  	tokenManager tokenmanager.AccessTokenManager
   120  	// localIdentities is the set of local identities for whom the
   121  	// tokenManager should fetch a token when preparing a request to be
   122  	// sent to S2A.
   123  	localIdentities []*commonpb.Identity
   124  }
   125  
   126  // NewClientHandshaker creates an s2aHandshaker instance that performs a
   127  // client-side TLS handshake using the S2A handshaker service.
   128  func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ClientHandshakerOptions) (Handshaker, error) {
   129  	stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
   134  	if err != nil {
   135  		grpclog.Infof("failed to create single token access token manager: %v", err)
   136  	}
   137  	return newClientHandshaker(stream, c, hsAddr, opts, tokenManager), nil
   138  }
   139  
   140  func newClientHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ClientHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
   141  	var localIdentities []*commonpb.Identity
   142  	if opts != nil {
   143  		localIdentities = []*commonpb.Identity{opts.LocalIdentity}
   144  	}
   145  	return &s2aHandshaker{
   146  		stream:          stream,
   147  		conn:            c,
   148  		clientOpts:      opts,
   149  		isClient:        true,
   150  		hsAddr:          hsAddr,
   151  		tokenManager:    tokenManager,
   152  		localIdentities: localIdentities,
   153  	}
   154  }
   155  
   156  // NewServerHandshaker creates an s2aHandshaker instance that performs a
   157  // server-side TLS handshake using the S2A handshaker service.
   158  func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ServerHandshakerOptions) (Handshaker, error) {
   159  	stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
   164  	if err != nil {
   165  		grpclog.Infof("failed to create single token access token manager: %v", err)
   166  	}
   167  	return newServerHandshaker(stream, c, hsAddr, opts, tokenManager), nil
   168  }
   169  
   170  func newServerHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ServerHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
   171  	var localIdentities []*commonpb.Identity
   172  	if opts != nil {
   173  		localIdentities = opts.LocalIdentities
   174  	}
   175  	return &s2aHandshaker{
   176  		stream:          stream,
   177  		conn:            c,
   178  		serverOpts:      opts,
   179  		isClient:        false,
   180  		hsAddr:          hsAddr,
   181  		tokenManager:    tokenManager,
   182  		localIdentities: localIdentities,
   183  	}
   184  }
   185  
   186  // ClientHandshake performs a client-side TLS handshake using the S2A handshaker
   187  // service. When complete, returns a TLS connection.
   188  func (h *s2aHandshaker) ClientHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
   189  	if !h.isClient {
   190  		return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client-side handshake")
   191  	}
   192  	// Extract the hostname from the target name. The target name is assumed to be an authority.
   193  	hostname, _, err := net.SplitHostPort(h.clientOpts.TargetName)
   194  	if err != nil {
   195  		// If the target name had no host port or could not be parsed, use it as is.
   196  		hostname = h.clientOpts.TargetName
   197  	}
   198  
   199  	// Prepare a client start message to send to the S2A handshaker service.
   200  	req := &s2apb.SessionReq{
   201  		ReqOneof: &s2apb.SessionReq_ClientStart{
   202  			ClientStart: &s2apb.ClientSessionStartReq{
   203  				ApplicationProtocols: []string{appProtocol},
   204  				MinTlsVersion:        h.clientOpts.MinTLSVersion,
   205  				MaxTlsVersion:        h.clientOpts.MaxTLSVersion,
   206  				TlsCiphersuites:      h.clientOpts.TLSCiphersuites,
   207  				TargetIdentities:     h.clientOpts.TargetIdentities,
   208  				LocalIdentity:        h.clientOpts.LocalIdentity,
   209  				TargetName:           hostname,
   210  			},
   211  		},
   212  		AuthMechanisms: h.getAuthMechanisms(),
   213  	}
   214  	conn, result, err := h.setUpSession(req)
   215  	if err != nil {
   216  		return nil, nil, err
   217  	}
   218  	authInfo, err := authinfo.NewS2AAuthInfo(result)
   219  	if err != nil {
   220  		return nil, nil, err
   221  	}
   222  	return conn, authInfo, nil
   223  }
   224  
   225  // ServerHandshake performs a server-side TLS handshake using the S2A handshaker
   226  // service. When complete, returns a TLS connection.
   227  func (h *s2aHandshaker) ServerHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
   228  	if h.isClient {
   229  		return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server-side handshake")
   230  	}
   231  	p := make([]byte, frameLimit)
   232  	n, err := h.conn.Read(p)
   233  	if err != nil {
   234  		return nil, nil, err
   235  	}
   236  	// Prepare a server start message to send to the S2A handshaker service.
   237  	req := &s2apb.SessionReq{
   238  		ReqOneof: &s2apb.SessionReq_ServerStart{
   239  			ServerStart: &s2apb.ServerSessionStartReq{
   240  				ApplicationProtocols: []string{appProtocol},
   241  				MinTlsVersion:        h.serverOpts.MinTLSVersion,
   242  				MaxTlsVersion:        h.serverOpts.MaxTLSVersion,
   243  				TlsCiphersuites:      h.serverOpts.TLSCiphersuites,
   244  				LocalIdentities:      h.serverOpts.LocalIdentities,
   245  				InBytes:              p[:n],
   246  			},
   247  		},
   248  		AuthMechanisms: h.getAuthMechanisms(),
   249  	}
   250  	conn, result, err := h.setUpSession(req)
   251  	if err != nil {
   252  		return nil, nil, err
   253  	}
   254  	authInfo, err := authinfo.NewS2AAuthInfo(result)
   255  	if err != nil {
   256  		return nil, nil, err
   257  	}
   258  	return conn, authInfo, nil
   259  }
   260  
   261  // setUpSession proxies messages between the peer and the S2A handshaker
   262  // service.
   263  func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.SessionResult, error) {
   264  	resp, err := h.accessHandshakerService(req)
   265  	if err != nil {
   266  		return nil, nil, err
   267  	}
   268  	// Check if the returned status is an error.
   269  	if resp.GetStatus() != nil {
   270  		if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
   271  			return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
   272  		}
   273  	}
   274  	// Calculate the extra unread bytes from the Session. Attempting to consume
   275  	// more than the bytes sent will throw an error.
   276  	var extra []byte
   277  	if req.GetServerStart() != nil {
   278  		if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
   279  			return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
   280  		}
   281  		extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
   282  	}
   283  	result, extra, err := h.processUntilDone(resp, extra)
   284  	if err != nil {
   285  		return nil, nil, err
   286  	}
   287  	if result.GetLocalIdentity() == nil {
   288  		return nil, nil, errors.New("local identity must be populated in session result")
   289  	}
   290  
   291  	// Create a new TLS record protocol using the Session Result.
   292  	newConn, err := record.NewConn(&record.ConnParameters{
   293  		NetConn:                     h.conn,
   294  		Ciphersuite:                 result.GetState().GetTlsCiphersuite(),
   295  		TLSVersion:                  result.GetState().GetTlsVersion(),
   296  		InTrafficSecret:             result.GetState().GetInKey(),
   297  		OutTrafficSecret:            result.GetState().GetOutKey(),
   298  		UnusedBuf:                   extra,
   299  		InSequence:                  result.GetState().GetInSequence(),
   300  		OutSequence:                 result.GetState().GetOutSequence(),
   301  		HSAddr:                      h.hsAddr,
   302  		ConnectionID:                result.GetState().GetConnectionId(),
   303  		LocalIdentity:               result.GetLocalIdentity(),
   304  		EnsureProcessSessionTickets: h.ensureProcessSessionTickets(),
   305  	})
   306  	if err != nil {
   307  		return nil, nil, err
   308  	}
   309  	return newConn, result, nil
   310  }
   311  
   312  func (h *s2aHandshaker) ensureProcessSessionTickets() *sync.WaitGroup {
   313  	if h.clientOpts == nil {
   314  		return nil
   315  	}
   316  	return h.clientOpts.EnsureProcessSessionTickets
   317  }
   318  
   319  // accessHandshakerService sends the session request to the S2A handshaker
   320  // service and returns the session response.
   321  func (h *s2aHandshaker) accessHandshakerService(req *s2apb.SessionReq) (*s2apb.SessionResp, error) {
   322  	if err := h.stream.Send(req); err != nil {
   323  		return nil, err
   324  	}
   325  	resp, err := h.stream.Recv()
   326  	if err != nil {
   327  		return nil, err
   328  	}
   329  	return resp, nil
   330  }
   331  
   332  // processUntilDone continues proxying messages between the peer and the S2A
   333  // handshaker service until the handshaker service returns the SessionResult at
   334  // the end of the handshake or an error occurs.
   335  func (h *s2aHandshaker) processUntilDone(resp *s2apb.SessionResp, unusedBytes []byte) (*s2apb.SessionResult, []byte, error) {
   336  	for {
   337  		if len(resp.OutFrames) > 0 {
   338  			if _, err := h.conn.Write(resp.OutFrames); err != nil {
   339  				return nil, nil, err
   340  			}
   341  		}
   342  		if resp.Result != nil {
   343  			return resp.Result, unusedBytes, nil
   344  		}
   345  		buf := make([]byte, frameLimit)
   346  		n, err := h.conn.Read(buf)
   347  		if err != nil && err != io.EOF {
   348  			return nil, nil, err
   349  		}
   350  		// If there is nothing to send to the handshaker service and nothing is
   351  		// received from the peer, then we are stuck. This covers the case when
   352  		// the peer is not responding. Note that handshaker service connection
   353  		// issues are caught in accessHandshakerService before we even get
   354  		// here.
   355  		if len(resp.OutFrames) == 0 && n == 0 {
   356  			return nil, nil, errPeerNotResponding
   357  		}
   358  		// Append extra bytes from the previous interaction with the handshaker
   359  		// service with the current buffer read from conn.
   360  		p := append(unusedBytes, buf[:n]...)
   361  		// From here on, p and unusedBytes point to the same slice.
   362  		resp, err = h.accessHandshakerService(&s2apb.SessionReq{
   363  			ReqOneof: &s2apb.SessionReq_Next{
   364  				Next: &s2apb.SessionNextReq{
   365  					InBytes: p,
   366  				},
   367  			},
   368  			AuthMechanisms: h.getAuthMechanisms(),
   369  		})
   370  		if err != nil {
   371  			return nil, nil, err
   372  		}
   373  
   374  		// Cache the local identity returned by S2A, if it is populated. This
   375  		// overwrites any existing local identities. This is done because, once the
   376  		// S2A has selected a local identity, then only that local identity should
   377  		// be asserted in future requests until the end of the current handshake.
   378  		if resp.GetLocalIdentity() != nil {
   379  			h.localIdentities = []*commonpb.Identity{resp.GetLocalIdentity()}
   380  		}
   381  
   382  		// Set unusedBytes based on the handshaker service response.
   383  		if resp.GetBytesConsumed() > uint32(len(p)) {
   384  			return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
   385  		}
   386  		unusedBytes = p[resp.GetBytesConsumed():]
   387  	}
   388  }
   389  
   390  // Close shuts down the handshaker and the stream to the S2A handshaker service
   391  // when the handshake is complete. It should be called when the caller obtains
   392  // the secure connection at the end of the handshake.
   393  func (h *s2aHandshaker) Close() error {
   394  	return h.stream.CloseSend()
   395  }
   396  
   397  func (h *s2aHandshaker) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
   398  	if h.tokenManager == nil {
   399  		return nil
   400  	}
   401  	// First handle the special case when no local identities have been provided
   402  	// by the application. In this case, an AuthenticationMechanism with no local
   403  	// identity will be sent.
   404  	if len(h.localIdentities) == 0 {
   405  		token, err := h.tokenManager.DefaultToken()
   406  		if err != nil {
   407  			grpclog.Infof("unable to get token for empty local identity: %v", err)
   408  			return nil
   409  		}
   410  		return []*s2apb.AuthenticationMechanism{
   411  			{
   412  				MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   413  					Token: token,
   414  				},
   415  			},
   416  		}
   417  	}
   418  
   419  	// Next, handle the case where the application (or the S2A) has provided
   420  	// one or more local identities.
   421  	var authMechanisms []*s2apb.AuthenticationMechanism
   422  	for _, localIdentity := range h.localIdentities {
   423  		token, err := h.tokenManager.Token(localIdentity)
   424  		if err != nil {
   425  			grpclog.Infof("unable to get token for local identity %v: %v", localIdentity, err)
   426  			continue
   427  		}
   428  
   429  		authMechanism := &s2apb.AuthenticationMechanism{
   430  			Identity: localIdentity,
   431  			MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   432  				Token: token,
   433  			},
   434  		}
   435  		authMechanisms = append(authMechanisms, authMechanism)
   436  	}
   437  	return authMechanisms
   438  }
   439  

View as plain text