1
2
3
4
5 package idtoken
6
7 import (
8 "context"
9 "crypto"
10 "crypto/ecdsa"
11 "crypto/elliptic"
12 "crypto/rsa"
13 "crypto/sha256"
14 "encoding/base64"
15 "encoding/json"
16 "fmt"
17 "math/big"
18 "net/http"
19 "strings"
20 "time"
21
22 "google.golang.org/api/option"
23 "google.golang.org/api/option/internaloption"
24 htransport "google.golang.org/api/transport/http"
25 )
26
// Certificate endpoints used to verify Google-issued ID tokens, and the
// per-coordinate size of an ES256 signature.
const (
	// googleSACertsURL serves the JWK set used for RS256 tokens (validateRS256).
	googleSACertsURL string = "https://www.googleapis.com/oauth2/v3/certs"
	// googleIAPCertsURL serves the JWK set used for ES256 tokens (validateES256).
	googleIAPCertsURL string = "https://www.gstatic.com/iap/verify/public_key-jwk"

	// es256KeySize is the byte length of each of the R and S halves of an
	// ES256 signature.
	es256KeySize int = 32
)
32
33 var (
34 defaultValidator = &Validator{client: newCachingClient(http.DefaultClient)}
35
36 now = time.Now
37 )
38
39 func defaultValidatorOpts() []ClientOption {
40 return []ClientOption{
41 internaloption.WithDefaultScopes("https://www.googleapis.com/auth/cloud-platform"),
42 option.WithoutAuthentication(),
43 }
44 }
45
46
// Payload represents the decoded payload of an ID token.
//
// The commonly used registered claims are mapped onto typed fields. Claims
// additionally holds every claim in the payload, including the ones that are
// also mapped to the typed fields, as decoded by encoding/json into a generic
// map (so JSON numbers appear as float64).
type Payload struct {
	Issuer string `json:"iss"`
	// Audience is the "aud" claim. NOTE(review): a token whose aud is a JSON
	// array fails to unmarshal into this string field.
	Audience string `json:"aud"`
	// Expires is the "exp" claim in Unix seconds; validate rejects the token
	// once the current Unix time is greater than this value.
	Expires  int64  `json:"exp"`
	IssuedAt int64  `json:"iat"`
	Subject  string `json:"sub,omitempty"`
	// Claims is filled by a second unmarshal of the whole payload; the "-"
	// tag keeps encoding/json from treating it as a claim of its own.
	Claims map[string]interface{} `json:"-"`
}
55
56
57
// jwt holds the three dot-separated segments of a compact-serialized JWT
// exactly as they appear in the token, still base64url-encoded. parseJWT
// builds it, and the decoded*/parsed* methods decode each segment on demand.
type jwt struct {
	header, payload, signature string
}
63
64
// jwtHeader is the decoded JOSE header of a JWT.
//
// validate dispatches on Algorithm (only "RS256" and "ES256" are accepted)
// and passes KeyID to the key lookup. Type is decoded but not checked in this
// file.
type jwtHeader struct {
	Algorithm string `json:"alg"`
	Type      string `json:"typ"`
	KeyID     string `json:"kid"`
}
70
71
72
// certResponse is a JWK set as obtained from the certificate endpoints via
// the caching client's getCert; findMatchingKey searches Keys by kid.
type certResponse struct {
	Keys []jwk `json:"keys"`
}
76
77
78
// jwk is a single JSON Web Key from a certResponse key set.
//
// Only some members are consulted in this file: Kid is matched by
// findMatchingKey; N and E (base64url, decoded with decode) form the RSA key
// in validateRS256; X and Y form the EC key in validateES256. Alg, Crv, Kty
// and Use are decoded but not checked here. In particular validateES256
// always assumes P-256 regardless of Crv.
type jwk struct {
	Alg string `json:"alg"`
	Crv string `json:"crv"`
	Kid string `json:"kid"`
	Kty string `json:"kty"`
	Use string `json:"use"`
	E   string `json:"e"`
	N   string `json:"n"`
	X   string `json:"x"`
	Y   string `json:"y"`
}
90
91
92
// Validator validates Google-signed ID tokens using a configurable HTTP
// client to fetch the public certificates. Create one with NewValidator.
//
// NOTE(review): the zero value has a nil client; whether that is usable
// depends on cachingClient.getCert, which is not visible here.
type Validator struct {
	client *cachingClient
}
96
97
98
// NewValidator creates a Validator whose certificate fetches go through an
// HTTP client built from opts.
//
// The package defaults from defaultValidatorOpts come first and the caller's
// opts are appended after them. Presumably, a caller option can therefore
// override a default; confirm against the option package's semantics.
func NewValidator(ctx context.Context, opts ...ClientOption) (*Validator, error) {
	allOpts := append(defaultValidatorOpts(), opts...)
	hc, _, err := htransport.NewClient(ctx, allOpts...)
	if err != nil {
		return nil, err
	}
	v := &Validator{client: newCachingClient(hc)}
	return v, nil
}
107
108
109
110
111
// Validate checks idToken against Google's published public keys, fetched
// with v's HTTP client. If audience is non-empty, it must equal the token's
// aud claim. On success, the parsed Payload is returned so the caller can
// check additional claims.
func (v *Validator) Validate(ctx context.Context, idToken, audience string) (*Payload, error) {
	p, err := v.validate(ctx, idToken, audience)
	return p, err
}
115
116
117
118
119
// Validate checks idToken against Google's published public keys using the
// package-level default Validator, which fetches certificates with
// http.DefaultClient. If audience is non-empty, it must equal the token's
// aud claim. On success, the parsed Payload is returned so the caller can
// check additional claims.
func Validate(ctx context.Context, idToken, audience string) (*Payload, error) {
	v := defaultValidator
	return v.validate(ctx, idToken, audience)
}
124
125
126
127
128
129
130
131
132
133
134 func ParsePayload(idToken string) (*Payload, error) {
135 jwt, err := parseJWT(idToken)
136 if err != nil {
137 return nil, err
138 }
139 return jwt.parsedPayload()
140 }
141
// validate is the shared implementation behind Validate and
// (*Validator).Validate.
//
// It splits and decodes idToken, then checks the claims: aud is checked only
// when audience is non-empty, and exp is checked against now. Finally it
// verifies the signature with the algorithm named in the header. RS256 keys
// come from googleSACertsURL and ES256 keys from googleIAPCertsURL, selected
// by the header's kid. Any other algorithm is rejected. The returned Payload
// is nil whenever err is non-nil.
func (v *Validator) validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
	// Named token rather than jwt so the jwt type is not shadowed.
	token, err := parseJWT(idToken)
	if err != nil {
		return nil, err
	}
	header, err := token.parsedHeader()
	if err != nil {
		return nil, err
	}
	payload, err := token.parsedPayload()
	if err != nil {
		return nil, err
	}
	sig, err := token.decodedSignature()
	if err != nil {
		return nil, err
	}

	// An empty audience disables the aud check.
	if audience != "" && payload.Audience != audience {
		return nil, fmt.Errorf("idtoken: audience provided does not match aud claim in the JWT")
	}

	// Read the clock once so the value compared and the value reported in
	// the error are the same instant.
	nowUnix := now().Unix()
	if nowUnix > payload.Expires {
		return nil, fmt.Errorf("idtoken: token expired: now=%v, expires=%v", nowUnix, payload.Expires)
	}

	switch header.Algorithm {
	case "RS256":
		err = v.validateRS256(ctx, header.KeyID, token.hashedContent(), sig)
	case "ES256":
		err = v.validateES256(ctx, header.KeyID, token.hashedContent(), sig)
	default:
		err = fmt.Errorf("idtoken: expected JWT signed with RS256 or ES256 but found %q", header.Algorithm)
	}
	if err != nil {
		return nil, err
	}
	return payload, nil
}
183
// validateRS256 verifies an RS256 (RSASSA-PKCS1-v1_5 with SHA-256) signature.
//
// The public key is the JWK whose kid equals keyID in the key set fetched
// from googleSACertsURL. Its base64url modulus (n) and exponent (e) are
// decoded into an rsa.PublicKey. hashedContent is the SHA-256 digest of the
// signing input and sig is the decoded signature segment.
func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
	resp, err := v.client.getCert(ctx, googleSACertsURL)
	if err != nil {
		return err
	}
	key, err := findMatchingKey(resp, keyID)
	if err != nil {
		return err
	}
	modulus, err := decode(key.N)
	if err != nil {
		return err
	}
	exponent, err := decode(key.E)
	if err != nil {
		return err
	}

	e := new(big.Int).SetBytes(exponent).Int64()
	pub := rsa.PublicKey{N: new(big.Int).SetBytes(modulus), E: int(e)}
	return rsa.VerifyPKCS1v15(&pub, crypto.SHA256, hashedContent, sig)
}
208
// validateES256 verifies an ES256 (ECDSA P-256 with SHA-256) signature.
//
// The public key is the JWK whose kid equals keyID in the key set fetched
// from googleIAPCertsURL. Its base64url x and y coordinates are decoded onto
// the P-256 curve. hashedContent is the SHA-256 digest of the signing input.
// sig is the decoded signature segment, which for ES256 must be exactly the
// 2*es256KeySize-byte concatenation R || S of two big-endian integers
// (RFC 7518 §3.4).
func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
	// sig comes straight from the (untrusted) token. Without this check a
	// signature shorter than es256KeySize makes sig[:es256KeySize] panic, and
	// any other wrong length is split at the wrong boundary. Checking first
	// also skips the cert fetch for malformed tokens.
	if len(sig) != 2*es256KeySize {
		return fmt.Errorf("idtoken: ES256 signature has invalid length %d, want %d", len(sig), 2*es256KeySize)
	}

	certResp, err := v.client.getCert(ctx, googleIAPCertsURL)
	if err != nil {
		return err
	}
	j, err := findMatchingKey(certResp, keyID)
	if err != nil {
		return err
	}
	dx, err := decode(j.X)
	if err != nil {
		return err
	}
	dy, err := decode(j.Y)
	if err != nil {
		return err
	}

	pk := &ecdsa.PublicKey{
		Curve: elliptic.P256(),
		X:     new(big.Int).SetBytes(dx),
		Y:     new(big.Int).SetBytes(dy),
	}
	r := new(big.Int).SetBytes(sig[:es256KeySize])
	s := new(big.Int).SetBytes(sig[es256KeySize:])
	if !ecdsa.Verify(pk, hashedContent, r, s) {
		return fmt.Errorf("idtoken: ES256 signature not valid")
	}
	return nil
}
239
// findMatchingKey returns a copy of the first key in response whose kid
// equals keyID. It fails if response is nil or no key matches.
func findMatchingKey(response *certResponse, keyID string) (*jwk, error) {
	if response == nil {
		return nil, fmt.Errorf("idtoken: cert response is nil")
	}
	for i := range response.Keys {
		if response.Keys[i].Kid != keyID {
			continue
		}
		// Copy the element before taking its address so the returned key
		// does not alias response.Keys.
		match := response.Keys[i]
		return &match, nil
	}
	return nil, fmt.Errorf("idtoken: could not find matching cert keyId for the token provided")
}
251
// parseJWT splits a compact-serialized token on "." into its header, payload
// and signature segments without decoding them. Exactly three segments are
// required.
func parseJWT(idToken string) (*jwt, error) {
	parts := strings.Split(idToken, ".")
	if n := len(parts); n != 3 {
		return nil, fmt.Errorf("idtoken: invalid token, token must have three segments; found %d", n)
	}
	header, payload, signature := parts[0], parts[1], parts[2]
	return &jwt{header: header, payload: payload, signature: signature}, nil
}
263
264
// decodedHeader returns the base64url-decoded bytes of the header segment.
func (j *jwt) decodedHeader() ([]byte, error) {
	raw, err := decode(j.header)
	if err == nil {
		return raw, nil
	}
	return nil, fmt.Errorf("idtoken: unable to decode JWT header: %v", err)
}
272
273
// decodedPayload returns the base64url-decoded bytes of the payload segment.
func (j *jwt) decodedPayload() ([]byte, error) {
	raw, err := decode(j.payload)
	if err == nil {
		return raw, nil
	}
	return nil, fmt.Errorf("idtoken: unable to decode JWT payload: %v", err)
}
281
282
// decodedSignature returns the base64url-decoded bytes of the signature
// segment, i.e. the raw signature handed to the RS256/ES256 verifiers.
func (j *jwt) decodedSignature() ([]byte, error) {
	raw, err := decode(j.signature)
	if err == nil {
		return raw, nil
	}
	return nil, fmt.Errorf("idtoken: unable to decode JWT signature: %v", err)
}
290
291
// parsedHeader decodes the header segment and unmarshals it as JSON into a
// jwtHeader. On an unmarshal error, the (possibly partially filled) header
// is returned alongside the error.
func (j *jwt) parsedHeader() (jwtHeader, error) {
	raw, err := j.decodedHeader()
	if err != nil {
		return jwtHeader{}, err
	}
	var hdr jwtHeader
	if err := json.Unmarshal(raw, &hdr); err != nil {
		return hdr, fmt.Errorf("idtoken: unable to unmarshal JWT header: %v", err)
	}
	return hdr, nil
}
304
305
306 func (j *jwt) parsedPayload() (*Payload, error) {
307 var p Payload
308 dp, err := j.decodedPayload()
309 if err != nil {
310 return nil, err
311 }
312 if err := json.Unmarshal(dp, &p); err != nil {
313 return nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload: %v", err)
314 }
315 if err := json.Unmarshal(dp, &p.Claims); err != nil {
316 return nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload claims: %v", err)
317 }
318 return &p, nil
319 }
320
321
// hashedContent returns the SHA-256 digest of the JWS signing input, which is
// the encoded header and payload joined by ".". The signature segment is not
// part of the hashed data.
func (j *jwt) hashedContent() []byte {
	h := sha256.New()
	// hash.Hash.Write never returns an error, so the results are ignored.
	h.Write([]byte(j.header))
	h.Write([]byte{'.'})
	h.Write([]byte(j.payload))
	return h.Sum(nil)
}
327
// String reassembles the token in compact form: header.payload.signature.
func (j *jwt) String() string {
	// fmt.Sprint inserts no spaces between adjacent string operands, so this
	// is a plain concatenation.
	return fmt.Sprint(j.header, ".", j.payload, ".", j.signature)
}
331
// decode decodes s as unpadded base64url (the URL-safe alphabet with no "="
// padding), the encoding used for every segment of a compact JWT and for the
// JWK key members.
func decode(s string) ([]byte, error) {
	b, err := base64.RawURLEncoding.DecodeString(s)
	return b, err
}
335
View as plain text