...

Source file src/github.com/okta/okta-jwt-verifier-golang/jwtverifier.go

Documentation: github.com/okta/okta-jwt-verifier-golang

     1  /*******************************************************************************
     2   * Copyright 2018 - Present Okta, Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *      http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   ******************************************************************************/
    16  
    17  package jwtverifier
    18  
    19  import (
    20  	"encoding/base64"
    21  	"encoding/json"
    22  	"fmt"
    23  	"net/http"
    24  	"regexp"
    25  	"strings"
    26  	"time"
    27  
    28  	"github.com/okta/okta-jwt-verifier-golang/adaptors"
    29  	"github.com/okta/okta-jwt-verifier-golang/adaptors/lestrratGoJwx"
    30  	"github.com/okta/okta-jwt-verifier-golang/discovery"
    31  	"github.com/okta/okta-jwt-verifier-golang/discovery/oidc"
    32  	"github.com/okta/okta-jwt-verifier-golang/errors"
    33  	"github.com/okta/okta-jwt-verifier-golang/utils"
    34  )
    35  
    36  var (
    37  	regx = regexp.MustCompile(`[a-zA-Z0-9-_]+\.[a-zA-Z0-9-_]+\.?([a-zA-Z0-9-_]+)[/a-zA-Z0-9-_]+?$`)
    38  )
    39  
    40  type JwtVerifier struct {
    41  	Issuer string
    42  
    43  	ClaimsToValidate map[string]string
    44  
    45  	Discovery discovery.Discovery
    46  
    47  	Adaptor adaptors.Adaptor
    48  
    49  	// Cache allows customization of the cache used to store resources
    50  	Cache func(func(string) (interface{}, error)) (utils.Cacher, error)
    51  
    52  	metadataCache utils.Cacher
    53  
    54  	leeway int64
    55  }
    56  
    57  type Jwt struct {
    58  	Claims map[string]interface{}
    59  }
    60  
    61  func fetchMetaData(url string) (interface{}, error) {
    62  	resp, err := http.Get(url)
    63  	if err != nil {
    64  		return nil, fmt.Errorf("request for metadata was not successful: %w", err)
    65  	}
    66  	defer resp.Body.Close()
    67  
    68  	metadata := make(map[string]interface{})
    69  	if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
    70  		return nil, err
    71  	}
    72  	return metadata, nil
    73  }
    74  
    75  func (j *JwtVerifier) New() *JwtVerifier {
    76  	// Default to OIDC discovery if none is defined
    77  	if j.Discovery == nil {
    78  		disc := oidc.Oidc{}
    79  		j.Discovery = disc.New()
    80  	}
    81  
    82  	if j.Cache == nil {
    83  		j.Cache = utils.NewDefaultCache
    84  	}
    85  
    86  	// Default to LestrratGoJwx Adaptor if none is defined
    87  	if j.Adaptor == nil {
    88  		adaptor := &lestrratGoJwx.LestrratGoJwx{Cache: j.Cache}
    89  		j.Adaptor = adaptor.New()
    90  	}
    91  
    92  	// Default to PT2M Leeway
    93  	j.leeway = 120
    94  
    95  	return j
    96  }
    97  
    98  func (j *JwtVerifier) SetLeeway(duration string) {
    99  	dur, _ := time.ParseDuration(duration)
   100  	j.leeway = int64(dur.Seconds())
   101  }
   102  
   103  func (j *JwtVerifier) VerifyAccessToken(jwt string) (*Jwt, error) {
   104  	validJwt, err := j.isValidJwt(jwt)
   105  	if !validJwt {
   106  		return nil, fmt.Errorf("token is not valid: %w", err)
   107  	}
   108  
   109  	resp, err := j.decodeJwt(jwt)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	token := resp.(map[string]interface{})
   115  
   116  	myJwt := Jwt{
   117  		Claims: token,
   118  	}
   119  
   120  	err = j.validateIss(token["iss"])
   121  	if err != nil {
   122  		return &myJwt, fmt.Errorf("the `Issuer` was not able to be validated. %w", err)
   123  	}
   124  
   125  	err = j.validateAudience(token["aud"])
   126  	if err != nil {
   127  		return &myJwt, fmt.Errorf("the `Audience` was not able to be validated. %w", err)
   128  	}
   129  
   130  	err = j.validateClientId(token["cid"])
   131  	if err != nil {
   132  		return &myJwt, fmt.Errorf("the `Client Id` was not able to be validated. %w", err)
   133  	}
   134  
   135  	err = j.validateExp(token["exp"])
   136  	if err != nil {
   137  		return &myJwt, fmt.Errorf("the `Expiration` was not able to be validated. %w", err)
   138  	}
   139  
   140  	err = j.validateIat(token["iat"])
   141  	if err != nil {
   142  		return &myJwt, fmt.Errorf("the `Issued At` was not able to be validated. %w", err)
   143  	}
   144  
   145  	return &myJwt, nil
   146  }
   147  
   148  func (j *JwtVerifier) decodeJwt(jwt string) (interface{}, error) {
   149  	metaData, err := j.getMetaData()
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  	jwksURI, ok := metaData["jwks_uri"].(string)
   154  	if !ok {
   155  		return nil, fmt.Errorf("failed to decode JWT: missing 'jwks_uri' from metadata")
   156  	}
   157  	resp, err := j.Adaptor.Decode(jwt, jwksURI)
   158  	if err != nil {
   159  		return nil, fmt.Errorf("could not decode token: %w", err)
   160  	}
   161  
   162  	return resp, nil
   163  }
   164  
   165  func (j *JwtVerifier) VerifyIdToken(jwt string) (*Jwt, error) {
   166  	validJwt, err := j.isValidJwt(jwt)
   167  	if !validJwt {
   168  		return nil, fmt.Errorf("token is not valid: %w", err)
   169  	}
   170  
   171  	resp, err := j.decodeJwt(jwt)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	token := resp.(map[string]interface{})
   177  
   178  	myJwt := Jwt{
   179  		Claims: token,
   180  	}
   181  
   182  	err = j.validateIss(token["iss"])
   183  	if err != nil {
   184  		return &myJwt, fmt.Errorf("the `Issuer` was not able to be validated. %w", err)
   185  	}
   186  
   187  	err = j.validateAudience(token["aud"])
   188  	if err != nil {
   189  		return &myJwt, fmt.Errorf("the `Audience` was not able to be validated. %w", err)
   190  	}
   191  
   192  	err = j.validateExp(token["exp"])
   193  	if err != nil {
   194  		return &myJwt, fmt.Errorf("the `Expiration` was not able to be validated. %w", err)
   195  	}
   196  
   197  	err = j.validateIat(token["iat"])
   198  	if err != nil {
   199  		return &myJwt, fmt.Errorf("the `Issued At` was not able to be validated. %w", err)
   200  	}
   201  
   202  	err = j.validateNonce(token["nonce"])
   203  	if err != nil {
   204  		return &myJwt, fmt.Errorf("the `Nonce` was not able to be validated. %w", err)
   205  	}
   206  
   207  	return &myJwt, nil
   208  }
   209  
   210  func (j *JwtVerifier) GetDiscovery() discovery.Discovery {
   211  	return j.Discovery
   212  }
   213  
   214  func (j *JwtVerifier) GetAdaptor() adaptors.Adaptor {
   215  	return j.Adaptor
   216  }
   217  
   218  func (j *JwtVerifier) validateNonce(nonce interface{}) error {
   219  	if nonce == nil {
   220  		nonce = ""
   221  	}
   222  
   223  	if nonce != j.ClaimsToValidate["nonce"] {
   224  		return fmt.Errorf("nonce: %s does not match %s", nonce, j.ClaimsToValidate["nonce"])
   225  	}
   226  	return nil
   227  }
   228  
   229  func (j *JwtVerifier) validateAudience(audience interface{}) error {
   230  	switch v := audience.(type) {
   231  	case string:
   232  		if v != j.ClaimsToValidate["aud"] {
   233  			return fmt.Errorf("aud: %s does not match %s", v, j.ClaimsToValidate["aud"])
   234  		}
   235  	case []string:
   236  		for _, element := range v {
   237  			if element == j.ClaimsToValidate["aud"] {
   238  				return nil
   239  			}
   240  		}
   241  		return fmt.Errorf("aud: %s does not match %s", v, j.ClaimsToValidate["aud"])
   242  	case []interface{}:
   243  		for _, e := range v {
   244  			element, ok := e.(string)
   245  			if !ok {
   246  				return fmt.Errorf("unknown type for audience validation")
   247  			}
   248  			if element == j.ClaimsToValidate["aud"] {
   249  				return nil
   250  			}
   251  		}
   252  		return fmt.Errorf("aud: %s does not match %s", v, j.ClaimsToValidate["aud"])
   253  	default:
   254  		return fmt.Errorf("unknown type for audience validation")
   255  	}
   256  
   257  	return nil
   258  }
   259  
   260  func (j *JwtVerifier) validateClientId(clientId interface{}) error {
   261  	// Client Id can be optional, it will be validated if it is present in the ClaimsToValidate array
   262  	if cid, exists := j.ClaimsToValidate["cid"]; exists && clientId != cid {
   263  		switch v := clientId.(type) {
   264  		case string:
   265  			if v != cid {
   266  				return fmt.Errorf("aud: %s does not match %s", v, cid)
   267  			}
   268  		case []string:
   269  			for _, element := range v {
   270  				if element == cid {
   271  					return nil
   272  				}
   273  			}
   274  			return fmt.Errorf("aud: %s does not match %s", v, cid)
   275  		default:
   276  			return fmt.Errorf("unknown type for clientId validation")
   277  		}
   278  	}
   279  	return nil
   280  }
   281  
   282  func (j *JwtVerifier) validateExp(exp interface{}) error {
   283  	expf, ok := exp.(float64)
   284  	if !ok {
   285  		return fmt.Errorf("exp: missing")
   286  	}
   287  	if float64(time.Now().Unix()-j.leeway) > expf {
   288  		return fmt.Errorf("the token is expired")
   289  	}
   290  	return nil
   291  }
   292  
   293  func (j *JwtVerifier) validateIat(iat interface{}) error {
   294  	iatf, ok := iat.(float64)
   295  	if !ok {
   296  		return fmt.Errorf("iat: missing")
   297  	}
   298  	if float64(time.Now().Unix()+j.leeway) < iatf {
   299  		return fmt.Errorf("the token was issued in the future")
   300  	}
   301  	return nil
   302  }
   303  
   304  func (j *JwtVerifier) validateIss(issuer interface{}) error {
   305  	if issuer != j.Issuer {
   306  		return fmt.Errorf("iss: %s does not match %s", issuer, j.Issuer)
   307  	}
   308  	return nil
   309  }
   310  
   311  func (j *JwtVerifier) getMetaData() (map[string]interface{}, error) {
   312  	metaDataUrl := j.Issuer + j.Discovery.GetWellKnownUrl()
   313  
   314  	if j.metadataCache == nil {
   315  		metadataCache, err := j.Cache(fetchMetaData)
   316  		if err != nil {
   317  			return nil, err
   318  		}
   319  		j.metadataCache = metadataCache
   320  	}
   321  
   322  	value, err := j.metadataCache.Get(metaDataUrl)
   323  	if err != nil {
   324  		return nil, err
   325  	}
   326  
   327  	metadata, ok := value.(map[string]interface{})
   328  	if !ok {
   329  		return nil, fmt.Errorf("unable to cast %v to metadata", value)
   330  	}
   331  	return metadata, nil
   332  }
   333  
   334  func (j *JwtVerifier) isValidJwt(jwt string) (bool, error) {
   335  	if jwt == "" {
   336  		return false, errors.JwtEmptyStringError()
   337  	}
   338  
   339  	// Verify that the JWT Follows correct JWT encoding.
   340  	jwtRegex := regx.MatchString
   341  	if !jwtRegex(jwt) {
   342  		return false, fmt.Errorf("token must contain at least 1 period ('.') and only characters 'a-Z 0-9 _'")
   343  	}
   344  
   345  	parts := strings.Split(jwt, ".")
   346  	header := parts[0]
   347  	header = padHeader(header)
   348  	headerDecoded, err := base64.StdEncoding.DecodeString(header)
   349  	if err != nil {
   350  		return false, fmt.Errorf("the tokens header does not appear to be a base64 encoded string")
   351  	}
   352  
   353  	var jsonObject map[string]interface{}
   354  	isHeaderJson := json.Unmarshal([]byte(headerDecoded), &jsonObject) == nil
   355  	if !isHeaderJson {
   356  		return false, fmt.Errorf("the tokens header is not a json object")
   357  	}
   358  
   359  	_, algExists := jsonObject["alg"]
   360  	_, kidExists := jsonObject["kid"]
   361  
   362  	if !algExists {
   363  		return false, fmt.Errorf("the tokens header must contain an 'alg'")
   364  	}
   365  
   366  	if !kidExists {
   367  		return false, fmt.Errorf("the tokens header must contain a 'kid'")
   368  	}
   369  
   370  	if jsonObject["alg"] != "RS256" {
   371  		return false, fmt.Errorf("the only supported alg is RS256")
   372  	}
   373  
   374  	return true, nil
   375  }
   376  
   377  func padHeader(header string) string {
   378  	if i := len(header) % 4; i != 0 {
   379  		header += strings.Repeat("=", 4-i)
   380  	}
   381  	return header
   382  }
   383  

View as plain text