...

Source file src/github.com/google/s2a-go/internal/fakehandshaker/service/s2a_service.go

Documentation: github.com/google/s2a-go/internal/fakehandshaker/service

     1  /*
     2   *
     3   * Copyright 2021 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package service is a fake S2A handshaker service.
    20  package service
    21  
    22  import (
    23  	"bytes"
    24  	"fmt"
    25  	"os"
    26  
    27  	"google.golang.org/grpc/codes"
    28  
    29  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
    30  	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
    31  )
    32  
    33  type handshakeState int
    34  
    35  const (
    36  	// initial is the state of the handshaker service before any handshake
    37  	// message has been received.
    38  	initial handshakeState = 0
    39  	// started is the state of the handshaker service when the handshake has
    40  	// been initiated but no bytes have been sent or received.
    41  	started handshakeState = 1
    42  	// sent is the state of the handshaker service when the handshake has been
    43  	// initiated and bytes have been sent.
    44  	sent handshakeState = 2
    45  	// completed is the state of the handshaker service when the handshake has
    46  	// been completed.
    47  	completed handshakeState = 3
    48  )
    49  
    50  const (
    51  	accessTokenEnvVariable = "S2A_ACCESS_TOKEN"
    52  	grpcAppProtocol        = "grpc"
    53  	clientHelloFrame       = "ClientHello"
    54  	clientFinishedFrame    = "ClientFinished"
    55  	serverFrame            = "ServerHelloAndFinished"
    56  )
    57  
    58  const (
    59  	inKey  = "kkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkk"
    60  	outKey = "kkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkk"
    61  )
    62  
    63  // FakeHandshakerService implements the s2apb.S2AServiceServer. The fake
    64  // handshaker service should not be used by more than 1 application at a time.
    65  type FakeHandshakerService struct {
    66  	s2apb.S2AServiceServer
    67  
    68  	assistingClient bool
    69  	state           handshakeState
    70  	peerIdentity    *commonpb.Identity
    71  	localIdentity   *commonpb.Identity
    72  }
    73  
    74  // SetUpSession sets up the S2A session.
    75  func (hs *FakeHandshakerService) SetUpSession(stream s2apb.S2AService_SetUpSessionServer) error {
    76  	for {
    77  		sessionReq, err := stream.Recv()
    78  		if err != nil {
    79  			return fmt.Errorf("stream recv failed: %v", err)
    80  		}
    81  		if err := hs.authenticateRequest(sessionReq); err != nil {
    82  			return fmt.Errorf("S2A cannot authenticate the request: %v", err)
    83  		}
    84  
    85  		var resp *s2apb.SessionResp
    86  		receivedTicket := false
    87  		switch req := sessionReq.ReqOneof.(type) {
    88  		case *s2apb.SessionReq_ClientStart:
    89  			resp = hs.processClientStart(req)
    90  		case *s2apb.SessionReq_ServerStart:
    91  			resp = hs.processServerStart(req)
    92  		case *s2apb.SessionReq_Next:
    93  			resp = hs.processNext(req)
    94  		case *s2apb.SessionReq_ResumptionTicket:
    95  			resp = hs.processResumptionTicket(req)
    96  			receivedTicket = true
    97  		default:
    98  			return fmt.Errorf("session request has unexpected type %T", req)
    99  		}
   100  
   101  		if err = stream.Send(resp); err != nil {
   102  			return fmt.Errorf("stream send failed: %v", err)
   103  		}
   104  
   105  		if receivedTicket || resp.GetResult() != nil {
   106  			return nil
   107  		}
   108  	}
   109  }
   110  
   111  // processClientStart processes a ClientSessionStartReq.
   112  func (hs *FakeHandshakerService) processClientStart(req *s2apb.SessionReq_ClientStart) *s2apb.SessionResp {
   113  	resp := s2apb.SessionResp{}
   114  	if hs.state != initial {
   115  		resp.Status = &s2apb.SessionStatus{
   116  			Code:    uint32(codes.FailedPrecondition),
   117  			Details: "client start handshake not in initial state",
   118  		}
   119  		return &resp
   120  	}
   121  	if len(req.ClientStart.GetApplicationProtocols()) != 1 ||
   122  		req.ClientStart.GetApplicationProtocols()[0] != grpcAppProtocol {
   123  		resp.Status = &s2apb.SessionStatus{
   124  			Code:    uint32(codes.InvalidArgument),
   125  			Details: "application protocol was not grpc",
   126  		}
   127  		return &resp
   128  	}
   129  	if req.ClientStart.GetMaxTlsVersion() != commonpb.TLSVersion_TLS1_3 {
   130  		resp.Status = &s2apb.SessionStatus{
   131  			Code:    uint32(codes.InvalidArgument),
   132  			Details: "max TLS version must be 1.3",
   133  		}
   134  		return &resp
   135  	}
   136  	if req.ClientStart.GetMinTlsVersion() != commonpb.TLSVersion_TLS1_3 {
   137  		resp.Status = &s2apb.SessionStatus{
   138  			Code:    uint32(codes.InvalidArgument),
   139  			Details: "min TLS version must be 1.3",
   140  		}
   141  		return &resp
   142  	}
   143  	resp.OutFrames = []byte(clientHelloFrame)
   144  	resp.BytesConsumed = 0
   145  	resp.Status = &s2apb.SessionStatus{Code: uint32(codes.OK)}
   146  	hs.localIdentity = req.ClientStart.LocalIdentity
   147  	if len(req.ClientStart.TargetIdentities) > 0 {
   148  		hs.peerIdentity = req.ClientStart.TargetIdentities[0]
   149  	}
   150  	hs.assistingClient = true
   151  	hs.state = sent
   152  	return &resp
   153  }
   154  
   155  // processServerStart processes a ServerSessionStartReq.
   156  func (hs *FakeHandshakerService) processServerStart(req *s2apb.SessionReq_ServerStart) *s2apb.SessionResp {
   157  	resp := s2apb.SessionResp{}
   158  	if hs.state != initial {
   159  		resp.Status = &s2apb.SessionStatus{
   160  			Code:    uint32(codes.FailedPrecondition),
   161  			Details: "server start handshake not in initial state",
   162  		}
   163  		return &resp
   164  	}
   165  	if len(req.ServerStart.GetApplicationProtocols()) != 1 ||
   166  		req.ServerStart.GetApplicationProtocols()[0] != grpcAppProtocol {
   167  		resp.Status = &s2apb.SessionStatus{
   168  			Code:    uint32(codes.InvalidArgument),
   169  			Details: "application protocol was not grpc",
   170  		}
   171  		return &resp
   172  	}
   173  	if req.ServerStart.GetMaxTlsVersion() != commonpb.TLSVersion_TLS1_3 {
   174  		resp.Status = &s2apb.SessionStatus{
   175  			Code:    uint32(codes.InvalidArgument),
   176  			Details: "max TLS version must be 1.3",
   177  		}
   178  		return &resp
   179  	}
   180  	if req.ServerStart.GetMinTlsVersion() != commonpb.TLSVersion_TLS1_3 {
   181  		resp.Status = &s2apb.SessionStatus{
   182  			Code:    uint32(codes.InvalidArgument),
   183  			Details: "min TLS version must be 1.3",
   184  		}
   185  		return &resp
   186  	}
   187  	if len(req.ServerStart.InBytes) == 0 {
   188  		resp.BytesConsumed = 0
   189  		hs.state = started
   190  	} else if bytes.Equal(req.ServerStart.InBytes, []byte(clientHelloFrame)) {
   191  		resp.OutFrames = []byte(serverFrame)
   192  		resp.BytesConsumed = uint32(len(clientHelloFrame))
   193  		hs.state = sent
   194  	} else {
   195  		resp.Status = &s2apb.SessionStatus{
   196  			Code:    uint32(codes.Internal),
   197  			Details: "server start request did not have the correct input bytes",
   198  		}
   199  		return &resp
   200  	}
   201  
   202  	resp.Status = &s2apb.SessionStatus{Code: uint32(codes.OK)}
   203  	if len(req.ServerStart.LocalIdentities) > 0 {
   204  		hs.localIdentity = req.ServerStart.LocalIdentities[0]
   205  	}
   206  	hs.assistingClient = false
   207  	return &resp
   208  }
   209  
   210  // processNext processes a SessionNext request.
   211  func (hs *FakeHandshakerService) processNext(req *s2apb.SessionReq_Next) *s2apb.SessionResp {
   212  	resp := s2apb.SessionResp{}
   213  	if hs.assistingClient {
   214  		if hs.state != sent {
   215  			resp.Status = &s2apb.SessionStatus{
   216  				Code:    uint32(codes.FailedPrecondition),
   217  				Details: "client handshake was not in sent state",
   218  			}
   219  			return &resp
   220  		}
   221  		if !bytes.Equal(req.Next.InBytes, []byte(serverFrame)) {
   222  			resp.Status = &s2apb.SessionStatus{
   223  				Code:    uint32(codes.Internal),
   224  				Details: "client request did not match server frame",
   225  			}
   226  			return &resp
   227  		}
   228  		resp.OutFrames = []byte(clientFinishedFrame)
   229  		resp.BytesConsumed = uint32(len(serverFrame))
   230  		hs.state = completed
   231  	} else {
   232  		if hs.state == started {
   233  			if !bytes.Equal(req.Next.InBytes, []byte(clientHelloFrame)) {
   234  				resp.Status = &s2apb.SessionStatus{
   235  					Code:    uint32(codes.Internal),
   236  					Details: "server request did not match client hello frame",
   237  				}
   238  				return &resp
   239  			}
   240  			resp.OutFrames = []byte(serverFrame)
   241  			resp.BytesConsumed = uint32(len(clientHelloFrame))
   242  			hs.state = sent
   243  		} else if hs.state == sent {
   244  			if !bytes.Equal(req.Next.InBytes[:len(clientFinishedFrame)], []byte(clientFinishedFrame)) {
   245  				resp.Status = &s2apb.SessionStatus{
   246  					Code:    uint32(codes.Internal),
   247  					Details: "server request did not match client finished frame",
   248  				}
   249  				return &resp
   250  			}
   251  			resp.BytesConsumed = uint32(len(clientFinishedFrame))
   252  			hs.state = completed
   253  		} else {
   254  			resp.Status = &s2apb.SessionStatus{
   255  				Code:    uint32(codes.FailedPrecondition),
   256  				Details: "server request was not in expected state",
   257  			}
   258  			return &resp
   259  		}
   260  	}
   261  	resp.Status = &s2apb.SessionStatus{Code: uint32(codes.OK)}
   262  	if hs.state == completed {
   263  		resp.Result = hs.getSessionResult()
   264  	}
   265  	return &resp
   266  }
   267  
   268  // processResumptionTicket processes a ResumptionTicketReq request.
   269  func (hs *FakeHandshakerService) processResumptionTicket(req *s2apb.SessionReq_ResumptionTicket) *s2apb.SessionResp {
   270  	return &s2apb.SessionResp{
   271  		Status: &s2apb.SessionStatus{Code: uint32(codes.OK)},
   272  	}
   273  }
   274  
   275  // getSessionResult returns a dummy SessionResult.
   276  func (hs *FakeHandshakerService) getSessionResult() *s2apb.SessionResult {
   277  	res := s2apb.SessionResult{}
   278  	res.ApplicationProtocol = grpcAppProtocol
   279  	res.State = &s2apb.SessionState{
   280  		TlsVersion:     commonpb.TLSVersion_TLS1_3,
   281  		TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
   282  		InKey:          []byte(inKey),
   283  		OutKey:         []byte(outKey),
   284  	}
   285  	res.PeerIdentity = hs.peerIdentity
   286  	res.LocalIdentity = hs.localIdentity
   287  	return &res
   288  }
   289  
   290  func (hs *FakeHandshakerService) authenticateRequest(request *s2apb.SessionReq) error {
   291  	// If the S2A_ACCESS_TOKEN environment variable has not been set, then do not
   292  	// enforce anything on the request.
   293  	acceptedToken := os.Getenv(accessTokenEnvVariable)
   294  	if acceptedToken == "" {
   295  		return nil
   296  	}
   297  	if len(request.GetAuthMechanisms()) == 0 {
   298  		return fmt.Errorf("expected token but none was received")
   299  	}
   300  	for _, authMechanism := range request.GetAuthMechanisms() {
   301  		if authMechanism.GetToken() != acceptedToken {
   302  			return fmt.Errorf("received token: %s, expected token: %s", authMechanism.GetToken(), acceptedToken)
   303  		}
   304  	}
   305  	return nil
   306  }
   307  

View as plain text