...

Source file src/cloud.google.com/go/auth/grpctransport/grpctransport_test.go

Documentation: cloud.google.com/go/auth/grpctransport

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package grpctransport
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"log"
    21  	"net"
    22  	"testing"
    23  
    24  	"cloud.google.com/go/auth"
    25  	"cloud.google.com/go/auth/credentials"
    26  	echo "cloud.google.com/go/auth/grpctransport/testdata"
    27  	"cloud.google.com/go/auth/internal"
    28  	"github.com/google/go-cmp/cmp"
    29  	"google.golang.org/grpc"
    30  	"google.golang.org/grpc/credentials/insecure"
    31  	"google.golang.org/grpc/metadata"
    32  )
    33  
    34  func TestCheckDirectPathEndPoint(t *testing.T) {
    35  	for _, testcase := range []struct {
    36  		name     string
    37  		endpoint string
    38  		want     bool
    39  	}{
    40  		{
    41  			name:     "empty endpoint are disallowed",
    42  			endpoint: "",
    43  			want:     false,
    44  		},
    45  		{
    46  			name:     "dns schemes are allowed",
    47  			endpoint: "dns:///foo",
    48  			want:     true,
    49  		},
    50  		{
    51  			name:     "host without no prefix are allowed",
    52  			endpoint: "foo",
    53  			want:     true,
    54  		},
    55  		{
    56  			name:     "host with port are allowed",
    57  			endpoint: "foo:1234",
    58  			want:     true,
    59  		},
    60  		{
    61  			name:     "non-dns schemes are disallowed",
    62  			endpoint: "https://foo",
    63  			want:     false,
    64  		},
    65  	} {
    66  		t.Run(testcase.name, func(t *testing.T) {
    67  			if got := checkDirectPathEndPoint(testcase.endpoint); got != testcase.want {
    68  				t.Fatalf("got %v, want %v", got, testcase.want)
    69  			}
    70  		})
    71  	}
    72  }
    73  
    74  func TestDial_FailsValidation(t *testing.T) {
    75  	tests := []struct {
    76  		name string
    77  		opts *Options
    78  	}{
    79  		{
    80  			name: "missing options",
    81  		},
    82  		{
    83  			name: "has creds with disable options, tp",
    84  			opts: &Options{
    85  				DisableAuthentication: true,
    86  				Credentials: auth.NewCredentials(&auth.CredentialsOptions{
    87  					TokenProvider: &staticTP{tok: &auth.Token{Value: "fakeToken"}},
    88  				}),
    89  			},
    90  		},
    91  		{
    92  			name: "has creds with disable options, cred file",
    93  			opts: &Options{
    94  				DisableAuthentication: true,
    95  				DetectOpts: &credentials.DetectOptions{
    96  					CredentialsFile: "abc.123",
    97  				},
    98  			},
    99  		},
   100  		{
   101  			name: "has creds with disable options, cred json",
   102  			opts: &Options{
   103  				DisableAuthentication: true,
   104  				DetectOpts: &credentials.DetectOptions{
   105  					CredentialsJSON: []byte(`{"foo":"bar"}`),
   106  				},
   107  			},
   108  		},
   109  	}
   110  	for _, tt := range tests {
   111  		t.Run(tt.name, func(t *testing.T) {
   112  			_, err := Dial(context.Background(), false, tt.opts)
   113  			if err == nil {
   114  				t.Fatal("NewClient() = _, nil, want error")
   115  			}
   116  		})
   117  	}
   118  }
   119  
   120  func TestDial_SkipValidation(t *testing.T) {
   121  	opts := &Options{
   122  		DisableAuthentication: true,
   123  		Credentials: auth.NewCredentials(&auth.CredentialsOptions{
   124  			TokenProvider: &staticTP{tok: &auth.Token{Value: "fakeToken"}},
   125  		}),
   126  	}
   127  	t.Run("invalid opts", func(t *testing.T) {
   128  		if err := opts.validate(); err == nil {
   129  			t.Fatalf("opts.validate() = nil, want error")
   130  		}
   131  	})
   132  
   133  	t.Run("skip invalid opts", func(t *testing.T) {
   134  		opts.InternalOptions = &InternalOptions{SkipValidation: true}
   135  		if err := opts.validate(); err != nil {
   136  			t.Fatalf("opts.validate() = %v, want nil", err)
   137  		}
   138  	})
   139  }
   140  
   141  func TestOptions_ResolveDetectOptions(t *testing.T) {
   142  	tests := []struct {
   143  		name string
   144  		in   *Options
   145  		want *credentials.DetectOptions
   146  	}{
   147  		{
   148  			name: "base",
   149  			in: &Options{
   150  				DetectOpts: &credentials.DetectOptions{
   151  					Scopes:          []string{"scope"},
   152  					CredentialsFile: "/path/to/a/file",
   153  				},
   154  			},
   155  			want: &credentials.DetectOptions{
   156  				Scopes:          []string{"scope"},
   157  				CredentialsFile: "/path/to/a/file",
   158  			},
   159  		},
   160  		{
   161  			name: "self-signed, with scope",
   162  			in: &Options{
   163  				InternalOptions: &InternalOptions{
   164  					EnableJWTWithScope: true,
   165  				},
   166  				DetectOpts: &credentials.DetectOptions{
   167  					Scopes:          []string{"scope"},
   168  					CredentialsFile: "/path/to/a/file",
   169  				},
   170  			},
   171  			want: &credentials.DetectOptions{
   172  				Scopes:           []string{"scope"},
   173  				CredentialsFile:  "/path/to/a/file",
   174  				UseSelfSignedJWT: true,
   175  			},
   176  		},
   177  		{
   178  			name: "self-signed, with aud",
   179  			in: &Options{
   180  				DetectOpts: &credentials.DetectOptions{
   181  					Audience:        "aud",
   182  					CredentialsFile: "/path/to/a/file",
   183  				},
   184  			},
   185  			want: &credentials.DetectOptions{
   186  				Audience:         "aud",
   187  				CredentialsFile:  "/path/to/a/file",
   188  				UseSelfSignedJWT: true,
   189  			},
   190  		},
   191  		{
   192  			name: "use default scopes",
   193  			in: &Options{
   194  				InternalOptions: &InternalOptions{
   195  					DefaultScopes:   []string{"default"},
   196  					DefaultAudience: "default",
   197  				},
   198  				DetectOpts: &credentials.DetectOptions{
   199  					CredentialsFile: "/path/to/a/file",
   200  				},
   201  			},
   202  			want: &credentials.DetectOptions{
   203  				Scopes:          []string{"default"},
   204  				CredentialsFile: "/path/to/a/file",
   205  			},
   206  		},
   207  		{
   208  			name: "don't use default scopes, scope provided",
   209  			in: &Options{
   210  				InternalOptions: &InternalOptions{
   211  					DefaultScopes:   []string{"default"},
   212  					DefaultAudience: "default",
   213  				},
   214  				DetectOpts: &credentials.DetectOptions{
   215  					Scopes:          []string{"non-default"},
   216  					CredentialsFile: "/path/to/a/file",
   217  				},
   218  			},
   219  			want: &credentials.DetectOptions{
   220  				Scopes:          []string{"non-default"},
   221  				CredentialsFile: "/path/to/a/file",
   222  			},
   223  		},
   224  		{
   225  			name: "don't use default scopes, aud provided",
   226  			in: &Options{
   227  				InternalOptions: &InternalOptions{
   228  					DefaultScopes:   []string{"default"},
   229  					DefaultAudience: "default",
   230  				},
   231  				DetectOpts: &credentials.DetectOptions{
   232  					Audience:        "non-default",
   233  					CredentialsFile: "/path/to/a/file",
   234  				},
   235  			},
   236  			want: &credentials.DetectOptions{
   237  				Audience:         "non-default",
   238  				CredentialsFile:  "/path/to/a/file",
   239  				UseSelfSignedJWT: true,
   240  			},
   241  		},
   242  		{
   243  			name: "use default aud",
   244  			in: &Options{
   245  				InternalOptions: &InternalOptions{
   246  					DefaultAudience: "default",
   247  				},
   248  				DetectOpts: &credentials.DetectOptions{
   249  					CredentialsFile: "/path/to/a/file",
   250  				},
   251  			},
   252  			want: &credentials.DetectOptions{
   253  				Audience:        "default",
   254  				CredentialsFile: "/path/to/a/file",
   255  			},
   256  		},
   257  	}
   258  	for _, tt := range tests {
   259  		t.Run(tt.name, func(t *testing.T) {
   260  			got := tt.in.resolveDetectOptions()
   261  			if diff := cmp.Diff(tt.want, got); diff != "" {
   262  				t.Errorf("mismatch (-want +got):\n%s", diff)
   263  			}
   264  		})
   265  	}
   266  }
   267  
   268  func TestGrpcCredentialsProvider_GetClientUniverseDomain(t *testing.T) {
   269  	nonDefault := "example.com"
   270  	tests := []struct {
   271  		name           string
   272  		universeDomain string
   273  		want           string
   274  	}{
   275  		{
   276  			name:           "default",
   277  			universeDomain: "",
   278  			want:           internal.DefaultUniverseDomain,
   279  		},
   280  		{
   281  			name:           "non-default",
   282  			universeDomain: nonDefault,
   283  			want:           nonDefault,
   284  		},
   285  	}
   286  	for _, tt := range tests {
   287  		t.Run(tt.name, func(t *testing.T) {
   288  			at := &grpcCredentialsProvider{clientUniverseDomain: tt.universeDomain}
   289  			got := at.getClientUniverseDomain()
   290  			if got != tt.want {
   291  				t.Errorf("got %q, want %q", got, tt.want)
   292  			}
   293  		})
   294  	}
   295  }
   296  
   297  func TestGrpcCredentialsProvider_TokenType(t *testing.T) {
   298  	tests := []struct {
   299  		name string
   300  		tok  *auth.Token
   301  		want string
   302  	}{
   303  		{
   304  			name: "type set",
   305  			tok: &auth.Token{
   306  				Value: "token",
   307  				Type:  "Basic",
   308  			},
   309  			want: "Basic token",
   310  		},
   311  		{
   312  			name: "type set",
   313  			tok: &auth.Token{
   314  				Value: "token",
   315  			},
   316  			want: "Bearer token",
   317  		},
   318  	}
   319  	for _, tc := range tests {
   320  		cp := grpcCredentialsProvider{
   321  			creds: &auth.Credentials{
   322  				TokenProvider: &staticTP{tok: tc.tok},
   323  			},
   324  		}
   325  		m, err := cp.GetRequestMetadata(context.Background(), "")
   326  		if err != nil {
   327  			log.Fatalf("cp.GetRequestMetadata() = %v, want nil", err)
   328  		}
   329  		if got := m["authorization"]; got != tc.want {
   330  			t.Fatalf("got %q, want %q", got, tc.want)
   331  		}
   332  	}
   333  }
   334  
   335  func TestNewClient_DetectedServiceAccount(t *testing.T) {
   336  	testQuota := "testquota"
   337  	wantHeader := "bar"
   338  	t.Setenv(internal.QuotaProjectEnvVar, testQuota)
   339  	l, err := net.Listen("tcp", "localhost:0")
   340  	if err != nil {
   341  		t.Fatal(err)
   342  	}
   343  	gsrv := grpc.NewServer()
   344  	defer gsrv.Stop()
   345  	echo.RegisterEchoerServer(gsrv, &fakeEchoService{
   346  		Fn: func(ctx context.Context, _ *echo.EchoRequest) (*echo.EchoReply, error) {
   347  			md, ok := metadata.FromIncomingContext(ctx)
   348  			if !ok {
   349  				t.Error("unable to extract metadata")
   350  				return nil, errors.New("oops")
   351  			}
   352  			if got := md.Get("authorization"); len(got) != 1 {
   353  				t.Errorf(`got "", want an auth token`)
   354  			}
   355  			if got := md.Get("Foo"); len(got) != 1 || got[0] != wantHeader {
   356  				t.Errorf("got %q, want %q", got, wantHeader)
   357  			}
   358  			if got := md.Get(quotaProjectHeaderKey); len(got) != 1 || got[0] != testQuota {
   359  				t.Errorf("got %q, want %q", got, testQuota)
   360  			}
   361  			return &echo.EchoReply{}, nil
   362  		},
   363  	})
   364  	go func() {
   365  		if err := gsrv.Serve(l); err != nil {
   366  			panic(err)
   367  		}
   368  	}()
   369  
   370  	pool, err := Dial(context.Background(), false, &Options{
   371  		Metadata: map[string]string{"Foo": wantHeader},
   372  		InternalOptions: &InternalOptions{
   373  			DefaultEndpointTemplate: l.Addr().String(),
   374  		},
   375  		DetectOpts: &credentials.DetectOptions{
   376  			Audience:         l.Addr().String(),
   377  			CredentialsFile:  "../internal/testdata/sa_universe_domain.json",
   378  			UseSelfSignedJWT: true,
   379  		},
   380  		GRPCDialOpts:   []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
   381  		UniverseDomain: "example.com", // Also configured in sa_universe_domain.json
   382  	})
   383  	if err != nil {
   384  		t.Fatalf("NewClient() = %v", err)
   385  	}
   386  	client := echo.NewEchoerClient(pool)
   387  	if _, err := client.Echo(context.Background(), &echo.EchoRequest{}); err != nil {
   388  		t.Fatalf("client.Echo() = %v", err)
   389  	}
   390  }
   391  
   392  type staticTP struct {
   393  	tok *auth.Token
   394  }
   395  
   396  func (tp *staticTP) Token(context.Context) (*auth.Token, error) {
   397  	return tp.tok, nil
   398  }
   399  
   400  type fakeEchoService struct {
   401  	Fn func(context.Context, *echo.EchoRequest) (*echo.EchoReply, error)
   402  	echo.UnimplementedEchoerServer
   403  }
   404  
   405  func (s *fakeEchoService) Echo(c context.Context, r *echo.EchoRequest) (*echo.EchoReply, error) {
   406  	return s.Fn(c, r)
   407  }
   408  

View as plain text