1 package keyfunc_test
2
3 import (
4 "crypto/ecdsa"
5 "crypto/ed25519"
6 "crypto/elliptic"
7 "crypto/rand"
8 "crypto/rsa"
9 "crypto/sha256"
10 "errors"
11 "fmt"
12 "testing"
13
14 "github.com/golang-jwt/jwt/v5"
15
16 "github.com/MicahParks/keyfunc/v2"
17 "github.com/MicahParks/keyfunc/v2/examples/custom/method"
18 )
19
// JWT header names and the key ID shared by the given-key tests.
const (
	algAttribute = "alg"     // JWT header key holding the signing algorithm name.
	kidAttribute = "kid"     // JWT header key holding the key ID.
	testKID      = "testkid" // Key ID every test registers its given key under.
)
30
31
32 func TestNewGivenCustom(t *testing.T) {
33 jwt.RegisterSigningMethod(method.CustomAlgHeader, func() jwt.SigningMethod {
34 return method.EmptyCustom{}
35 })
36
37 givenKeys := make(map[string]keyfunc.GivenKey)
38 key := addCustom(givenKeys, testKID)
39
40 jwks := keyfunc.NewGiven(givenKeys)
41
42 token := jwt.New(method.EmptyCustom{})
43 token.Header[algAttribute] = method.CustomAlgHeader
44 token.Header[kidAttribute] = testKID
45
46 signParseValidate(t, token, key, jwks)
47 }
48
49
50 func TestNewGivenCustomAlg(t *testing.T) {
51 jwt.RegisterSigningMethod(method.CustomAlgHeader, func() jwt.SigningMethod {
52 return method.EmptyCustom{}
53 })
54
55 const key = "test-key"
56 givenKeys := make(map[string]keyfunc.GivenKey)
57 givenKeys[testKID] = keyfunc.NewGivenCustom(key, keyfunc.GivenKeyOptions{
58 Algorithm: method.CustomAlgHeader,
59 })
60
61 jwks := keyfunc.NewGiven(givenKeys)
62
63 token := jwt.New(method.EmptyCustom{})
64 token.Header[algAttribute] = method.CustomAlgHeader
65 token.Header[kidAttribute] = testKID
66
67 signParseValidate(t, token, key, jwks)
68 }
69
70
71
72 func TestNewGivenCustomAlg_NegativeCase(t *testing.T) {
73 jwt.RegisterSigningMethod(method.CustomAlgHeader, func() jwt.SigningMethod {
74 return method.EmptyCustom{}
75 })
76
77 const key = jwt.UnsafeAllowNoneSignatureType
78 givenKeys := make(map[string]keyfunc.GivenKey)
79 givenKeys[testKID] = keyfunc.NewGivenCustom(key, keyfunc.GivenKeyOptions{
80 Algorithm: method.CustomAlgHeader,
81 })
82
83 jwks := keyfunc.NewGiven(givenKeys)
84
85 token := jwt.New(method.EmptyCustom{})
86 token.Header[algAttribute] = jwt.SigningMethodNone.Alg()
87 token.Header[kidAttribute] = testKID
88
89 jwtB64, err := token.SignedString(key)
90 if err != nil {
91 t.Fatalf(logFmt, "Failed to sign the JWT.", err)
92 }
93
94 parsed, err := jwt.NewParser().Parse(jwtB64, jwks.Keyfunc)
95 if !errors.Is(err, keyfunc.ErrJWKAlgMismatch) {
96 t.Fatalf("Failed to return ErrJWKAlgMismatch: %v.", err)
97 }
98
99 if parsed.Valid {
100 t.Fatalf("The JWT was valid.")
101 }
102 }
103
104
105 func TestNewGivenKeyECDSA(t *testing.T) {
106 givenKeys := make(map[string]keyfunc.GivenKey)
107 key, err := addECDSA(givenKeys, testKID)
108 if err != nil {
109 t.Fatalf(err.Error())
110 }
111
112 jwks := keyfunc.NewGiven(givenKeys)
113
114 token := jwt.New(jwt.SigningMethodES256)
115 token.Header[kidAttribute] = testKID
116
117 signParseValidate(t, token, key, jwks)
118 }
119
120
121 func TestNewGivenKeyEdDSA(t *testing.T) {
122 givenKeys := make(map[string]keyfunc.GivenKey)
123 key, err := addEdDSA(givenKeys, testKID)
124 if err != nil {
125 t.Fatalf(err.Error())
126 }
127
128 jwks := keyfunc.NewGiven(givenKeys)
129
130 token := jwt.New(jwt.SigningMethodEdDSA)
131 token.Header[kidAttribute] = testKID
132
133 signParseValidate(t, token, key, jwks)
134 }
135
136
137 func TestNewGivenKeyHMAC(t *testing.T) {
138 givenKeys := make(map[string]keyfunc.GivenKey)
139 key, err := addHMAC(givenKeys, testKID)
140 if err != nil {
141 t.Fatalf(err.Error())
142 }
143
144 jwks := keyfunc.NewGiven(givenKeys)
145
146 token := jwt.New(jwt.SigningMethodHS256)
147 token.Header[kidAttribute] = testKID
148
149 signParseValidate(t, token, key, jwks)
150 }
151
152
153 func TestNewGivenKeyRSA(t *testing.T) {
154 givenKeys := make(map[string]keyfunc.GivenKey)
155 key, err := addRSA(givenKeys, testKID)
156 if err != nil {
157 t.Fatalf(err.Error())
158 }
159
160 jwks := keyfunc.NewGiven(givenKeys)
161
162 token := jwt.New(jwt.SigningMethodRS256)
163 token.Header[kidAttribute] = testKID
164
165 signParseValidate(t, token, key, jwks)
166 }
167
168
169 func TestNewGivenKeysFromJSON(t *testing.T) {
170
171 key := []byte("test-hmac-secret")
172 const testJSON = `{
173 "keys": [
174 {
175 "kid": "testkid",
176 "kty": "oct",
177 "alg": "HS256",
178 "use": "sig",
179 "k": "dGVzdC1obWFjLXNlY3JldA"
180 }
181 ]
182 }`
183
184 givenKeys, err := keyfunc.NewGivenKeysFromJSON([]byte(testJSON))
185 if err != nil {
186 t.Fatalf(logFmt, "Failed to parse given keys from JSON.", err)
187 }
188
189 jwks := keyfunc.NewGiven(givenKeys)
190
191 token := jwt.New(jwt.SigningMethodHS256)
192 token.Header[kidAttribute] = testKID
193
194 signParseValidate(t, token, key, jwks)
195 }
196
197
198 func TestNewGivenKeysFromJSON_BadParse(t *testing.T) {
199 const testJSON = "{not the best syntax"
200 _, err := keyfunc.NewGivenKeysFromJSON([]byte(testJSON))
201 if err == nil {
202 t.Fatalf("Expected a JSON parse error")
203 }
204 }
205
206
207 func addCustom(givenKeys map[string]keyfunc.GivenKey, kid string) (key string) {
208 key = ""
209 givenKeys[kid] = keyfunc.NewGivenCustom(key, keyfunc.GivenKeyOptions{
210 Algorithm: method.CustomAlgHeader,
211 })
212 return key
213 }
214
215
216 func addECDSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key *ecdsa.PrivateKey, err error) {
217 key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
218 if err != nil {
219 return nil, fmt.Errorf("failed to create ECDSA key: %w", err)
220 }
221
222 givenKeys[kid] = keyfunc.NewGivenECDSA(&key.PublicKey, keyfunc.GivenKeyOptions{
223 Algorithm: jwt.SigningMethodES256.Alg(),
224 })
225
226 return key, nil
227 }
228
229
230 func addEdDSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key ed25519.PrivateKey, err error) {
231 pub, key, err := ed25519.GenerateKey(rand.Reader)
232 if err != nil {
233 return nil, fmt.Errorf("failed to create ECDSA key: %w", err)
234 }
235
236 givenKeys[kid] = keyfunc.NewGivenEdDSA(pub, keyfunc.GivenKeyOptions{
237 Algorithm: jwt.SigningMethodEdDSA.Alg(),
238 })
239
240 return key, nil
241 }
242
243
244 func addHMAC(givenKeys map[string]keyfunc.GivenKey, kid string) (secret []byte, err error) {
245 secret = make([]byte, sha256.BlockSize)
246 _, err = rand.Read(secret)
247 if err != nil {
248 return nil, fmt.Errorf("failed to create HMAC secret: %w", err)
249 }
250
251 givenKeys[kid] = keyfunc.NewGivenHMAC(secret, keyfunc.GivenKeyOptions{
252 Algorithm: jwt.SigningMethodHS256.Alg(),
253 })
254
255 return secret, nil
256 }
257
258
259 func addRSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key *rsa.PrivateKey, err error) {
260 key, err = rsa.GenerateKey(rand.Reader, 2048)
261 if err != nil {
262 return nil, fmt.Errorf("failed to create RSA key: %w", err)
263 }
264
265 givenKeys[kid] = keyfunc.NewGivenRSA(&key.PublicKey, keyfunc.GivenKeyOptions{
266 Algorithm: jwt.SigningMethodRS256.Alg(),
267 })
268
269 return key, nil
270 }
271
272
273 func signParseValidate(t *testing.T, token *jwt.Token, key interface{}, jwks *keyfunc.JWKS) {
274 jwtB64, err := token.SignedString(key)
275 if err != nil {
276 t.Fatalf(logFmt, "Failed to sign the JWT.", err)
277 }
278
279 parsed, err := jwt.Parse(jwtB64, jwks.Keyfunc)
280 if err != nil {
281 t.Fatalf(logFmt, "Failed to parse the JWT.", err)
282 }
283
284 if !parsed.Valid {
285 t.Fatalf("The JWT was not valid.")
286 }
287 }
288
View as plain text