1 package keyfunc_test
2
3 import (
4 "crypto/rand"
5 "crypto/rsa"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "math/big"
11 "net/http"
12 "net/http/httptest"
13 "os"
14 "path/filepath"
15 "testing"
16
17 "github.com/golang-jwt/jwt/v4"
18
19 "github.com/MicahParks/keyfunc"
20 )
21
// Key IDs attached to the test keys. givenKID is used only for a locally
// supplied ("given") key. remoteKID is used both for a given key and for the
// key published in the JWKS served over HTTP, so the two sources collide on it.
const givenKID, remoteKID = "givenKID", "remoteKID"
26
27
// Minimal JSON shapes used to build the JWKS document served to keyfunc in
// these tests. Only the members needed to describe an RSA public key are
// modeled.
type (
	// pseudoJWKS is the top-level JWKS document: an object with a "keys" array.
	pseudoJWKS struct {
		Keys []pseudoJSONKey `json:"keys"`
	}

	// pseudoJSONKey is one JWK entry. In this file E and N are filled with the
	// unpadded base64url encodings of an RSA public exponent and modulus.
	pseudoJSONKey struct {
		KID string `json:"kid"`
		KTY string `json:"kty"`
		E   string `json:"e"`
		N   string `json:"n"`
	}
)
39
40
// TestNewGiven verifies how keyfunc combines locally supplied ("given") keys
// with keys fetched from a remote JWKS, both without and with
// Options.GivenKIDOverride.
//
// keysAndJWKS supplies given keys under givenKID and remoteKID. It also
// supplies a separately generated key that the local HTTP server publishes
// under remoteKID. That means remoteKID is present in both sources with
// different key material.
func TestNewGiven(t *testing.T) {
	// t.TempDir is removed automatically after the test finishes. This replaces
	// a manual MkdirTemp/RemoveAll pair that called t.Fatalf from a defer.
	tempDir := t.TempDir()

	jwksFile := filepath.Join(tempDir, jwksFilePath)

	givenKeys, givenPrivateKeys, jwksBytes, remotePrivateKeys, err := keysAndJWKS()
	if err != nil {
		t.Fatalf(logFmt, "Failed to create cryptographic keys for the test.", err)
	}

	if err = os.WriteFile(jwksFile, jwksBytes, 0600); err != nil {
		t.Fatalf(logFmt, "Failed to write JWKS file to temporary directory.", err)
	}

	// Serve the temporary directory so the JWKS file is reachable over HTTP.
	server := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
	defer server.Close()

	// Any refresh error is unexpected here.
	// NOTE(review): this panics rather than calling t methods. The presumed
	// reason is that keyfunc may invoke the handler outside the test
	// goroutine — confirm.
	testingRefreshErrorHandler := func(err error) {
		panic(fmt.Sprintf(logFmt, "Unhandled JWKS error.", err))
	}

	jwksURL := server.URL + jwksFilePath

	options := keyfunc.Options{
		GivenKeys:           givenKeys,
		RefreshErrorHandler: testingRefreshErrorHandler,
	}

	jwks, err := keyfunc.Get(jwksURL, options)
	if err != nil {
		t.Fatalf(logFmt, "Failed to get the JWKS from the testing URL.", err)
	}

	// Default behavior: on a KID collision the remote key is the one used.
	// A given key whose KID exists only locally still verifies.
	createSignParseValidate(t, givenPrivateKeys, jwks, givenKID, true)

	// The given key stored under remoteKID is not the one used, so its
	// signature must fail verification.
	createSignParseValidate(t, givenPrivateKeys, jwks, remoteKID, false)

	// The remote key for remoteKID verifies.
	createSignParseValidate(t, remotePrivateKeys, jwks, remoteKID, true)

	// With GivenKIDOverride, the given key wins on a KID collision.
	options.GivenKIDOverride = true
	jwks, err = keyfunc.Get(jwksURL, options)
	if err != nil {
		t.Fatalf(logFmt, "Failed to recreate JWKS.", err)
	}

	// The unique given key still verifies.
	createSignParseValidate(t, givenPrivateKeys, jwks, givenKID, true)

	// The given key under remoteKID now takes precedence and verifies.
	createSignParseValidate(t, givenPrivateKeys, jwks, remoteKID, true)

	// The remote key under remoteKID is overridden, so its signature fails.
	createSignParseValidate(t, remotePrivateKeys, jwks, remoteKID, false)
}
109
110
// createSignParseValidate checks one signing/verification scenario. It
// creates an RS256 JWT whose header carries kid, signs it with keys[kid],
// parses it back through jwks.Keyfunc, and fails the test unless the outcome
// matches shouldValidate.
//
// When shouldValidate is false, the only accepted failure is a signature
// verification error (rsa.ErrVerification). Any other parse error, or a
// successful parse, fails the test.
func createSignParseValidate(t *testing.T, keys map[string]*rsa.PrivateKey, jwks *keyfunc.JWKS, kid string, shouldValidate bool) {
	// Attribute failures to the caller's line so each scenario is identifiable.
	t.Helper()

	// Without this guard, a missing entry would reach the signer as a nil
	// *rsa.PrivateKey, which may panic inside crypto/rsa instead of failing
	// the test cleanly.
	privateKey, ok := keys[kid]
	if !ok || privateKey == nil {
		t.Fatalf("No private key for kid %q.", kid)
	}

	unsignedToken := jwt.New(jwt.SigningMethodRS256)
	unsignedToken.Header[kidAttribute] = kid

	jwtB64, err := unsignedToken.SignedString(privateKey)
	if err != nil {
		t.Fatalf(logFmt, "Failed to sign the JWT.", err)
	}

	token, err := jwt.Parse(jwtB64, jwks.Keyfunc)
	if err != nil {
		// A signature mismatch is the expected failure mode for negative cases.
		if !shouldValidate && errors.Is(err, rsa.ErrVerification) {
			return
		}
		t.Fatalf(logFmt, "Failed to parse the JWT.", err)
	}

	if !shouldValidate {
		t.Fatal("The token should not have parsed properly.")
	}

	if !token.Valid {
		t.Fatal("The JWT is not valid.")
	}
}
136
137
// keysAndJWKS builds the key material for TestNewGiven.
//
// Results:
//   - givenKeys: populated by addRSA for givenKID and then remoteKID
//     (presumably the public halves as keyfunc.GivenKey values — see addRSA).
//   - givenPrivateKeys: the private keys returned by addRSA, under the same KIDs.
//   - jwksBytes: a JSON JWKS with one RSA public key under remoteKID, taken
//     from a separately generated 2048-bit key.
//   - remotePrivateKeys: that separately generated private key, under remoteKID.
//
// On any failure, all four values are nil and err wraps the cause.
func keysAndJWKS() (givenKeys map[string]keyfunc.GivenKey, givenPrivateKeys map[string]*rsa.PrivateKey, jwksBytes []byte, remotePrivateKeys map[string]*rsa.PrivateKey, err error) {
	const rsaErrMessage = "failed to create RSA key: %w"

	givenKeys = make(map[string]keyfunc.GivenKey)
	givenPrivateKeys = make(map[string]*rsa.PrivateKey)

	// One given key per KID, created in a fixed order. remoteKID is included on
	// purpose so it collides with the KID published in the remote JWKS.
	for _, kid := range []string{givenKID, remoteKID} {
		privateKey, addErr := addRSA(givenKeys, kid)
		if addErr != nil {
			return nil, nil, nil, nil, fmt.Errorf(rsaErrMessage, addErr)
		}
		givenPrivateKeys[kid] = privateKey
	}

	// Generated independently, so it differs from the given key under remoteKID.
	remoteKey, genErr := rsa.GenerateKey(rand.Reader, 2048)
	if genErr != nil {
		return nil, nil, nil, nil, fmt.Errorf(rsaErrMessage, genErr)
	}
	remotePrivateKeys = map[string]*rsa.PrivateKey{remoteKID: remoteKey}

	// JWK "e" and "n" are the unpadded base64url encodings of the big-endian
	// exponent and modulus bytes.
	exponentBytes := big.NewInt(int64(remoteKey.PublicKey.E)).Bytes()
	modulusBytes := remoteKey.PublicKey.N.Bytes()
	document := pseudoJWKS{
		Keys: []pseudoJSONKey{{
			KID: remoteKID,
			KTY: "RSA",
			E:   base64.RawURLEncoding.EncodeToString(exponentBytes),
			N:   base64.RawURLEncoding.EncodeToString(modulusBytes),
		}},
	}

	encoded, marshalErr := json.Marshal(document)
	if marshalErr != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to marshal the JWKS to JSON: %w", marshalErr)
	}

	return givenKeys, givenPrivateKeys, encoded, remotePrivateKeys, nil
}
179
View as plain text