1
18
19
20
21 package s2a
22
23 import (
24 "context"
25 "crypto/tls"
26 "errors"
27 "fmt"
28 "net"
29 "sync"
30 "time"
31
32 "github.com/golang/protobuf/proto"
33 "github.com/google/s2a-go/fallback"
34 "github.com/google/s2a-go/internal/handshaker"
35 "github.com/google/s2a-go/internal/handshaker/service"
36 "github.com/google/s2a-go/internal/tokenmanager"
37 "github.com/google/s2a-go/internal/v2"
38 "github.com/google/s2a-go/retry"
39 "google.golang.org/grpc/credentials"
40 "google.golang.org/grpc/grpclog"
41
42 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
43 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
44 )
45
const (
	// s2aSecurityProtocol is the SecurityProtocol reported in the
	// credentials.ProtocolInfo of legacy-mode transport credentials.
	s2aSecurityProtocol = "tls"

	// defaultTimeout is the timeout applied to the context created by the
	// legacy-mode ServerHandshake in this file.
	defaultTimeout = 30 * time.Second
)
51
52
53
54
55 type s2aTransportCreds struct {
56 info *credentials.ProtocolInfo
57 minTLSVersion commonpb.TLSVersion
58 maxTLSVersion commonpb.TLSVersion
59
60
61 tlsCiphersuites []commonpb.Ciphersuite
62
63 localIdentity *commonpb.Identity
64
65 localIdentities []*commonpb.Identity
66
67 targetIdentities []*commonpb.Identity
68 isClient bool
69 s2aAddr string
70 ensureProcessSessionTickets *sync.WaitGroup
71 }
72
73
74
75 func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
76 if opts == nil {
77 return nil, errors.New("nil client options")
78 }
79 var targetIdentities []*commonpb.Identity
80 for _, targetIdentity := range opts.TargetIdentities {
81 protoTargetIdentity, err := toProtoIdentity(targetIdentity)
82 if err != nil {
83 return nil, err
84 }
85 targetIdentities = append(targetIdentities, protoTargetIdentity)
86 }
87 localIdentity, err := toProtoIdentity(opts.LocalIdentity)
88 if err != nil {
89 return nil, err
90 }
91 if opts.EnableLegacyMode {
92 return &s2aTransportCreds{
93 info: &credentials.ProtocolInfo{
94 SecurityProtocol: s2aSecurityProtocol,
95 },
96 minTLSVersion: commonpb.TLSVersion_TLS1_3,
97 maxTLSVersion: commonpb.TLSVersion_TLS1_3,
98 tlsCiphersuites: []commonpb.Ciphersuite{
99 commonpb.Ciphersuite_AES_128_GCM_SHA256,
100 commonpb.Ciphersuite_AES_256_GCM_SHA384,
101 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
102 },
103 localIdentity: localIdentity,
104 targetIdentities: targetIdentities,
105 isClient: true,
106 s2aAddr: opts.S2AAddress,
107 ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
108 }, nil
109 }
110 verificationMode := getVerificationMode(opts.VerificationMode)
111 var fallbackFunc fallback.ClientHandshake
112 if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
113 fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
114 }
115 return v2.NewClientCreds(opts.S2AAddress, opts.TransportCreds, localIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
116 }
117
118
119
120 func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
121 if opts == nil {
122 return nil, errors.New("nil server options")
123 }
124 var localIdentities []*commonpb.Identity
125 for _, localIdentity := range opts.LocalIdentities {
126 protoLocalIdentity, err := toProtoIdentity(localIdentity)
127 if err != nil {
128 return nil, err
129 }
130 localIdentities = append(localIdentities, protoLocalIdentity)
131 }
132 if opts.EnableLegacyMode {
133 return &s2aTransportCreds{
134 info: &credentials.ProtocolInfo{
135 SecurityProtocol: s2aSecurityProtocol,
136 },
137 minTLSVersion: commonpb.TLSVersion_TLS1_3,
138 maxTLSVersion: commonpb.TLSVersion_TLS1_3,
139 tlsCiphersuites: []commonpb.Ciphersuite{
140 commonpb.Ciphersuite_AES_128_GCM_SHA256,
141 commonpb.Ciphersuite_AES_256_GCM_SHA384,
142 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
143 },
144 localIdentities: localIdentities,
145 isClient: false,
146 s2aAddr: opts.S2AAddress,
147 }, nil
148 }
149 verificationMode := getVerificationMode(opts.VerificationMode)
150 return v2.NewServerCreds(opts.S2AAddress, opts.TransportCreds, localIdentities, verificationMode, opts.getS2AStream)
151 }
152
153
154 func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
155 if !c.isClient {
156 return nil, nil, errors.New("client handshake called using server transport credentials")
157 }
158
159 var cancel context.CancelFunc
160 ctx, cancel = context.WithCancel(ctx)
161 defer cancel()
162
163
164 hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
165 if err != nil {
166 grpclog.Infof("Failed to connect to S2A: %v", err)
167 return nil, nil, err
168 }
169
170 opts := &handshaker.ClientHandshakerOptions{
171 MinTLSVersion: c.minTLSVersion,
172 MaxTLSVersion: c.maxTLSVersion,
173 TLSCiphersuites: c.tlsCiphersuites,
174 TargetIdentities: c.targetIdentities,
175 LocalIdentity: c.localIdentity,
176 TargetName: serverAuthority,
177 EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
178 }
179 chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
180 if err != nil {
181 grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
182 return nil, nil, err
183 }
184 defer func() {
185 if err != nil {
186 if closeErr := chs.Close(); closeErr != nil {
187 grpclog.Infof("Close failed unexpectedly: %v", err)
188 err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
189 }
190 }
191 }()
192
193 secConn, authInfo, err := chs.ClientHandshake(context.Background())
194 if err != nil {
195 grpclog.Infof("Handshake failed: %v", err)
196 return nil, nil, err
197 }
198 return secConn, authInfo, nil
199 }
200
201
202 func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
203 if c.isClient {
204 return nil, nil, errors.New("server handshake called using client transport credentials")
205 }
206
207 ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
208 defer cancel()
209
210
211 hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
212 if err != nil {
213 grpclog.Infof("Failed to connect to S2A: %v", err)
214 return nil, nil, err
215 }
216
217 opts := &handshaker.ServerHandshakerOptions{
218 MinTLSVersion: c.minTLSVersion,
219 MaxTLSVersion: c.maxTLSVersion,
220 TLSCiphersuites: c.tlsCiphersuites,
221 LocalIdentities: c.localIdentities,
222 }
223 shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
224 if err != nil {
225 grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
226 return nil, nil, err
227 }
228 defer func() {
229 if err != nil {
230 if closeErr := shs.Close(); closeErr != nil {
231 grpclog.Infof("Close failed unexpectedly: %v", err)
232 err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
233 }
234 }
235 }()
236
237 secConn, authInfo, err := shs.ServerHandshake(context.Background())
238 if err != nil {
239 grpclog.Infof("Handshake failed: %v", err)
240 return nil, nil, err
241 }
242 return secConn, authInfo, nil
243 }
244
245 func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
246 return *c.info
247 }
248
249 func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
250 info := *c.info
251 var localIdentity *commonpb.Identity
252 if c.localIdentity != nil {
253 localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
254 }
255 var localIdentities []*commonpb.Identity
256 if c.localIdentities != nil {
257 localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
258 for i, localIdentity := range c.localIdentities {
259 localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
260 }
261 }
262 var targetIdentities []*commonpb.Identity
263 if c.targetIdentities != nil {
264 targetIdentities = make([]*commonpb.Identity, len(c.targetIdentities))
265 for i, targetIdentity := range c.targetIdentities {
266 targetIdentities[i] = proto.Clone(targetIdentity).(*commonpb.Identity)
267 }
268 }
269 return &s2aTransportCreds{
270 info: &info,
271 minTLSVersion: c.minTLSVersion,
272 maxTLSVersion: c.maxTLSVersion,
273 tlsCiphersuites: c.tlsCiphersuites,
274 localIdentity: localIdentity,
275 localIdentities: localIdentities,
276 targetIdentities: targetIdentities,
277 isClient: c.isClient,
278 s2aAddr: c.s2aAddr,
279 }
280 }
281
282 func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
283 c.info.ServerName = serverNameOverride
284 return nil
285 }
286
287
// TLSClientConfigOptions holds the per-call parameters passed to
// TLSClientConfigFactory.Build.
type TLSClientConfigOptions struct {
	// ServerName is the server name used when building the client config.
	// The S2A-backed factory treats an empty value, or nil options, as "".
	ServerName string
}
295
296
// TLSClientConfigFactory builds client-side *tls.Config values.
// Instances are obtained from NewTLSClientConfigFactory.
type TLSClientConfigFactory interface {
	// Build returns a client *tls.Config for the connection described by opts.
	Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
}
300
301
302 func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
303 if opts == nil {
304 return nil, fmt.Errorf("opts must be non-nil")
305 }
306 if opts.EnableLegacyMode {
307 return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
308 }
309 tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
310 if err != nil {
311
312
313 grpclog.Infof("Access token manager not initialized: %v", err)
314 return &s2aTLSClientConfigFactory{
315 s2av2Address: opts.S2AAddress,
316 transportCreds: opts.TransportCreds,
317 tokenManager: nil,
318 verificationMode: getVerificationMode(opts.VerificationMode),
319 serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
320 }, nil
321 }
322 return &s2aTLSClientConfigFactory{
323 s2av2Address: opts.S2AAddress,
324 transportCreds: opts.TransportCreds,
325 tokenManager: tokenManager,
326 verificationMode: getVerificationMode(opts.VerificationMode),
327 serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
328 }, nil
329 }
330
// s2aTLSClientConfigFactory is the TLSClientConfigFactory returned by
// NewTLSClientConfigFactory. Build forwards these fields to
// v2.NewClientTLSConfig.
type s2aTLSClientConfigFactory struct {
	// s2av2Address is the S2A address taken from ClientOptions.S2AAddress.
	s2av2Address string
	// transportCreds is taken from ClientOptions.TransportCreds; presumably
	// used for the connection to S2A (confirm in v2.NewClientTLSConfig).
	transportCreds credentials.TransportCredentials
	// tokenManager is nil when the access token manager failed to initialize.
	tokenManager tokenmanager.AccessTokenManager
	// verificationMode is the proto form of ClientOptions.VerificationMode.
	verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
	// serverAuthorizationPolicy is copied from the client options.
	serverAuthorizationPolicy []byte
}
338
339 func (f *s2aTLSClientConfigFactory) Build(
340 ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
341 serverName := ""
342 if opts != nil && opts.ServerName != "" {
343 serverName = opts.ServerName
344 }
345 return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
346 }
347
348 func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
349 switch verificationMode {
350 case ConnectToGoogle:
351 return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
352 case Spiffe:
353 return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
354 default:
355 return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
356 }
357 }
358
359
360
361
362
363
364
365
366
367 func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
368
369 return func(ctx context.Context, network, addr string) (net.Conn, error) {
370
371 fallback := func(err error) (net.Conn, error) {
372 if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
373 opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
374 fbDialer := opts.FallbackOpts.FallbackDialer
375 grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
376 fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
377 if fbErr != nil {
378 return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
379 }
380 return fbConn, nil
381 }
382 return nil, err
383 }
384
385 factory, err := NewTLSClientConfigFactory(opts)
386 if err != nil {
387 grpclog.Infof("error creating S2A client config factory: %v", err)
388 return fallback(err)
389 }
390
391 serverName, _, err := net.SplitHostPort(addr)
392 if err != nil {
393 serverName = addr
394 }
395 timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
396 defer cancel()
397
398 var s2aTLSConfig *tls.Config
399 retry.Run(timeoutCtx,
400 func() error {
401 s2aTLSConfig, err = factory.Build(timeoutCtx, &TLSClientConfigOptions{
402 ServerName: serverName,
403 })
404 return err
405 })
406 if err != nil {
407 grpclog.Infof("error building S2A TLS config: %v", err)
408 return fallback(err)
409 }
410
411 s2aDialer := &tls.Dialer{
412 Config: s2aTLSConfig,
413 }
414 var c net.Conn
415 retry.Run(timeoutCtx,
416 func() error {
417 c, err = s2aDialer.DialContext(timeoutCtx, network, addr)
418 return err
419 })
420 if err != nil {
421 grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
422 return fallback(err)
423 }
424 grpclog.Infof("success dialing MTLS to %s with S2A", addr)
425 return c, nil
426 }
427 }
428
View as plain text