...
1 package keyfunc
2
3 import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "net/http"
9 "sync"
10 "time"
11 )
12
var (
	// ErrJWKAlgMismatch indicates that the JWK for the requested key ID was
	// found, but its "alg" parameter did not match the token's algorithm.
	ErrJWKAlgMismatch = errors.New(`the given JWK was found, but its "alg" parameter's value did not match the expected algorithm`)

	// ErrJWKUseWhitelist indicates that the JWK for the requested key ID was
	// found, but its "use" parameter is not in the configured whitelist.
	ErrJWKUseWhitelist = errors.New(`the given JWK was found, but its "use" parameter's value was not whitelisted`)

	// ErrKIDNotFound indicates that no key with the requested key ID is
	// available in the JWKS.
	ErrKIDNotFound = errors.New("the given key ID was not found in the JWKS")

	// ErrMissingAssets indicates that a JWK lacked the parameters needed to
	// build a public key (presumably returned by the per-key-type parsers,
	// which are not shown here).
	ErrMissingAssets = errors.New("required assets are missing to create a public key")
)
27
28
// ErrorHandler is a callback that receives an error. A JWKS stores one in its
// refreshErrorHandler field (presumably invoked when a background refresh
// fails — the call site is not shown here).
type ErrorHandler func(error)
30
// JWKUse is the value of a JWK's "use" parameter.
type JWKUse string

