1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package idtoken
16
17 import (
18 "context"
19 "crypto"
20 "crypto/ecdsa"
21 "crypto/elliptic"
22 "crypto/rsa"
23 "crypto/sha256"
24 "encoding/base64"
25 "encoding/json"
26 "fmt"
27 "math/big"
28 "net/http"
29 "strings"
30 "time"
31
32 "cloud.google.com/go/auth/internal"
33 "cloud.google.com/go/auth/internal/jwt"
34 )
35
// es256KeySize is the byte length of each half (r, then s) of a raw ES256
// signature as it appears in the token's signature segment.
const es256KeySize int = 32

// googleIAPCertsURL is the JWK set consulted when verifying ES256 tokens.
const googleIAPCertsURL string = "https://www.gstatic.com/iap/verify/public_key-jwk"

// googleSACertsURL is the JWK set consulted when verifying RS256 tokens.
const googleSACertsURL string = "https://www.googleapis.com/oauth2/v3/certs"
41
42 var (
43 defaultValidator = &Validator{client: newCachingClient(internal.CloneDefaultClient())}
44
45 now = time.Now
46 )
47
48
49
50 type certResponse struct {
51 Keys []jwk `json:"keys"`
52 }
53
54
55
// jwk is one JSON Web Key from a certificate endpoint's key set.
// Alg, Crv, Kty and Use are decoded but not consulted by the verification
// code in this file; keys are selected by Kid alone.
type jwk struct {
	Alg string `json:"alg"`
	Crv string `json:"crv"`
	Kid string `json:"kid"` // compared against the JWT header's key ID
	Kty string `json:"kty"`
	Use string `json:"use"`

	// RSA public exponent and modulus, base64url without padding (RS256).
	E string `json:"e"`
	N string `json:"n"`

	// EC public point coordinates, base64url without padding (ES256).
	X string `json:"x"`
	Y string `json:"y"`
}
67
68
69 type Validator struct {
70 client *cachingClient
71 }
72
73
// ValidatorOptions configures NewValidator. A nil *ValidatorOptions is
// accepted and behaves like the zero value.
type ValidatorOptions struct {
	// Client is the HTTP client used to fetch signing certificates.
	// Optional: when nil, NewValidator uses a clone of the package's
	// default client.
	Client *http.Client
}
78
79
80
81 func NewValidator(opts *ValidatorOptions) (*Validator, error) {
82 var client *http.Client
83 if opts != nil && opts.Client != nil {
84 client = opts.Client
85 } else {
86 client = internal.CloneDefaultClient()
87 }
88 return &Validator{client: newCachingClient(client)}, nil
89 }
90
91
92
93
94
95 func (v *Validator) Validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
96 return v.validate(ctx, idToken, audience)
97 }
98
99
100
101
102
103 func Validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
104 return defaultValidator.validate(ctx, idToken, audience)
105 }
106
107
108
109
110
111
112
113
114
115
116 func ParsePayload(idToken string) (*Payload, error) {
117 _, payload, _, err := parseToken(idToken)
118 if err != nil {
119 return nil, err
120 }
121 return payload, nil
122 }
123
124 func (v *Validator) validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
125 header, payload, sig, err := parseToken(idToken)
126 if err != nil {
127 return nil, err
128 }
129
130 if audience != "" && payload.Audience != audience {
131 return nil, fmt.Errorf("idtoken: audience provided does not match aud claim in the JWT")
132 }
133
134 if now().Unix() > payload.Expires {
135 return nil, fmt.Errorf("idtoken: token expired: now=%v, expires=%v", now().Unix(), payload.Expires)
136 }
137 hashedContent := hashHeaderPayload(idToken)
138 switch header.Algorithm {
139 case jwt.HeaderAlgRSA256:
140 if err := v.validateRS256(ctx, header.KeyID, hashedContent, sig); err != nil {
141 return nil, err
142 }
143 case "ES256":
144 if err := v.validateES256(ctx, header.KeyID, hashedContent, sig); err != nil {
145 return nil, err
146 }
147 default:
148 return nil, fmt.Errorf("idtoken: expected JWT signed with RS256 or ES256 but found %q", header.Algorithm)
149 }
150
151 return payload, nil
152 }
153
154 func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
155 certResp, err := v.client.getCert(ctx, googleSACertsURL)
156 if err != nil {
157 return err
158 }
159 j, err := findMatchingKey(certResp, keyID)
160 if err != nil {
161 return err
162 }
163 dn, err := decode(j.N)
164 if err != nil {
165 return err
166 }
167 de, err := decode(j.E)
168 if err != nil {
169 return err
170 }
171
172 pk := &rsa.PublicKey{
173 N: new(big.Int).SetBytes(dn),
174 E: int(new(big.Int).SetBytes(de).Int64()),
175 }
176 return rsa.VerifyPKCS1v15(pk, crypto.SHA256, hashedContent, sig)
177 }
178
179 func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
180 certResp, err := v.client.getCert(ctx, googleIAPCertsURL)
181 if err != nil {
182 return err
183 }
184 j, err := findMatchingKey(certResp, keyID)
185 if err != nil {
186 return err
187 }
188 dx, err := decode(j.X)
189 if err != nil {
190 return err
191 }
192 dy, err := decode(j.Y)
193 if err != nil {
194 return err
195 }
196
197 pk := &ecdsa.PublicKey{
198 Curve: elliptic.P256(),
199 X: new(big.Int).SetBytes(dx),
200 Y: new(big.Int).SetBytes(dy),
201 }
202 r := big.NewInt(0).SetBytes(sig[:es256KeySize])
203 s := big.NewInt(0).SetBytes(sig[es256KeySize:])
204 if valid := ecdsa.Verify(pk, hashedContent, r, s); !valid {
205 return fmt.Errorf("idtoken: ES256 signature not valid")
206 }
207 return nil
208 }
209
210 func findMatchingKey(response *certResponse, keyID string) (*jwk, error) {
211 if response == nil {
212 return nil, fmt.Errorf("idtoken: cert response is nil")
213 }
214 for _, v := range response.Keys {
215 if v.Kid == keyID {
216 return &v, nil
217 }
218 }
219 return nil, fmt.Errorf("idtoken: could not find matching cert keyId for the token provided")
220 }
221
222 func parseToken(idToken string) (*jwt.Header, *Payload, []byte, error) {
223 segments := strings.Split(idToken, ".")
224 if len(segments) != 3 {
225 return nil, nil, nil, fmt.Errorf("idtoken: invalid token, token must have three segments; found %d", len(segments))
226 }
227
228 dh, err := decode(segments[0])
229 if err != nil {
230 return nil, nil, nil, fmt.Errorf("idtoken: unable to decode JWT header: %v", err)
231 }
232 var header *jwt.Header
233 err = json.Unmarshal(dh, &header)
234 if err != nil {
235 return nil, nil, nil, fmt.Errorf("idtoken: unable to unmarshal JWT header: %v", err)
236 }
237
238
239 dp, err := decode(segments[1])
240 if err != nil {
241 return nil, nil, nil, fmt.Errorf("idtoken: unable to decode JWT claims: %v", err)
242 }
243 var payload *Payload
244 if err := json.Unmarshal(dp, &payload); err != nil {
245 return nil, nil, nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload: %v", err)
246 }
247 if err := json.Unmarshal(dp, &payload.Claims); err != nil {
248 return nil, nil, nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload claims: %v", err)
249 }
250
251
252 signature, err := decode(segments[2])
253 if err != nil {
254 return nil, nil, nil, fmt.Errorf("idtoken: unable to decode JWT signature: %v", err)
255 }
256 return header, payload, signature, nil
257 }
258
259
// hashHeaderPayload returns the SHA-256 digest of the JWS signing input: the
// text of idtoken before its final '.', i.e. "header.payload". idtoken must
// contain at least one '.'; otherwise the slice expression panics.
func hashHeaderPayload(idtoken string) []byte {
	signingInputEnd := strings.LastIndexByte(idtoken, '.')
	digest := sha256.Sum256([]byte(idtoken[:signingInputEnd]))
	return digest[:]
}
266
// decode base64url-decodes s using the unpadded URL-safe alphabet
// (RFC 4648 §5 without '=' padding), the encoding used for JWS segments.
// Padded input is rejected.
func decode(s string) ([]byte, error) {
	enc := base64.RawURLEncoding
	return enc.DecodeString(s)
}
270
View as plain text