...

Source file src/github.com/google/s2a-go/internal/v2/fakes2av2/fakes2av2.go

Documentation: github.com/google/s2a-go/internal/v2/fakes2av2

     1  /*
     2   *
     3   * Copyright 2022 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package fakes2av2 is a fake S2Av2 Go implementation.
    20  package fakes2av2
    21  
import (
	"bytes"
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"log"
	"time"

	"google.golang.org/grpc/codes"

	_ "embed"

	commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
	s2av2ctx "github.com/google/s2a-go/internal/proto/v2/s2a_context_go_proto"
	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
)
    42  
// Test credentials embedded from the testdata directory at build time.
//
// Within this file the PEM certificates are handed out as certificate chains
// and used as trust roots for peer verification, and the PEM certificate/key
// pairs are loaded to produce signatures for offloaded private key operations.
var (
	//go:embed testdata/client_root_cert.pem
	clientCert []byte
	//go:embed testdata/client_root_cert.der
	clientDERCert []byte
	//go:embed testdata/client_root_key.pem
	clientKey []byte
	//go:embed testdata/server_root_cert.pem
	serverCert []byte
	//go:embed testdata/server_root_cert.der
	serverDERCert []byte
	//go:embed testdata/server_root_key.pem
	serverKey []byte
)
    57  
// Server is a fake S2A Server for testing.
type Server struct {
	s2av2pb.UnimplementedS2AServiceServer
	// ExpectedToken is the token S2Av2 expects to be attached to the SessionReq.
	ExpectedToken string
	// ShouldNotReturnClientCredentials indicates whether the fake S2Av2 should
	// not return credentials when GetTlsConfiguration is called by a client.
	ShouldNotReturnClientCredentials bool
	// isAssistingClientSide is set by findConnectionSide from the connection
	// side of the most recent valid GetTlsConfigurationReq and selects which
	// embedded key pair signs offloaded private key operations.
	// ServerAuthorizationPolicy, when non-nil, is the serialized policy that a
	// server-peer validation request must carry for verification to proceed.
	isAssistingClientSide            bool
	ServerAuthorizationPolicy        []byte
	// TODO(rmehta19): Decide whether to also store validationResult (bool).
	// Set this after validating token attached to first SessionReq. Check
	// this field before completing subsequent SessionReq.
}
    72  
    73  // SetUpSession receives SessionReq, performs request, and returns a
    74  // SessionResp, all on the server stream.
    75  func (s *Server) SetUpSession(stream s2av2pb.S2AService_SetUpSessionServer) error {
    76  	for {
    77  		req, err := stream.Recv()
    78  		if err != nil {
    79  			log.Printf("Fake S2A Service: failed to receive SessionReq: %v", err)
    80  			return err
    81  		}
    82  		// Call one of the 4 possible RespOneof's
    83  		// TODO(rmehta19): Consider validating the body of the request.
    84  		var resp *s2av2pb.SessionResp
    85  		switch x := req.ReqOneof.(type) {
    86  		case *s2av2pb.SessionReq_GetTlsConfigurationReq:
    87  			if err := s.hasValidToken(req.GetAuthenticationMechanisms()); err != nil {
    88  				log.Printf("Fake S2A Service: authentication error: %v", err)
    89  				return err
    90  			}
    91  			if err := s.findConnectionSide(req); err != nil {
    92  				resp = &s2av2pb.SessionResp{
    93  					Status: &s2av2pb.Status{
    94  						Code:    uint32(codes.InvalidArgument),
    95  						Details: err.Error(),
    96  					},
    97  				}
    98  				break
    99  			}
   100  			resp, err = getTLSConfiguration(req.GetGetTlsConfigurationReq(), s.ShouldNotReturnClientCredentials)
   101  			if err != nil {
   102  				log.Printf("Fake S2A Service: failed to build SessionResp with GetTlsConfigurationResp: %v", err)
   103  				return err
   104  			}
   105  		case *s2av2pb.SessionReq_OffloadPrivateKeyOperationReq:
   106  			resp, err = offloadPrivateKeyOperation(req.GetOffloadPrivateKeyOperationReq(), s.isAssistingClientSide)
   107  			if err != nil {
   108  				log.Printf("Fake S2A Service: failed to build SessionResp with OffloadPrivateKeyOperationResp: %v", err)
   109  				return err
   110  			}
   111  		case *s2av2pb.SessionReq_OffloadResumptionKeyOperationReq:
   112  			// TODO(rmehta19): Implement fake.
   113  		case *s2av2pb.SessionReq_ValidatePeerCertificateChainReq:
   114  			resp, err = validatePeerCertificateChain(req.GetValidatePeerCertificateChainReq(), s.ServerAuthorizationPolicy)
   115  			if err != nil {
   116  				log.Printf("Fake S2A Service: failed to build SessionResp with ValidatePeerCertificateChainResp: %v", err)
   117  				return err
   118  			}
   119  		default:
   120  			return fmt.Errorf("SessionReq.ReqOneof has unexpected type %T", x)
   121  		}
   122  		if err := stream.Send(resp); err != nil {
   123  			log.Printf("Fake S2A Service: failed to send SessionResp: %v", err)
   124  			return err
   125  		}
   126  	}
   127  }
   128  
   129  func (s *Server) findConnectionSide(req *s2av2pb.SessionReq) error {
   130  	switch connSide := req.GetGetTlsConfigurationReq().GetConnectionSide(); connSide {
   131  	case commonpb.ConnectionSide_CONNECTION_SIDE_CLIENT:
   132  		s.isAssistingClientSide = true
   133  	case commonpb.ConnectionSide_CONNECTION_SIDE_SERVER:
   134  		s.isAssistingClientSide = false
   135  	default:
   136  		return fmt.Errorf("unknown ConnectionSide: %v", connSide)
   137  	}
   138  	return nil
   139  }
   140  
   141  func (s *Server) hasValidToken(authMechanisms []*s2av2pb.AuthenticationMechanism) error {
   142  	if len(authMechanisms) == 0 {
   143  		return nil
   144  	}
   145  	for _, v := range authMechanisms {
   146  		token := v.GetToken()
   147  		if token == s.ExpectedToken {
   148  			return nil
   149  		}
   150  	}
   151  	return errors.New("SessionReq has no AuthenticationMechanism with a valid token")
   152  }
   153  
   154  func offloadPrivateKeyOperation(req *s2av2pb.OffloadPrivateKeyOperationReq, isAssistingClientSide bool) (*s2av2pb.SessionResp, error) {
   155  	switch x := req.GetOperation(); x {
   156  	case s2av2pb.OffloadPrivateKeyOperationReq_SIGN:
   157  		var root tls.Certificate
   158  		var err error
   159  		// Retrieve S2Av2 implementation of crypto.Signer.
   160  		if isAssistingClientSide {
   161  			root, err = tls.X509KeyPair(clientCert, clientKey)
   162  			if err != nil {
   163  				return nil, err
   164  			}
   165  		} else {
   166  			root, err = tls.X509KeyPair(serverCert, serverKey)
   167  			if err != nil {
   168  				return nil, err
   169  			}
   170  		}
   171  		var signedBytes []byte
   172  		if req.GetSignatureAlgorithm() == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PKCS1_SHA256 {
   173  			signedBytes, err = root.PrivateKey.(crypto.Signer).Sign(rand.Reader, req.GetSha256Digest(), crypto.SHA256)
   174  			if err != nil {
   175  				return nil, err
   176  			}
   177  		} else if req.GetSignatureAlgorithm() == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256 {
   178  			opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA256}
   179  			signedBytes, err = root.PrivateKey.(crypto.Signer).Sign(rand.Reader, req.GetSha256Digest(), opts)
   180  			if err != nil {
   181  				return nil, err
   182  			}
   183  		} else {
   184  			return &s2av2pb.SessionResp{
   185  				Status: &s2av2pb.Status{
   186  					Code:    uint32(codes.InvalidArgument),
   187  					Details: fmt.Sprintf("invalid signature algorithm: %v", req.GetSignatureAlgorithm()),
   188  				},
   189  			}, nil
   190  		}
   191  		return &s2av2pb.SessionResp{
   192  			Status: &s2av2pb.Status{
   193  				Code: uint32(codes.OK),
   194  			},
   195  			RespOneof: &s2av2pb.SessionResp_OffloadPrivateKeyOperationResp{
   196  				OffloadPrivateKeyOperationResp: &s2av2pb.OffloadPrivateKeyOperationResp{
   197  					OutBytes: signedBytes,
   198  				},
   199  			},
   200  		}, nil
   201  	case s2av2pb.OffloadPrivateKeyOperationReq_DECRYPT:
   202  		return nil, errors.New("decrypt operation not implemented yet")
   203  	default:
   204  		return nil, fmt.Errorf("unspecified private key operation requested: %d", x)
   205  	}
   206  }
   207  
   208  func validatePeerCertificateChain(req *s2av2pb.ValidatePeerCertificateChainReq, serverAuthorizationPolicy []byte) (*s2av2pb.SessionResp, error) {
   209  	switch x := req.PeerOneof.(type) {
   210  	case *s2av2pb.ValidatePeerCertificateChainReq_ClientPeer_:
   211  		return verifyClientPeer(req)
   212  	case *s2av2pb.ValidatePeerCertificateChainReq_ServerPeer_:
   213  		return verifyServerPeer(req, serverAuthorizationPolicy)
   214  	default:
   215  		err := fmt.Errorf("peer verification failed: invalid Peer type %T", x)
   216  		return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
   217  	}
   218  }
   219  
   220  // TODO(rmehta19): Update this to return ciphersuites in Client/Server TlsConfiguration.
   221  func getTLSConfiguration(req *s2av2pb.GetTlsConfigurationReq, shouldNotReturnClientCredentials bool) (*s2av2pb.SessionResp, error) {
   222  	if req.GetConnectionSide() == commonpb.ConnectionSide_CONNECTION_SIDE_CLIENT {
   223  		if shouldNotReturnClientCredentials {
   224  			return &s2av2pb.SessionResp{
   225  				Status: &s2av2pb.Status{
   226  					Code: uint32(codes.OK),
   227  				},
   228  				RespOneof: &s2av2pb.SessionResp_GetTlsConfigurationResp{
   229  					GetTlsConfigurationResp: &s2av2pb.GetTlsConfigurationResp{
   230  						TlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration_{
   231  							ClientTlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration{
   232  								MinTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
   233  								MaxTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
   234  							},
   235  						},
   236  					},
   237  				},
   238  			}, nil
   239  		}
   240  		return &s2av2pb.SessionResp{
   241  			Status: &s2av2pb.Status{
   242  				Code: uint32(codes.OK),
   243  			},
   244  			RespOneof: &s2av2pb.SessionResp_GetTlsConfigurationResp{
   245  				GetTlsConfigurationResp: &s2av2pb.GetTlsConfigurationResp{
   246  					TlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration_{
   247  						ClientTlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration{
   248  							CertificateChain: []string{
   249  								string(clientCert),
   250  							},
   251  							MinTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
   252  							MaxTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
   253  						},
   254  					},
   255  				},
   256  			},
   257  		}, nil
   258  	} else if req.GetConnectionSide() == commonpb.ConnectionSide_CONNECTION_SIDE_SERVER {
   259  		return &s2av2pb.SessionResp{
   260  			Status: &s2av2pb.Status{
   261  				Code: uint32(codes.OK),
   262  			},
   263  			RespOneof: &s2av2pb.SessionResp_GetTlsConfigurationResp{
   264  				GetTlsConfigurationResp: &s2av2pb.GetTlsConfigurationResp{
   265  					TlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_{
   266  						ServerTlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration{
   267  							CertificateChain: []string{
   268  								string(serverCert),
   269  							},
   270  							MinTlsVersion:            commonpb.TLSVersion_TLS_VERSION_1_3,
   271  							MaxTlsVersion:            commonpb.TLSVersion_TLS_VERSION_1_3,
   272  							TlsResumptionEnabled:     false,
   273  							RequestClientCertificate: s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY,
   274  							MaxOverheadOfTicketAead:  0,
   275  						},
   276  					},
   277  				},
   278  			},
   279  		}, nil
   280  	}
   281  	return nil, fmt.Errorf("unspecified connection side: %v", req.GetConnectionSide())
   282  }
   283  
   284  func buildValidatePeerCertificateChainSessionResp(StatusCode uint32, StatusDetails string, ValidationResult s2av2pb.ValidatePeerCertificateChainResp_ValidationResult, ValidationDetails string, Context *s2av2ctx.S2AContext) *s2av2pb.SessionResp {
   285  	return &s2av2pb.SessionResp{
   286  		Status: &s2av2pb.Status{
   287  			Code:    StatusCode,
   288  			Details: StatusDetails,
   289  		},
   290  		RespOneof: &s2av2pb.SessionResp_ValidatePeerCertificateChainResp{
   291  			ValidatePeerCertificateChainResp: &s2av2pb.ValidatePeerCertificateChainResp{
   292  				ValidationResult:  ValidationResult,
   293  				ValidationDetails: ValidationDetails,
   294  				Context:           Context,
   295  			},
   296  		},
   297  	}
   298  }
   299  
   300  func verifyClientPeer(req *s2av2pb.ValidatePeerCertificateChainReq) (*s2av2pb.SessionResp, error) {
   301  	derCertChain := req.GetClientPeer().CertificateChain
   302  	if len(derCertChain) == 0 {
   303  		s := "client peer verification failed: client cert chain is empty"
   304  		return buildValidatePeerCertificateChainSessionResp(uint32(codes.OK), "", s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), nil
   305  	}
   306  
   307  	// Obtain the set of root certificates.
   308  	rootCertPool := x509.NewCertPool()
   309  	if ok := rootCertPool.AppendCertsFromPEM(clientCert); ok != true {
   310  		err := errors.New("client peer verification failed: S2Av2 could not obtain/parse roots")
   311  		return buildValidatePeerCertificateChainSessionResp(uint32(codes.Internal), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
   312  	}
   313  
   314  	// Set the Intermediates: certs between leaf and root, excluding the leaf and root.
   315  	intermediateCertPool := x509.NewCertPool()
   316  	for i := 1; i < (len(derCertChain)); i++ {
   317  		x509Cert, err := x509.ParseCertificate(derCertChain[i])
   318  		if err != nil {
   319  			return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
   320  		}
   321  		intermediateCertPool.AddCert(x509Cert)
   322  	}
   323  
   324  	// Verify the leaf certificate.
   325  	opts := x509.VerifyOptions{
   326  		CurrentTime:   time.Now(),
   327  		Roots:         rootCertPool,
   328  		Intermediates: intermediateCertPool,
   329  	}
   330  	x509LeafCert, err := x509.ParseCertificate(derCertChain[0])
   331  	if err != nil {
   332  		s := fmt.Sprintf("client peer verification failed: %v", err)
   333  		return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), s, s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), err
   334  	}
   335  	if _, err := x509LeafCert.Verify(opts); err != nil {
   336  		s := fmt.Sprintf("client peer verification failed: %v", err)
   337  		return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), s, s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), nil
   338  	}
   339  	return buildValidatePeerCertificateChainSessionResp(uint32(codes.OK), "", s2av2pb.ValidatePeerCertificateChainResp_SUCCESS, "client peer verification succeeded", &s2av2ctx.S2AContext{}), nil
   340  }
   341  
   342  func verifyServerPeer(req *s2av2pb.ValidatePeerCertificateChainReq, serverAuthorizationPolicy []byte) (*s2av2pb.SessionResp, error) {
   343  	if serverAuthorizationPolicy != nil {
   344  		if got := req.GetServerPeer().SerializedUnrestrictedClientPolicy; !bytes.Equal(got, serverAuthorizationPolicy) {
   345  			err := fmt.Errorf("server peer verification failed: invalid server authorization policy, expected: %s, got: %s",
   346  				serverAuthorizationPolicy, got)
   347  			return buildValidatePeerCertificateChainSessionResp(uint32(codes.Internal), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
   348  		}
   349  	}
   350  	derCertChain := req.GetServerPeer().CertificateChain
   351  	if len(derCertChain) == 0 {
   352  		s := "server peer verification failed: server cert chain is empty"
   353  		return buildValidatePeerCertificateChainSessionResp(uint32(codes.OK), "", s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), nil
   354  	}
   355  
   356  	// Obtain the set of root certificates.
   357  	rootCertPool := x509.NewCertPool()
   358  	if ok := rootCertPool.AppendCertsFromPEM(serverCert); ok != true {
   359  		err := errors.New("server peer verification failed: S2Av2 could not obtain/parse roots")
   360  		return buildValidatePeerCertificateChainSessionResp(uint32(codes.Internal), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
   361  	}
   362  
   363  	// Set the Intermediates: certs between leaf and root, excluding the leaf and root.
   364  	intermediateCertPool := x509.NewCertPool()
   365  	for i := 1; i < (len(derCertChain)); i++ {
   366  		x509Cert, err := x509.ParseCertificate(derCertChain[i])
   367  		if err != nil {
   368  			return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
   369  		}
   370  		intermediateCertPool.AddCert(x509Cert)
   371  	}
   372  
   373  	// Verify the leaf certificate.
   374  	opts := x509.VerifyOptions{
   375  		CurrentTime:   time.Now(),
   376  		Roots:         rootCertPool,
   377  		Intermediates: intermediateCertPool,
   378  	}
   379  	x509LeafCert, err := x509.ParseCertificate(derCertChain[0])
   380  	if err != nil {
   381  		s := fmt.Sprintf("server peer verification failed: %v", err)
   382  		return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), s, s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), err
   383  	}
   384  	if _, err := x509LeafCert.Verify(opts); err != nil {
   385  		s := fmt.Sprintf("server peer verification failed: %v", err)
   386  		return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), s, s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), nil
   387  	}
   388  
   389  	return buildValidatePeerCertificateChainSessionResp(uint32(codes.OK), "", s2av2pb.ValidatePeerCertificateChainResp_SUCCESS, "server peer verification succeeded", &s2av2ctx.S2AContext{}), nil
   390  }
   391  

View as plain text