...

Source file src/google.golang.org/api/idtoken/validate.go

Documentation: google.golang.org/api/idtoken

     1  // Copyright 2020 Google LLC.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package idtoken
     6  
     7  import (
     8  	"context"
     9  	"crypto"
    10  	"crypto/ecdsa"
    11  	"crypto/elliptic"
    12  	"crypto/rsa"
    13  	"crypto/sha256"
    14  	"encoding/base64"
    15  	"encoding/json"
    16  	"fmt"
    17  	"math/big"
    18  	"net/http"
    19  	"strings"
    20  	"time"
    21  
    22  	"google.golang.org/api/option"
    23  	"google.golang.org/api/option/internaloption"
    24  	htransport "google.golang.org/api/transport/http"
    25  )
    26  
    27  const (
    28  	es256KeySize      int    = 32
    29  	googleIAPCertsURL string = "https://www.gstatic.com/iap/verify/public_key-jwk"
    30  	googleSACertsURL  string = "https://www.googleapis.com/oauth2/v3/certs"
    31  )
    32  
    33  var (
    34  	defaultValidator = &Validator{client: newCachingClient(http.DefaultClient)}
    35  	// now aliases time.Now for testing.
    36  	now = time.Now
    37  )
    38  
    39  func defaultValidatorOpts() []ClientOption {
    40  	return []ClientOption{
    41  		internaloption.WithDefaultScopes("https://www.googleapis.com/auth/cloud-platform"),
    42  		option.WithoutAuthentication(),
    43  	}
    44  }
    45  
// Payload represents a decoded payload of an ID Token.
//
// The typed fields are populated from the JSON claims named in their tags.
// Expires ("exp") is compared against the current Unix time during
// validation; IssuedAt ("iat") is presumably also Unix seconds, per the JWT
// convention, but is not checked here.
//
// Claims is skipped by the struct mapping (json:"-") and is filled in a
// separate decoding pass with every claim in the payload, including the ones
// that also appear in the typed fields above.
type Payload struct {
	Issuer   string                 `json:"iss"`
	Audience string                 `json:"aud"`
	Expires  int64                  `json:"exp"`
	IssuedAt int64                  `json:"iat"`
	Subject  string                 `json:"sub,omitempty"`
	Claims   map[string]interface{} `json:"-"`
}
    55  
// jwt holds the three dot-separated segments of a compact JWT exactly as they
// appear in the token string, still base64url-encoded. Its methods decode,
// parse and hash those segments.
type jwt struct {
	header    string
	payload   string
	signature string
}

// jwtHeader represents a parsed JWT header segment. Validation reads
// Algorithm to choose between RS256 and ES256, and KeyID to select the
// matching key from the fetched cert set.
type jwtHeader struct {
	Algorithm string `json:"alg"`
	Type      string `json:"typ"`
	KeyID     string `json:"kid"`
}

// certResponse represents a JWK set: the "keys" array returned by the known
// Google cert endpoints.
type certResponse struct {
	Keys []jwk `json:"keys"`
}

// jwk is a simplified representation of a standard JWK. It only includes the
// fields used by Google's cert endpoints.
//
// Kid is matched against the token's key ID. For RS256 keys, N and E (the
// RSA modulus and exponent) are read. For ES256 keys, X and Y (the curve
// point coordinates) are read. All four are base64url strings. Alg, Crv, Kty
// and Use are decoded but not consulted by the validation code in this file.
type jwk struct {
	Alg string `json:"alg"`
	Crv string `json:"crv"`
	Kid string `json:"kid"`
	Kty string `json:"kty"`
	Use string `json:"use"`
	E   string `json:"e"`
	N   string `json:"n"`
	X   string `json:"x"`
	Y   string `json:"y"`
}
    90  
