1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package gdch
16
import (
	"context"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"cloud.google.com/go/auth"
	"cloud.google.com/go/auth/internal"
	"cloud.google.com/go/auth/internal/credsfile"
	"cloud.google.com/go/auth/internal/jwt"
)
35
// GrantType is the value sent as the "grant_type" form field of the token
// exchange request.
const GrantType = "urn:ietf:params:oauth:token-type:token-exchange"

// Remaining fixed form values of the token exchange request.
const (
	// requestTokenType is sent as "requested_token_type".
	requestTokenType = "urn:ietf:params:oauth:token-type:access_token"
	// subjectTokenType is sent as "subject_token_type".
	subjectTokenType = "urn:k8s:params:oauth:token-type:serviceaccount"
)
42
// gdchSupportFormatVersions is the set of credentials-file format versions
// that are accepted. Versions not listed read as false.
var gdchSupportFormatVersions = map[string]bool{"1": true}
48
49
// Options carries the caller-supplied settings for NewTokenProvider.
type Options struct {
	// STSAudience is sent as the "audience" form field of the token exchange
	// request. NewTokenProvider rejects an empty value.
	STSAudience string

	// Client is the HTTP client used to call the token URL.
	Client *http.Client
}
54
55
56
57 func NewTokenProvider(f *credsfile.GDCHServiceAccountFile, o *Options) (auth.TokenProvider, error) {
58 if !gdchSupportFormatVersions[f.FormatVersion] {
59 return nil, fmt.Errorf("credentials: unsupported gdch_service_account format %q", f.FormatVersion)
60 }
61 if o.STSAudience == "" {
62 return nil, errors.New("credentials: STSAudience must be set for the GDCH auth flows")
63 }
64 pk, err := internal.ParseKey([]byte(f.PrivateKey))
65 if err != nil {
66 return nil, err
67 }
68 certPool, err := loadCertPool(f.CertPath)
69 if err != nil {
70 return nil, err
71 }
72
73 tp := gdchProvider{
74 serviceIdentity: fmt.Sprintf("system:serviceaccount:%s:%s", f.Project, f.Name),
75 tokenURL: f.TokenURL,
76 aud: o.STSAudience,
77 pk: pk,
78 pkID: f.PrivateKeyID,
79 certPool: certPool,
80 client: o.Client,
81 }
82 return tp, nil
83 }
84
// loadCertPool reads the PEM file at path and returns a certificate pool
// holding the certificates parsed from it.
//
// An error is returned if the file cannot be read or if no certificate could
// be parsed from its contents. Without the second check a malformed file would
// silently produce an empty pool, which trusts no certificate authority when
// used as a TLS root set.
func loadCertPool(path string) (*x509.CertPool, error) {
	pem, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("credentials: failed to read certificate: %w", err)
	}
	pool := x509.NewCertPool()
	// AppendCertsFromPEM reports whether at least one certificate was added.
	if !pool.AppendCertsFromPEM(pem) {
		return nil, fmt.Errorf("credentials: no valid certificates found in %q", path)
	}
	return pool, nil
}
94
// gdchProvider holds everything needed to sign a JWT for a service identity
// and exchange it at tokenURL for an access token.
type gdchProvider struct {
	// serviceIdentity is used as both issuer and subject of the signed JWT.
	serviceIdentity string

	// tokenURL is the endpoint the exchange request is posted to; it is also
	// the audience claim of the signed JWT.
	tokenURL string

	// aud is sent as the "audience" form field of the exchange request.
	aud string

	// pk signs the JWT.
	pk *rsa.PrivateKey

	// pkID is placed in the JWT header as the key ID.
	pkID string

	// certPool is passed to addCertToTransport together with client.
	certPool *x509.CertPool

	// client performs the HTTP request to tokenURL.
	client *http.Client
}
105
106 func (g gdchProvider) Token(ctx context.Context) (*auth.Token, error) {
107 addCertToTransport(g.client, g.certPool)
108 iat := time.Now()
109 exp := iat.Add(time.Hour)
110 claims := jwt.Claims{
111 Iss: g.serviceIdentity,
112 Sub: g.serviceIdentity,
113 Aud: g.tokenURL,
114 Iat: iat.Unix(),
115 Exp: exp.Unix(),
116 }
117 h := jwt.Header{
118 Algorithm: jwt.HeaderAlgRSA256,
119 Type: jwt.HeaderType,
120 KeyID: string(g.pkID),
121 }
122 payload, err := jwt.EncodeJWS(&h, &claims, g.pk)
123 if err != nil {
124 return nil, err
125 }
126 v := url.Values{}
127 v.Set("grant_type", GrantType)
128 v.Set("audience", g.aud)
129 v.Set("requested_token_type", requestTokenType)
130 v.Set("subject_token", payload)
131 v.Set("subject_token_type", subjectTokenType)
132 resp, err := g.client.PostForm(g.tokenURL, v)
133 if err != nil {
134 return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
135 }
136 defer resp.Body.Close()
137 body, err := internal.ReadAll(resp.Body)
138 if err != nil {
139 return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
140 }
141 if c := resp.StatusCode; c < http.StatusOK || c > http.StatusMultipleChoices {
142 return nil, &auth.Error{
143 Response: resp,
144 Body: body,
145 }
146 }
147
148 var tokenRes struct {
149 AccessToken string `json:"access_token"`
150 TokenType string `json:"token_type"`
151 ExpiresIn int64 `json:"expires_in"`
152 }
153 if err := json.Unmarshal(body, &tokenRes); err != nil {
154 return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
155 }
156 token := &auth.Token{
157 Value: tokenRes.AccessToken,
158 Type: tokenRes.TokenType,
159 }
160 raw := make(map[string]interface{})
161 json.Unmarshal(body, &raw)
162 token.Metadata = raw
163
164 if secs := tokenRes.ExpiresIn; secs > 0 {
165 token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
166 }
167 return token, nil
168 }
169
170
171
172
173
// addCertToTransport configures hc so that its transport uses certPool as the
// root CA set for TLS.
//
// If hc.Transport is an *http.Transport it is updated in place, keeping its
// other transport settings. Otherwise (including when it is nil) a clone of
// http.DefaultTransport is installed on hc, replacing whatever transport was
// there. In both cases any TLS configuration previously set on the transport
// is replaced by one that only sets RootCAs.
//
// hc must be non-nil.
func addCertToTransport(hc *http.Client, certPool *x509.CertPool) {
	trans, ok := hc.Transport.(*http.Transport)
	if !ok {
		// Clone rather than mutate the process-wide default transport. Fall
		// back to a zero transport if the default has been swapped for some
		// other RoundTripper, instead of panicking on the type assertion.
		if def, isDefault := http.DefaultTransport.(*http.Transport); isDefault {
			trans = def.Clone()
		} else {
			trans = &http.Transport{}
		}
		// Install the new transport on the client; without this assignment
		// the configured pool would be discarded with the local variable.
		hc.Transport = trans
	}
	trans.TLSClientConfig = &tls.Config{
		RootCAs: certPool,
	}
}
183
View as plain text