...
1 package keyfunc
2
3 import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "net/http"
9 "sync"
10 "time"
11 )
12
// Sentinel errors exposed by this package. getKey wraps ErrJWKUseWhitelist
// with %w, so callers should compare with errors.Is rather than ==.
var (
	// ErrJWKUseWhitelist indicates that a JWK was found, but the value of its
	// "use" parameter is not in the configured whitelist.
	ErrJWKUseWhitelist = errors.New(`the given JWK was found, but its "use" parameter's value was not whitelisted`)

	// ErrKIDNotFound indicates that the requested key ID is not present in
	// the JWKS.
	ErrKIDNotFound = errors.New("the given key ID was not found in the JWKS")

	// ErrMissingAssets indicates that the assets required to create a public
	// key are missing.
	ErrMissingAssets = errors.New("required assets are missing to create a public key")
)
23
24
// ErrorHandler is the signature of a callback that consumes an error.
// It is the type of the JWKS refreshErrorHandler field.
type ErrorHandler func(err error)
26
// JWKUse is the value of a JWK's "use" parameter. NewJSON converts the
// decoded "use" string of every key to this type, and getKey checks it
// against the configured whitelist.
type JWKUse string

const (
	// UseEncryption is the "use" value "enc".
	UseEncryption JWKUse = "enc"

	// UseOmitted is the empty "use" value, i.e. the JWK carried no "use"
	// parameter (or an empty one).
	UseOmitted JWKUse = ""

	// UseSignature is the "use" value "sig".
	UseSignature JWKUse = "sig"
)
39
40
// jsonWebKey is the JSON decoding target for a single entry of a JWKS "keys"
// array. Every field is kept as the raw string found in the document.
type jsonWebKey struct {
	// Curve is the "crv" member.
	Curve string `json:"crv"`
	// Exponent is the "e" member.
	Exponent string `json:"e"`
	// K is the "k" member.
	K string `json:"k"`
	// ID is the "kid" member; NewJSON uses it as the map key for the parsed key.
	ID string `json:"kid"`
	// Modulus is the "n" member.
	Modulus string `json:"n"`
	// Type is the "kty" member; NewJSON switches on it to pick a parser.
	Type string `json:"kty"`
	// Use is the "use" member; NewJSON converts it to a JWKUse.
	Use string `json:"use"`
	// X is the "x" member.
	X string `json:"x"`
	// Y is the "y" member.
	Y string `json:"y"`
}
52
53
54 type parsedJWK struct {
55 use JWKUse
56 public interface{}
57 }
58
59
60 type JWKS struct {
61 jwkUseWhitelist map[JWKUse]struct{}
62 cancel context.CancelFunc
63 client *http.Client
64 ctx context.Context
65 raw []byte
66 givenKeys map[string]GivenKey
67 givenKIDOverride bool
68 jwksURL string
69 keys map[string]parsedJWK
70 mux sync.RWMutex
71 refreshErrorHandler ErrorHandler
72 refreshInterval time.Duration
73 refreshRateLimit time.Duration
74 refreshRequests chan context.CancelFunc
75 refreshTimeout time.Duration
76 refreshUnknownKID bool
77 requestFactory func(ctx context.Context, url string) (*http.Request, error)
78 responseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
79 }
80
81
82 type rawJWKS struct {
83 Keys []*jsonWebKey `json:"keys"`
84 }
85
86
87 func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) {
88 var rawKS rawJWKS
89 err = json.Unmarshal(jwksBytes, &rawKS)
90 if err != nil {
91 return nil, err
92 }
93
94
95 jwks = &JWKS{
96 keys: make(map[string]parsedJWK, len(rawKS.Keys)),
97 }
98 for _, key := range rawKS.Keys {
99 var keyInter interface{}
100 switch keyType := key.Type; keyType {
101 case ktyEC:
102 keyInter, err = key.ECDSA()
103 if err != nil {
104 continue
105 }
106 case ktyOKP:
107 keyInter, err = key.EdDSA()
108 if err != nil {
109 continue
110 }
111 case ktyOct:
112 keyInter, err = key.Oct()
113 if err != nil {
114 continue
115 }
116 case ktyRSA:
117 keyInter, err = key.RSA()
118 if err != nil {
119 continue
120 }
121 default:
122
123 continue
124 }
125
126 jwks.keys[key.ID] = parsedJWK{
127 use: JWKUse(key.Use),
128 public: keyInter,
129 }
130 }
131
132 return jwks, nil
133 }
134
135
136
137 func (j *JWKS) EndBackground() {
138 if j.cancel != nil {
139 j.cancel()
140 }
141 }
142
143
144 func (j *JWKS) KIDs() (kids []string) {
145 j.mux.RLock()
146 defer j.mux.RUnlock()
147 kids = make([]string, len(j.keys))
148 index := 0
149 for kid := range j.keys {
150 kids[index] = kid
151 index++
152 }
153 return kids
154 }
155
156
157 func (j *JWKS) Len() int {
158 j.mux.RLock()
159 defer j.mux.RUnlock()
160 return len(j.keys)
161 }
162
163
164 func (j *JWKS) RawJWKS() []byte {
165 j.mux.RLock()
166 defer j.mux.RUnlock()
167 raw := make([]byte, len(j.raw))
168 copy(raw, j.raw)
169 return raw
170 }
171
172
173 func (j *JWKS) ReadOnlyKeys() map[string]interface{} {
174 keys := make(map[string]interface{})
175 j.mux.Lock()
176 for kid, cryptoKey := range j.keys {
177 keys[kid] = cryptoKey.public
178 }
179 j.mux.Unlock()
180 return keys
181 }
182
183
184 func (j *JWKS) getKey(kid string) (jsonKey interface{}, err error) {
185 j.mux.RLock()
186 pubKey, ok := j.keys[kid]
187 j.mux.RUnlock()
188
189 if !ok {
190 if !j.refreshUnknownKID {
191 return nil, ErrKIDNotFound
192 }
193
194 ctx, cancel := context.WithCancel(j.ctx)
195
196
197 select {
198 case <-j.ctx.Done():
199 return
200 case j.refreshRequests <- cancel:
201 default:
202
203 return nil, ErrKIDNotFound
204 }
205
206
207 <-ctx.Done()
208
209 j.mux.RLock()
210 defer j.mux.RUnlock()
211 if pubKey, ok = j.keys[kid]; !ok {
212 return nil, ErrKIDNotFound
213 }
214 }
215
216
217 if len(j.jwkUseWhitelist) > 0 {
218 _, ok = j.jwkUseWhitelist[pubKey.use]
219 if !ok {
220 return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not whitelisted`, ErrJWKUseWhitelist, pubKey.use)
221 }
222 }
223
224 return pubKey.public, nil
225 }
226
View as plain text