// Validator provides a way to validate Google ID Tokens with a user provided
// http.Client. Build one with NewValidator; the package-level Validate
// function uses a default Validator built on http.DefaultClient.
type Validator struct {
	// client fetches the JWK sets used to verify signatures. Judging by its
	// name it presumably caches them; see cachingClient to confirm.
	client *cachingClient
}
    96  
    97  // NewValidator creates a Validator that uses the options provided to configure
    98  // a the internal http.Client that will be used to make requests to fetch JWKs.
    99  func NewValidator(ctx context.Context, opts ...ClientOption) (*Validator, error) {
   100  	opts = append(defaultValidatorOpts(), opts...)
   101  	client, _, err := htransport.NewClient(ctx, opts...)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	return &Validator{client: newCachingClient(client)}, nil
   106  }
   107  
   108  // Validate is used to validate the provided idToken with a known Google cert
   109  // URL. If audience is not empty the audience claim of the Token is validated.
   110  // Upon successful validation a parsed token Payload is returned allowing the
   111  // caller to validate any additional claims.
   112  func (v *Validator) Validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
   113  	return v.validate(ctx, idToken, audience)
   114  }
   115  
   116  // Validate is used to validate the provided idToken with a known Google cert
   117  // URL. If audience is not empty the audience claim of the Token is validated.
   118  // Upon successful validation a parsed token Payload is returned allowing the
   119  // caller to validate any additional claims.
   120  func Validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
   121  	// TODO(codyoss): consider adding a check revoked version of the api. See: https://pkg.go.dev/firebase.google.com/go/auth?tab=doc#Client.VerifyIDTokenAndCheckRevoked
   122  	return defaultValidator.validate(ctx, idToken, audience)
   123  }
   124  
   125  // ParsePayload parses the given token and returns its payload.
   126  //
   127  // Warning: This function does not validate the token prior to parsing it.
   128  //
   129  // ParsePayload is primarily meant to be used to inspect a token's payload. This is
   130  // useful when validation fails and the payload needs to be inspected.
   131  //
   132  // Note: A successful Validate() invocation with the same token will return an
   133  // identical payload.
   134  func ParsePayload(idToken string) (*Payload, error) {
   135  	jwt, err := parseJWT(idToken)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	return jwt.parsedPayload()
   140  }
   141  
   142  func (v *Validator) validate(ctx context.Context, idToken string, audience string) (*Payload, error) {
   143  	jwt, err := parseJWT(idToken)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	header, err := jwt.parsedHeader()
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	payload, err := jwt.parsedPayload()
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  	sig, err := jwt.decodedSignature()
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	if audience != "" && payload.Audience != audience {
   161  		return nil, fmt.Errorf("idtoken: audience provided does not match aud claim in the JWT")
   162  	}
   163  
   164  	if now().Unix() > payload.Expires {
   165  		return nil, fmt.Errorf("idtoken: token expired: now=%v, expires=%v", now().Unix(), payload.Expires)
   166  	}
   167  
   168  	switch header.Algorithm {
   169  	case "RS256":
   170  		if err := v.validateRS256(ctx, header.KeyID, jwt.hashedContent(), sig); err != nil {
   171  			return nil, err
   172  		}
   173  	case "ES256":
   174  		if err := v.validateES256(ctx, header.KeyID, jwt.hashedContent(), sig); err != nil {
   175  			return nil, err
   176  		}
   177  	default:
   178  		return nil, fmt.Errorf("idtoken: expected JWT signed with RS256 or ES256 but found %q", header.Algorithm)
   179  	}
   180  
   181  	return payload, nil
   182  }
   183  
   184  func (v *Validator) validateRS256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
   185  	certResp, err := v.client.getCert(ctx, googleSACertsURL)
   186  	if err != nil {
   187  		return err
   188  	}
   189  	j, err := findMatchingKey(certResp, keyID)
   190  	if err != nil {
   191  		return err
   192  	}
   193  	dn, err := decode(j.N)
   194  	if err != nil {
   195  		return err
   196  	}
   197  	de, err := decode(j.E)
   198  	if err != nil {
   199  		return err
   200  	}
   201  
   202  	pk := &rsa.PublicKey{
   203  		N: new(big.Int).SetBytes(dn),
   204  		E: int(new(big.Int).SetBytes(de).Int64()),
   205  	}
   206  	return rsa.VerifyPKCS1v15(pk, crypto.SHA256, hashedContent, sig)
   207  }
   208  
   209  func (v *Validator) validateES256(ctx context.Context, keyID string, hashedContent []byte, sig []byte) error {
   210  	certResp, err := v.client.getCert(ctx, googleIAPCertsURL)
   211  	if err != nil {
   212  		return err
   213  	}
   214  	j, err := findMatchingKey(certResp, keyID)
   215  	if err != nil {
   216  		return err
   217  	}
   218  	dx, err := decode(j.X)
   219  	if err != nil {
   220  		return err
   221  	}
   222  	dy, err := decode(j.Y)
   223  	if err != nil {
   224  		return err
   225  	}
   226  
   227  	pk := &ecdsa.PublicKey{
   228  		Curve: elliptic.P256(),
   229  		X:     new(big.Int).SetBytes(dx),
   230  		Y:     new(big.Int).SetBytes(dy),
   231  	}
   232  	r := big.NewInt(0).SetBytes(sig[:es256KeySize])
   233  	s := big.NewInt(0).SetBytes(sig[es256KeySize:])
   234  	if valid := ecdsa.Verify(pk, hashedContent, r, s); !valid {
   235  		return fmt.Errorf("idtoken: ES256 signature not valid")
   236  	}
   237  	return nil
   238  }
   239  
   240  func findMatchingKey(response *certResponse, keyID string) (*jwk, error) {
   241  	if response == nil {
   242  		return nil, fmt.Errorf("idtoken: cert response is nil")
   243  	}
   244  	for _, v := range response.Keys {
   245  		if v.Kid == keyID {
   246  			return &v, nil
   247  		}
   248  	}
   249  	return nil, fmt.Errorf("idtoken: could not find matching cert keyId for the token provided")
   250  }
   251  
   252  func parseJWT(idToken string) (*jwt, error) {
   253  	segments := strings.Split(idToken, ".")
   254  	if len(segments) != 3 {
   255  		return nil, fmt.Errorf("idtoken: invalid token, token must have three segments; found %d", len(segments))
   256  	}
   257  	return &jwt{
   258  		header:    segments[0],
   259  		payload:   segments[1],
   260  		signature: segments[2],
   261  	}, nil
   262  }
   263  
   264  // decodedHeader base64 decodes the header segment.
   265  func (j *jwt) decodedHeader() ([]byte, error) {
   266  	dh, err := decode(j.header)
   267  	if err != nil {
   268  		return nil, fmt.Errorf("idtoken: unable to decode JWT header: %v", err)
   269  	}
   270  	return dh, nil
   271  }
   272  
   273  // decodedPayload base64 payload the header segment.
   274  func (j *jwt) decodedPayload() ([]byte, error) {
   275  	p, err := decode(j.payload)
   276  	if err != nil {
   277  		return nil, fmt.Errorf("idtoken: unable to decode JWT payload: %v", err)
   278  	}
   279  	return p, nil
   280  }
   281  
   282  // decodedPayload base64 payload the header segment.
   283  func (j *jwt) decodedSignature() ([]byte, error) {
   284  	p, err := decode(j.signature)
   285  	if err != nil {
   286  		return nil, fmt.Errorf("idtoken: unable to decode JWT signature: %v", err)
   287  	}
   288  	return p, nil
   289  }
   290  
   291  // parsedHeader returns a struct representing a JWT header.
   292  func (j *jwt) parsedHeader() (jwtHeader, error) {
   293  	var h jwtHeader
   294  	dh, err := j.decodedHeader()
   295  	if err != nil {
   296  		return h, err
   297  	}
   298  	err = json.Unmarshal(dh, &h)
   299  	if err != nil {
   300  		return h, fmt.Errorf("idtoken: unable to unmarshal JWT header: %v", err)
   301  	}
   302  	return h, nil
   303  }
   304  
   305  // parsedPayload returns a struct representing a JWT payload.
   306  func (j *jwt) parsedPayload() (*Payload, error) {
   307  	var p Payload
   308  	dp, err := j.decodedPayload()
   309  	if err != nil {
   310  		return nil, err
   311  	}
   312  	if err := json.Unmarshal(dp, &p); err != nil {
   313  		return nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload: %v", err)
   314  	}
   315  	if err := json.Unmarshal(dp, &p.Claims); err != nil {
   316  		return nil, fmt.Errorf("idtoken: unable to unmarshal JWT payload claims: %v", err)
   317  	}
   318  	return &p, nil
   319  }
   320  
   321  // hashedContent gets the SHA256 checksum for verification of the JWT.
   322  func (j *jwt) hashedContent() []byte {
   323  	signedContent := j.header + "." + j.payload
   324  	hashed := sha256.Sum256([]byte(signedContent))
   325  	return hashed[:]
   326  }
   327  
   328  func (j *jwt) String() string {
   329  	return fmt.Sprintf("%s.%s.%s", j.header, j.payload, j.signature)
   330  }
   331  
   332  func decode(s string) ([]byte, error) {
   333  	return base64.RawURLEncoding.DecodeString(s)
   334  }
   335  

View as plain text