1 package jwt
2
3 import (
4 "encoding/base64"
5 "encoding/json"
6 "fmt"
7 "reflect"
8
9 "github.com/ory/x/errorsx"
10 "gopkg.in/square/go-jose.v2"
11 "gopkg.in/square/go-jose.v2/jwt"
12 )
13
14
15
16
17
18
19
20 type Token struct {
21 Header map[string]interface{}
22 Claims MapClaims
23 Method jose.SignatureAlgorithm
24 valid bool
25 }
26
27 const (
28 SigningMethodNone = jose.SignatureAlgorithm("none")
29
30 UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed"
31
32 JWTHeaderType = jose.HeaderKey("typ")
33 JWTHeaderTypeValue = "JWT"
34 )
35
36 type unsafeNoneMagicConstant string
37
38
39
40 func (t *Token) Valid() bool {
41 return t.valid
42 }
43
44
45
46
47
48
// Claims is implemented by claim sets that can validate themselves.
type Claims interface {
	// Valid returns a non-nil error when the claims are not acceptable.
	Valid() error
}
52
53
54 func NewWithClaims(method jose.SignatureAlgorithm, claims MapClaims) *Token {
55 return &Token{
56 Claims: claims,
57 Method: method,
58 Header: map[string]interface{}{},
59 }
60 }
61
62 func (t *Token) toJoseHeader() map[jose.HeaderKey]interface{} {
63 h := map[jose.HeaderKey]interface{}{
64 JWTHeaderType: JWTHeaderTypeValue,
65 }
66 for k, v := range t.Header {
67 h[jose.HeaderKey(k)] = v
68 }
69 return h
70 }
71
72
73
74
75 func (t *Token) SignedString(k interface{}) (rawToken string, err error) {
76 if _, ok := k.(unsafeNoneMagicConstant); ok {
77 rawToken, err = unsignedToken(t)
78 return
79
80 }
81 var signer jose.Signer
82 key := jose.SigningKey{
83 Algorithm: t.Method,
84 Key: k,
85 }
86 opts := &jose.SignerOptions{ExtraHeaders: t.toJoseHeader()}
87 signer, err = jose.NewSigner(key, opts)
88 if err != nil {
89 err = errorsx.WithStack(err)
90 return
91 }
92
93
94
95
96
97 claims := map[string]interface{}(t.Claims)
98 rawToken, err = jwt.Signed(signer).Claims(claims).CompactSerialize()
99 if err != nil {
100 err = &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err}
101 return
102 }
103 return
104 }
105
106 func unsignedToken(t *Token) (string, error) {
107 t.Header["alg"] = "none"
108 t.Header[string(JWTHeaderType)] = JWTHeaderTypeValue
109 hbytes, err := json.Marshal(&t.Header)
110 if err != nil {
111 return "", errorsx.WithStack(err)
112 }
113 bbytes, err := json.Marshal(&t.Claims)
114 if err != nil {
115 return "", errorsx.WithStack(err)
116 }
117 h := base64.RawURLEncoding.EncodeToString(hbytes)
118 b := base64.RawURLEncoding.EncodeToString(bbytes)
119 return fmt.Sprintf("%v.%v.", h, b), nil
120 }
121
122 func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) {
123 token := &Token{Claims: claims}
124 if len(parsedToken.Headers) != 1 {
125 return nil, &ValidationError{text: fmt.Sprintf("only one header supported, got %v", len(parsedToken.Headers)), Errors: ValidationErrorMalformed}
126 }
127
128
129 h := parsedToken.Headers[0]
130 token.Header = map[string]interface{}{
131 "alg": h.Algorithm,
132 }
133 if h.KeyID != "" {
134 token.Header["kid"] = h.KeyID
135 }
136 for k, v := range h.ExtraHeaders {
137 token.Header[string(k)] = v
138 }
139
140 token.Method = jose.SignatureAlgorithm(h.Algorithm)
141
142 return token, nil
143 }
144
145
146
147
148
149 type Keyfunc func(*Token) (interface{}, error)
150
151 func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) {
152 return ParseWithClaims(tokenString, MapClaims{}, keyFunc)
153 }
154
155
156
157
158 func ParseWithClaims(rawToken string, claims MapClaims, keyFunc Keyfunc) (*Token, error) {
159
160 parsedToken, err := jwt.ParseSigned(rawToken)
161 if err != nil {
162 return &Token{}, &ValidationError{Errors: ValidationErrorMalformed, text: err.Error()}
163 }
164
165
166
167
168
169
170
171
172 if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
173 return nil, &ValidationError{Errors: ValidationErrorClaimsInvalid, text: err.Error()}
174 }
175
176
177 token, err := newToken(parsedToken, claims)
178 if err != nil {
179 return nil, err
180 }
181
182 if keyFunc == nil {
183
184 return token, &ValidationError{Errors: ValidationErrorUnverifiable, text: "no Keyfunc was provided."}
185 }
186
187
188 verificationKey, err := keyFunc(token)
189 if err != nil {
190
191 if ve, ok := err.(*ValidationError); ok {
192 return token, ve
193 }
194 return token, &ValidationError{Errors: ValidationErrorUnverifiable, Inner: err}
195 }
196 if verificationKey == nil {
197 return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: "keyfunc returned a nil verification key"}
198 }
199
200
201
202
203 verificationKey = pointer(verificationKey)
204
205
206 _, validNoneKey := verificationKey.(*unsafeNoneMagicConstant)
207 isSignedToken := !(token.Method == SigningMethodNone && validNoneKey)
208 if isSignedToken {
209 if err := parsedToken.Claims(verificationKey, &claims); err != nil {
210 return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: err.Error()}
211 }
212 }
213
214
215
216
217 if err := claims.Valid(); err != nil {
218 if e, ok := err.(*ValidationError); !ok {
219 err = &ValidationError{Inner: e, Errors: ValidationErrorClaimsInvalid}
220 }
221 return token, err
222 }
223
224
225 token.valid = true
226 return token, nil
227 }
228
229
230
// pointer returns v unchanged when it already holds a pointer. Otherwise it
// returns a pointer to a freshly allocated copy of v. v must not be nil:
// reflect.Value.Type panics on the zero Value.
func pointer(v interface{}) interface{} {
	rv := reflect.ValueOf(v)
	if rv.Kind() == reflect.Ptr {
		return v
	}
	boxed := reflect.New(rv.Type())
	boxed.Elem().Set(rv)
	return boxed.Interface()
}
239
View as plain text