...

Source file src/github.com/ory/fosite/handler/oauth2/strategy_jwt.go

Documentation: github.com/ory/fosite/handler/oauth2

     1  /*
     2   * Copyright © 2015-2018 Aeneas Rekkas <aeneas+oss@aeneas.io>
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * @author		Aeneas Rekkas <aeneas+oss@aeneas.io>
    17   * @copyright 	2015-2018 Aeneas Rekkas <aeneas+oss@aeneas.io>
    18   * @license 	Apache-2.0
    19   *
    20   */
    21  
    22  package oauth2
    23  
    24  import (
    25  	"context"
    26  	"strings"
    27  	"time"
    28  
    29  	"github.com/pkg/errors"
    30  
    31  	"github.com/ory/fosite"
    32  	"github.com/ory/fosite/token/jwt"
    33  	"github.com/ory/x/errorsx"
    34  )
    35  
// DefaultJWTStrategy issues access tokens as signed JWTs (via the embedded
// jwt.JWTStrategy) and delegates refresh tokens and authorization codes to an
// HMACSHAStrategy.
//
// NOTE(review): this type was previously described as an "RS256" strategy; the
// signing algorithm is actually whatever the embedded jwt.JWTStrategy uses —
// confirm against the configured signer.
type DefaultJWTStrategy struct {
	// JWTStrategy signs (Generate) and decodes (Decode) access-token JWTs.
	jwt.JWTStrategy
	// HMACSHAStrategy generates, validates and computes signatures for
	// refresh tokens and authorization codes.
	HMACSHAStrategy *HMACSHAStrategy
	// Issuer is passed to the claims' WithDefaults when generating access
	// tokens (presumably only used when the claims carry no issuer — verify).
	Issuer          string
	// ScopeField is passed to the claims' WithScopeField when generating
	// access tokens.
	ScopeField      jwt.JWTScopeFieldEnum
}
    43  
// WithIssuer sets h.Issuer, which is used when generating access-token claims,
// and returns h so that configuration calls can be chained. h is modified in
// place.
func (h *DefaultJWTStrategy) WithIssuer(issuer string) *DefaultJWTStrategy {
	h.Issuer = issuer
	return h
}

// WithScopeField sets h.ScopeField, which is applied to generated
// access-token claims, and returns h so that configuration calls can be
// chained. h is modified in place.
func (h *DefaultJWTStrategy) WithScopeField(scopeField jwt.JWTScopeFieldEnum) *DefaultJWTStrategy {
	h.ScopeField = scopeField
	return h
}
    53  
    54  func (h DefaultJWTStrategy) signature(token string) string {
    55  	split := strings.Split(token, ".")
    56  	if len(split) != 3 {
    57  		return ""
    58  	}
    59  
    60  	return split[2]
    61  }
    62  
    63  func (h DefaultJWTStrategy) AccessTokenSignature(token string) string {
    64  	return h.signature(token)
    65  }
    66  
    67  func (h *DefaultJWTStrategy) GenerateAccessToken(ctx context.Context, requester fosite.Requester) (token string, signature string, err error) {
    68  	return h.generate(ctx, fosite.AccessToken, requester)
    69  }
    70  
    71  func (h *DefaultJWTStrategy) ValidateAccessToken(ctx context.Context, _ fosite.Requester, token string) error {
    72  	_, err := validate(ctx, h.JWTStrategy, token)
    73  	return err
    74  }
    75  
    76  func (h DefaultJWTStrategy) RefreshTokenSignature(token string) string {
    77  	return h.HMACSHAStrategy.RefreshTokenSignature(token)
    78  }
    79  
    80  func (h DefaultJWTStrategy) AuthorizeCodeSignature(token string) string {
    81  	return h.HMACSHAStrategy.AuthorizeCodeSignature(token)
    82  }
    83  
    84  func (h *DefaultJWTStrategy) GenerateRefreshToken(ctx context.Context, req fosite.Requester) (token string, signature string, err error) {
    85  	return h.HMACSHAStrategy.GenerateRefreshToken(ctx, req)
    86  }
    87  
    88  func (h *DefaultJWTStrategy) ValidateRefreshToken(ctx context.Context, req fosite.Requester, token string) error {
    89  	return h.HMACSHAStrategy.ValidateRefreshToken(ctx, req, token)
    90  }
    91  
    92  func (h *DefaultJWTStrategy) GenerateAuthorizeCode(ctx context.Context, req fosite.Requester) (token string, signature string, err error) {
    93  	return h.HMACSHAStrategy.GenerateAuthorizeCode(ctx, req)
    94  }
    95  
    96  func (h *DefaultJWTStrategy) ValidateAuthorizeCode(ctx context.Context, req fosite.Requester, token string) error {
    97  	return h.HMACSHAStrategy.ValidateAuthorizeCode(ctx, req, token)
    98  }
    99  
   100  func validate(ctx context.Context, jwtStrategy jwt.JWTStrategy, token string) (t *jwt.Token, err error) {
   101  	t, err = jwtStrategy.Decode(ctx, token)
   102  	if err == nil {
   103  		err = t.Claims.Valid()
   104  		return
   105  	}
   106  
   107  	var e *jwt.ValidationError
   108  	if err != nil && errors.As(err, &e) {
   109  		err = errorsx.WithStack(toRFCErr(e).WithWrap(err).WithDebug(err.Error()))
   110  	}
   111  
   112  	return
   113  }
   114  
   115  func toRFCErr(v *jwt.ValidationError) *fosite.RFC6749Error {
   116  	switch {
   117  	case v == nil:
   118  		return nil
   119  	case v.Has(jwt.ValidationErrorMalformed):
   120  		return fosite.ErrInvalidTokenFormat
   121  	case v.Has(jwt.ValidationErrorUnverifiable | jwt.ValidationErrorSignatureInvalid):
   122  		return fosite.ErrTokenSignatureMismatch
   123  	case v.Has(jwt.ValidationErrorExpired):
   124  		return fosite.ErrTokenExpired
   125  	case v.Has(jwt.ValidationErrorAudience |
   126  		jwt.ValidationErrorIssuedAt |
   127  		jwt.ValidationErrorIssuer |
   128  		jwt.ValidationErrorNotValidYet |
   129  		jwt.ValidationErrorId |
   130  		jwt.ValidationErrorClaimsInvalid):
   131  		return fosite.ErrTokenClaim
   132  	default:
   133  		return fosite.ErrRequestUnauthorized
   134  	}
   135  }
   136  
   137  func (h *DefaultJWTStrategy) generate(ctx context.Context, tokenType fosite.TokenType, requester fosite.Requester) (string, string, error) {
   138  	if jwtSession, ok := requester.GetSession().(JWTSessionContainer); !ok {
   139  		return "", "", errors.Errorf("Session must be of type JWTSessionContainer but got type: %T", requester.GetSession())
   140  	} else if jwtSession.GetJWTClaims() == nil {
   141  		return "", "", errors.New("GetTokenClaims() must not be nil")
   142  	} else {
   143  		claims := jwtSession.GetJWTClaims().
   144  			With(
   145  				jwtSession.GetExpiresAt(tokenType),
   146  				requester.GetGrantedScopes(),
   147  				requester.GetGrantedAudience(),
   148  			).
   149  			WithDefaults(
   150  				time.Now().UTC(),
   151  				h.Issuer,
   152  			).
   153  			WithScopeField(
   154  				h.ScopeField,
   155  			)
   156  
   157  		return h.JWTStrategy.Generate(ctx, claims.ToMapClaims(), jwtSession.GetJWTHeader())
   158  	}
   159  }
   160  

View as plain text