...
1
18
19 package record
20
21 import (
22 "context"
23 "fmt"
24 "sync"
25 "time"
26
27 "github.com/google/s2a-go/internal/handshaker/service"
28 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
29 s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
30 "github.com/google/s2a-go/internal/tokenmanager"
31 "google.golang.org/grpc/codes"
32 "google.golang.org/grpc/grpclog"
33 )
34
35
36
// sessionTimeout bounds a single attempt to deliver session tickets to the
// S2A handshaker service: it is the deadline of the context used for dialing,
// opening the session stream, and exchanging the ticket request/response.
const sessionTimeout = 5 * time.Second
38
39
// s2aTicketSender is implemented by types that can hand TLS session
// (resumption) tickets over to the S2A handshaker service.
type s2aTicketSender interface {
	// sendTicketsToS2A sends sessionTickets to S2A and signals on callComplete
	// when the attempt is over. The ticketSender implementation does the
	// work in a background goroutine, sends true on callComplete and then
	// closes it.
	sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool)
}
45
46
// ticketStream is the minimal view of an S2A session stream needed to send a
// ticket request and read its response. The stream returned by
// S2AServiceClient.SetUpSession satisfies it. NOTE(review): presumably it is
// an interface so that a fake stream can be substituted in tests.
type ticketStream interface {
	Send(*s2apb.SessionReq) error
	Recv() (*s2apb.SessionResp, error)
}
51
// ticketSender is the s2aTicketSender implementation that dials the S2A
// handshaker service and delivers session tickets over a session stream.
type ticketSender struct {
	// hsAddr is the address of the S2A handshaker service passed to
	// service.Dial.
	hsAddr string
	// connectionID is sent to S2A as the ConnectionId of the resumption
	// ticket request.
	connectionID uint64
	// localIdentity is sent to S2A as the LocalIdentity of the resumption
	// ticket request. It also selects which token getAuthMechanisms asks the
	// token manager for, where nil means the default token.
	localIdentity *commonpb.Identity
	// tokenManager supplies access tokens for the request's authentication
	// mechanisms; when nil, no authentication mechanisms are attached.
	tokenManager tokenmanager.AccessTokenManager
	// ensureProcessSessionTickets, if non-nil, is incremented for every
	// sendTicketsToS2A call and decremented when that call's delivery attempt
	// finishes. NOTE(review): presumably lets callers or tests wait for
	// in-flight ticket deliveries — confirm at the construction site.
	ensureProcessSessionTickets *sync.WaitGroup
}
67
68
69
70
71 func (t *ticketSender) sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool) {
72
73
74
75
76 if t.ensureProcessSessionTickets != nil {
77 t.ensureProcessSessionTickets.Add(1)
78 }
79 go func() {
80 if err := func() error {
81 defer func() {
82 if t.ensureProcessSessionTickets != nil {
83 t.ensureProcessSessionTickets.Done()
84 }
85 }()
86 ctx, cancel := context.WithTimeout(context.Background(), sessionTimeout)
87 defer cancel()
88
89
90 hsConn, err := service.Dial(ctx, t.hsAddr, nil)
91 if err != nil {
92 return err
93 }
94 client := s2apb.NewS2AServiceClient(hsConn)
95 session, err := client.SetUpSession(ctx)
96 if err != nil {
97 return err
98 }
99 defer func() {
100 if err := session.CloseSend(); err != nil {
101 grpclog.Error(err)
102 }
103 }()
104 return t.writeTicketsToStream(session, sessionTickets)
105 }(); err != nil {
106 grpclog.Errorf("failed to send resumption tickets to S2A with identity: %v, %v",
107 t.localIdentity, err)
108 }
109 callComplete <- true
110 close(callComplete)
111 }()
112 }
113
114
115 func (t *ticketSender) writeTicketsToStream(stream ticketStream, sessionTickets [][]byte) error {
116 if err := stream.Send(
117 &s2apb.SessionReq{
118 ReqOneof: &s2apb.SessionReq_ResumptionTicket{
119 ResumptionTicket: &s2apb.ResumptionTicketReq{
120 InBytes: sessionTickets,
121 ConnectionId: t.connectionID,
122 LocalIdentity: t.localIdentity,
123 },
124 },
125 AuthMechanisms: t.getAuthMechanisms(),
126 },
127 ); err != nil {
128 return err
129 }
130 sessionResp, err := stream.Recv()
131 if err != nil {
132 return err
133 }
134 if sessionResp.GetStatus().GetCode() != uint32(codes.OK) {
135 return fmt.Errorf("s2a session ticket response had error status: %v, %v",
136 sessionResp.GetStatus().GetCode(), sessionResp.GetStatus().GetDetails())
137 }
138 return nil
139 }
140
141 func (t *ticketSender) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
142 if t.tokenManager == nil {
143 return nil
144 }
145
146
147
148 if t.localIdentity == nil {
149 token, err := t.tokenManager.DefaultToken()
150 if err != nil {
151 grpclog.Infof("unable to get token for empty local identity: %v", err)
152 return nil
153 }
154 return []*s2apb.AuthenticationMechanism{
155 {
156 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
157 Token: token,
158 },
159 },
160 }
161 }
162
163
164
165 token, err := t.tokenManager.Token(t.localIdentity)
166 if err != nil {
167 grpclog.Infof("unable to get token for local identity %v: %v", t.localIdentity, err)
168 return nil
169 }
170 return []*s2apb.AuthenticationMechanism{
171 {
172 Identity: t.localIdentity,
173 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
174 Token: token,
175 },
176 },
177 }
178 }
179
View as plain text