1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package grpctransport
16
17 import (
18 "context"
19 "errors"
20 "log"
21 "net"
22 "testing"
23
24 "cloud.google.com/go/auth"
25 "cloud.google.com/go/auth/credentials"
26 echo "cloud.google.com/go/auth/grpctransport/testdata"
27 "cloud.google.com/go/auth/internal"
28 "github.com/google/go-cmp/cmp"
29 "google.golang.org/grpc"
30 "google.golang.org/grpc/credentials/insecure"
31 "google.golang.org/grpc/metadata"
32 )
33
34 func TestCheckDirectPathEndPoint(t *testing.T) {
35 for _, testcase := range []struct {
36 name string
37 endpoint string
38 want bool
39 }{
40 {
41 name: "empty endpoint are disallowed",
42 endpoint: "",
43 want: false,
44 },
45 {
46 name: "dns schemes are allowed",
47 endpoint: "dns:///foo",
48 want: true,
49 },
50 {
51 name: "host without no prefix are allowed",
52 endpoint: "foo",
53 want: true,
54 },
55 {
56 name: "host with port are allowed",
57 endpoint: "foo:1234",
58 want: true,
59 },
60 {
61 name: "non-dns schemes are disallowed",
62 endpoint: "https://foo",
63 want: false,
64 },
65 } {
66 t.Run(testcase.name, func(t *testing.T) {
67 if got := checkDirectPathEndPoint(testcase.endpoint); got != testcase.want {
68 t.Fatalf("got %v, want %v", got, testcase.want)
69 }
70 })
71 }
72 }
73
74 func TestDial_FailsValidation(t *testing.T) {
75 tests := []struct {
76 name string
77 opts *Options
78 }{
79 {
80 name: "missing options",
81 },
82 {
83 name: "has creds with disable options, tp",
84 opts: &Options{
85 DisableAuthentication: true,
86 Credentials: auth.NewCredentials(&auth.CredentialsOptions{
87 TokenProvider: &staticTP{tok: &auth.Token{Value: "fakeToken"}},
88 }),
89 },
90 },
91 {
92 name: "has creds with disable options, cred file",
93 opts: &Options{
94 DisableAuthentication: true,
95 DetectOpts: &credentials.DetectOptions{
96 CredentialsFile: "abc.123",
97 },
98 },
99 },
100 {
101 name: "has creds with disable options, cred json",
102 opts: &Options{
103 DisableAuthentication: true,
104 DetectOpts: &credentials.DetectOptions{
105 CredentialsJSON: []byte(`{"foo":"bar"}`),
106 },
107 },
108 },
109 }
110 for _, tt := range tests {
111 t.Run(tt.name, func(t *testing.T) {
112 _, err := Dial(context.Background(), false, tt.opts)
113 if err == nil {
114 t.Fatal("NewClient() = _, nil, want error")
115 }
116 })
117 }
118 }
119
120 func TestDial_SkipValidation(t *testing.T) {
121 opts := &Options{
122 DisableAuthentication: true,
123 Credentials: auth.NewCredentials(&auth.CredentialsOptions{
124 TokenProvider: &staticTP{tok: &auth.Token{Value: "fakeToken"}},
125 }),
126 }
127 t.Run("invalid opts", func(t *testing.T) {
128 if err := opts.validate(); err == nil {
129 t.Fatalf("opts.validate() = nil, want error")
130 }
131 })
132
133 t.Run("skip invalid opts", func(t *testing.T) {
134 opts.InternalOptions = &InternalOptions{SkipValidation: true}
135 if err := opts.validate(); err != nil {
136 t.Fatalf("opts.validate() = %v, want nil", err)
137 }
138 })
139 }
140
141 func TestOptions_ResolveDetectOptions(t *testing.T) {
142 tests := []struct {
143 name string
144 in *Options
145 want *credentials.DetectOptions
146 }{
147 {
148 name: "base",
149 in: &Options{
150 DetectOpts: &credentials.DetectOptions{
151 Scopes: []string{"scope"},
152 CredentialsFile: "/path/to/a/file",
153 },
154 },
155 want: &credentials.DetectOptions{
156 Scopes: []string{"scope"},
157 CredentialsFile: "/path/to/a/file",
158 },
159 },
160 {
161 name: "self-signed, with scope",
162 in: &Options{
163 InternalOptions: &InternalOptions{
164 EnableJWTWithScope: true,
165 },
166 DetectOpts: &credentials.DetectOptions{
167 Scopes: []string{"scope"},
168 CredentialsFile: "/path/to/a/file",
169 },
170 },
171 want: &credentials.DetectOptions{
172 Scopes: []string{"scope"},
173 CredentialsFile: "/path/to/a/file",
174 UseSelfSignedJWT: true,
175 },
176 },
177 {
178 name: "self-signed, with aud",
179 in: &Options{
180 DetectOpts: &credentials.DetectOptions{
181 Audience: "aud",
182 CredentialsFile: "/path/to/a/file",
183 },
184 },
185 want: &credentials.DetectOptions{
186 Audience: "aud",
187 CredentialsFile: "/path/to/a/file",
188 UseSelfSignedJWT: true,
189 },
190 },
191 {
192 name: "use default scopes",
193 in: &Options{
194 InternalOptions: &InternalOptions{
195 DefaultScopes: []string{"default"},
196 DefaultAudience: "default",
197 },
198 DetectOpts: &credentials.DetectOptions{
199 CredentialsFile: "/path/to/a/file",
200 },
201 },
202 want: &credentials.DetectOptions{
203 Scopes: []string{"default"},
204 CredentialsFile: "/path/to/a/file",
205 },
206 },
207 {
208 name: "don't use default scopes, scope provided",
209 in: &Options{
210 InternalOptions: &InternalOptions{
211 DefaultScopes: []string{"default"},
212 DefaultAudience: "default",
213 },
214 DetectOpts: &credentials.DetectOptions{
215 Scopes: []string{"non-default"},
216 CredentialsFile: "/path/to/a/file",
217 },
218 },
219 want: &credentials.DetectOptions{
220 Scopes: []string{"non-default"},
221 CredentialsFile: "/path/to/a/file",
222 },
223 },
224 {
225 name: "don't use default scopes, aud provided",
226 in: &Options{
227 InternalOptions: &InternalOptions{
228 DefaultScopes: []string{"default"},
229 DefaultAudience: "default",
230 },
231 DetectOpts: &credentials.DetectOptions{
232 Audience: "non-default",
233 CredentialsFile: "/path/to/a/file",
234 },
235 },
236 want: &credentials.DetectOptions{
237 Audience: "non-default",
238 CredentialsFile: "/path/to/a/file",
239 UseSelfSignedJWT: true,
240 },
241 },
242 {
243 name: "use default aud",
244 in: &Options{
245 InternalOptions: &InternalOptions{
246 DefaultAudience: "default",
247 },
248 DetectOpts: &credentials.DetectOptions{
249 CredentialsFile: "/path/to/a/file",
250 },
251 },
252 want: &credentials.DetectOptions{
253 Audience: "default",
254 CredentialsFile: "/path/to/a/file",
255 },
256 },
257 }
258 for _, tt := range tests {
259 t.Run(tt.name, func(t *testing.T) {
260 got := tt.in.resolveDetectOptions()
261 if diff := cmp.Diff(tt.want, got); diff != "" {
262 t.Errorf("mismatch (-want +got):\n%s", diff)
263 }
264 })
265 }
266 }
267
268 func TestGrpcCredentialsProvider_GetClientUniverseDomain(t *testing.T) {
269 nonDefault := "example.com"
270 tests := []struct {
271 name string
272 universeDomain string
273 want string
274 }{
275 {
276 name: "default",
277 universeDomain: "",
278 want: internal.DefaultUniverseDomain,
279 },
280 {
281 name: "non-default",
282 universeDomain: nonDefault,
283 want: nonDefault,
284 },
285 }
286 for _, tt := range tests {
287 t.Run(tt.name, func(t *testing.T) {
288 at := &grpcCredentialsProvider{clientUniverseDomain: tt.universeDomain}
289 got := at.getClientUniverseDomain()
290 if got != tt.want {
291 t.Errorf("got %q, want %q", got, tt.want)
292 }
293 })
294 }
295 }
296
297 func TestGrpcCredentialsProvider_TokenType(t *testing.T) {
298 tests := []struct {
299 name string
300 tok *auth.Token
301 want string
302 }{
303 {
304 name: "type set",
305 tok: &auth.Token{
306 Value: "token",
307 Type: "Basic",
308 },
309 want: "Basic token",
310 },
311 {
312 name: "type set",
313 tok: &auth.Token{
314 Value: "token",
315 },
316 want: "Bearer token",
317 },
318 }
319 for _, tc := range tests {
320 cp := grpcCredentialsProvider{
321 creds: &auth.Credentials{
322 TokenProvider: &staticTP{tok: tc.tok},
323 },
324 }
325 m, err := cp.GetRequestMetadata(context.Background(), "")
326 if err != nil {
327 log.Fatalf("cp.GetRequestMetadata() = %v, want nil", err)
328 }
329 if got := m["authorization"]; got != tc.want {
330 t.Fatalf("got %q, want %q", got, tc.want)
331 }
332 }
333 }
334
335 func TestNewClient_DetectedServiceAccount(t *testing.T) {
336 testQuota := "testquota"
337 wantHeader := "bar"
338 t.Setenv(internal.QuotaProjectEnvVar, testQuota)
339 l, err := net.Listen("tcp", "localhost:0")
340 if err != nil {
341 t.Fatal(err)
342 }
343 gsrv := grpc.NewServer()
344 defer gsrv.Stop()
345 echo.RegisterEchoerServer(gsrv, &fakeEchoService{
346 Fn: func(ctx context.Context, _ *echo.EchoRequest) (*echo.EchoReply, error) {
347 md, ok := metadata.FromIncomingContext(ctx)
348 if !ok {
349 t.Error("unable to extract metadata")
350 return nil, errors.New("oops")
351 }
352 if got := md.Get("authorization"); len(got) != 1 {
353 t.Errorf(`got "", want an auth token`)
354 }
355 if got := md.Get("Foo"); len(got) != 1 || got[0] != wantHeader {
356 t.Errorf("got %q, want %q", got, wantHeader)
357 }
358 if got := md.Get(quotaProjectHeaderKey); len(got) != 1 || got[0] != testQuota {
359 t.Errorf("got %q, want %q", got, testQuota)
360 }
361 return &echo.EchoReply{}, nil
362 },
363 })
364 go func() {
365 if err := gsrv.Serve(l); err != nil {
366 panic(err)
367 }
368 }()
369
370 pool, err := Dial(context.Background(), false, &Options{
371 Metadata: map[string]string{"Foo": wantHeader},
372 InternalOptions: &InternalOptions{
373 DefaultEndpointTemplate: l.Addr().String(),
374 },
375 DetectOpts: &credentials.DetectOptions{
376 Audience: l.Addr().String(),
377 CredentialsFile: "../internal/testdata/sa_universe_domain.json",
378 UseSelfSignedJWT: true,
379 },
380 GRPCDialOpts: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
381 UniverseDomain: "example.com",
382 })
383 if err != nil {
384 t.Fatalf("NewClient() = %v", err)
385 }
386 client := echo.NewEchoerClient(pool)
387 if _, err := client.Echo(context.Background(), &echo.EchoRequest{}); err != nil {
388 t.Fatalf("client.Echo() = %v", err)
389 }
390 }
391
392 type staticTP struct {
393 tok *auth.Token
394 }
395
396 func (tp *staticTP) Token(context.Context) (*auth.Token, error) {
397 return tp.tok, nil
398 }
399
400 type fakeEchoService struct {
401 Fn func(context.Context, *echo.EchoRequest) (*echo.EchoReply, error)
402 echo.UnimplementedEchoerServer
403 }
404
405 func (s *fakeEchoService) Echo(c context.Context, r *echo.EchoRequest) (*echo.EchoReply, error) {
406 return s.Fn(c, r)
407 }
408
View as plain text