1
2
3
4
5 package impersonate
6
7 import (
8 "bytes"
9 "context"
10 "encoding/json"
11 "fmt"
12 "io"
13 "net/http"
14 "net/url"
15 "strings"
16 "time"
17
18 "golang.org/x/oauth2"
19 )
20
21
22
23 func user(ctx context.Context, c CredentialsConfig, client *http.Client, lifetime time.Duration, isStaticToken bool) (oauth2.TokenSource, error) {
24 u := userTokenSource{
25 client: client,
26 targetPrincipal: c.TargetPrincipal,
27 subject: c.Subject,
28 lifetime: lifetime,
29 }
30 u.delegates = make([]string, len(c.Delegates))
31 for i, v := range c.Delegates {
32 u.delegates[i] = formatIAMServiceAccountName(v)
33 }
34 u.scopes = make([]string, len(c.Scopes))
35 copy(u.scopes, c.Scopes)
36 if isStaticToken {
37 tok, err := u.Token()
38 if err != nil {
39 return nil, err
40 }
41 return oauth2.StaticTokenSource(tok), nil
42 }
43 return oauth2.ReuseTokenSource(nil, u), nil
44 }
45
// claimSet is the JWT payload that signJWT marshals to JSON and submits for
// signing. Scope and Sub are omitted from the JSON when empty.
type claimSet struct {
	// Iss is the issuer; signJWT sets it to the target principal.
	Iss string `json:"iss"`

	// Scope is the space-separated list of requested scopes.
	Scope string `json:"scope,omitempty"`

	// Sub is the subject; signJWT sets it to the configured subject.
	Sub string `json:"sub,omitempty"`

	// Aud is the audience; signJWT sets it to the token endpoint URL.
	Aud string `json:"aud"`

	// Iat is the issue time in Unix seconds.
	Iat int64 `json:"iat"`

	// Exp is the expiry time in Unix seconds.
	Exp int64 `json:"exp"`
}
54
// signJWTRequest is the JSON body that signJWT POSTs to the signJwt endpoint.
type signJWTRequest struct {
	// Payload is the JSON-encoded claim set to be signed, carried as a string.
	Payload string `json:"payload"`

	// Delegates is the list of delegate names sent along with the request;
	// it is omitted from the JSON when empty.
	Delegates []string `json:"delegates,omitempty"`
}
59
// signJWTResponse is the JSON body decoded from a successful signJwt call.
type signJWTResponse struct {
	// KeyID is decoded from the "keyId" field of the response.
	// NOTE(review): presumably the ID of the key used for signing — confirm
	// against the service documentation.
	KeyID string `json:"keyId"`

	// SignedJWT is decoded from the "signedJwt" field of the response; it is
	// the value signJWT returns to its caller.
	SignedJWT string `json:"signedJwt"`
}
68
// exchangeTokenResponse is the JSON body decoded from a successful call to
// the token endpoint in exchangeToken.
type exchangeTokenResponse struct {
	// AccessToken is copied into the resulting oauth2.Token.
	AccessToken string `json:"access_token"`

	// TokenType is copied into the resulting oauth2.Token.
	TokenType string `json:"token_type"`

	// ExpiresIn is the token lifetime; exchangeToken interprets it as a
	// number of seconds counted from when the request was started.
	ExpiresIn int64 `json:"expires_in"`
}
74
// userTokenSource produces tokens by having a JWT signed on behalf of
// targetPrincipal and then exchanging that JWT at the token endpoint.
type userTokenSource struct {
	// client performs both the signing and the exchange HTTP requests.
	client *http.Client

	// targetPrincipal is used as the JWT issuer and names the account in the
	// signJwt request URL.
	targetPrincipal string
	// subject is placed in the JWT "sub" claim.
	subject string
	// scopes are joined with spaces into the JWT "scope" claim.
	scopes []string
	// lifetime is added to the current time to compute the JWT "exp" claim.
	lifetime time.Duration
	// delegates are forwarded in the signJwt request body.
	delegates []string
}
84
85 func (u userTokenSource) Token() (*oauth2.Token, error) {
86 signedJWT, err := u.signJWT()
87 if err != nil {
88 return nil, err
89 }
90 return u.exchangeToken(signedJWT)
91 }
92
93 func (u userTokenSource) signJWT() (string, error) {
94 now := time.Now()
95 exp := now.Add(u.lifetime)
96 claims := claimSet{
97 Iss: u.targetPrincipal,
98 Scope: strings.Join(u.scopes, " "),
99 Sub: u.subject,
100 Aud: fmt.Sprintf("%s/token", oauth2Endpoint),
101 Iat: now.Unix(),
102 Exp: exp.Unix(),
103 }
104 payloadBytes, err := json.Marshal(claims)
105 if err != nil {
106 return "", fmt.Errorf("impersonate: unable to marshal claims: %v", err)
107 }
108 signJWTReq := signJWTRequest{
109 Payload: string(payloadBytes),
110 Delegates: u.delegates,
111 }
112
113 bodyBytes, err := json.Marshal(signJWTReq)
114 if err != nil {
115 return "", fmt.Errorf("impersonate: unable to marshal request: %v", err)
116 }
117 reqURL := fmt.Sprintf("%s/v1/%s:signJwt", iamCredentailsEndpoint, formatIAMServiceAccountName(u.targetPrincipal))
118 req, err := http.NewRequest("POST", reqURL, bytes.NewReader(bodyBytes))
119 if err != nil {
120 return "", fmt.Errorf("impersonate: unable to create request: %v", err)
121 }
122 req.Header.Set("Content-Type", "application/json")
123 rawResp, err := u.client.Do(req)
124 if err != nil {
125 return "", fmt.Errorf("impersonate: unable to sign JWT: %v", err)
126 }
127 body, err := io.ReadAll(io.LimitReader(rawResp.Body, 1<<20))
128 if err != nil {
129 return "", fmt.Errorf("impersonate: unable to read body: %v", err)
130 }
131 if c := rawResp.StatusCode; c < 200 || c > 299 {
132 return "", fmt.Errorf("impersonate: status code %d: %s", c, body)
133 }
134
135 var signJWTResp signJWTResponse
136 if err := json.Unmarshal(body, &signJWTResp); err != nil {
137 return "", fmt.Errorf("impersonate: unable to parse response: %v", err)
138 }
139 return signJWTResp.SignedJWT, nil
140 }
141
142 func (u userTokenSource) exchangeToken(signedJWT string) (*oauth2.Token, error) {
143 now := time.Now()
144 v := url.Values{}
145 v.Set("grant_type", "assertion")
146 v.Set("assertion_type", "http://oauth.net/grant_type/jwt/1.0/bearer")
147 v.Set("assertion", signedJWT)
148 rawResp, err := u.client.PostForm(fmt.Sprintf("%s/token", oauth2Endpoint), v)
149 if err != nil {
150 return nil, fmt.Errorf("impersonate: unable to exchange token: %v", err)
151 }
152 body, err := io.ReadAll(io.LimitReader(rawResp.Body, 1<<20))
153 if err != nil {
154 return nil, fmt.Errorf("impersonate: unable to read body: %v", err)
155 }
156 if c := rawResp.StatusCode; c < 200 || c > 299 {
157 return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
158 }
159
160 var tokenResp exchangeTokenResponse
161 if err := json.Unmarshal(body, &tokenResp); err != nil {
162 return nil, fmt.Errorf("impersonate: unable to parse response: %v", err)
163 }
164
165 return &oauth2.Token{
166 AccessToken: tokenResp.AccessToken,
167 TokenType: tokenResp.TokenType,
168 Expiry: now.Add(time.Second * time.Duration(tokenResp.ExpiresIn)),
169 }, nil
170 }
171
View as plain text