1
18
19
20 package fakes2av2
21
22 import (
23 "bytes"
24 "crypto"
25 "crypto/rand"
26 "crypto/rsa"
27 "crypto/tls"
28 "crypto/x509"
29 "errors"
30 "fmt"
31 "log"
32 "time"
33
34 "google.golang.org/grpc/codes"
35
36 _ "embed"
37
38 commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
39 s2av2ctx "github.com/google/s2a-go/internal/proto/v2/s2a_context_go_proto"
40 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
41 )
42
// Certificates and private keys that the fake S2Av2 service hands out and
// uses.
//
// NOTE(review): the blank `embed` import suggests these slices were meant to
// be populated by embed directives that are not present in this copy. Without
// them every slice is nil: tls.X509KeyPair and AppendCertsFromPEM fail, and
// getTLSConfiguration returns empty certificate strings. Confirm against the
// original source.
var (
	// clientCert is the PEM client certificate. It is returned as the client
	// certificate chain, loaded together with clientKey for signing, and used
	// as the only trusted root when verifying client peers.
	clientCert []byte
	// clientDERCert is presumably the DER form of the client certificate
	// (going by its name); it is not referenced by the code shown here.
	clientDERCert []byte
	// clientKey is the PEM private key paired with clientCert.
	clientKey []byte
	// serverCert is the PEM server certificate. It is returned as the server
	// certificate chain, loaded together with serverKey for signing, and used
	// as the only trusted root when verifying server peers.
	serverCert []byte
	// serverDERCert is presumably the DER form of the server certificate
	// (going by its name); it is not referenced by the code shown here.
	serverDERCert []byte
	// serverKey is the PEM private key paired with serverCert.
	serverKey []byte
)
57
58
// Server is a fake S2Av2 service for tests. It embeds
// UnimplementedS2AServiceServer and provides its own SetUpSession.
type Server struct {
	s2av2pb.UnimplementedS2AServiceServer
	// ExpectedToken is the token that at least one AuthenticationMechanism of a
	// GetTlsConfigurationReq must carry. A request with no
	// AuthenticationMechanisms at all is accepted regardless of this value.
	ExpectedToken string
	// ShouldNotReturnClientCredentials, when true, makes the client-side
	// GetTlsConfigurationResp omit the certificate chain.
	ShouldNotReturnClientCredentials bool
	// isAssistingClientSide records whether the most recent
	// GetTlsConfigurationReq was for the client side; it selects which embedded
	// key is used for offloaded private key operations.
	// NOTE(review): this lives on the Server value, so every stream served by
	// the same *Server shares it. Confirm that callers never interleave
	// client- and server-side sessions on one instance.
	isAssistingClientSide bool
	// ServerAuthorizationPolicy, when non-nil, must equal the serialized
	// unrestricted client policy carried in server-peer validation requests.
	ServerAuthorizationPolicy []byte
}
72
73
74
75 func (s *Server) SetUpSession(stream s2av2pb.S2AService_SetUpSessionServer) error {
76 for {
77 req, err := stream.Recv()
78 if err != nil {
79 log.Printf("Fake S2A Service: failed to receive SessionReq: %v", err)
80 return err
81 }
82
83
84 var resp *s2av2pb.SessionResp
85 switch x := req.ReqOneof.(type) {
86 case *s2av2pb.SessionReq_GetTlsConfigurationReq:
87 if err := s.hasValidToken(req.GetAuthenticationMechanisms()); err != nil {
88 log.Printf("Fake S2A Service: authentication error: %v", err)
89 return err
90 }
91 if err := s.findConnectionSide(req); err != nil {
92 resp = &s2av2pb.SessionResp{
93 Status: &s2av2pb.Status{
94 Code: uint32(codes.InvalidArgument),
95 Details: err.Error(),
96 },
97 }
98 break
99 }
100 resp, err = getTLSConfiguration(req.GetGetTlsConfigurationReq(), s.ShouldNotReturnClientCredentials)
101 if err != nil {
102 log.Printf("Fake S2A Service: failed to build SessionResp with GetTlsConfigurationResp: %v", err)
103 return err
104 }
105 case *s2av2pb.SessionReq_OffloadPrivateKeyOperationReq:
106 resp, err = offloadPrivateKeyOperation(req.GetOffloadPrivateKeyOperationReq(), s.isAssistingClientSide)
107 if err != nil {
108 log.Printf("Fake S2A Service: failed to build SessionResp with OffloadPrivateKeyOperationResp: %v", err)
109 return err
110 }
111 case *s2av2pb.SessionReq_OffloadResumptionKeyOperationReq:
112
113 case *s2av2pb.SessionReq_ValidatePeerCertificateChainReq:
114 resp, err = validatePeerCertificateChain(req.GetValidatePeerCertificateChainReq(), s.ServerAuthorizationPolicy)
115 if err != nil {
116 log.Printf("Fake S2A Service: failed to build SessionResp with ValidatePeerCertificateChainResp: %v", err)
117 return err
118 }
119 default:
120 return fmt.Errorf("SessionReq.ReqOneof has unexpected type %T", x)
121 }
122 if err := stream.Send(resp); err != nil {
123 log.Printf("Fake S2A Service: failed to send SessionResp: %v", err)
124 return err
125 }
126 }
127 }
128
129 func (s *Server) findConnectionSide(req *s2av2pb.SessionReq) error {
130 switch connSide := req.GetGetTlsConfigurationReq().GetConnectionSide(); connSide {
131 case commonpb.ConnectionSide_CONNECTION_SIDE_CLIENT:
132 s.isAssistingClientSide = true
133 case commonpb.ConnectionSide_CONNECTION_SIDE_SERVER:
134 s.isAssistingClientSide = false
135 default:
136 return fmt.Errorf("unknown ConnectionSide: %v", connSide)
137 }
138 return nil
139 }
140
141 func (s *Server) hasValidToken(authMechanisms []*s2av2pb.AuthenticationMechanism) error {
142 if len(authMechanisms) == 0 {
143 return nil
144 }
145 for _, v := range authMechanisms {
146 token := v.GetToken()
147 if token == s.ExpectedToken {
148 return nil
149 }
150 }
151 return errors.New("SessionReq has no AuthenticationMechanism with a valid token")
152 }
153
154 func offloadPrivateKeyOperation(req *s2av2pb.OffloadPrivateKeyOperationReq, isAssistingClientSide bool) (*s2av2pb.SessionResp, error) {
155 switch x := req.GetOperation(); x {
156 case s2av2pb.OffloadPrivateKeyOperationReq_SIGN:
157 var root tls.Certificate
158 var err error
159
160 if isAssistingClientSide {
161 root, err = tls.X509KeyPair(clientCert, clientKey)
162 if err != nil {
163 return nil, err
164 }
165 } else {
166 root, err = tls.X509KeyPair(serverCert, serverKey)
167 if err != nil {
168 return nil, err
169 }
170 }
171 var signedBytes []byte
172 if req.GetSignatureAlgorithm() == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PKCS1_SHA256 {
173 signedBytes, err = root.PrivateKey.(crypto.Signer).Sign(rand.Reader, req.GetSha256Digest(), crypto.SHA256)
174 if err != nil {
175 return nil, err
176 }
177 } else if req.GetSignatureAlgorithm() == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256 {
178 opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA256}
179 signedBytes, err = root.PrivateKey.(crypto.Signer).Sign(rand.Reader, req.GetSha256Digest(), opts)
180 if err != nil {
181 return nil, err
182 }
183 } else {
184 return &s2av2pb.SessionResp{
185 Status: &s2av2pb.Status{
186 Code: uint32(codes.InvalidArgument),
187 Details: fmt.Sprintf("invalid signature algorithm: %v", req.GetSignatureAlgorithm()),
188 },
189 }, nil
190 }
191 return &s2av2pb.SessionResp{
192 Status: &s2av2pb.Status{
193 Code: uint32(codes.OK),
194 },
195 RespOneof: &s2av2pb.SessionResp_OffloadPrivateKeyOperationResp{
196 OffloadPrivateKeyOperationResp: &s2av2pb.OffloadPrivateKeyOperationResp{
197 OutBytes: signedBytes,
198 },
199 },
200 }, nil
201 case s2av2pb.OffloadPrivateKeyOperationReq_DECRYPT:
202 return nil, errors.New("decrypt operation not implemented yet")
203 default:
204 return nil, fmt.Errorf("unspecified private key operation requested: %d", x)
205 }
206 }
207
208 func validatePeerCertificateChain(req *s2av2pb.ValidatePeerCertificateChainReq, serverAuthorizationPolicy []byte) (*s2av2pb.SessionResp, error) {
209 switch x := req.PeerOneof.(type) {
210 case *s2av2pb.ValidatePeerCertificateChainReq_ClientPeer_:
211 return verifyClientPeer(req)
212 case *s2av2pb.ValidatePeerCertificateChainReq_ServerPeer_:
213 return verifyServerPeer(req, serverAuthorizationPolicy)
214 default:
215 err := fmt.Errorf("peer verification failed: invalid Peer type %T", x)
216 return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
217 }
218 }
219
220
221 func getTLSConfiguration(req *s2av2pb.GetTlsConfigurationReq, shouldNotReturnClientCredentials bool) (*s2av2pb.SessionResp, error) {
222 if req.GetConnectionSide() == commonpb.ConnectionSide_CONNECTION_SIDE_CLIENT {
223 if shouldNotReturnClientCredentials {
224 return &s2av2pb.SessionResp{
225 Status: &s2av2pb.Status{
226 Code: uint32(codes.OK),
227 },
228 RespOneof: &s2av2pb.SessionResp_GetTlsConfigurationResp{
229 GetTlsConfigurationResp: &s2av2pb.GetTlsConfigurationResp{
230 TlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration_{
231 ClientTlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration{
232 MinTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
233 MaxTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
234 },
235 },
236 },
237 },
238 }, nil
239 }
240 return &s2av2pb.SessionResp{
241 Status: &s2av2pb.Status{
242 Code: uint32(codes.OK),
243 },
244 RespOneof: &s2av2pb.SessionResp_GetTlsConfigurationResp{
245 GetTlsConfigurationResp: &s2av2pb.GetTlsConfigurationResp{
246 TlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration_{
247 ClientTlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration{
248 CertificateChain: []string{
249 string(clientCert),
250 },
251 MinTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
252 MaxTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
253 },
254 },
255 },
256 },
257 }, nil
258 } else if req.GetConnectionSide() == commonpb.ConnectionSide_CONNECTION_SIDE_SERVER {
259 return &s2av2pb.SessionResp{
260 Status: &s2av2pb.Status{
261 Code: uint32(codes.OK),
262 },
263 RespOneof: &s2av2pb.SessionResp_GetTlsConfigurationResp{
264 GetTlsConfigurationResp: &s2av2pb.GetTlsConfigurationResp{
265 TlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_{
266 ServerTlsConfiguration: &s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration{
267 CertificateChain: []string{
268 string(serverCert),
269 },
270 MinTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
271 MaxTlsVersion: commonpb.TLSVersion_TLS_VERSION_1_3,
272 TlsResumptionEnabled: false,
273 RequestClientCertificate: s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY,
274 MaxOverheadOfTicketAead: 0,
275 },
276 },
277 },
278 },
279 }, nil
280 }
281 return nil, fmt.Errorf("unspecified connection side: %v", req.GetConnectionSide())
282 }
283
284 func buildValidatePeerCertificateChainSessionResp(StatusCode uint32, StatusDetails string, ValidationResult s2av2pb.ValidatePeerCertificateChainResp_ValidationResult, ValidationDetails string, Context *s2av2ctx.S2AContext) *s2av2pb.SessionResp {
285 return &s2av2pb.SessionResp{
286 Status: &s2av2pb.Status{
287 Code: StatusCode,
288 Details: StatusDetails,
289 },
290 RespOneof: &s2av2pb.SessionResp_ValidatePeerCertificateChainResp{
291 ValidatePeerCertificateChainResp: &s2av2pb.ValidatePeerCertificateChainResp{
292 ValidationResult: ValidationResult,
293 ValidationDetails: ValidationDetails,
294 Context: Context,
295 },
296 },
297 }
298 }
299
300 func verifyClientPeer(req *s2av2pb.ValidatePeerCertificateChainReq) (*s2av2pb.SessionResp, error) {
301 derCertChain := req.GetClientPeer().CertificateChain
302 if len(derCertChain) == 0 {
303 s := "client peer verification failed: client cert chain is empty"
304 return buildValidatePeerCertificateChainSessionResp(uint32(codes.OK), "", s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), nil
305 }
306
307
308 rootCertPool := x509.NewCertPool()
309 if ok := rootCertPool.AppendCertsFromPEM(clientCert); ok != true {
310 err := errors.New("client peer verification failed: S2Av2 could not obtain/parse roots")
311 return buildValidatePeerCertificateChainSessionResp(uint32(codes.Internal), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
312 }
313
314
315 intermediateCertPool := x509.NewCertPool()
316 for i := 1; i < (len(derCertChain)); i++ {
317 x509Cert, err := x509.ParseCertificate(derCertChain[i])
318 if err != nil {
319 return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
320 }
321 intermediateCertPool.AddCert(x509Cert)
322 }
323
324
325 opts := x509.VerifyOptions{
326 CurrentTime: time.Now(),
327 Roots: rootCertPool,
328 Intermediates: intermediateCertPool,
329 }
330 x509LeafCert, err := x509.ParseCertificate(derCertChain[0])
331 if err != nil {
332 s := fmt.Sprintf("client peer verification failed: %v", err)
333 return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), s, s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), err
334 }
335 if _, err := x509LeafCert.Verify(opts); err != nil {
336 s := fmt.Sprintf("client peer verification failed: %v", err)
337 return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), s, s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), nil
338 }
339 return buildValidatePeerCertificateChainSessionResp(uint32(codes.OK), "", s2av2pb.ValidatePeerCertificateChainResp_SUCCESS, "client peer verification succeeded", &s2av2ctx.S2AContext{}), nil
340 }
341
342 func verifyServerPeer(req *s2av2pb.ValidatePeerCertificateChainReq, serverAuthorizationPolicy []byte) (*s2av2pb.SessionResp, error) {
343 if serverAuthorizationPolicy != nil {
344 if got := req.GetServerPeer().SerializedUnrestrictedClientPolicy; !bytes.Equal(got, serverAuthorizationPolicy) {
345 err := fmt.Errorf("server peer verification failed: invalid server authorization policy, expected: %s, got: %s",
346 serverAuthorizationPolicy, got)
347 return buildValidatePeerCertificateChainSessionResp(uint32(codes.Internal), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
348 }
349 }
350 derCertChain := req.GetServerPeer().CertificateChain
351 if len(derCertChain) == 0 {
352 s := "server peer verification failed: server cert chain is empty"
353 return buildValidatePeerCertificateChainSessionResp(uint32(codes.OK), "", s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), nil
354 }
355
356
357 rootCertPool := x509.NewCertPool()
358 if ok := rootCertPool.AppendCertsFromPEM(serverCert); ok != true {
359 err := errors.New("server peer verification failed: S2Av2 could not obtain/parse roots")
360 return buildValidatePeerCertificateChainSessionResp(uint32(codes.Internal), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
361 }
362
363
364 intermediateCertPool := x509.NewCertPool()
365 for i := 1; i < (len(derCertChain)); i++ {
366 x509Cert, err := x509.ParseCertificate(derCertChain[i])
367 if err != nil {
368 return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), err.Error(), s2av2pb.ValidatePeerCertificateChainResp_FAILURE, err.Error(), &s2av2ctx.S2AContext{}), err
369 }
370 intermediateCertPool.AddCert(x509Cert)
371 }
372
373
374 opts := x509.VerifyOptions{
375 CurrentTime: time.Now(),
376 Roots: rootCertPool,
377 Intermediates: intermediateCertPool,
378 }
379 x509LeafCert, err := x509.ParseCertificate(derCertChain[0])
380 if err != nil {
381 s := fmt.Sprintf("server peer verification failed: %v", err)
382 return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), s, s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), err
383 }
384 if _, err := x509LeafCert.Verify(opts); err != nil {
385 s := fmt.Sprintf("server peer verification failed: %v", err)
386 return buildValidatePeerCertificateChainSessionResp(uint32(codes.InvalidArgument), s, s2av2pb.ValidatePeerCertificateChainResp_FAILURE, s, &s2av2ctx.S2AContext{}), nil
387 }
388
389 return buildValidatePeerCertificateChainSessionResp(uint32(codes.OK), "", s2av2pb.ValidatePeerCertificateChainResp_SUCCESS, "server peer verification succeeded", &s2av2ctx.S2AContext{}), nil
390 }
391
View as plain text