1
18
19 package v2
20
21 import (
22 "context"
23 "os"
24 "reflect"
25 "testing"
26 "time"
27
28 "github.com/google/go-cmp/cmp"
29 "github.com/google/s2a-go/fallback"
30 "github.com/google/s2a-go/internal/tokenmanager"
31 "github.com/google/s2a-go/stream"
32 "google.golang.org/protobuf/testing/protocmp"
33
34 commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
35 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
36 )
37
// fakes2av2Address is a placeholder S2Av2 address handed to the credential
// constructors in these tests.
var fakes2av2Address = "0.0.0.0:0"
41
42 func TestNewClientCreds(t *testing.T) {
43 os.Setenv("S2A_ACCESS_TOKEN", "TestNewClientCreds_s2a_access_token")
44 for _, tc := range []struct {
45 description string
46 }{
47 {
48 description: "static",
49 },
50 } {
51 t.Run(tc.description, func(t *testing.T) {
52 c, err := NewClientCreds(fakes2av2Address, nil, &commonpbv1.Identity{
53 IdentityOneof: &commonpbv1.Identity_Hostname{
54 Hostname: "test_rsa_client_identity",
55 },
56 }, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil, nil, nil)
57 if err != nil {
58 t.Fatalf("NewClientCreds() failed: %v", err)
59 }
60 if got, want := c.Info().SecurityProtocol, s2aSecurityProtocol; got != want {
61 t.Errorf("c.Info().SecurityProtocol = %v, want %v", got, want)
62 }
63 _, ok := c.(*s2av2TransportCreds)
64 if !ok {
65 t.Fatal("The created creds is not of type s2av2TransportCreds")
66 }
67 })
68 }
69 }
70
71 func TestNewServerCreds(t *testing.T) {
72 os.Setenv("S2A_ACCESS_TOKEN", "TestNewServerCreds_s2a_access_token")
73 for _, tc := range []struct {
74 description string
75 }{
76 {
77 description: "static",
78 },
79 } {
80 t.Run(tc.description, func(t *testing.T) {
81 localIdentities := []*commonpbv1.Identity{
82 {
83 IdentityOneof: &commonpbv1.Identity_Hostname{
84 Hostname: "test_rsa_server_identity",
85 },
86 },
87 }
88 c, err := NewServerCreds(fakes2av2Address, nil, localIdentities, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil)
89 if err != nil {
90 t.Fatalf("NewServerCreds() failed: %v", err)
91 }
92 if got, want := c.Info().SecurityProtocol, s2aSecurityProtocol; got != want {
93 t.Errorf("c.Info().SecurityProtocol = %v, want %v", got, want)
94 }
95 _, ok := c.(*s2av2TransportCreds)
96 if !ok {
97 t.Fatal("The created creds is not of type s2av2TransportCreds")
98 }
99 })
100 }
101 }
102
103 func TestClientHandshakeFail(t *testing.T) {
104 cc := &s2av2TransportCreds{isClient: false}
105 if _, _, err := cc.ClientHandshake(context.Background(), "", nil); err == nil {
106 t.Errorf("c.ClientHandshake(nil, \"\", nil) should fail with incorrect transport credentials")
107 }
108 }
109
110 func TestServerHandshakeFail(t *testing.T) {
111 sc := &s2av2TransportCreds{isClient: true}
112 if _, _, err := sc.ServerHandshake(nil); err == nil {
113 t.Errorf("c.ServerHandshake(nil) should fail with incorrect transport credentials")
114 }
115 }
116
117 func TestInfo(t *testing.T) {
118 os.Setenv("S2A_ACCESS_TOKEN", "TestInfo_s2a_access_token")
119 c, err := NewClientCreds(fakes2av2Address, nil, &commonpbv1.Identity{
120 IdentityOneof: &commonpbv1.Identity_Hostname{
121 Hostname: "test_rsa_client_identity",
122 },
123 }, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil, nil, nil)
124 if err != nil {
125 t.Fatalf("NewClientCreds() failed: %v", err)
126 }
127 info := c.Info()
128 if got, want := info.SecurityProtocol, "tls"; got != want {
129 t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
130 }
131 }
132
133 func TestCloneClient(t *testing.T) {
134 os.Setenv("S2A_ACCESS_TOKEN", "TestCloneClient_s2a_access_token")
135 fallbackFunc, err := fallback.DefaultFallbackClientHandshakeFunc("example.com")
136 if err != nil {
137 t.Errorf("error creating fallback handshake function: %v", err)
138 }
139 c, err := NewClientCreds(fakes2av2Address, nil, &commonpbv1.Identity{
140 IdentityOneof: &commonpbv1.Identity_Hostname{
141 Hostname: "test_rsa_client_identity",
142 },
143 }, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, fallbackFunc, nil, nil)
144 if err != nil {
145 t.Fatalf("NewClientCreds() failed: %v", err)
146 }
147 cc := c.Clone()
148 s2av2Creds, ok := c.(*s2av2TransportCreds)
149 if !ok {
150 t.Fatal("The created creds is not of type s2av2TransportCreds")
151 }
152 s2av2CloneCreds, ok := cc.(*s2av2TransportCreds)
153 if !ok {
154 t.Fatal("The created clone creds is not of type s2aTransportCreds")
155 }
156 if got, want := cmp.Equal(s2av2Creds, s2av2CloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2av2TransportCreds{}), cmp.Comparer(func(x, y tokenmanager.AccessTokenManager) bool {
157 xToken, err := x.DefaultToken()
158 if err != nil {
159 t.Errorf("Failed to compare cloned creds: %v", err)
160 }
161 yToken, err := y.DefaultToken()
162 if err != nil {
163 t.Errorf("Failed to compare cloned creds: %v", err)
164 }
165 if xToken == yToken {
166 return true
167 }
168 return false
169 }), cmp.Comparer(func(x, y fallback.ClientHandshake) bool {
170 return reflect.ValueOf(x) == reflect.ValueOf(y)
171 })), true; got != want {
172 t.Errorf("cmp.Equal(%+v, %+v) = %v, want %v", s2av2Creds, s2av2CloneCreds, got, want)
173 }
174
175 s2av2CloneCreds.info.SecurityProtocol = "s2a"
176 if got, want := cmp.Equal(s2av2Creds, s2av2CloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2av2TransportCreds{}), cmp.Comparer(func(x, y tokenmanager.AccessTokenManager) bool {
177 xToken, err := x.DefaultToken()
178 if err != nil {
179 t.Errorf("Failed to compare cloned creds: %v", err)
180 }
181 yToken, err := y.DefaultToken()
182 if err != nil {
183 t.Errorf("Failed to compare cloned creds: %v", err)
184 }
185 if xToken == yToken {
186 return true
187 }
188 return false
189 })), false; got != want {
190 t.Errorf("cmp.Equal(%+v, %+v) = %v, want %v", s2av2Creds, s2av2CloneCreds, got, want)
191 }
192 }
193
194 func TestCloneServer(t *testing.T) {
195 os.Setenv("S2A_ACCESS_TOKEN", "TestCloneServer_s2a_access_token")
196 localIdentities := []*commonpbv1.Identity{
197 {
198 IdentityOneof: &commonpbv1.Identity_Hostname{
199 Hostname: "test_rsa_server_identity",
200 },
201 },
202 }
203 c, err := NewServerCreds(fakes2av2Address, nil, localIdentities, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil)
204 if err != nil {
205 t.Fatalf("NewServerCreds() failed: %v", err)
206 }
207 cc := c.Clone()
208 s2av2Creds, ok := c.(*s2av2TransportCreds)
209 if !ok {
210 t.Fatal("The created creds is not of type s2av2TransportCreds")
211 }
212 s2av2CloneCreds, ok := cc.(*s2av2TransportCreds)
213 if !ok {
214 t.Fatal("The created clone creds is not of type s2aTransportCreds")
215 }
216 if got, want := cmp.Equal(s2av2Creds, s2av2CloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2av2TransportCreds{}), cmp.Comparer(func(x, y tokenmanager.AccessTokenManager) bool {
217 xToken, err := x.DefaultToken()
218 if err != nil {
219 t.Errorf("Failed to compare cloned creds: %v", err)
220 }
221 yToken, err := y.DefaultToken()
222 if err != nil {
223 t.Errorf("Failed to compare cloned creds: %v", err)
224 }
225 if xToken == yToken {
226 return true
227 }
228 return false
229 })), true; got != want {
230 t.Errorf("cmp.Equal(%+v, %+v) = %v, want %v", s2av2Creds, s2av2CloneCreds, got, want)
231 }
232
233 s2av2CloneCreds.info.SecurityProtocol = "s2a"
234 if got, want := cmp.Equal(s2av2Creds, s2av2CloneCreds, protocmp.Transform(), cmp.AllowUnexported(s2av2TransportCreds{}), cmp.Comparer(func(x, y tokenmanager.AccessTokenManager) bool {
235 xToken, err := x.DefaultToken()
236 if err != nil {
237 t.Errorf("Failed to compare cloned creds: %v", err)
238 }
239 yToken, err := y.DefaultToken()
240 if err != nil {
241 t.Errorf("Failed to compare cloned creds: %v", err)
242 }
243 if xToken == yToken {
244 return true
245 }
246 return false
247 })), false; got != want {
248 t.Errorf("cmp.Equal(%+v, %+v) = %v, want %v", s2av2Creds, s2av2CloneCreds, got, want)
249 }
250 }
251
252 func TestOverrideServerName(t *testing.T) {
253
254 os.Setenv("S2A_ACCESS_TOKEN", "TestOverrideServerName_s2a_access_token")
255 c, err := NewClientCreds(fakes2av2Address, nil, &commonpbv1.Identity{
256 IdentityOneof: &commonpbv1.Identity_Hostname{
257 Hostname: "test_rsa_client_identity",
258 },
259 }, s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE, nil, nil, nil)
260 s2av2Creds, ok := c.(*s2av2TransportCreds)
261 if !ok {
262 t.Fatal("The created creds is not of type s2av2TransportCreds")
263 }
264 if err != nil {
265 t.Fatalf("NewClientCreds() failed: %v", err)
266 }
267 if got, want := c.Info().ServerName, ""; got != want {
268 t.Errorf("c.Info().ServerName = %v, want %v", got, want)
269 }
270 if got, want := s2av2Creds.serverName, ""; got != want {
271 t.Errorf("c.serverName = %v, want %v", got, want)
272 }
273 for _, tc := range []struct {
274 description string
275 override string
276 wantServerName string
277 expectError bool
278 }{
279 {
280 description: "empty string",
281 override: "",
282 wantServerName: "",
283 },
284 {
285 description: "host only",
286 override: "server.name",
287 wantServerName: "server.name",
288 },
289 {
290 description: "invalid syntax",
291 override: "server::",
292 wantServerName: "server::",
293 },
294 {
295 description: "split host port",
296 override: "host:port",
297 wantServerName: "host",
298 },
299 } {
300 t.Run(tc.description, func(t *testing.T) {
301 c.OverrideServerName(tc.override)
302 if got, want := c.Info().ServerName, tc.wantServerName; got != want {
303 t.Errorf("c.Info().ServerName = %v, want %v", got, want)
304 }
305 if got, want := s2av2Creds.serverName, tc.wantServerName; got != want {
306 t.Errorf("c.serverName = %v, want %v", got, want)
307 }
308 })
309 }
310 }
311
312 type s2ATestStream struct {
313 debug string
314 }
315
316 func (x s2ATestStream) Send(m *s2av2pb.SessionReq) error {
317 return nil
318 }
319
320 func (x s2ATestStream) Recv() (*s2av2pb.SessionResp, error) {
321 return nil, nil
322 }
323
324 func (x s2ATestStream) CloseSend() error {
325 return nil
326 }
327
328 func TestCreateStream(t *testing.T) {
329 for _, tc := range []struct {
330 description string
331 }{
332 {
333 description: "static",
334 },
335 } {
336 t.Run(tc.description, func(t *testing.T) {
337 s2AStream, err := createStream(context.TODO(), "fake address", nil, func(ctx context.Context, s2av2Address string) (stream.S2AStream, error) {
338 return s2ATestStream{debug: "test s2a stream"}, nil
339 })
340 if err != nil {
341 t.Fatalf("New S2AStream failed: %v", err)
342 }
343 testStream, ok := s2AStream.(s2ATestStream)
344 if !ok {
345 t.Fatal("The created stream is not of type s2ATestStream")
346 }
347 if testStream.debug != "test s2a stream" {
348 t.Errorf("The created stream is not the intended stream")
349 }
350 })
351 }
352 }
353
354 func TestGetS2ATimeout(t *testing.T) {
355 oldEnvValue := os.Getenv(s2aTimeoutEnv)
356 defer os.Setenv(s2aTimeoutEnv, oldEnvValue)
357
358
359 os.Unsetenv(s2aTimeoutEnv)
360 if got, want := GetS2ATimeout(), defaultS2ATimeout; got != want {
361 t.Fatalf("GetS2ATimeout should return default if S2A_TIMEOUT is not set")
362 }
363
364
365 os.Setenv(s2aTimeoutEnv, "")
366 if got, want := GetS2ATimeout(), defaultS2ATimeout; got != want {
367 t.Fatalf("GetS2ATimeout should return default if S2A_TIMEOUT is set to empty string")
368 }
369
370
371 os.Setenv(s2aTimeoutEnv, "5s")
372 if got, want := GetS2ATimeout(), 5*time.Second; got != want {
373 t.Fatalf("expected timeout to be 5s")
374 }
375
376
377 os.Setenv(s2aTimeoutEnv, "5abc")
378 if got, want := GetS2ATimeout(), defaultS2ATimeout; got != want {
379 t.Fatalf("expected timeout to be default if the set timeout is invalid")
380 }
381 }
382
View as plain text