1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package impersonate
16
17 import (
18 "bytes"
19 "context"
20 "encoding/json"
21 "fmt"
22 "net/http"
23 "net/url"
24 "strings"
25 "time"
26
27 "cloud.google.com/go/auth"
28 "cloud.google.com/go/auth/internal"
29 )
30
31
32
33 func user(opts *CredentialsOptions, client *http.Client, lifetime time.Duration, isStaticToken bool) (auth.TokenProvider, error) {
34 u := userTokenProvider{
35 client: client,
36 targetPrincipal: opts.TargetPrincipal,
37 subject: opts.Subject,
38 lifetime: lifetime,
39 }
40 u.delegates = make([]string, len(opts.Delegates))
41 for i, v := range opts.Delegates {
42 u.delegates[i] = formatIAMServiceAccountName(v)
43 }
44 u.scopes = make([]string, len(opts.Scopes))
45 copy(u.scopes, opts.Scopes)
46 var tpo *auth.CachedTokenProviderOptions
47 if isStaticToken {
48 tpo = &auth.CachedTokenProviderOptions{
49 DisableAutoRefresh: true,
50 }
51 }
52 return auth.NewCachedTokenProvider(u, tpo), nil
53 }
54
// Wire types for the two HTTP calls made by userTokenProvider: the IAM
// Credentials signJwt request/response and the OAuth2 token-exchange response.
type (
	// claimSet is the JWT payload that signJwt is asked to sign. Iat and Exp
	// hold Unix timestamps in seconds.
	claimSet struct {
		Iss   string `json:"iss"`
		Scope string `json:"scope,omitempty"`
		Sub   string `json:"sub,omitempty"`
		Aud   string `json:"aud"`
		Iat   int64  `json:"iat"`
		Exp   int64  `json:"exp"`
	}

	// signJWTRequest is the JSON body POSTed to the signJwt endpoint. Payload
	// is the JSON-encoded claimSet; Delegates is omitted when empty.
	signJWTRequest struct {
		Payload   string   `json:"payload"`
		Delegates []string `json:"delegates,omitempty"`
	}

	// signJWTResponse is the JSON body returned by the signJwt endpoint.
	signJWTResponse struct {
		// KeyID is decoded but not used in this file (presumably the ID of
		// the signing key).
		KeyID string `json:"keyId"`
		// SignedJWT is the signed token, later sent as the assertion in the
		// OAuth2 token exchange.
		SignedJWT string `json:"signedJwt"`
	}

	// exchangeTokenResponse is the JSON body returned by the OAuth2 token
	// endpoint. ExpiresIn is a lifetime in seconds.
	exchangeTokenResponse struct {
		AccessToken string `json:"access_token"`
		TokenType   string `json:"token_type"`
		ExpiresIn   int64  `json:"expires_in"`
	}
)
83
// userTokenProvider obtains access tokens in two steps: it has the IAM
// Credentials signJwt endpoint sign a JWT for targetPrincipal, then exchanges
// that JWT at the OAuth2 token endpoint. All fields are populated by user.
type userTokenProvider struct {
	// client sends both the signJwt request and the token-exchange request.
	client *http.Client

	// targetPrincipal is the account named in the signJwt URL and is used as
	// the JWT "iss" claim.
	targetPrincipal string
	// subject becomes the JWT "sub" claim (omitted when empty).
	subject string
	// scopes are space-joined into the JWT "scope" claim.
	scopes []string
	// lifetime is added to the signing time to produce the JWT "exp" claim.
	lifetime time.Duration
	// delegates are forwarded verbatim in the signJwt request body.
	delegates []string
}
93
94 func (u userTokenProvider) Token(ctx context.Context) (*auth.Token, error) {
95 signedJWT, err := u.signJWT()
96 if err != nil {
97 return nil, err
98 }
99 return u.exchangeToken(ctx, signedJWT)
100 }
101
102 func (u userTokenProvider) signJWT() (string, error) {
103 now := time.Now()
104 exp := now.Add(u.lifetime)
105 claims := claimSet{
106 Iss: u.targetPrincipal,
107 Scope: strings.Join(u.scopes, " "),
108 Sub: u.subject,
109 Aud: fmt.Sprintf("%s/token", oauth2Endpoint),
110 Iat: now.Unix(),
111 Exp: exp.Unix(),
112 }
113 payloadBytes, err := json.Marshal(claims)
114 if err != nil {
115 return "", fmt.Errorf("impersonate: unable to marshal claims: %w", err)
116 }
117 signJWTReq := signJWTRequest{
118 Payload: string(payloadBytes),
119 Delegates: u.delegates,
120 }
121
122 bodyBytes, err := json.Marshal(signJWTReq)
123 if err != nil {
124 return "", fmt.Errorf("impersonate: unable to marshal request: %w", err)
125 }
126 reqURL := fmt.Sprintf("%s/v1/%s:signJwt", iamCredentialsEndpoint, formatIAMServiceAccountName(u.targetPrincipal))
127 req, err := http.NewRequest("POST", reqURL, bytes.NewReader(bodyBytes))
128 if err != nil {
129 return "", fmt.Errorf("impersonate: unable to create request: %w", err)
130 }
131 req.Header.Set("Content-Type", "application/json")
132 rawResp, err := u.client.Do(req)
133 if err != nil {
134 return "", fmt.Errorf("impersonate: unable to sign JWT: %w", err)
135 }
136 body, err := internal.ReadAll(rawResp.Body)
137 if err != nil {
138 return "", fmt.Errorf("impersonate: unable to read body: %w", err)
139 }
140 if c := rawResp.StatusCode; c < 200 || c > 299 {
141 return "", fmt.Errorf("impersonate: status code %d: %s", c, body)
142 }
143
144 var signJWTResp signJWTResponse
145 if err := json.Unmarshal(body, &signJWTResp); err != nil {
146 return "", fmt.Errorf("impersonate: unable to parse response: %w", err)
147 }
148 return signJWTResp.SignedJWT, nil
149 }
150
151 func (u userTokenProvider) exchangeToken(ctx context.Context, signedJWT string) (*auth.Token, error) {
152 v := url.Values{}
153 v.Set("grant_type", "assertion")
154 v.Set("assertion_type", "http://oauth.net/grant_type/jwt/1.0/bearer")
155 v.Set("assertion", signedJWT)
156 req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/token", oauth2Endpoint), strings.NewReader(v.Encode()))
157 if err != nil {
158 return nil, err
159 }
160 rawResp, err := u.client.Do(req)
161 if err != nil {
162 return nil, fmt.Errorf("impersonate: unable to exchange token: %w", err)
163 }
164 body, err := internal.ReadAll(rawResp.Body)
165 if err != nil {
166 return nil, fmt.Errorf("impersonate: unable to read body: %w", err)
167 }
168 if c := rawResp.StatusCode; c < 200 || c > 299 {
169 return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
170 }
171
172 var tokenResp exchangeTokenResponse
173 if err := json.Unmarshal(body, &tokenResp); err != nil {
174 return nil, fmt.Errorf("impersonate: unable to parse response: %w", err)
175 }
176
177 return &auth.Token{
178 Value: tokenResp.AccessToken,
179 Type: tokenResp.TokenType,
180 Expiry: time.Now().Add(time.Second * time.Duration(tokenResp.ExpiresIn)),
181 }, nil
182 }
183
View as plain text