1
18
19 package service
20
21 import (
22 "errors"
23 "os"
24 "strings"
25 "testing"
26
27 "github.com/google/go-cmp/cmp"
28 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
29 s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
30 "google.golang.org/grpc"
31 "google.golang.org/grpc/codes"
32 "google.golang.org/protobuf/testing/protocmp"
33 )
34
// testAccessToken is the token the tests configure as the accepted token and
// attach to requests that are expected to authenticate successfully.
const testAccessToken = "test_access_token"
38
39 type fakeS2ASetupSessionServer struct {
40 grpc.ServerStream
41 recvCount int
42 reqs []*s2apb.SessionReq
43 resps []*s2apb.SessionResp
44 }
45
46 func (f *fakeS2ASetupSessionServer) Send(resp *s2apb.SessionResp) error {
47 f.resps = append(f.resps, resp)
48 return nil
49 }
50
51 func (f *fakeS2ASetupSessionServer) Recv() (*s2apb.SessionReq, error) {
52 if f.recvCount == len(f.reqs) {
53 return nil, errors.New("request buffer was fully exhausted")
54 }
55 req := f.reqs[f.recvCount]
56 f.recvCount++
57 return req, nil
58 }
59
60 func TestSetupSession(t *testing.T) {
61 os.Setenv(accessTokenEnvVariable, "")
62 for _, tc := range []struct {
63 desc string
64
65 reqs []*s2apb.SessionReq
66 outResps []*s2apb.SessionResp
67 hasNonOKStatus bool
68 }{
69 {
70 desc: "client failure no app protocols",
71 reqs: []*s2apb.SessionReq{
72 {
73 ReqOneof: &s2apb.SessionReq_ClientStart{
74 ClientStart: &s2apb.ClientSessionStartReq{},
75 },
76 },
77 },
78 hasNonOKStatus: true,
79 },
80 {
81 desc: "client failure non initial state",
82 reqs: []*s2apb.SessionReq{
83 {
84 ReqOneof: &s2apb.SessionReq_ClientStart{
85 ClientStart: &s2apb.ClientSessionStartReq{
86 ApplicationProtocols: []string{grpcAppProtocol},
87 MinTlsVersion: commonpb.TLSVersion_TLS1_3,
88 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
89 TlsCiphersuites: []commonpb.Ciphersuite{
90 commonpb.Ciphersuite_AES_128_GCM_SHA256,
91 commonpb.Ciphersuite_AES_256_GCM_SHA384,
92 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
93 },
94 },
95 },
96 },
97 {
98 ReqOneof: &s2apb.SessionReq_ClientStart{
99 ClientStart: &s2apb.ClientSessionStartReq{
100 ApplicationProtocols: []string{grpcAppProtocol},
101 MinTlsVersion: commonpb.TLSVersion_TLS1_3,
102 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
103 TlsCiphersuites: []commonpb.Ciphersuite{
104 commonpb.Ciphersuite_AES_128_GCM_SHA256,
105 commonpb.Ciphersuite_AES_256_GCM_SHA384,
106 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
107 },
108 },
109 },
110 },
111 },
112 outResps: []*s2apb.SessionResp{
113 {
114 OutFrames: []byte(clientHelloFrame),
115 Status: &s2apb.SessionStatus{
116 Code: uint32(codes.OK),
117 },
118 },
119 },
120 hasNonOKStatus: true,
121 },
122 {
123 desc: "client test",
124 reqs: []*s2apb.SessionReq{
125 {
126 ReqOneof: &s2apb.SessionReq_ClientStart{
127 ClientStart: &s2apb.ClientSessionStartReq{
128 ApplicationProtocols: []string{grpcAppProtocol},
129 MinTlsVersion: commonpb.TLSVersion_TLS1_3,
130 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
131 TlsCiphersuites: []commonpb.Ciphersuite{
132 commonpb.Ciphersuite_AES_128_GCM_SHA256,
133 commonpb.Ciphersuite_AES_256_GCM_SHA384,
134 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
135 },
136 LocalIdentity: &commonpb.Identity{
137 IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
138 },
139 TargetIdentities: []*commonpb.Identity{
140 {
141 IdentityOneof: &commonpb.Identity_SpiffeId{SpiffeId: "peer spiffe identity"},
142 },
143 },
144 },
145 },
146 },
147 {
148 ReqOneof: &s2apb.SessionReq_Next{
149 Next: &s2apb.SessionNextReq{
150 InBytes: []byte(serverFrame),
151 },
152 },
153 },
154 },
155 outResps: []*s2apb.SessionResp{
156 {
157 OutFrames: []byte(clientHelloFrame),
158 Status: &s2apb.SessionStatus{
159 Code: uint32(codes.OK),
160 },
161 },
162 {
163 OutFrames: []byte(clientFinishedFrame),
164 BytesConsumed: uint32(len(serverFrame)),
165 Result: &s2apb.SessionResult{
166 ApplicationProtocol: grpcAppProtocol,
167 State: &s2apb.SessionState{
168 TlsVersion: commonpb.TLSVersion_TLS1_3,
169 TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
170 InKey: []byte(inKey),
171 OutKey: []byte(outKey),
172 },
173 PeerIdentity: &commonpb.Identity{
174 IdentityOneof: &commonpb.Identity_SpiffeId{SpiffeId: "peer spiffe identity"},
175 },
176 LocalIdentity: &commonpb.Identity{
177 IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
178 },
179 },
180 Status: &s2apb.SessionStatus{
181 Code: uint32(codes.OK),
182 },
183 },
184 },
185 },
186 {
187 desc: "server failure no app protocols",
188 reqs: []*s2apb.SessionReq{
189 {
190 ReqOneof: &s2apb.SessionReq_ServerStart{
191 ServerStart: &s2apb.ServerSessionStartReq{},
192 },
193 },
194 },
195 hasNonOKStatus: true,
196 },
197 {
198 desc: "server failure non initial state",
199 reqs: []*s2apb.SessionReq{
200 {
201 ReqOneof: &s2apb.SessionReq_ServerStart{
202 ServerStart: &s2apb.ServerSessionStartReq{
203 ApplicationProtocols: []string{grpcAppProtocol},
204 MinTlsVersion: commonpb.TLSVersion_TLS1_3,
205 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
206 TlsCiphersuites: []commonpb.Ciphersuite{
207 commonpb.Ciphersuite_AES_128_GCM_SHA256,
208 commonpb.Ciphersuite_AES_256_GCM_SHA384,
209 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
210 },
211 },
212 },
213 },
214 {
215 ReqOneof: &s2apb.SessionReq_ServerStart{
216 ServerStart: &s2apb.ServerSessionStartReq{
217 ApplicationProtocols: []string{grpcAppProtocol},
218 MinTlsVersion: commonpb.TLSVersion_TLS1_3,
219 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
220 TlsCiphersuites: []commonpb.Ciphersuite{
221 commonpb.Ciphersuite_AES_128_GCM_SHA256,
222 commonpb.Ciphersuite_AES_256_GCM_SHA384,
223 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
224 },
225 },
226 },
227 },
228 },
229 outResps: []*s2apb.SessionResp{
230 {
231 Status: &s2apb.SessionStatus{
232 Code: uint32(codes.OK),
233 },
234 },
235 },
236 hasNonOKStatus: true,
237 },
238 {
239 desc: "server test",
240 reqs: []*s2apb.SessionReq{
241 {
242 ReqOneof: &s2apb.SessionReq_ServerStart{
243 ServerStart: &s2apb.ServerSessionStartReq{
244 ApplicationProtocols: []string{grpcAppProtocol},
245 MinTlsVersion: commonpb.TLSVersion_TLS1_3,
246 MaxTlsVersion: commonpb.TLSVersion_TLS1_3,
247 TlsCiphersuites: []commonpb.Ciphersuite{
248 commonpb.Ciphersuite_AES_128_GCM_SHA256,
249 commonpb.Ciphersuite_AES_256_GCM_SHA384,
250 commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
251 },
252 InBytes: []byte(clientHelloFrame),
253 LocalIdentities: []*commonpb.Identity{
254 {
255 IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
256 },
257 },
258 },
259 },
260 },
261 {
262 ReqOneof: &s2apb.SessionReq_Next{
263 Next: &s2apb.SessionNextReq{
264 InBytes: []byte(clientFinishedFrame),
265 },
266 },
267 },
268 },
269 outResps: []*s2apb.SessionResp{
270 {
271 OutFrames: []byte(serverFrame),
272 BytesConsumed: uint32(len(clientHelloFrame)),
273 Status: &s2apb.SessionStatus{
274 Code: uint32(codes.OK),
275 },
276 },
277 {
278 BytesConsumed: uint32(len(clientFinishedFrame)),
279 Result: &s2apb.SessionResult{
280 ApplicationProtocol: grpcAppProtocol,
281 State: &s2apb.SessionState{
282 TlsVersion: commonpb.TLSVersion_TLS1_3,
283 TlsCiphersuite: commonpb.Ciphersuite_AES_128_GCM_SHA256,
284 InKey: []byte(inKey),
285 OutKey: []byte(outKey),
286 },
287 LocalIdentity: &commonpb.Identity{
288 IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
289 },
290 },
291 Status: &s2apb.SessionStatus{
292 Code: uint32(codes.OK),
293 },
294 },
295 },
296 },
297 {
298 desc: "resumption ticket test",
299 reqs: []*s2apb.SessionReq{
300 {
301 ReqOneof: &s2apb.SessionReq_ResumptionTicket{
302 ResumptionTicket: &s2apb.ResumptionTicketReq{
303 ConnectionId: 1234,
304 LocalIdentity: &commonpb.Identity{
305 IdentityOneof: &commonpb.Identity_Hostname{Hostname: "local hostname"},
306 },
307 },
308 },
309 },
310 },
311 outResps: []*s2apb.SessionResp{
312 {
313 Status: &s2apb.SessionStatus{
314 Code: uint32(codes.OK),
315 },
316 },
317 },
318 hasNonOKStatus: false,
319 },
320 } {
321 t.Run(tc.desc, func(t *testing.T) {
322 hs := FakeHandshakerService{}
323 stream := &fakeS2ASetupSessionServer{reqs: tc.reqs}
324 if got, want := hs.SetUpSession(stream) == nil, !tc.hasNonOKStatus; got != want {
325 t.Errorf("hs.SetUpSession(%v) = (err=nil) = %v, want %v", stream, got, want)
326 }
327 hasNonOKStatus := false
328 for i := range tc.reqs {
329 if stream.resps[i].GetStatus().GetCode() != uint32(codes.OK) {
330 hasNonOKStatus = true
331 break
332 }
333 if got, want := stream.resps[i], tc.outResps[i]; !cmp.Equal(got, want, protocmp.Transform()) {
334 t.Fatalf("stream.resps[%d] = %v, want %v", i, got, want)
335 }
336 }
337 if got, want := hasNonOKStatus, tc.hasNonOKStatus; got != want {
338 t.Errorf("hasNonOKStatus = %v, want %v", got, want)
339 }
340 })
341 }
342 }
343
344 func TestAuthenticateRequest(t *testing.T) {
345 for _, tc := range []struct {
346 description string
347 acceptedToken string
348 request *s2apb.SessionReq
349 expectedError string
350 }{
351 {
352 description: "access token env variable is not set",
353 },
354 {
355 description: "request contains valid token",
356 acceptedToken: testAccessToken,
357 request: &s2apb.SessionReq{
358 AuthMechanisms: []*s2apb.AuthenticationMechanism{
359 {
360 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
361 Token: testAccessToken,
362 },
363 },
364 },
365 },
366 },
367 {
368 description: "request contains invalid token",
369 acceptedToken: testAccessToken,
370 request: &s2apb.SessionReq{
371 AuthMechanisms: []*s2apb.AuthenticationMechanism{
372 {
373 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
374 Token: "bad_access_token",
375 },
376 },
377 },
378 },
379 expectedError: "received token: bad_access_token, expected token: test_access_token",
380 },
381 {
382 description: "request contains valid and invalid tokens",
383 acceptedToken: testAccessToken,
384 request: &s2apb.SessionReq{
385 AuthMechanisms: []*s2apb.AuthenticationMechanism{
386 {
387 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
388 Token: testAccessToken,
389 },
390 },
391 {
392 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
393 Token: "bad_access_token",
394 },
395 },
396 },
397 },
398 expectedError: "received token: bad_access_token, expected token: test_access_token",
399 },
400 } {
401 t.Run(tc.description, func(t *testing.T) {
402 os.Setenv(accessTokenEnvVariable, tc.acceptedToken)
403 hs := &FakeHandshakerService{}
404 err := hs.authenticateRequest(tc.request)
405 if got, want := (err == nil), (tc.expectedError == ""); got != want {
406 t.Errorf("(err == nil): %t, (tc.expectedError == \"\"): %t", got, want)
407 }
408 if err != nil && !strings.Contains(err.Error(), tc.expectedError) {
409 t.Errorf("hs.authenticateRequest(%v)=%v, expected error to have substring: %v", tc.request, err, tc.expectedError)
410 }
411 })
412 }
413 }
414
View as plain text