...

Source file src/cloud.google.com/go/auth/credentials/impersonate/user_test.go

Documentation: cloud.google.com/go/auth/credentials/impersonate

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package impersonate
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"encoding/json"
    21  	"io"
    22  	"net/http"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"cloud.google.com/go/auth/internal"
    28  	"cloud.google.com/go/auth/internal/jwt"
    29  )
    30  
    31  func TestNewCredentials_user(t *testing.T) {
    32  	ctx := context.Background()
    33  	tests := []struct {
    34  		name            string
    35  		targetPrincipal string
    36  		scopes          []string
    37  		lifetime        time.Duration
    38  		subject         string
    39  		wantErr         bool
    40  		universeDomain  string
    41  	}{
    42  		{
    43  			name:    "missing targetPrincipal",
    44  			wantErr: true,
    45  		},
    46  		{
    47  			name:            "missing scopes",
    48  			targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
    49  			wantErr:         true,
    50  		},
    51  		{
    52  			name:            "lifetime over max",
    53  			targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
    54  			scopes:          []string{"scope"},
    55  			lifetime:        13 * time.Hour,
    56  			wantErr:         true,
    57  		},
    58  		{
    59  			name:            "works",
    60  			targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
    61  			scopes:          []string{"scope"},
    62  			subject:         "admin@example.com",
    63  			wantErr:         false,
    64  		},
    65  		{
    66  			name:            "universeDomain",
    67  			targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
    68  			scopes:          []string{"scope"},
    69  			subject:         "admin@example.com",
    70  			wantErr:         true,
    71  			// Non-GDU Universe Domain should result in error if
    72  			// CredentialsConfig.Subject is present for domain-wide delegation.
    73  			universeDomain: "example.com",
    74  		},
    75  	}
    76  
    77  	for _, tt := range tests {
    78  		userTok := "user-token"
    79  		name := tt.name
    80  		t.Run(name, func(t *testing.T) {
    81  			client := &http.Client{
    82  				Transport: RoundTripFn(func(req *http.Request) *http.Response {
    83  					defer req.Body.Close()
    84  					if strings.Contains(req.URL.Path, "signJwt") {
    85  						b, err := io.ReadAll(req.Body)
    86  						if err != nil {
    87  							t.Error(err)
    88  						}
    89  						var r signJWTRequest
    90  						if err := json.Unmarshal(b, &r); err != nil {
    91  							t.Error(err)
    92  						}
    93  						jwtPayload := map[string]interface{}{}
    94  						if err := json.Unmarshal([]byte(r.Payload), &jwtPayload); err != nil {
    95  							t.Error(err)
    96  						}
    97  						if got, want := jwtPayload["iss"].(string), tt.targetPrincipal; got != want {
    98  							t.Errorf("got %q, want %q", got, want)
    99  						}
   100  						if got, want := jwtPayload["sub"].(string), tt.subject; got != want {
   101  							t.Errorf("got %q, want %q", got, want)
   102  						}
   103  						if got, want := jwtPayload["scope"].(string), strings.Join(tt.scopes, ","); got != want {
   104  							t.Errorf("got %q, want %q", got, want)
   105  						}
   106  
   107  						resp := signJWTResponse{
   108  							KeyID:     "123",
   109  							SignedJWT: jwt.HeaderType,
   110  						}
   111  						b, err = json.Marshal(&resp)
   112  						if err != nil {
   113  							t.Fatalf("unable to marshal response: %v", err)
   114  						}
   115  						return &http.Response{
   116  							StatusCode: 200,
   117  							Body:       io.NopCloser(bytes.NewReader(b)),
   118  							Header:     make(http.Header),
   119  						}
   120  					}
   121  					if strings.Contains(req.URL.Path, "/token") {
   122  						resp := exchangeTokenResponse{
   123  							AccessToken: userTok,
   124  							TokenType:   internal.TokenTypeBearer,
   125  							ExpiresIn:   int64(time.Hour.Seconds()),
   126  						}
   127  						b, err := json.Marshal(&resp)
   128  						if err != nil {
   129  							t.Fatalf("unable to marshal response: %v", err)
   130  						}
   131  						return &http.Response{
   132  							StatusCode: 200,
   133  							Body:       io.NopCloser(bytes.NewReader(b)),
   134  							Header:     make(http.Header),
   135  						}
   136  					}
   137  					return nil
   138  				}),
   139  			}
   140  			ts, err := NewCredentials(&CredentialsOptions{
   141  				TargetPrincipal: tt.targetPrincipal,
   142  				Scopes:          tt.scopes,
   143  				Lifetime:        tt.lifetime,
   144  				Subject:         tt.subject,
   145  				Client:          client,
   146  				UniverseDomain:  tt.universeDomain,
   147  			})
   148  			if tt.wantErr && err != nil {
   149  				return
   150  			}
   151  			if err != nil {
   152  				t.Fatal(err)
   153  			}
   154  			tok, err := ts.Token(ctx)
   155  			if err != nil {
   156  				t.Fatal(err)
   157  			}
   158  			if tok.Value != userTok {
   159  				t.Fatalf("got %q, want %q", tok.Value, userTok)
   160  			}
   161  		})
   162  	}
   163  }
   164  

View as plain text