1
18
19
20 package service
21
22 import (
23 "bytes"
24 "fmt"
25 "os"
26
27 "google.golang.org/grpc/codes"
28
29 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
30 s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
31 )
32
// handshakeState records how far a fake handshake has progressed.
type handshakeState int

// Handshake states, in increasing order of progress.
const (
	// initial is the state before any start request has been processed.
	initial handshakeState = iota // 0
	// started means a ServerStart request arrived without any input bytes.
	started // 1
	// sent means this side has emitted its first handshake frame.
	sent // 2
	// completed means the handshake finished and a session result is due.
	completed // 3
)
49
// Configuration and negotiation constants for the fake handshaker.
const (
	// accessTokenEnvVariable names the environment variable that, when
	// non-empty, holds the only access token authenticateRequest accepts.
	accessTokenEnvVariable = "S2A_ACCESS_TOKEN"
	// grpcAppProtocol is the single application protocol start requests must
	// advertise, and the one reported in every session result.
	grpcAppProtocol = "grpc"
)

// Canned handshake frames exchanged between the fake client and server.
const (
	clientHelloFrame    = "ClientHello"
	clientFinishedFrame = "ClientFinished"
	serverFrame         = "ServerHelloAndFinished"
)

// Fixed key material returned as InKey/OutKey in every session result.
const (
	inKey  = "kkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkk"
	outKey = "kkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkk"
)
62
63
64
65 type FakeHandshakerService struct {
66 s2apb.S2AServiceServer
67
68 assistingClient bool
69 state handshakeState
70 peerIdentity *commonpb.Identity
71 localIdentity *commonpb.Identity
72 }
73
74
75 func (hs *FakeHandshakerService) SetUpSession(stream s2apb.S2AService_SetUpSessionServer) error {
76 for {
77 sessionReq, err := stream.Recv()
78 if err != nil {
79 return fmt.Errorf("stream recv failed: %v", err)
80 }
81 if err := hs.authenticateRequest(sessionReq); err != nil {
82 return fmt.Errorf("S2A cannot authenticate the request: %v", err)
83 }
84
85 var resp *s2apb.SessionResp
86 receivedTicket := false
87 switch req := sessionReq.ReqOneof.(type) {
88 case *s2apb.SessionReq_ClientStart:
89 resp = hs.processClientStart(req)
90 case *s2apb.SessionReq_ServerStart:
91 resp = hs.processServerStart(req)
92 case *s2apb.SessionReq_Next:
93 resp = hs.processNext(req)
94 case *s2apb.SessionReq_ResumptionTicket:
95 resp = hs.processResumptionTicket(req)
96 receivedTicket = true
97 default:
98 return fmt.Errorf("session request has unexpected type %T", req)
99 }
100
101 if err = stream.Send(resp); err != nil {
102 return fmt.Errorf("stream send failed: %v", err)
103 }
104
105 if receivedTicket || resp.GetResult() != nil {
106 return nil
107 }
108 }
109 }
110
111
112 func (hs *FakeHandshakerService) processClientStart(req *s2apb.SessionReq_ClientStart) *s2apb.SessionResp {
113 resp := s2apb.SessionResp{}
114 if hs.state != initial {
115 resp.Status = &s2apb.SessionStatus{
116 Code: uint32(codes.FailedPrecondition),
117 Details: "client start handshake not in initial state",
118 }
119 return &resp
120 }
121 if len(req.ClientStart.GetApplicationProtocols()) != 1 ||
122 req.ClientStart.GetApplicationProtocols()[0] != grpcAppProtocol {
123 resp.Status = &s2apb.SessionStatus{
124 Code: uint32(codes.InvalidArgument),
125 Details: "application protocol was not grpc",
126 }
127 return &resp
128 }
129 if req.ClientStart.GetMaxTlsVersion() != commonpb.TLSVersion_TLS1_3 {
130 resp.Status = &s2apb.SessionStatus{
131 Code: uint32(codes.InvalidArgument),
132 Details: "max TLS version must be 1.3",
133 }
134 return &resp
135 }
136 if req.ClientStart.GetMinTlsVersion() != commonpb.TLSVersion_TLS1_3 {
137 resp.Status = &s2apb.SessionStatus{
138 Code: uint32(codes.InvalidArgument),
139 Details: "min TLS version must be 1.3",
140 }
141 return &resp
142 }
143 resp.OutFrames = []byte(clientHelloFrame)
144 resp.BytesConsumed = 0
145 resp.Status = &s2apb.SessionStatus{Code: uint32(codes.OK)}
146 hs.localIdentity = req.ClientStart.LocalIdentity
147 if len(req.ClientStart.TargetIdentities) > 0 {
148 hs.peerIdentity = req.ClientStart.TargetIdentities[0]
149 }
150 hs.assistingClient = true
151 hs.state = sent
152 return &resp
153 }
154
155
156 func (hs *FakeHandshakerService) processServerStart(req *s2apb.SessionReq_ServerStart) *s2apb.SessionResp {
157 resp := s2apb.SessionResp{}
158 if hs.state != initial {
159 resp.Status = &s2apb.SessionStatus{
160 Code: uint32(codes.FailedPrecondition),
161 Details: "server start handshake not in initial state",
162 }
163 return &resp
164 }
165 if len(req.ServerStart.GetApplicationProtocols()) != 1 ||
166 req.ServerStart.GetApplicationProtocols()[0] != grpcAppProtocol {
167 resp.Status = &s2apb.SessionStatus{
168 Code: uint32(codes.InvalidArgument),
169 Details: "application protocol was not grpc",
170 }
171 return &resp
172 }
173 if req.ServerStart.GetMaxTlsVersion() != commonpb.TLSVersion_TLS1_3 {
174 resp.Status = &s2apb.SessionStatus{
175 Code: uint32(codes.InvalidArgument),
176 Details: "max TLS version must be 1.3",
177 }
178 return &resp
179 }
180 if req.ServerStart.GetMinTlsVersion() != commonpb.TLSVersion_TLS1_3 {
181 resp.Status = &s2apb.SessionStatus{
182 Code: uint32(codes.InvalidArgument),
183 Details: "min TLS version must be 1.3",
184 }
185 return &resp
186 }
187 if len(req.ServerStart.InBytes) == 0 {
188 resp.BytesConsumed = 0
189 hs.state = started
190 } else if bytes.Equal(req.ServerStart.InBytes, []byte(clientHelloFrame)) {
191 resp.OutFrames = []byte(serverFrame)
192 resp.BytesConsumed = uint32(len(clientHelloFrame))
193 hs.state = sent
194 } else {
195 resp.Status = &s2apb.SessionStatus{
196 Code: uint32(codes.Internal),
197 Details: "server start request did not have the correct input bytes",
198 }
199 return &resp
200 }
201
202 resp.Status = &s2apb.SessionStatus{Code: uint32(codes.OK)}
203 if len(req.ServerStart.LocalIdentities) > 0 {
204 hs.localIdentity = req.ServerStart.LocalIdentities[0]
205 }
206 hs.assistingClient = false
207 return &resp
208 }
209
210
211 func (hs *FakeHandshakerService) processNext(req *s2apb.SessionReq_Next) *s2apb.SessionResp {
212 resp := s2apb.SessionResp{}
213 if hs.assistingClient {
214 if hs.state != sent {
215 resp.Status = &s2apb.SessionStatus{
216 Code: uint32(codes.FailedPrecondition),
217 Details: "client handshake was not in sent state",
218 }
219 return &resp
220 }
221 if !bytes.Equal(req.Next.InBytes, []byte(serverFrame)) {
222 resp.Status = &s2apb.SessionStatus{
223 Code: uint32(codes.Internal),
224 Details: "client request did not match server frame",
225 }
226 return &resp
227 }
228 resp.OutFrames = []byte(clientFinishedFrame)
229 resp.BytesConsumed = uint32(len(serverFrame))
230 hs.state = completed
231 } else {
232 if hs.state == started {
233 if !bytes.Equal(req.Next.InBytes, []byte(clientHelloFrame)) {
234 resp.Status = &s2apb.SessionStatus{
235 Code: uint32(codes.Internal),
236 Details: "server request did not match client hello frame",
237 }
238 return &resp
239 }
240 resp.OutFrames = []byte(serverFrame)
241 resp.BytesConsumed = uint32(len(clientHelloFrame))
242 hs.state = sent
243 } else if hs.state == sent {
244 if !bytes.Equal(req.Next.InBytes[:len(clientFinishedFrame)], []byte(clientFinishedFrame)) {
245 resp.Status = &s2apb.SessionStatus{
246 Code: uint32(codes.Internal),
247 Details: "server request did not match client finished frame",
248 }
249 return &resp
250 }
251 resp.BytesConsumed = uint32(len(clientFinishedFrame))
252 hs.state = completed
253 } else {
254 resp.Status = &s2apb.SessionStatus{
255 Code: uint32(codes.FailedPrecondition),
256 Details: "server request was not in expected state",
257 }
258 return &resp
259 }
260 }
261 resp.Status = &s2apb.SessionStatus{Code: uint32(codes.OK)}
262 if hs.state == completed {
263 resp.Result = hs.getSessionResult()
264 }
265 return &resp
266 }
267
268
269 func (hs *FakeHandshakerService) processResumptionTicket(req *s2apb.SessionReq_ResumptionTicket) *s2apb.SessionResp {
270 return &s2apb.SessionResp{
271 Status: &s2apb.SessionStatus{Code: uint32(codes.OK)},
272 }
273 }
274
275
276 func (hs *FakeHandshakerService) getSessionResult() *s2apb.SessionResult {
277 res := s2apb.SessionResult{}
278 res.ApplicationProtocol = grpcAppProtocol
279 res.State = &s2apb.SessionState{
280 TlsVersion: commonpb.TLSVersion_TLS1_3,
281 TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
282 InKey: []byte(inKey),
283 OutKey: []byte(outKey),
284 }
285 res.PeerIdentity = hs.peerIdentity
286 res.LocalIdentity = hs.localIdentity
287 return &res
288 }
289
290 func (hs *FakeHandshakerService) authenticateRequest(request *s2apb.SessionReq) error {
291
292
293 acceptedToken := os.Getenv(accessTokenEnvVariable)
294 if acceptedToken == "" {
295 return nil
296 }
297 if len(request.GetAuthMechanisms()) == 0 {
298 return fmt.Errorf("expected token but none was received")
299 }
300 for _, authMechanism := range request.GetAuthMechanisms() {
301 if authMechanism.GetToken() != acceptedToken {
302 return fmt.Errorf("received token: %s, expected token: %s", authMechanism.GetToken(), acceptedToken)
303 }
304 }
305 return nil
306 }
307
View as plain text