...

Source file src/cloud.google.com/go/auth/credentials/impersonate/impersonate_test.go

Documentation: cloud.google.com/go/auth/credentials/impersonate

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package impersonate
    16  
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
)
    29  
    30  func TestNewCredentials_serviceAccount(t *testing.T) {
    31  	ctx := context.Background()
    32  	tests := []struct {
    33  		name    string
    34  		config  CredentialsOptions
    35  		wantErr error
    36  	}{
    37  		{
    38  			name:    "missing targetPrincipal",
    39  			wantErr: errMissingTargetPrincipal,
    40  		},
    41  		{
    42  			name: "missing scopes",
    43  			config: CredentialsOptions{
    44  				TargetPrincipal: "foo@project-id.iam.gserviceaccount.com",
    45  			},
    46  			wantErr: errMissingScopes,
    47  		},
    48  		{
    49  			name: "lifetime over max",
    50  			config: CredentialsOptions{
    51  				TargetPrincipal: "foo@project-id.iam.gserviceaccount.com",
    52  				Scopes:          []string{"scope"},
    53  				Lifetime:        13 * time.Hour,
    54  			},
    55  			wantErr: errLifetimeOverMax,
    56  		},
    57  		{
    58  			name: "works",
    59  			config: CredentialsOptions{
    60  				TargetPrincipal: "foo@project-id.iam.gserviceaccount.com",
    61  				Scopes:          []string{"scope"},
    62  			},
    63  			wantErr: nil,
    64  		},
    65  		{
    66  			name: "universe domain",
    67  			config: CredentialsOptions{
    68  				TargetPrincipal: "foo@project-id.iam.gserviceaccount.com",
    69  				Scopes:          []string{"scope"},
    70  				Subject:         "admin@example.com",
    71  				UniverseDomain:  "example.com",
    72  			},
    73  			wantErr: errUniverseNotSupportedDomainWideDelegation,
    74  		},
    75  	}
    76  
    77  	for _, tt := range tests {
    78  		name := tt.name
    79  		t.Run(name, func(t *testing.T) {
    80  			saTok := "sa-token"
    81  			client := &http.Client{
    82  				Transport: RoundTripFn(func(req *http.Request) *http.Response {
    83  					if strings.Contains(req.URL.Path, "generateAccessToken") {
    84  						defer req.Body.Close()
    85  						b, err := io.ReadAll(req.Body)
    86  						if err != nil {
    87  							t.Error(err)
    88  						}
    89  						var r generateAccessTokenRequest
    90  						if err := json.Unmarshal(b, &r); err != nil {
    91  							t.Error(err)
    92  						}
    93  						if !cmp.Equal(r.Scope, tt.config.Scopes) {
    94  							t.Errorf("got %v, want %v", r.Scope, tt.config.Scopes)
    95  						}
    96  						if !strings.Contains(req.URL.Path, tt.config.TargetPrincipal) {
    97  							t.Errorf("got %q, want %q", req.URL.Path, tt.config.TargetPrincipal)
    98  						}
    99  
   100  						resp := generateAccessTokenResponse{
   101  							AccessToken: saTok,
   102  							ExpireTime:  time.Now().Format(time.RFC3339),
   103  						}
   104  						b, err = json.Marshal(&resp)
   105  						if err != nil {
   106  							t.Fatalf("unable to marshal response: %v", err)
   107  						}
   108  						return &http.Response{
   109  							StatusCode: 200,
   110  							Body:       io.NopCloser(bytes.NewReader(b)),
   111  							Header:     http.Header{},
   112  						}
   113  					}
   114  					return nil
   115  				}),
   116  			}
   117  			tt.config.Client = client
   118  			ts, err := NewCredentials(&tt.config)
   119  			if err != nil {
   120  				if err != tt.wantErr {
   121  					t.Fatalf("err: %v", err)
   122  				}
   123  			} else {
   124  				tok, err := ts.Token(ctx)
   125  				if err != nil {
   126  					t.Fatal(err)
   127  				}
   128  				if tok.Value != saTok {
   129  					t.Fatalf("got %q, want %q", tok.Value, saTok)
   130  				}
   131  			}
   132  		})
   133  	}
   134  }
   135  
   136  type RoundTripFn func(req *http.Request) *http.Response
   137  
   138  func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) { return f(req), nil }
   139  
   140  func TestCredentialsOptions_UniverseDomain(t *testing.T) {
   141  	testCases := []struct {
   142  		name               string
   143  		opts               *CredentialsOptions
   144  		wantUniverseDomain string
   145  		wantIsGDU          bool
   146  	}{
   147  		{
   148  			name:               "empty",
   149  			opts:               &CredentialsOptions{},
   150  			wantUniverseDomain: "googleapis.com",
   151  			wantIsGDU:          true,
   152  		},
   153  		{
   154  			name: "defaults",
   155  			opts: &CredentialsOptions{
   156  				UniverseDomain: "googleapis.com",
   157  			},
   158  			wantUniverseDomain: "googleapis.com",
   159  			wantIsGDU:          true,
   160  		},
   161  		{
   162  			name: "non-GDU",
   163  			opts: &CredentialsOptions{
   164  				UniverseDomain: "example.com",
   165  			},
   166  			wantUniverseDomain: "example.com",
   167  			wantIsGDU:          false,
   168  		},
   169  	}
   170  	for _, tc := range testCases {
   171  		t.Run(tc.name, func(t *testing.T) {
   172  			if got := tc.opts.getUniverseDomain(); got != tc.wantUniverseDomain {
   173  				t.Errorf("got %v, want %v", got, tc.wantUniverseDomain)
   174  			}
   175  			if got := tc.opts.isUniverseDomainGDU(); got != tc.wantIsGDU {
   176  				t.Errorf("got %v, want %v", got, tc.wantIsGDU)
   177  			}
   178  		})
   179  	}
   180  }
   181  

View as plain text