1
18
19 package s2a
20
21 import (
22 "context"
23 "testing"
24
25 "github.com/google/go-cmp/cmp"
26 "google.golang.org/protobuf/testing/protocmp"
27
28 s2apb "github.com/google/s2a-go/internal/proto/common_go_proto"
29 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
30 )
31
32 func TestNewClientCreds(t *testing.T) {
33 for _, tc := range []struct {
34 desc string
35 opts *ClientOptions
36 outMinTLSVersion s2apb.TLSVersion
37 outMaxTLSVersion s2apb.TLSVersion
38 outTLSCiphersuites []s2apb.Ciphersuite
39 outLocalIdentity *s2apb.Identity
40 outTargetIdentities []*s2apb.Identity
41 outS2AAddress string
42 }{
43 {
44 desc: "only hostnames",
45 opts: &ClientOptions{
46 TargetIdentities: []Identity{
47 &hostname{"test_server_hostname"},
48 },
49 LocalIdentity: &hostname{"test_client_hostname"},
50 S2AAddress: "test_s2a_address",
51 EnableLegacyMode: true,
52 },
53 outMinTLSVersion: s2apb.TLSVersion_TLS1_3,
54 outMaxTLSVersion: s2apb.TLSVersion_TLS1_3,
55 outTLSCiphersuites: []s2apb.Ciphersuite{
56 s2apb.Ciphersuite_AES_128_GCM_SHA256,
57 s2apb.Ciphersuite_AES_256_GCM_SHA384,
58 s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256},
59 outTargetIdentities: []*s2apb.Identity{
60 {
61 IdentityOneof: &s2apb.Identity_Hostname{
62 Hostname: "test_server_hostname",
63 },
64 },
65 },
66 outLocalIdentity: &s2apb.Identity{
67 IdentityOneof: &s2apb.Identity_Hostname{
68 Hostname: "test_client_hostname",
69 },
70 },
71 outS2AAddress: "test_s2a_address",
72 },
73 {
74 desc: "only spiffe IDs",
75 opts: &ClientOptions{
76 TargetIdentities: []Identity{
77 &spiffeID{"test_server_spiffe_id"},
78 },
79 LocalIdentity: &spiffeID{"test_client_spiffe_id"},
80 S2AAddress: "test_s2a_address",
81 EnableLegacyMode: true,
82 },
83 outMinTLSVersion: s2apb.TLSVersion_TLS1_3,
84 outMaxTLSVersion: s2apb.TLSVersion_TLS1_3,
85 outTLSCiphersuites: []s2apb.Ciphersuite{
86 s2apb.Ciphersuite_AES_128_GCM_SHA256,
87 s2apb.Ciphersuite_AES_256_GCM_SHA384,
88 s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256},
89 outTargetIdentities: []*s2apb.Identity{
90 {
91 IdentityOneof: &s2apb.Identity_SpiffeId{
92 SpiffeId: "test_server_spiffe_id",
93 },
94 },
95 },
96 outLocalIdentity: &s2apb.Identity{
97 IdentityOneof: &s2apb.Identity_SpiffeId{
98 SpiffeId: "test_client_spiffe_id",
99 },
100 },
101 outS2AAddress: "test_s2a_address",
102 },
103 {
104 desc: "only UIDs",
105 opts: &ClientOptions{
106 TargetIdentities: []Identity{
107 &uid{"test_server_uid"},
108 },
109 LocalIdentity: &uid{"test_client_uid"},
110 S2AAddress: "test_s2a_address",
111 EnableLegacyMode: true,
112 },
113 outMinTLSVersion: s2apb.TLSVersion_TLS1_3,
114 outMaxTLSVersion: s2apb.TLSVersion_TLS1_3,
115 outTLSCiphersuites: []s2apb.Ciphersuite{
116 s2apb.Ciphersuite_AES_128_GCM_SHA256,
117 s2apb.Ciphersuite_AES_256_GCM_SHA384,
118 s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256},
119 outTargetIdentities: []*s2apb.Identity{
120 {
121 IdentityOneof: &s2apb.Identity_Uid{
122 Uid: "test_server_uid",
123 },
124 },
125 },
126 outLocalIdentity: &s2apb.Identity{
127 IdentityOneof: &s2apb.Identity_Uid{
128 Uid: "test_client_uid",
129 },
130 },
131 outS2AAddress: "test_s2a_address",
132 },
133 {
134 desc: "mixed identities",
135 opts: &ClientOptions{
136 TargetIdentities: []Identity{
137 &spiffeID{"test_server_spiffe_id"},
138 &hostname{"test_server_hostname"},
139 &uid{"test_server_uid"},
140 },
141 LocalIdentity: &spiffeID{"test_client_spiffe_id"},
142 S2AAddress: "test_s2a_address",
143 EnableLegacyMode: true,
144 },
145 outMinTLSVersion: s2apb.TLSVersion_TLS1_3,
146 outMaxTLSVersion: s2apb.TLSVersion_TLS1_3,
147 outTLSCiphersuites: []s2apb.Ciphersuite{
148 s2apb.Ciphersuite_AES_128_GCM_SHA256,
149 s2apb.Ciphersuite_AES_256_GCM_SHA384,
150 s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256},
151 outTargetIdentities: []*s2apb.Identity{
152 {
153 IdentityOneof: &s2apb.Identity_SpiffeId{
154 SpiffeId: "test_server_spiffe_id",
155 },
156 },
157 {
158 IdentityOneof: &s2apb.Identity_Hostname{
159 Hostname: "test_server_hostname",
160 },
161 },
162 {
163 IdentityOneof: &s2apb.Identity_Uid{
164 Uid: "test_server_uid",
165 },
166 },
167 },
168 outLocalIdentity: &s2apb.Identity{
169 IdentityOneof: &s2apb.Identity_SpiffeId{
170 SpiffeId: "test_client_spiffe_id",
171 },
172 },
173 outS2AAddress: "test_s2a_address",
174 },
175 } {
176 t.Run(tc.desc, func(t *testing.T) {
177 c, err := NewClientCreds(tc.opts)
178 if err != nil {
179 t.Fatalf("NewClientCreds(_) failed: %v", err)
180 }
181 if got, want := c.Info().SecurityProtocol, s2aSecurityProtocol; got != want {
182 t.Errorf("c.Info().SecurityProtocol = %v, want %v", got, want)
183 }
184 s2aCreds, ok := c.(*s2aTransportCreds)
185 if !ok {
186 t.Fatal("The created creds is not of type s2aTransportCreds")
187 }
188 if got, want := s2aCreds.minTLSVersion, tc.outMinTLSVersion; got != want {
189 t.Errorf("s2aCreds.minTLSVersion = %v, want %v", got, want)
190 }
191 if got, want := s2aCreds.maxTLSVersion, tc.outMaxTLSVersion; got != want {
192 t.Errorf("s2aCreds.maxTLSVersion = %v, want %v", got, want)
193 }
194 if got, want := s2aCreds.tlsCiphersuites, tc.outTLSCiphersuites; !cmp.Equal(got, want) {
195 t.Errorf("s2aCreds.tlsCiphersuites = %v, want %v", got, want)
196 }
197 if got, want := s2aCreds.targetIdentities, tc.outTargetIdentities; !cmp.Equal(got, want, protocmp.Transform()) {
198 t.Errorf("s2aCreds.targetIdentities = %v, want %v", got, want)
199 }
200 if got, want := s2aCreds.localIdentity, tc.outLocalIdentity; !cmp.Equal(got, want, protocmp.Transform()) {
201 t.Errorf("s2aCreds.localIdentity = %v, want %v", got, want)
202 }
203 if got, want := s2aCreds.s2aAddr, tc.outS2AAddress; got != want {
204 t.Errorf("s2aCreds.s2aAddr = %v, want %v", got, want)
205 }
206 })
207 }
208 }
209
210 func TestNewServerCreds(t *testing.T) {
211 for _, tc := range []struct {
212 desc string
213 opts *ServerOptions
214 outMinTLSVersion s2apb.TLSVersion
215 outMaxTLSVersion s2apb.TLSVersion
216 outTLSCiphersuites []s2apb.Ciphersuite
217 outLocalIdentities []*s2apb.Identity
218 outS2AAddress string
219 }{
220 {
221 desc: "only hostnames",
222 opts: &ServerOptions{
223 LocalIdentities: []Identity{
224 &hostname{"test_server_hostname"},
225 },
226 S2AAddress: "test_s2a_address",
227 EnableLegacyMode: true,
228 },
229 outMinTLSVersion: s2apb.TLSVersion_TLS1_3,
230 outMaxTLSVersion: s2apb.TLSVersion_TLS1_3,
231 outTLSCiphersuites: []s2apb.Ciphersuite{
232 s2apb.Ciphersuite_AES_128_GCM_SHA256,
233 s2apb.Ciphersuite_AES_256_GCM_SHA384,
234 s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256},
235 outLocalIdentities: []*s2apb.Identity{
236 {
237 IdentityOneof: &s2apb.Identity_Hostname{
238 Hostname: "test_server_hostname",
239 },
240 },
241 },
242 outS2AAddress: "test_s2a_address",
243 },
244 {
245 desc: "only spiffe IDs",
246 opts: &ServerOptions{
247 LocalIdentities: []Identity{
248 &spiffeID{"test_server_spiffe_id"},
249 },
250 S2AAddress: "test_s2a_address",
251 EnableLegacyMode: true,
252 },
253 outMinTLSVersion: s2apb.TLSVersion_TLS1_3,
254 outMaxTLSVersion: s2apb.TLSVersion_TLS1_3,
255 outTLSCiphersuites: []s2apb.Ciphersuite{
256 s2apb.Ciphersuite_AES_128_GCM_SHA256,
257 s2apb.Ciphersuite_AES_256_GCM_SHA384,
258 s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256},
259 outLocalIdentities: []*s2apb.Identity{
260 {
261 IdentityOneof: &s2apb.Identity_SpiffeId{
262 SpiffeId: "test_server_spiffe_id",
263 },
264 },
265 },
266 outS2AAddress: "test_s2a_address",
267 },
268 {
269 desc: "only UIDs",
270 opts: &ServerOptions{
271 LocalIdentities: []Identity{
272 &uid{"test_server_uid"},
273 },
274 S2AAddress: "test_s2a_address",
275 EnableLegacyMode: true,
276 },
277 outMinTLSVersion: s2apb.TLSVersion_TLS1_3,
278 outMaxTLSVersion: s2apb.TLSVersion_TLS1_3,
279 outTLSCiphersuites: []s2apb.Ciphersuite{
280 s2apb.Ciphersuite_AES_128_GCM_SHA256,
281 s2apb.Ciphersuite_AES_256_GCM_SHA384,
282 s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256},
283 outLocalIdentities: []*s2apb.Identity{
284 {
285 IdentityOneof: &s2apb.Identity_Uid{
286 Uid: "test_server_uid",
287 },
288 },
289 },
290 outS2AAddress: "test_s2a_address",
291 },
292 {
293 desc: "mixed identities",
294 opts: &ServerOptions{
295 LocalIdentities: []Identity{
296 &spiffeID{"test_server_spiffe_id"},
297 &hostname{"test_server_hostname"},
298 &uid{"test_server_uid"},
299 },
300 S2AAddress: "test_s2a_address",
301 EnableLegacyMode: true,
302 },
303 outMinTLSVersion: s2apb.TLSVersion_TLS1_3,
304 outMaxTLSVersion: s2apb.TLSVersion_TLS1_3,
305 outTLSCiphersuites: []s2apb.Ciphersuite{
306 s2apb.Ciphersuite_AES_128_GCM_SHA256,
307 s2apb.Ciphersuite_AES_256_GCM_SHA384,
308 s2apb.Ciphersuite_CHACHA20_POLY1305_SHA256},
309 outLocalIdentities: []*s2apb.Identity{
310 {
311 IdentityOneof: &s2apb.Identity_SpiffeId{
312 SpiffeId: "test_server_spiffe_id",
313 },
314 },
315 {
316 IdentityOneof: &s2apb.Identity_Hostname{
317 Hostname: "test_server_hostname",
318 },
319 },
320 {
321 IdentityOneof: &s2apb.Identity_Uid{
322 Uid: "test_server_uid",
323 },
324 },
325 },
326 outS2AAddress: "test_s2a_address",
327 },
328 } {
329 t.Run(tc.desc, func(t *testing.T) {
330 c, err := NewServerCreds(tc.opts)
331 if err != nil {
332 t.Fatalf("NewServerCreds(_) failed: %v", err)
333 }
334 if got, want := c.Info().SecurityProtocol, s2aSecurityProtocol; got != want {
335 t.Errorf("c.Info().SecurityProtocol = %v, want %v", got, want)
336 }
337 s2aCreds, ok := c.(*s2aTransportCreds)
338 if !ok {
339 t.Fatal("The created creds is not of type s2aTransportCreds")
340 }
341 if got, want := s2aCreds.minTLSVersion, tc.outMinTLSVersion; got != want {
342 t.Errorf("s2aCreds.minTLSVersion = %v, want %v", got, want)
343 }
344 if got, want := s2aCreds.maxTLSVersion, tc.outMaxTLSVersion; got != want {
345 t.Errorf("s2aCreds.maxTLSVersion = %v, want %v", got, want)
346 }
347 if got, want := s2aCreds.tlsCiphersuites, tc.outTLSCiphersuites; !cmp.Equal(got, want) {
348 t.Errorf("s2aCreds.tlsCiphersuites = %v, want %v", got, want)
349 }
350 if got, want := s2aCreds.localIdentities, tc.outLocalIdentities; !cmp.Equal(got, want, protocmp.Transform()) {
351 t.Errorf("s2aCreds.localIdentities = %v, want %v", got, want)
352 }
353 if got, want := s2aCreds.s2aAddr, tc.outS2AAddress; got != want {
354 t.Errorf("s2aCreds.s2aAddr = %v, want %v", got, want)
355 }
356 })
357 }
358 }
359
360 func TestHandshakeFail(t *testing.T) {
361 cc := &s2aTransportCreds{isClient: false}
362 if _, _, err := cc.ClientHandshake(context.Background(), "", nil); err == nil {
363 t.Errorf("c.ClientHandshake(nil, \"\", nil) should fail with incorrect transport credentials")
364 }
365 sc := &s2aTransportCreds{isClient: true}
366 if _, _, err := sc.ServerHandshake(nil); err == nil {
367 t.Errorf("c.ServerHandshake(nil) should fail with incorrect transport credentials")
368 }
369 }
370
371 func TestInfo(t *testing.T) {
372
373
374 c, err := NewServerCreds(&ServerOptions{})
375 if err != nil {
376 t.Fatalf("NewServerCreds(&ServerOptions{}) failed: %v", err)
377 }
378 info := c.Info()
379 if got, want := info.ProtocolVersion, ""; got != want {
380 t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
381 }
382 if got, want := info.SecurityProtocol, "tls"; got != want {
383 t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
384 }
385 if got, want := info.ServerName, ""; got != want {
386 t.Errorf("info.ServerName=%v, want %v", got, want)
387 }
388 }
389
390 func TestCloneClient(t *testing.T) {
391 opt := &ClientOptions{
392 TargetIdentities: []Identity{
393 &spiffeID{"test_server_spiffe_id"},
394 &hostname{"test_server_hostname"},
395 },
396 LocalIdentity: &hostname{"test_client_hostname"},
397 S2AAddress: "test_s2a_address",
398 EnableLegacyMode: true,
399 }
400 c, err := NewClientCreds(opt)
401 if err != nil {
402 t.Fatalf("NewClientCreds(%v) failed: %v", opt, err)
403 }
404 cc := c.Clone()
405 s2aCreds, ok := c.(*s2aTransportCreds)
406 if !ok {
407 t.Fatal("The created creds is not of type s2aTransportCreds")
408 }
409 s2aCloneCreds, ok := cc.(*s2aTransportCreds)
410 if !ok {
411 t.Fatal("The created cloned creds is not of type s2aTransportCreds")
412 }
413 if got, want := cmp.Equal(s2aCreds, s2aCloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2aTransportCreds{})), true; got != want {
414 t.Errorf("cmp.Equal(%v, %v) = %v, want %v", s2aCreds, s2aCloneCreds, got, want)
415 }
416
417 s2aCloneCreds.targetIdentities[0] = &s2apb.Identity{
418 IdentityOneof: &s2apb.Identity_SpiffeId{
419 SpiffeId: "new_spiffe_id",
420 },
421 }
422 if got, want := cmp.Equal(s2aCreds, s2aCloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2aTransportCreds{})), false; got != want {
423 t.Errorf("cmp.Equal(%v, %v) = %v, want %v", s2aCreds, s2aCloneCreds, got, want)
424 }
425 }
426
427 func TestCloneServer(t *testing.T) {
428 c, err := NewServerCreds(&ServerOptions{
429 LocalIdentities: []Identity{
430 &spiffeID{"test_server_spiffe_id"},
431 &hostname{"test_server_hostname"},
432 },
433 S2AAddress: "test_s2a_address",
434 EnableLegacyMode: true,
435 })
436 if err != nil {
437 t.Fatalf("NewServerCreds(&ServerOptions{}) failed: %v", err)
438 }
439 cc := c.Clone()
440 s2aCreds, ok := c.(*s2aTransportCreds)
441 if !ok {
442 t.Fatal("The created creds is not of type s2aTransportCreds")
443 }
444 s2aCloneCreds, ok := cc.(*s2aTransportCreds)
445 if !ok {
446 t.Fatal("The created cloned creds is not of type s2aTransportCreds")
447 }
448 if got, want := cmp.Equal(s2aCreds, s2aCloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2aTransportCreds{})), true; got != want {
449 t.Errorf("cmp.Equal(%v, %v) = %v, want %v", s2aCreds, s2aCloneCreds, got, want)
450 }
451
452 s2aCloneCreds.localIdentities[0] = &s2apb.Identity{
453 IdentityOneof: &s2apb.Identity_SpiffeId{
454 SpiffeId: "new_spiffe_id",
455 },
456 }
457 if got, want := cmp.Equal(s2aCreds, s2aCloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2aTransportCreds{})), false; got != want {
458 t.Errorf("cmp.Equal(%v, %v) = %v, want %v", s2aCreds, s2aCloneCreds, got, want)
459 }
460 }
461
462 func TestOverrideServerName(t *testing.T) {
463 wantServerName := "server.name"
464
465
466 c, err := NewServerCreds(&ServerOptions{})
467 if err != nil {
468 t.Fatalf("NewServerCreds(&ServerOptions{}) failed: %v", err)
469 }
470 if got, want := c.Info().ServerName, ""; got != want {
471 t.Errorf("c.Info().ServerName = %v, want %v", got, want)
472 }
473 if err := c.OverrideServerName(wantServerName); err != nil {
474 t.Fatalf("c.OverrideServerName(%v) failed: %v", wantServerName, err)
475 }
476 if got, want := c.Info().ServerName, wantServerName; got != want {
477 t.Errorf("c.Info().ServerName = %v, want %v", got, want)
478 }
479 }
480
481 func TestGetVerificationMode(t *testing.T) {
482 for _, tc := range []struct {
483 description string
484 verificationMode VerificationModeType
485 expVerificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
486 }{
487 {
488 description: "connect to google",
489 verificationMode: ConnectToGoogle,
490 expVerificationMode: s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE,
491 },
492 {
493 description: "spiffe",
494 verificationMode: Spiffe,
495 expVerificationMode: s2av2pb.ValidatePeerCertificateChainReq_SPIFFE,
496 },
497 {
498 description: "unspecified",
499 verificationMode: Unspecified,
500 expVerificationMode: s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED,
501 },
502 } {
503 t.Run(tc.description, func(t *testing.T) {
504 if got, want := getVerificationMode(tc.verificationMode), tc.expVerificationMode; got != want {
505 t.Errorf("got = %v, want = %v", got, want)
506 }
507 })
508 }
509 }
510
View as plain text