// Known values of the JWK "use" parameter.
const (
	// UseEncryption marks a key intended for encryption.
	UseEncryption JWKUse = "enc"

	// UseOmitted is the value held by a JWK that has no "use" parameter.
	UseOmitted JWKUse = ""

	// UseSignature marks a key intended for signatures.
	UseSignature JWKUse = "sig"
)
43
44
// jsonWebKey is the wire form of a single entry in a JWKS "keys" array. Only
// the string-valued parameters this package reads are decoded.
type jsonWebKey struct {
	// Algorithm is the "alg" parameter; when non-empty it must match the token's alg.
	Algorithm string `json:"alg"`
	// Curve is the "crv" parameter (elliptic-curve and OKP keys).
	Curve string `json:"crv"`
	// Exponent is the "e" parameter (RSA public exponent, presumably base64url).
	Exponent string `json:"e"`
	// K is the "k" parameter (symmetric key material, presumably base64url).
	K string `json:"k"`
	// ID is the "kid" parameter; it is the map key a parsed JWK is stored under.
	ID string `json:"kid"`
	// Modulus is the "n" parameter (RSA modulus, presumably base64url).
	Modulus string `json:"n"`
	// Type is the "kty" parameter; it selects which parser NewJSON uses.
	Type string `json:"kty"`
	// Use is the "use" parameter, e.g. "sig" or "enc".
	Use string `json:"use"`
	// X is the "x" parameter (curve point coordinate or OKP public key).
	X string `json:"x"`
	// Y is the "y" parameter (curve point coordinate).
	Y string `json:"y"`
}
57
58
59 type parsedJWK struct {
60 algorithm string
61 public interface{}
62 use JWKUse
63 }
64
65
66 type JWKS struct {
67 jwkUseWhitelist map[JWKUse]struct{}
68 cancel context.CancelFunc
69 client *http.Client
70 ctx context.Context
71 raw []byte
72 givenKeys map[string]GivenKey
73 givenKIDOverride bool
74 jwksURL string
75 keys map[string]parsedJWK
76 mux sync.RWMutex
77 refreshErrorHandler ErrorHandler
78 refreshInterval time.Duration
79 refreshRateLimit time.Duration
80 refreshRequests chan refreshRequest
81 refreshTimeout time.Duration
82 refreshUnknownKID bool
83 requestFactory func(ctx context.Context, url string) (*http.Request, error)
84 responseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
85 }
86
87
88 type rawJWKS struct {
89 Keys []*jsonWebKey `json:"keys"`
90 }
91
92
93 func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) {
94 var rawKS rawJWKS
95 err = json.Unmarshal(jwksBytes, &rawKS)
96 if err != nil {
97 return nil, err
98 }
99
100
101 jwks = &JWKS{
102 keys: make(map[string]parsedJWK, len(rawKS.Keys)),
103 }
104 for _, key := range rawKS.Keys {
105 var keyInter interface{}
106 switch keyType := key.Type; keyType {
107 case ktyEC:
108 keyInter, err = key.ECDSA()
109 if err != nil {
110 continue
111 }
112 case ktyOKP:
113 keyInter, err = key.EdDSA()
114 if err != nil {
115 continue
116 }
117 case ktyOct:
118 keyInter, err = key.Oct()
119 if err != nil {
120 continue
121 }
122 case ktyRSA:
123 keyInter, err = key.RSA()
124 if err != nil {
125 continue
126 }
127 default:
128
129 continue
130 }
131
132 jwks.keys[key.ID] = parsedJWK{
133 algorithm: key.Algorithm,
134 use: JWKUse(key.Use),
135 public: keyInter,
136 }
137 }
138
139 return jwks, nil
140 }
141
142
143
144 func (j *JWKS) EndBackground() {
145 if j.cancel != nil {
146 j.cancel()
147 }
148 }
149
150
151 func (j *JWKS) KIDs() (kids []string) {
152 j.mux.RLock()
153 defer j.mux.RUnlock()
154 kids = make([]string, len(j.keys))
155 index := 0
156 for kid := range j.keys {
157 kids[index] = kid
158 index++
159 }
160 return kids
161 }
162
163
164 func (j *JWKS) Len() int {
165 j.mux.RLock()
166 defer j.mux.RUnlock()
167 return len(j.keys)
168 }
169
170
171 func (j *JWKS) RawJWKS() []byte {
172 j.mux.RLock()
173 defer j.mux.RUnlock()
174 raw := make([]byte, len(j.raw))
175 copy(raw, j.raw)
176 return raw
177 }
178
179
180 func (j *JWKS) ReadOnlyKeys() map[string]interface{} {
181 keys := make(map[string]interface{})
182 j.mux.Lock()
183 for kid, cryptoKey := range j.keys {
184 keys[kid] = cryptoKey.public
185 }
186 j.mux.Unlock()
187 return keys
188 }
189
190
191 func (j *JWKS) getKey(alg, kid string) (jsonKey interface{}, err error) {
192 j.mux.RLock()
193 pubKey, ok := j.keys[kid]
194 j.mux.RUnlock()
195
196 if !ok {
197 if !j.refreshUnknownKID {
198 return nil, ErrKIDNotFound
199 }
200
201 ctx, cancel := context.WithCancel(j.ctx)
202 req := refreshRequest{
203 cancel: cancel,
204 }
205
206
207 select {
208 case <-j.ctx.Done():
209 return
210 case j.refreshRequests <- req:
211 default:
212
213 return nil, ErrKIDNotFound
214 }
215
216
217 <-ctx.Done()
218
219 j.mux.RLock()
220 defer j.mux.RUnlock()
221 if pubKey, ok = j.keys[kid]; !ok {
222 return nil, ErrKIDNotFound
223 }
224 }
225
226
227 if len(j.jwkUseWhitelist) > 0 {
228 _, ok = j.jwkUseWhitelist[pubKey.use]
229 if !ok {
230 return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not whitelisted`, ErrJWKUseWhitelist, pubKey.use)
231 }
232 }
233
234 if pubKey.algorithm != "" && pubKey.algorithm != alg {
235 return nil, fmt.Errorf(`%w: JWK "alg" parameter value %q does not match token "alg" parameter value %q`, ErrJWKAlgMismatch, pubKey.algorithm, alg)
236 }
237
238 return pubKey.public, nil
239 }
240
View as plain text