1
21
22 package oauth2
23
24 import (
25 "context"
26 "strings"
27 "time"
28
29 "github.com/pkg/errors"
30
31 "github.com/ory/fosite"
32 "github.com/ory/fosite/token/jwt"
33 "github.com/ory/x/errorsx"
34 )
35
36
// DefaultJWTStrategy issues and validates access tokens as signed JWTs using
// the embedded JWTStrategy, while refresh tokens and authorize codes are
// delegated to HMACSHAStrategy.
type DefaultJWTStrategy struct {
	// JWTStrategy is used to sign (Generate) and decode (Decode) access tokens.
	jwt.JWTStrategy
	// HMACSHAStrategy handles refresh tokens and authorize codes
	// (generation, validation and signature extraction).
	HMACSHAStrategy *HMACSHAStrategy
	// Issuer is passed to the claims' WithDefaults when access tokens are generated.
	Issuer string
	// ScopeField is applied to generated claims via WithScopeField.
	ScopeField jwt.JWTScopeFieldEnum
}
43
44 func (h *DefaultJWTStrategy) WithIssuer(issuer string) *DefaultJWTStrategy {
45 h.Issuer = issuer
46 return h
47 }
48
49 func (h *DefaultJWTStrategy) WithScopeField(scopeField jwt.JWTScopeFieldEnum) *DefaultJWTStrategy {
50 h.ScopeField = scopeField
51 return h
52 }
53
// signature returns the text after the last '.' in token, i.e. the third
// dot-separated segment of a compact JWT. It returns "" unless token consists
// of exactly three segments (exactly two dots).
func (h DefaultJWTStrategy) signature(token string) string {
	// Two dots <=> three segments; this mirrors len(strings.Split(token, ".")) == 3
	// without allocating the intermediate slice.
	if strings.Count(token, ".") != 2 {
		return ""
	}
	return token[strings.LastIndex(token, ".")+1:]
}
62
63 func (h DefaultJWTStrategy) AccessTokenSignature(token string) string {
64 return h.signature(token)
65 }
66
// GenerateAccessToken builds a signed JWT access token for requester by
// calling generate with fosite.AccessToken, returning its token, signature
// and error results unchanged.
func (h *DefaultJWTStrategy) GenerateAccessToken(ctx context.Context, requester fosite.Requester) (token string, signature string, err error) {
	token, signature, err = h.generate(ctx, fosite.AccessToken, requester)
	return token, signature, err
}
70
// ValidateAccessToken decodes and validates token with the embedded
// JWTStrategy (via validate). The requester is not consulted, and the decoded
// token is discarded; only the validation error, if any, is returned.
func (h *DefaultJWTStrategy) ValidateAccessToken(ctx context.Context, _ fosite.Requester, token string) error {
	if _, err := validate(ctx, h.JWTStrategy, token); err != nil {
		return err
	}
	return nil
}
75
76 func (h DefaultJWTStrategy) RefreshTokenSignature(token string) string {
77 return h.HMACSHAStrategy.RefreshTokenSignature(token)
78 }
79
80 func (h DefaultJWTStrategy) AuthorizeCodeSignature(token string) string {
81 return h.HMACSHAStrategy.AuthorizeCodeSignature(token)
82 }
83
// GenerateRefreshToken creates a refresh token using the HMAC-SHA strategy
// and returns its results unchanged.
func (h *DefaultJWTStrategy) GenerateRefreshToken(ctx context.Context, req fosite.Requester) (token string, signature string, err error) {
	delegate := h.HMACSHAStrategy
	token, signature, err = delegate.GenerateRefreshToken(ctx, req)
	return token, signature, err
}
87
// ValidateRefreshToken checks a refresh token by deferring entirely to the
// HMAC-SHA strategy.
func (h *DefaultJWTStrategy) ValidateRefreshToken(ctx context.Context, req fosite.Requester, token string) error {
	delegate := h.HMACSHAStrategy
	return delegate.ValidateRefreshToken(ctx, req, token)
}
91
// GenerateAuthorizeCode creates an authorize code using the HMAC-SHA strategy
// and returns its results unchanged.
func (h *DefaultJWTStrategy) GenerateAuthorizeCode(ctx context.Context, req fosite.Requester) (token string, signature string, err error) {
	delegate := h.HMACSHAStrategy
	token, signature, err = delegate.GenerateAuthorizeCode(ctx, req)
	return token, signature, err
}
95
// ValidateAuthorizeCode checks an authorize code by deferring entirely to the
// HMAC-SHA strategy.
func (h *DefaultJWTStrategy) ValidateAuthorizeCode(ctx context.Context, req fosite.Requester, token string) error {
	delegate := h.HMACSHAStrategy
	return delegate.ValidateAuthorizeCode(ctx, req, token)
}
99
// validate decodes token with jwtStrategy and, if decoding succeeds, runs the
// claims' own Valid check.
//
// The token is returned exactly as produced by Decode. Any resulting error,
// from either Decode or Claims.Valid, that carries a *jwt.ValidationError is
// translated into the matching fosite RFC 6749 error via toRFCErr. The
// original error is kept as the wrapped cause and as debug text. Other errors
// are returned unchanged.
func validate(ctx context.Context, jwtStrategy jwt.JWTStrategy, token string) (*jwt.Token, error) {
	t, err := jwtStrategy.Decode(ctx, token)
	if err == nil {
		err = t.Claims.Valid()
	}
	if err == nil {
		return t, nil
	}

	// Map validation failures from both stages the same way. Previously only
	// Decode errors were mapped, so claim errors leaked out in raw form.
	var ve *jwt.ValidationError
	if errors.As(err, &ve) {
		// toRFCErr returns nil for a nil *ValidationError; guard so we never
		// call WithWrap on a nil pointer.
		if rfcErr := toRFCErr(ve); rfcErr != nil {
			return t, errorsx.WithStack(rfcErr.WithWrap(err).WithDebug(err.Error()))
		}
	}
	return t, err
}
114
115 func toRFCErr(v *jwt.ValidationError) *fosite.RFC6749Error {
116 switch {
117 case v == nil:
118 return nil
119 case v.Has(jwt.ValidationErrorMalformed):
120 return fosite.ErrInvalidTokenFormat
121 case v.Has(jwt.ValidationErrorUnverifiable | jwt.ValidationErrorSignatureInvalid):
122 return fosite.ErrTokenSignatureMismatch
123 case v.Has(jwt.ValidationErrorExpired):
124 return fosite.ErrTokenExpired
125 case v.Has(jwt.ValidationErrorAudience |
126 jwt.ValidationErrorIssuedAt |
127 jwt.ValidationErrorIssuer |
128 jwt.ValidationErrorNotValidYet |
129 jwt.ValidationErrorId |
130 jwt.ValidationErrorClaimsInvalid):
131 return fosite.ErrTokenClaim
132 default:
133 return fosite.ErrRequestUnauthorized
134 }
135 }
136
// generate builds and signs a JWT of the given tokenType for requester.
//
// The requester's session must implement JWTSessionContainer and return
// non-nil JWT claims; otherwise an error is returned. The claims are extended
// as follows before the result is signed by the embedded JWTStrategy with the
// session's JWT header:
//   - With: the session's expiry for tokenType, plus the requester's granted
//     scopes and audience.
//   - WithDefaults: the current UTC time and h.Issuer.
//   - WithScopeField: h.ScopeField.
//
// It returns whatever JWTStrategy.Generate returns.
func (h *DefaultJWTStrategy) generate(ctx context.Context, tokenType fosite.TokenType, requester fosite.Requester) (string, string, error) {
	session := requester.GetSession()
	jwtSession, ok := session.(JWTSessionContainer)
	if !ok {
		return "", "", errors.Errorf("Session must be of type JWTSessionContainer but got type: %T", session)
	}

	// Fetch the claims once and reuse them for both the nil check and the build.
	baseClaims := jwtSession.GetJWTClaims()
	if baseClaims == nil {
		// The message names the method actually called (previously it said
		// "GetTokenClaims()", which does not exist).
		return "", "", errors.New("GetJWTClaims() must not be nil")
	}

	claims := baseClaims.
		With(
			jwtSession.GetExpiresAt(tokenType),
			requester.GetGrantedScopes(),
			requester.GetGrantedAudience(),
		).
		WithDefaults(
			time.Now().UTC(),
			h.Issuer,
		).
		WithScopeField(
			h.ScopeField,
		)

	return h.JWTStrategy.Generate(ctx, claims.ToMapClaims(), jwtSession.GetJWTHeader())
}
160
View as plain text