...

Source file src/cloud.google.com/go/auth/internal/jwt/jwt.go

Documentation: cloud.google.com/go/auth/internal/jwt

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package jwt
    16  
    17  import (
    18  	"bytes"
    19  	"crypto"
    20  	"crypto/rand"
    21  	"crypto/rsa"
    22  	"crypto/sha256"
    23  	"encoding/base64"
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"strings"
    28  	"time"
    29  )
    30  
    31  const (
    32  	// HeaderAlgRSA256 is the RS256 [Header.Algorithm].
    33  	HeaderAlgRSA256 = "RS256"
    34  	// HeaderAlgES256 is the ES256 [Header.Algorithm].
    35  	HeaderAlgES256 = "ES256"
    36  	// HeaderType is the standard [Header.Type].
    37  	HeaderType = "JWT"
    38  )
    39  
    40  // Header represents a JWT header.
    41  type Header struct {
    42  	Algorithm string `json:"alg"`
    43  	Type      string `json:"typ"`
    44  	KeyID     string `json:"kid"`
    45  }
    46  
    47  func (h *Header) encode() (string, error) {
    48  	b, err := json.Marshal(h)
    49  	if err != nil {
    50  		return "", err
    51  	}
    52  	return base64.RawURLEncoding.EncodeToString(b), nil
    53  }
    54  
    55  // Claims represents the claims set of a JWT.
    56  type Claims struct {
    57  	// Iss is the issuer JWT claim.
    58  	Iss string `json:"iss"`
    59  	// Scope is the scope JWT claim.
    60  	Scope string `json:"scope,omitempty"`
    61  	// Exp is the expiry JWT claim. If unset, default is in one hour from now.
    62  	Exp int64 `json:"exp"`
    63  	// Iat is the subject issued at claim. If unset, default is now.
    64  	Iat int64 `json:"iat"`
    65  	// Aud is the audience JWT claim. Optional.
    66  	Aud string `json:"aud"`
    67  	// Sub is the subject JWT claim. Optional.
    68  	Sub string `json:"sub,omitempty"`
    69  	// AdditionalClaims contains any additional non-standard JWT claims. Optional.
    70  	AdditionalClaims map[string]interface{} `json:"-"`
    71  }
    72  
    73  func (c *Claims) encode() (string, error) {
    74  	// Compensate for skew
    75  	now := time.Now().Add(-10 * time.Second)
    76  	if c.Iat == 0 {
    77  		c.Iat = now.Unix()
    78  	}
    79  	if c.Exp == 0 {
    80  		c.Exp = now.Add(time.Hour).Unix()
    81  	}
    82  	if c.Exp < c.Iat {
    83  		return "", fmt.Errorf("jwt: invalid Exp = %d; must be later than Iat = %d", c.Exp, c.Iat)
    84  	}
    85  
    86  	b, err := json.Marshal(c)
    87  	if err != nil {
    88  		return "", err
    89  	}
    90  
    91  	if len(c.AdditionalClaims) == 0 {
    92  		return base64.RawURLEncoding.EncodeToString(b), nil
    93  	}
    94  
    95  	// Marshal private claim set and then append it to b.
    96  	prv, err := json.Marshal(c.AdditionalClaims)
    97  	if err != nil {
    98  		return "", fmt.Errorf("invalid map of additional claims %v: %w", c.AdditionalClaims, err)
    99  	}
   100  
   101  	// Concatenate public and private claim JSON objects.
   102  	if !bytes.HasSuffix(b, []byte{'}'}) {
   103  		return "", fmt.Errorf("invalid JSON %s", b)
   104  	}
   105  	if !bytes.HasPrefix(prv, []byte{'{'}) {
   106  		return "", fmt.Errorf("invalid JSON %s", prv)
   107  	}
   108  	b[len(b)-1] = ','         // Replace closing curly brace with a comma.
   109  	b = append(b, prv[1:]...) // Append private claims.
   110  	return base64.RawURLEncoding.EncodeToString(b), nil
   111  }
   112  
   113  // EncodeJWS encodes the data using the provided key as a JSON web signature.
   114  func EncodeJWS(header *Header, c *Claims, key *rsa.PrivateKey) (string, error) {
   115  	head, err := header.encode()
   116  	if err != nil {
   117  		return "", err
   118  	}
   119  	claims, err := c.encode()
   120  	if err != nil {
   121  		return "", err
   122  	}
   123  	ss := fmt.Sprintf("%s.%s", head, claims)
   124  	h := sha256.New()
   125  	h.Write([]byte(ss))
   126  	sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil))
   127  	if err != nil {
   128  		return "", err
   129  	}
   130  	return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil
   131  }
   132  
   133  // DecodeJWS decodes a claim set from a JWS payload.
   134  func DecodeJWS(payload string) (*Claims, error) {
   135  	// decode returned id token to get expiry
   136  	s := strings.Split(payload, ".")
   137  	if len(s) < 2 {
   138  		return nil, errors.New("invalid token received")
   139  	}
   140  	decoded, err := base64.RawURLEncoding.DecodeString(s[1])
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	c := &Claims{}
   145  	if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(c); err != nil {
   146  		return nil, err
   147  	}
   148  	if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(&c.AdditionalClaims); err != nil {
   149  		return nil, err
   150  	}
   151  	return c, err
   152  }
   153  
   154  // VerifyJWS tests whether the provided JWT token's signature was produced by
   155  // the private key associated with the provided public key.
   156  func VerifyJWS(token string, key *rsa.PublicKey) error {
   157  	parts := strings.Split(token, ".")
   158  	if len(parts) != 3 {
   159  		return errors.New("jwt: invalid token received, token must have 3 parts")
   160  	}
   161  
   162  	signedContent := parts[0] + "." + parts[1]
   163  	signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
   164  	if err != nil {
   165  		return err
   166  	}
   167  
   168  	h := sha256.New()
   169  	h.Write([]byte(signedContent))
   170  	return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
   171  }
   172  

View as plain text