...

Source file src/cloud.google.com/go/auth/credentials/idtoken/validate.go

Documentation: cloud.google.com/go/auth/credentials/idtoken

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package idtoken
    16  
    17  import (
    18  	"context"
    19  	"crypto"
    20  	"crypto/ecdsa"
    21  	"crypto/elliptic"
    22  	"crypto/rsa"
    23  	"crypto/sha256"
    24  	"encoding/base64"
    25  	"encoding/json"
    26  	"fmt"
    27  	"math/big"
    28  	"net/http"
    29  	"strings"
    30  	"time"
    31  
    32  	"cloud.google.com/go/auth/internal"
    33  	"cloud.google.com/go/auth/internal/jwt"
    34  )
    35  
    36  const (
    37  	es256KeySize      int    = 32
    38  	googleIAPCertsURL string = "https://www.gstatic.com/iap/verify/public_key-jwk"
    39  	googleSACertsURL  string = "https://www.googleapis.com/oauth2/v3/certs"
    40  )
    41  
    42  var (
    43  	defaultValidator = &Validator{client: newCachingClient(internal.CloneDefaultClient())}
    44  	// now aliases time.Now for testing.
    45  	now = time.Now
    46  )
    47  
    48  // certResponse represents a list jwks. It is the format returned from known
    49  // Google cert endpoints.
    50  type certResponse struct {
    51  	Keys []jwk `json:"keys"`
    52  }
    53  
    54  // jwk is a simplified representation of a standard jwk. It only includes the
    55  // fields used by Google's cert endpoints.
    56  type jwk struct {
    57  	Alg string `json:"alg"`
    58  	Crv string `json:"crv"`
    59  	Kid string `json:"kid"`
    60  	Kty string `json:"kty"`
    61  	Use string `json:"use"`
    62  	E   string `json:"e"`
    63  	N   string `json:"n"`
    64  	X   string `json:"x"`
    65  	Y   string `json:"y"`
    66  }
    67  
    68  // Validator provides a way to validate Google ID Tokens
    69  type Validator struct {
    70  	client *cachingClient
    71  }
    72  
    73  // ValidatorOptions provides a way to configure a [Validator].
    74  type ValidatorOptions struct {
    75  	// Client used to make requests to the certs URL. Optional.
    76  	Client *http.Client
    77  }
    78  
    79  // NewValidator creates a Validator that uses the options provided to configure
    80  // a the internal http.Client that will be used to make requests to fetch JWKs.
    81  func NewValidator(opts *ValidatorOptions) (*Validator, error) {
    82  	var client *http.Client
    83  	if opts != nil && opts.Client != nil {
    84  		client = opts.Client
    85  	} else {
    86  		client = internal.CloneDefaultClient()
    87  	}
    88  	return &Validator{client: newCachingClient(client)}, nil
    89  }
    90  
    91  // Validate is used to validate the provided idToken with a known Google cert
    92  // URL. If audience is not empty the audience claim of the Token is validated.
    93  // Upon successful validation a parsed token Payload is returned allowing the
    94  // caller to validate any additional claims.
    95  func (v *Validator) Validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
    96  	return v.validate(ctx, idToken, audience)
    97  }
    98  
    99  // Validate is used to validate the provided idToken with a known Google cert
   100  // URL. If audience is not empty the audience claim of the Token is validated.
   101  // Upon successful validation a parsed token Payload is returned allowing the
   102  // caller to validate any additional claims.
   103  func Validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
   104  	return defaultValidator.validate(ctx, idToken, audience)
   105  }
   106  
   107  // ParsePayload parses the given token and returns its payload.
   108  //
   109  // Warning: This function does not validate the token prior to parsing it.
   110  //
   111  // ParsePayload is primarily meant to be used to inspect a token's payload. This is
   112  // useful when validation fails and the payload needs to be inspected.
   113  //
   114  // Note: A successful Validate() invocation with the same token will return an
   115  // identical payload.
   116  func ParsePayload(idToken string) (*Payload, error) {
   117  	_, payload, _, err := parseToken(idToken)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	return payload, nil
   122  }
   123  
   124  func (v *Validator) validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
   125  	header, payload, sig, err := parseToken(idToken)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	if audience != "" && payload.Audience != audience {
   131  		return nil, fmt.Errorf("idtoken: audience provided does not match aud claim in the JWT")
   132  	}
   133  
   134  	if now().Unix() > payload.Expires {
   135  		return nil, fmt.Errorf("idtoken: token expired: now=%v, expires=%v", now().Unix(), payload.Expires)
   136  	}
   137  	hashedContent := hashHeaderPayload(idToken)
   138  	switch header.Algorithm {
   139  	case jwt.HeaderAlgRSA256:
   140  		if err := v.validateRS256(ctx, header.KeyID, hashedContent, sig); err != nil {
   141  			return nil, err
   142  		}
   143  	case "ES256":
   144  		if err := v.validateES256(ctx, header.KeyID, hashedContent, sig); err != nil {
   145  			return nil, err
   146  		}
   147  	default:
   148  		return nil, fmt.Errorf("idtoken: expected JWT signed with RS256 or ES256 but found %q", header.Algorithm)
   149  	}
   150  
   151  	return payload, nil
   152  }
   153  
   154  func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
   155  	certResp, err := v.client.getCert(ctx, googleSACertsURL)
   156  	if err != nil {
   157  		return err
   158  	}
   159  	j, err := findMatchingKey(certResp, keyID)
   160  	if err != nil {
   161  		return err
   162  	}
   163  	dn, err := decode(j.N)
   164  	if err != nil {
   165  		return err
   166  	}
   167  	de, err := decode(j.E)
   168  	if err != nil {
   169  		return err
   170  	}
   171  
   172  	pk := &rsa.PublicKey{
   173  		N: new(big.Int).SetBytes(dn),
   174  		E: int(new(big.Int).SetBytes(de).Int64()),
   175  	}
   176  	return rsa.VerifyPKCS1v15(pk, crypto.SHA256, hashedContent, sig)
   177  }
   178  
   179  func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
   180  	certResp, err := v.client.getCert(ctx, googleIAPCertsURL)
   181  	if err != nil {
   182  		return err
   183  	}
   184  	j, err := findMatchingKey(certResp, keyID)
   185  	if err != nil {
   186  		return err
   187  	}
   188  	dx, err := decode(j.X)
   189  	if err != nil {
   190  		return err
   191  	}
   192  	dy, err := decode(j.Y)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	pk := &ecdsa.PublicKey{
   198  		Curve: elliptic.P256(),
   199  		X:     new(big.Int).SetBytes(dx),
   200  		Y:     new(big.Int).SetBytes(dy),
   201  	}
   202  	r := big.NewInt(0).SetBytes(sig[:es256KeySize])
   203  	s := big.NewInt(0).SetBytes(sig[es256KeySize:])
   204  	if valid := ecdsa.Verify(pk, hashedContent, r, s); !valid {
   205  		return fmt.Errorf("idtoken: ES256 signature not valid")
   206  	}
   207  	return nil
   208  }
   209  
   210  func findMatchingKey(response *certResponse, keyID string) (*jwk, error) {
   211  	if response == nil {
   212  		return nil, fmt.Errorf("idtoken: cert response is nil")
   213  	}
   214  	for _, v := range response.Keys {
   215  		if v.Kid == keyID {
   216  			return &v, nil
   217  		}
   218  	}
   219  	return nil, fmt.Errorf("idtoken: could not find matching cert keyId for the token provided")
   220  }
   221  
   222  func parseToken(idToken string) (*jwt.Header, *Payload, []byte, error) {
   223  	segments := strings.Split(idToken, ".")
   224  	if len(segments) != 3 {
   225  		return nil, nil, nil, fmt.Errorf("idtoken: invalid token, token must have three segments; found %d", len(segments))
   226  	}
   227  	// Header
   228  	dh, err := decode(segments[0])
   229  	if err != nil {
   230  		return nil, nil, nil, fmt.Errorf("idtoken: unable to decode JWT header: %v", err)
   231  	}
   232  	var header *jwt.Header
   233  	err = json.Unmarshal(dh, &header)
   234  	if err != nil {
   235  		return nil, nil, nil, fmt.Errorf("idtoken: unable to unmarshal JWT header: %v", err)
   236  	}
   237  
   238  	// Payload
   239  	dp, err := decode(segments[1])
   240  	if err != nil {
   241  		return nil, nil, nil, fmt.Errorf("idtoken: unable to decode JWT claims: %v", err)
   242  	}
   243  	var payload *Payload
   244  	if err := json.Unmarshal(dp, &payload); err != nil {
   245  		return nil, nil, nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload: %v", err)
   246  	}
   247  	if err := json.Unmarshal(dp, &payload.Claims); err != nil {
   248  		return nil, nil, nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload claims: %v", err)
   249  	}
   250  
   251  	// Signature
   252  	signature, err := decode(segments[2])
   253  	if err != nil {
   254  		return nil, nil, nil, fmt.Errorf("idtoken: unable to decode JWT signature: %v", err)
   255  	}
   256  	return header, payload, signature, nil
   257  }
   258  
   259  // hashHeaderPayload gets the SHA256 checksum for verification of the JWT.
   260  func hashHeaderPayload(idtoken string) []byte {
   261  	// remove the sig from the token
   262  	content := idtoken[:strings.LastIndex(idtoken, ".")]
   263  	hashed := sha256.Sum256([]byte(content))
   264  	return hashed[:]
   265  }
   266  
   267  func decode(s string) ([]byte, error) {
   268  	return base64.RawURLEncoding.DecodeString(s)
   269  }
   270  

View as plain text