...

Source file src/cloud.google.com/go/auth/credentials/impersonate/user.go

Documentation: cloud.google.com/go/auth/credentials/impersonate

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package impersonate
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"encoding/json"
    21  	"fmt"
    22  	"net/http"
    23  	"net/url"
    24  	"strings"
    25  	"time"
    26  
    27  	"cloud.google.com/go/auth"
    28  	"cloud.google.com/go/auth/internal"
    29  )
    30  
    31  // user provides an auth flow for domain-wide delegation, setting
    32  // CredentialsConfig.Subject to be the impersonated user.
    33  func user(opts *CredentialsOptions, client *http.Client, lifetime time.Duration, isStaticToken bool) (auth.TokenProvider, error) {
    34  	u := userTokenProvider{
    35  		client:          client,
    36  		targetPrincipal: opts.TargetPrincipal,
    37  		subject:         opts.Subject,
    38  		lifetime:        lifetime,
    39  	}
    40  	u.delegates = make([]string, len(opts.Delegates))
    41  	for i, v := range opts.Delegates {
    42  		u.delegates[i] = formatIAMServiceAccountName(v)
    43  	}
    44  	u.scopes = make([]string, len(opts.Scopes))
    45  	copy(u.scopes, opts.Scopes)
    46  	var tpo *auth.CachedTokenProviderOptions
    47  	if isStaticToken {
    48  		tpo = &auth.CachedTokenProviderOptions{
    49  			DisableAutoRefresh: true,
    50  		}
    51  	}
    52  	return auth.NewCachedTokenProvider(u, tpo), nil
    53  }
    54  
    55  type claimSet struct {
    56  	Iss   string `json:"iss"`
    57  	Scope string `json:"scope,omitempty"`
    58  	Sub   string `json:"sub,omitempty"`
    59  	Aud   string `json:"aud"`
    60  	Iat   int64  `json:"iat"`
    61  	Exp   int64  `json:"exp"`
    62  }
    63  
    64  type signJWTRequest struct {
    65  	Payload   string   `json:"payload"`
    66  	Delegates []string `json:"delegates,omitempty"`
    67  }
    68  
    69  type signJWTResponse struct {
    70  	// KeyID is the key used to sign the JWT.
    71  	KeyID string `json:"keyId"`
    72  	// SignedJwt contains the automatically generated header; the
    73  	// client-supplied payload; and the signature, which is generated using
    74  	// the key referenced by the `kid` field in the header.
    75  	SignedJWT string `json:"signedJwt"`
    76  }
    77  
    78  type exchangeTokenResponse struct {
    79  	AccessToken string `json:"access_token"`
    80  	TokenType   string `json:"token_type"`
    81  	ExpiresIn   int64  `json:"expires_in"`
    82  }
    83  
    84  type userTokenProvider struct {
    85  	client *http.Client
    86  
    87  	targetPrincipal string
    88  	subject         string
    89  	scopes          []string
    90  	lifetime        time.Duration
    91  	delegates       []string
    92  }
    93  
    94  func (u userTokenProvider) Token(ctx context.Context) (*auth.Token, error) {
    95  	signedJWT, err := u.signJWT()
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	return u.exchangeToken(ctx, signedJWT)
   100  }
   101  
   102  func (u userTokenProvider) signJWT() (string, error) {
   103  	now := time.Now()
   104  	exp := now.Add(u.lifetime)
   105  	claims := claimSet{
   106  		Iss:   u.targetPrincipal,
   107  		Scope: strings.Join(u.scopes, " "),
   108  		Sub:   u.subject,
   109  		Aud:   fmt.Sprintf("%s/token", oauth2Endpoint),
   110  		Iat:   now.Unix(),
   111  		Exp:   exp.Unix(),
   112  	}
   113  	payloadBytes, err := json.Marshal(claims)
   114  	if err != nil {
   115  		return "", fmt.Errorf("impersonate: unable to marshal claims: %w", err)
   116  	}
   117  	signJWTReq := signJWTRequest{
   118  		Payload:   string(payloadBytes),
   119  		Delegates: u.delegates,
   120  	}
   121  
   122  	bodyBytes, err := json.Marshal(signJWTReq)
   123  	if err != nil {
   124  		return "", fmt.Errorf("impersonate: unable to marshal request: %w", err)
   125  	}
   126  	reqURL := fmt.Sprintf("%s/v1/%s:signJwt", iamCredentialsEndpoint, formatIAMServiceAccountName(u.targetPrincipal))
   127  	req, err := http.NewRequest("POST", reqURL, bytes.NewReader(bodyBytes))
   128  	if err != nil {
   129  		return "", fmt.Errorf("impersonate: unable to create request: %w", err)
   130  	}
   131  	req.Header.Set("Content-Type", "application/json")
   132  	rawResp, err := u.client.Do(req)
   133  	if err != nil {
   134  		return "", fmt.Errorf("impersonate: unable to sign JWT: %w", err)
   135  	}
   136  	body, err := internal.ReadAll(rawResp.Body)
   137  	if err != nil {
   138  		return "", fmt.Errorf("impersonate: unable to read body: %w", err)
   139  	}
   140  	if c := rawResp.StatusCode; c < 200 || c > 299 {
   141  		return "", fmt.Errorf("impersonate: status code %d: %s", c, body)
   142  	}
   143  
   144  	var signJWTResp signJWTResponse
   145  	if err := json.Unmarshal(body, &signJWTResp); err != nil {
   146  		return "", fmt.Errorf("impersonate: unable to parse response: %w", err)
   147  	}
   148  	return signJWTResp.SignedJWT, nil
   149  }
   150  
   151  func (u userTokenProvider) exchangeToken(ctx context.Context, signedJWT string) (*auth.Token, error) {
   152  	v := url.Values{}
   153  	v.Set("grant_type", "assertion")
   154  	v.Set("assertion_type", "http://oauth.net/grant_type/jwt/1.0/bearer")
   155  	v.Set("assertion", signedJWT)
   156  	req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/token", oauth2Endpoint), strings.NewReader(v.Encode()))
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	rawResp, err := u.client.Do(req)
   161  	if err != nil {
   162  		return nil, fmt.Errorf("impersonate: unable to exchange token: %w", err)
   163  	}
   164  	body, err := internal.ReadAll(rawResp.Body)
   165  	if err != nil {
   166  		return nil, fmt.Errorf("impersonate: unable to read body: %w", err)
   167  	}
   168  	if c := rawResp.StatusCode; c < 200 || c > 299 {
   169  		return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
   170  	}
   171  
   172  	var tokenResp exchangeTokenResponse
   173  	if err := json.Unmarshal(body, &tokenResp); err != nil {
   174  		return nil, fmt.Errorf("impersonate: unable to parse response: %w", err)
   175  	}
   176  
   177  	return &auth.Token{
   178  		Value:  tokenResp.AccessToken,
   179  		Type:   tokenResp.TokenType,
   180  		Expiry: time.Now().Add(time.Second * time.Duration(tokenResp.ExpiresIn)),
   181  	}, nil
   182  }
   183  

View as plain text