...

Source file src/github.com/google/s2a-go/internal/record/ticketsender.go

Documentation: github.com/google/s2a-go/internal/record

     1  /*
     2   *
     3   * Copyright 2021 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package record
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"sync"
    25  	"time"
    26  
    27  	"github.com/google/s2a-go/internal/handshaker/service"
    28  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
    29  	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
    30  	"github.com/google/s2a-go/internal/tokenmanager"
    31  	"google.golang.org/grpc/codes"
    32  	"google.golang.org/grpc/grpclog"
    33  )
    34  
    35  // sessionTimeout is the timeout for creating a session with the S2A handshaker
    36  // service.
    37  const sessionTimeout = time.Second * 5
    38  
    39  // s2aTicketSender sends session tickets to the S2A handshaker service.
    40  type s2aTicketSender interface {
    41  	// sendTicketsToS2A sends the given session tickets to the S2A handshaker
    42  	// service.
    43  	sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool)
    44  }
    45  
    46  // ticketStream is the stream used to send and receive session information.
    47  type ticketStream interface {
    48  	Send(*s2apb.SessionReq) error
    49  	Recv() (*s2apb.SessionResp, error)
    50  }
    51  
// ticketSender implements s2aTicketSender. It sends session tickets to the S2A
// handshaker service in a ResumptionTicketReq over a session opened with
// SetUpSession.
type ticketSender struct {
	// hsAddr stores the address of the S2A handshaker service. It is the
	// address passed to service.Dial.
	hsAddr string
	// connectionID is the connection identifier that was created and sent by
	// S2A at the end of a handshake. It is copied into every
	// ResumptionTicketReq.
	connectionID uint64
	// localIdentity is the local identity that was used by S2A during session
	// setup and included in the session result. It is sent in the
	// ResumptionTicketReq, used to select the access token, and included in
	// error logs.
	localIdentity *commonpb.Identity
	// tokenManager manages access tokens for authenticating to S2A. If nil,
	// requests are sent without any AuthenticationMechanism.
	tokenManager tokenmanager.AccessTokenManager
	// ensureProcessSessionTickets allows users to wait and ensure that all
	// available session tickets are sent to S2A before a process completes.
	// When non-nil, sendTicketsToS2A calls Add(1) before starting its
	// goroutine and Done() once the send attempt has finished, whether or not
	// it succeeded.
	ensureProcessSessionTickets *sync.WaitGroup
}
    67  
    68  // sendTicketsToS2A sends the given sessionTickets to the S2A handshaker
    69  // service. This is done asynchronously and writes to the error logs if an error
    70  // occurs.
    71  func (t *ticketSender) sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool) {
    72  	// Note that the goroutine is in the function rather than at the caller
    73  	// because the fake ticket sender used for testing must run synchronously
    74  	// so that the session tickets can be accessed from it after the tests have
    75  	// been run.
    76  	if t.ensureProcessSessionTickets != nil {
    77  		t.ensureProcessSessionTickets.Add(1)
    78  	}
    79  	go func() {
    80  		if err := func() error {
    81  			defer func() {
    82  				if t.ensureProcessSessionTickets != nil {
    83  					t.ensureProcessSessionTickets.Done()
    84  				}
    85  			}()
    86  			ctx, cancel := context.WithTimeout(context.Background(), sessionTimeout)
    87  			defer cancel()
    88  			// The transportCreds only needs to be set when talking to S2AV2 and also
    89  			// if mTLS is required.
    90  			hsConn, err := service.Dial(ctx, t.hsAddr, nil)
    91  			if err != nil {
    92  				return err
    93  			}
    94  			client := s2apb.NewS2AServiceClient(hsConn)
    95  			session, err := client.SetUpSession(ctx)
    96  			if err != nil {
    97  				return err
    98  			}
    99  			defer func() {
   100  				if err := session.CloseSend(); err != nil {
   101  					grpclog.Error(err)
   102  				}
   103  			}()
   104  			return t.writeTicketsToStream(session, sessionTickets)
   105  		}(); err != nil {
   106  			grpclog.Errorf("failed to send resumption tickets to S2A with identity: %v, %v",
   107  				t.localIdentity, err)
   108  		}
   109  		callComplete <- true
   110  		close(callComplete)
   111  	}()
   112  }
   113  
   114  // writeTicketsToStream writes the given session tickets to the given stream.
   115  func (t *ticketSender) writeTicketsToStream(stream ticketStream, sessionTickets [][]byte) error {
   116  	if err := stream.Send(
   117  		&s2apb.SessionReq{
   118  			ReqOneof: &s2apb.SessionReq_ResumptionTicket{
   119  				ResumptionTicket: &s2apb.ResumptionTicketReq{
   120  					InBytes:       sessionTickets,
   121  					ConnectionId:  t.connectionID,
   122  					LocalIdentity: t.localIdentity,
   123  				},
   124  			},
   125  			AuthMechanisms: t.getAuthMechanisms(),
   126  		},
   127  	); err != nil {
   128  		return err
   129  	}
   130  	sessionResp, err := stream.Recv()
   131  	if err != nil {
   132  		return err
   133  	}
   134  	if sessionResp.GetStatus().GetCode() != uint32(codes.OK) {
   135  		return fmt.Errorf("s2a session ticket response had error status: %v, %v",
   136  			sessionResp.GetStatus().GetCode(), sessionResp.GetStatus().GetDetails())
   137  	}
   138  	return nil
   139  }
   140  
   141  func (t *ticketSender) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
   142  	if t.tokenManager == nil {
   143  		return nil
   144  	}
   145  	// First handle the special case when no local identity has been provided
   146  	// by the application. In this case, an AuthenticationMechanism with no local
   147  	// identity will be sent.
   148  	if t.localIdentity == nil {
   149  		token, err := t.tokenManager.DefaultToken()
   150  		if err != nil {
   151  			grpclog.Infof("unable to get token for empty local identity: %v", err)
   152  			return nil
   153  		}
   154  		return []*s2apb.AuthenticationMechanism{
   155  			{
   156  				MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   157  					Token: token,
   158  				},
   159  			},
   160  		}
   161  	}
   162  
   163  	// Next, handle the case where the application (or the S2A) has specified
   164  	// a local identity.
   165  	token, err := t.tokenManager.Token(t.localIdentity)
   166  	if err != nil {
   167  		grpclog.Infof("unable to get token for local identity %v: %v", t.localIdentity, err)
   168  		return nil
   169  	}
   170  	return []*s2apb.AuthenticationMechanism{
   171  		{
   172  			Identity: t.localIdentity,
   173  			MechanismOneof: &s2apb.AuthenticationMechanism_Token{
   174  				Token: token,
   175  			},
   176  		},
   177  	}
   178  }
   179  

View as plain text