1
18
19
20
21 package v2
22
23 import (
24 "context"
25 "crypto/tls"
26 "errors"
27 "net"
28 "os"
29 "time"
30
31 "github.com/golang/protobuf/proto"
32 "github.com/google/s2a-go/fallback"
33 "github.com/google/s2a-go/internal/handshaker/service"
34 "github.com/google/s2a-go/internal/tokenmanager"
35 "github.com/google/s2a-go/internal/v2/tlsconfigstore"
36 "github.com/google/s2a-go/retry"
37 "github.com/google/s2a-go/stream"
38 "google.golang.org/grpc"
39 "google.golang.org/grpc/credentials"
40 "google.golang.org/grpc/grpclog"
41
42 commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
43 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
44 )
45
// Package-level settings for the S2Av2 transport credentials.
const (
	// s2aSecurityProtocol is reported as ProtocolInfo.SecurityProtocol.
	s2aSecurityProtocol = "tls"
	// defaultS2ATimeout is the fallback returned by GetS2ATimeout when the
	// environment does not supply a usable value.
	defaultS2ATimeout = 6 * time.Second
	// s2aTimeoutEnv names the environment variable read by GetS2ATimeout.
	s2aTimeoutEnv = "S2A_TIMEOUT"
)
53
// s2av2TransportCreds implements credentials.TransportCredentials on top of
// the S2Av2 service. A single type serves both roles; isClient selects which
// handshake method may be used.
type s2av2TransportCreds struct {
	// info is returned (by value) from Info().
	info *credentials.ProtocolInfo
	// isClient is true for credentials built by NewClientCreds and false for
	// those built by NewServerCreds.
	isClient bool
	// serverName is set by OverrideServerName; when non-empty it is used
	// instead of the dial authority when requesting the client TLS config.
	serverName string
	// s2av2Address is the address of the S2Av2 service passed to createStream.
	s2av2Address string
	// transportCreds are the credentials used to dial the S2Av2 service.
	transportCreds credentials.TransportCredentials
	// tokenManager is nil when NewSingleTokenAccessTokenManager failed at
	// construction time.
	tokenManager *tokenmanager.AccessTokenManager

	// localIdentity is the identity sent when fetching the client TLS config.
	localIdentity *commonpbv1.Identity

	// localIdentities are the identities sent when fetching the server TLS config.
	localIdentities []*commonpbv1.Identity
	// verificationMode is forwarded to tlsconfigstore for both roles.
	verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
	// fallbackClientHandshake, when non-nil, is invoked if any step of the
	// client handshake fails.
	fallbackClientHandshake fallback.ClientHandshake
	// getS2AStream, when non-nil, replaces the default dial-and-SetUpSession
	// logic in createStream.
	getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)
	// serverAuthorizationPolicy is forwarded to GetTLSConfigurationForClient.
	serverAuthorizationPolicy []byte
}
70
71
72
73 func NewClientCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentity *commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error), serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
74
75 accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
76
77 creds := &s2av2TransportCreds{
78 info: &credentials.ProtocolInfo{
79 SecurityProtocol: s2aSecurityProtocol,
80 },
81 isClient: true,
82 serverName: "",
83 s2av2Address: s2av2Address,
84 transportCreds: transportCreds,
85 localIdentity: localIdentity,
86 verificationMode: verificationMode,
87 fallbackClientHandshake: fallbackClientHandshakeFunc,
88 getS2AStream: getS2AStream,
89 serverAuthorizationPolicy: serverAuthorizationPolicy,
90 }
91 if err != nil {
92 creds.tokenManager = nil
93 } else {
94 creds.tokenManager = &accessTokenManager
95 }
96 if grpclog.V(1) {
97 grpclog.Info("Created client S2Av2 transport credentials.")
98 }
99 return creds, nil
100 }
101
102
103
104 func NewServerCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentities []*commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (credentials.TransportCredentials, error) {
105
106 accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
107 creds := &s2av2TransportCreds{
108 info: &credentials.ProtocolInfo{
109 SecurityProtocol: s2aSecurityProtocol,
110 },
111 isClient: false,
112 s2av2Address: s2av2Address,
113 transportCreds: transportCreds,
114 localIdentities: localIdentities,
115 verificationMode: verificationMode,
116 getS2AStream: getS2AStream,
117 }
118 if err != nil {
119 creds.tokenManager = nil
120 } else {
121 creds.tokenManager = &accessTokenManager
122 }
123 if grpclog.V(1) {
124 grpclog.Info("Created server S2Av2 transport credentials.")
125 }
126 return creds, nil
127 }
128
129
130 func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
131 if !c.isClient {
132 return nil, nil, errors.New("client handshake called using server transport credentials")
133 }
134
135 serverName := removeServerNamePort(serverAuthority)
136 timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
137 defer cancel()
138 var s2AStream stream.S2AStream
139 var err error
140 retry.Run(timeoutCtx,
141 func() error {
142 s2AStream, err = createStream(timeoutCtx, c.s2av2Address, c.transportCreds, c.getS2AStream)
143 return err
144 })
145 if err != nil {
146 grpclog.Infof("Failed to connect to S2Av2: %v", err)
147 if c.fallbackClientHandshake != nil {
148 return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
149 }
150 return nil, nil, err
151 }
152 defer s2AStream.CloseSend()
153 if grpclog.V(1) {
154 grpclog.Infof("Connected to S2Av2.")
155 }
156 var config *tls.Config
157
158 var tokenManager tokenmanager.AccessTokenManager
159 if c.tokenManager == nil {
160 tokenManager = nil
161 } else {
162 tokenManager = *c.tokenManager
163 }
164
165 sn := serverName
166 if c.serverName != "" {
167 sn = c.serverName
168 }
169 retry.Run(timeoutCtx,
170 func() error {
171 config, err = tlsconfigstore.GetTLSConfigurationForClient(sn, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
172 return err
173 })
174 if err != nil {
175 grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
176 if c.fallbackClientHandshake != nil {
177 return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
178 }
179 return nil, nil, err
180 }
181 if grpclog.V(1) {
182 grpclog.Infof("Got client TLS config from S2Av2.")
183 }
184
185 creds := credentials.NewTLS(config)
186 var conn net.Conn
187 var authInfo credentials.AuthInfo
188 retry.Run(timeoutCtx,
189 func() error {
190 conn, authInfo, err = creds.ClientHandshake(timeoutCtx, serverName, rawConn)
191 return err
192 })
193 if err != nil {
194 grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
195 if c.fallbackClientHandshake != nil {
196 return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
197 }
198 return nil, nil, err
199 }
200 grpclog.Infof("Successfully done client handshake using S2Av2 to: %s", serverName)
201
202 return conn, authInfo, err
203 }
204
205
206 func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
207 if c.isClient {
208 return nil, nil, errors.New("server handshake called using client transport credentials")
209 }
210 ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
211 defer cancel()
212 var s2AStream stream.S2AStream
213 var err error
214 retry.Run(ctx,
215 func() error {
216 s2AStream, err = createStream(ctx, c.s2av2Address, c.transportCreds, c.getS2AStream)
217 return err
218 })
219 if err != nil {
220 grpclog.Infof("Failed to connect to S2Av2: %v", err)
221 return nil, nil, err
222 }
223 defer s2AStream.CloseSend()
224 if grpclog.V(1) {
225 grpclog.Infof("Connected to S2Av2.")
226 }
227
228 var tokenManager tokenmanager.AccessTokenManager
229 if c.tokenManager == nil {
230 tokenManager = nil
231 } else {
232 tokenManager = *c.tokenManager
233 }
234
235 var config *tls.Config
236 retry.Run(ctx,
237 func() error {
238 config, err = tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
239 return err
240 })
241 if err != nil {
242 grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
243 return nil, nil, err
244 }
245 if grpclog.V(1) {
246 grpclog.Infof("Got server TLS config from S2Av2.")
247 }
248
249 creds := credentials.NewTLS(config)
250 var conn net.Conn
251 var authInfo credentials.AuthInfo
252 retry.Run(ctx,
253 func() error {
254 conn, authInfo, err = creds.ServerHandshake(rawConn)
255 return err
256 })
257 if err != nil {
258 grpclog.Infof("Failed to do server handshake using S2Av2: %v", err)
259 return nil, nil, err
260 }
261 return conn, authInfo, err
262 }
263
264
265 func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
266 return *c.info
267 }
268
269
270 func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
271 info := *c.info
272 serverName := c.serverName
273 fallbackClientHandshake := c.fallbackClientHandshake
274
275 s2av2Address := c.s2av2Address
276 var tokenManager tokenmanager.AccessTokenManager
277 if c.tokenManager == nil {
278 tokenManager = nil
279 } else {
280 tokenManager = *c.tokenManager
281 }
282 verificationMode := c.verificationMode
283 var localIdentity *commonpbv1.Identity
284 if c.localIdentity != nil {
285 localIdentity = proto.Clone(c.localIdentity).(*commonpbv1.Identity)
286 }
287 var localIdentities []*commonpbv1.Identity
288 if c.localIdentities != nil {
289 localIdentities = make([]*commonpbv1.Identity, len(c.localIdentities))
290 for i, localIdentity := range c.localIdentities {
291 localIdentities[i] = proto.Clone(localIdentity).(*commonpbv1.Identity)
292 }
293 }
294 creds := &s2av2TransportCreds{
295 info: &info,
296 isClient: c.isClient,
297 serverName: serverName,
298 fallbackClientHandshake: fallbackClientHandshake,
299 s2av2Address: s2av2Address,
300 localIdentity: localIdentity,
301 localIdentities: localIdentities,
302 verificationMode: verificationMode,
303 }
304 if c.tokenManager == nil {
305 creds.tokenManager = nil
306 } else {
307 creds.tokenManager = &tokenManager
308 }
309 return creds
310 }
311
312
313
314 func NewClientTLSConfig(
315 ctx context.Context,
316 s2av2Address string,
317 transportCreds credentials.TransportCredentials,
318 tokenManager tokenmanager.AccessTokenManager,
319 verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
320 serverName string,
321 serverAuthorizationPolicy []byte) (*tls.Config, error) {
322 s2AStream, err := createStream(ctx, s2av2Address, transportCreds, nil)
323 if err != nil {
324 grpclog.Infof("Failed to connect to S2Av2: %v", err)
325 return nil, err
326 }
327
328 return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
329 }
330
331
332
333 func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
334 serverName := removeServerNamePort(serverNameOverride)
335 c.info.ServerName = serverName
336 c.serverName = serverName
337 return nil
338 }
339
340
// removeServerNamePort returns the host part of a "host:port" string. If the
// input cannot be split by net.SplitHostPort, it is returned unchanged. This
// covers a bare hostname or an unbracketed IPv6 literal.
func removeServerNamePort(serverName string) string {
	if host, _, splitErr := net.SplitHostPort(serverName); splitErr == nil {
		return host
	}
	return serverName
}
348
349 type s2AGrpcStream struct {
350 stream s2av2pb.S2AService_SetUpSessionClient
351 }
352
353 func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
354 return x.stream.Send(m)
355 }
356
357 func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
358 return x.stream.Recv()
359 }
360
361 func (x s2AGrpcStream) CloseSend() error {
362 return x.stream.CloseSend()
363 }
364
365 func createStream(ctx context.Context, s2av2Address string, transportCreds credentials.TransportCredentials, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (stream.S2AStream, error) {
366 if getS2AStream != nil {
367 return getS2AStream(ctx, s2av2Address)
368 }
369
370 conn, err := service.Dial(ctx, s2av2Address, transportCreds)
371 if err != nil {
372 return nil, err
373 }
374 client := s2av2pb.NewS2AServiceClient(conn)
375 gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
376 if err != nil {
377 return nil, err
378 }
379 return &s2AGrpcStream{
380 stream: gRPCStream,
381 }, nil
382 }
383
384
385 func GetS2ATimeout() time.Duration {
386 timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
387 if err != nil {
388 return defaultS2ATimeout
389 }
390 return timeout
391 }
392
View as plain text