1
18
19
20 package handshaker
21
22 import (
23 "context"
24 "errors"
25 "fmt"
26 "io"
27 "net"
28 "sync"
29
30 "github.com/google/s2a-go/internal/authinfo"
31 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
32 s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
33 "github.com/google/s2a-go/internal/record"
34 "github.com/google/s2a-go/internal/tokenmanager"
35 grpc "google.golang.org/grpc"
36 "google.golang.org/grpc/codes"
37 "google.golang.org/grpc/credentials"
38 "google.golang.org/grpc/grpclog"
39 )
40
var (
	// errPeerNotResponding is returned by processUntilDone when a round trip
	// neither sent frames to the peer nor received any bytes from it.
	errPeerNotResponding = errors.New("peer is not responding and re-connection should be attempted")

	// appProtocol is the single application protocol listed in every
	// ClientStart/ServerStart request sent to the S2A.
	appProtocol = "grpc"

	// frameLimit is the size in bytes of the buffer used for each read from
	// the peer connection during the handshake (64 KiB).
	frameLimit = 64 * 1024
)
49
50
51 type Handshaker interface {
52
53
54 ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
55
56
57 ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
58
59
60 Close() error
61 }
62
63
64
65 type ClientHandshakerOptions struct {
66
67 MinTLSVersion commonpb.TLSVersion
68
69 MaxTLSVersion commonpb.TLSVersion
70
71
72 TLSCiphersuites []commonpb.Ciphersuite
73
74
75
76 TargetIdentities []*commonpb.Identity
77
78
79 LocalIdentity *commonpb.Identity
80
81
82 TargetName string
83
84
85 EnsureProcessSessionTickets *sync.WaitGroup
86 }
87
88
89
90 type ServerHandshakerOptions struct {
91
92 MinTLSVersion commonpb.TLSVersion
93
94 MaxTLSVersion commonpb.TLSVersion
95
96
97 TLSCiphersuites []commonpb.Ciphersuite
98
99
100
101 LocalIdentities []*commonpb.Identity
102 }
103
104
105 type s2aHandshaker struct {
106
107 stream s2apb.S2AService_SetUpSessionClient
108
109 conn net.Conn
110
111 clientOpts *ClientHandshakerOptions
112
113 serverOpts *ServerHandshakerOptions
114
115 isClient bool
116
117 hsAddr string
118
119 tokenManager tokenmanager.AccessTokenManager
120
121
122
123 localIdentities []*commonpb.Identity
124 }
125
126
127
128 func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ClientHandshakerOptions) (Handshaker, error) {
129 stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
130 if err != nil {
131 return nil, err
132 }
133 tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
134 if err != nil {
135 grpclog.Infof("failed to create single token access token manager: %v", err)
136 }
137 return newClientHandshaker(stream, c, hsAddr, opts, tokenManager), nil
138 }
139
140 func newClientHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ClientHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
141 var localIdentities []*commonpb.Identity
142 if opts != nil {
143 localIdentities = []*commonpb.Identity{opts.LocalIdentity}
144 }
145 return &s2aHandshaker{
146 stream: stream,
147 conn: c,
148 clientOpts: opts,
149 isClient: true,
150 hsAddr: hsAddr,
151 tokenManager: tokenManager,
152 localIdentities: localIdentities,
153 }
154 }
155
156
157
158 func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ServerHandshakerOptions) (Handshaker, error) {
159 stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
160 if err != nil {
161 return nil, err
162 }
163 tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
164 if err != nil {
165 grpclog.Infof("failed to create single token access token manager: %v", err)
166 }
167 return newServerHandshaker(stream, c, hsAddr, opts, tokenManager), nil
168 }
169
170 func newServerHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ServerHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
171 var localIdentities []*commonpb.Identity
172 if opts != nil {
173 localIdentities = opts.LocalIdentities
174 }
175 return &s2aHandshaker{
176 stream: stream,
177 conn: c,
178 serverOpts: opts,
179 isClient: false,
180 hsAddr: hsAddr,
181 tokenManager: tokenManager,
182 localIdentities: localIdentities,
183 }
184 }
185
186
187
188 func (h *s2aHandshaker) ClientHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
189 if !h.isClient {
190 return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client-side handshake")
191 }
192
193 hostname, _, err := net.SplitHostPort(h.clientOpts.TargetName)
194 if err != nil {
195
196 hostname = h.clientOpts.TargetName
197 }
198
199
200 req := &s2apb.SessionReq{
201 ReqOneof: &s2apb.SessionReq_ClientStart{
202 ClientStart: &s2apb.ClientSessionStartReq{
203 ApplicationProtocols: []string{appProtocol},
204 MinTlsVersion: h.clientOpts.MinTLSVersion,
205 MaxTlsVersion: h.clientOpts.MaxTLSVersion,
206 TlsCiphersuites: h.clientOpts.TLSCiphersuites,
207 TargetIdentities: h.clientOpts.TargetIdentities,
208 LocalIdentity: h.clientOpts.LocalIdentity,
209 TargetName: hostname,
210 },
211 },
212 AuthMechanisms: h.getAuthMechanisms(),
213 }
214 conn, result, err := h.setUpSession(req)
215 if err != nil {
216 return nil, nil, err
217 }
218 authInfo, err := authinfo.NewS2AAuthInfo(result)
219 if err != nil {
220 return nil, nil, err
221 }
222 return conn, authInfo, nil
223 }
224
225
226
227 func (h *s2aHandshaker) ServerHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
228 if h.isClient {
229 return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server-side handshake")
230 }
231 p := make([]byte, frameLimit)
232 n, err := h.conn.Read(p)
233 if err != nil {
234 return nil, nil, err
235 }
236
237 req := &s2apb.SessionReq{
238 ReqOneof: &s2apb.SessionReq_ServerStart{
239 ServerStart: &s2apb.ServerSessionStartReq{
240 ApplicationProtocols: []string{appProtocol},
241 MinTlsVersion: h.serverOpts.MinTLSVersion,
242 MaxTlsVersion: h.serverOpts.MaxTLSVersion,
243 TlsCiphersuites: h.serverOpts.TLSCiphersuites,
244 LocalIdentities: h.serverOpts.LocalIdentities,
245 InBytes: p[:n],
246 },
247 },
248 AuthMechanisms: h.getAuthMechanisms(),
249 }
250 conn, result, err := h.setUpSession(req)
251 if err != nil {
252 return nil, nil, err
253 }
254 authInfo, err := authinfo.NewS2AAuthInfo(result)
255 if err != nil {
256 return nil, nil, err
257 }
258 return conn, authInfo, nil
259 }
260
261
262
263 func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.SessionResult, error) {
264 resp, err := h.accessHandshakerService(req)
265 if err != nil {
266 return nil, nil, err
267 }
268
269 if resp.GetStatus() != nil {
270 if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
271 return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
272 }
273 }
274
275
276 var extra []byte
277 if req.GetServerStart() != nil {
278 if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
279 return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
280 }
281 extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
282 }
283 result, extra, err := h.processUntilDone(resp, extra)
284 if err != nil {
285 return nil, nil, err
286 }
287 if result.GetLocalIdentity() == nil {
288 return nil, nil, errors.New("local identity must be populated in session result")
289 }
290
291
292 newConn, err := record.NewConn(&record.ConnParameters{
293 NetConn: h.conn,
294 Ciphersuite: result.GetState().GetTlsCiphersuite(),
295 TLSVersion: result.GetState().GetTlsVersion(),
296 InTrafficSecret: result.GetState().GetInKey(),
297 OutTrafficSecret: result.GetState().GetOutKey(),
298 UnusedBuf: extra,
299 InSequence: result.GetState().GetInSequence(),
300 OutSequence: result.GetState().GetOutSequence(),
301 HSAddr: h.hsAddr,
302 ConnectionID: result.GetState().GetConnectionId(),
303 LocalIdentity: result.GetLocalIdentity(),
304 EnsureProcessSessionTickets: h.ensureProcessSessionTickets(),
305 })
306 if err != nil {
307 return nil, nil, err
308 }
309 return newConn, result, nil
310 }
311
312 func (h *s2aHandshaker) ensureProcessSessionTickets() *sync.WaitGroup {
313 if h.clientOpts == nil {
314 return nil
315 }
316 return h.clientOpts.EnsureProcessSessionTickets
317 }
318
319
320
321 func (h *s2aHandshaker) accessHandshakerService(req *s2apb.SessionReq) (*s2apb.SessionResp, error) {
322 if err := h.stream.Send(req); err != nil {
323 return nil, err
324 }
325 resp, err := h.stream.Recv()
326 if err != nil {
327 return nil, err
328 }
329 return resp, nil
330 }
331
332
333
334
335 func (h *s2aHandshaker) processUntilDone(resp *s2apb.SessionResp, unusedBytes []byte) (*s2apb.SessionResult, []byte, error) {
336 for {
337 if len(resp.OutFrames) > 0 {
338 if _, err := h.conn.Write(resp.OutFrames); err != nil {
339 return nil, nil, err
340 }
341 }
342 if resp.Result != nil {
343 return resp.Result, unusedBytes, nil
344 }
345 buf := make([]byte, frameLimit)
346 n, err := h.conn.Read(buf)
347 if err != nil && err != io.EOF {
348 return nil, nil, err
349 }
350
351
352
353
354
355 if len(resp.OutFrames) == 0 && n == 0 {
356 return nil, nil, errPeerNotResponding
357 }
358
359
360 p := append(unusedBytes, buf[:n]...)
361
362 resp, err = h.accessHandshakerService(&s2apb.SessionReq{
363 ReqOneof: &s2apb.SessionReq_Next{
364 Next: &s2apb.SessionNextReq{
365 InBytes: p,
366 },
367 },
368 AuthMechanisms: h.getAuthMechanisms(),
369 })
370 if err != nil {
371 return nil, nil, err
372 }
373
374
375
376
377
378 if resp.GetLocalIdentity() != nil {
379 h.localIdentities = []*commonpb.Identity{resp.GetLocalIdentity()}
380 }
381
382
383 if resp.GetBytesConsumed() > uint32(len(p)) {
384 return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
385 }
386 unusedBytes = p[resp.GetBytesConsumed():]
387 }
388 }
389
390
391
392
393 func (h *s2aHandshaker) Close() error {
394 return h.stream.CloseSend()
395 }
396
397 func (h *s2aHandshaker) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
398 if h.tokenManager == nil {
399 return nil
400 }
401
402
403
404 if len(h.localIdentities) == 0 {
405 token, err := h.tokenManager.DefaultToken()
406 if err != nil {
407 grpclog.Infof("unable to get token for empty local identity: %v", err)
408 return nil
409 }
410 return []*s2apb.AuthenticationMechanism{
411 {
412 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
413 Token: token,
414 },
415 },
416 }
417 }
418
419
420
421 var authMechanisms []*s2apb.AuthenticationMechanism
422 for _, localIdentity := range h.localIdentities {
423 token, err := h.tokenManager.Token(localIdentity)
424 if err != nil {
425 grpclog.Infof("unable to get token for local identity %v: %v", localIdentity, err)
426 continue
427 }
428
429 authMechanism := &s2apb.AuthenticationMechanism{
430 Identity: localIdentity,
431 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
432 Token: token,
433 },
434 }
435 authMechanisms = append(authMechanisms, authMechanism)
436 }
437 return authMechanisms
438 }
439
View as plain text