1 package jwk
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "sort"
8
9 "github.com/lestrrat-go/iter/arrayiter"
10 "github.com/lestrrat-go/jwx/internal/json"
11 "github.com/lestrrat-go/jwx/internal/pool"
12 "github.com/pkg/errors"
13 )
14
// keysKey is the JSON member name under which a key set stores its list of
// keys; every other member name is treated as a set-level private parameter.
const keysKey = `keys`
16
17
18 func NewSet() Set {
19 return &set{
20 privateParams: make(map[string]interface{}),
21 }
22 }
23
24 func (s *set) Set(n string, v interface{}) error {
25 s.mu.RLock()
26 defer s.mu.RUnlock()
27
28 if n == keysKey {
29 vl, ok := v.([]Key)
30 if !ok {
31 return errors.Errorf(`value for field "keys" must be []jwk.Key`)
32 }
33 s.keys = vl
34 return nil
35 }
36
37 s.privateParams[n] = v
38 return nil
39 }
40
41 func (s *set) Field(n string) (interface{}, bool) {
42 s.mu.RLock()
43 defer s.mu.RUnlock()
44
45 v, ok := s.privateParams[n]
46 return v, ok
47 }
48
49 func (s *set) Get(idx int) (Key, bool) {
50 s.mu.RLock()
51 defer s.mu.RUnlock()
52
53 if idx >= 0 && idx < len(s.keys) {
54 return s.keys[idx], true
55 }
56 return nil, false
57 }
58
59 func (s *set) Len() int {
60 s.mu.RLock()
61 defer s.mu.RUnlock()
62
63 return len(s.keys)
64 }
65
66
67 func (s *set) indexNL(key Key) int {
68 for i, k := range s.keys {
69 if k == key {
70 return i
71 }
72 }
73 return -1
74 }
75
76 func (s *set) Index(key Key) int {
77 s.mu.RLock()
78 defer s.mu.RUnlock()
79
80 return s.indexNL(key)
81 }
82
83 func (s *set) Add(key Key) bool {
84 s.mu.Lock()
85 defer s.mu.Unlock()
86
87 if i := s.indexNL(key); i > -1 {
88 return false
89 }
90 s.keys = append(s.keys, key)
91 return true
92 }
93
94 func (s *set) Remove(key Key) bool {
95 s.mu.Lock()
96 defer s.mu.Unlock()
97
98 for i, k := range s.keys {
99 if k == key {
100 switch i {
101 case 0:
102 s.keys = s.keys[1:]
103 case len(s.keys) - 1:
104 s.keys = s.keys[:i]
105 default:
106 s.keys = append(s.keys[:i], s.keys[i+1:]...)
107 }
108 return true
109 }
110 }
111 return false
112 }
113
114 func (s *set) Clear() {
115 s.mu.Lock()
116 defer s.mu.Unlock()
117
118 s.keys = nil
119 }
120
121 func (s *set) Iterate(ctx context.Context) KeyIterator {
122 ch := make(chan *KeyPair, s.Len())
123 go iterate(ctx, s.keys, ch)
124 return arrayiter.New(ch)
125 }
126
127 func iterate(ctx context.Context, keys []Key, ch chan *KeyPair) {
128 defer close(ch)
129
130 for i, key := range keys {
131 pair := &KeyPair{Index: i, Value: key}
132 select {
133 case <-ctx.Done():
134 return
135 case ch <- pair:
136 }
137 }
138 }
139
140 func (s *set) MarshalJSON() ([]byte, error) {
141 s.mu.RLock()
142 defer s.mu.RUnlock()
143
144 buf := pool.GetBytesBuffer()
145 defer pool.ReleaseBytesBuffer(buf)
146 enc := json.NewEncoder(buf)
147
148 fields := []string{keysKey}
149 for k := range s.privateParams {
150 fields = append(fields, k)
151 }
152 sort.Strings(fields)
153
154 buf.WriteByte('{')
155 for i, field := range fields {
156 if i > 0 {
157 buf.WriteByte(',')
158 }
159 fmt.Fprintf(buf, `%q:`, field)
160 if field != keysKey {
161 if err := enc.Encode(s.privateParams[field]); err != nil {
162 return nil, errors.Wrapf(err, `failed to marshal field %q`, field)
163 }
164 } else {
165 buf.WriteByte('[')
166 for j, k := range s.keys {
167 if j > 0 {
168 buf.WriteByte(',')
169 }
170 if err := enc.Encode(k); err != nil {
171 return nil, errors.Wrapf(err, `failed to marshal key #%d`, i)
172 }
173 }
174 buf.WriteByte(']')
175 }
176 }
177 buf.WriteByte('}')
178
179 ret := make([]byte, buf.Len())
180 copy(ret, buf.Bytes())
181 return ret, nil
182 }
183
184 func (s *set) UnmarshalJSON(data []byte) error {
185 s.mu.Lock()
186 defer s.mu.Unlock()
187
188 s.privateParams = make(map[string]interface{})
189 s.keys = nil
190
191 var options []ParseOption
192 var ignoreParseError bool
193 if dc := s.dc; dc != nil {
194 if localReg := dc.Registry(); localReg != nil {
195 options = append(options, withLocalRegistry(localReg))
196 }
197 ignoreParseError = dc.IgnoreParseError()
198 }
199
200 var sawKeysField bool
201 dec := json.NewDecoder(bytes.NewReader(data))
202 LOOP:
203 for {
204 tok, err := dec.Token()
205 if err != nil {
206 return errors.Wrap(err, `error reading token`)
207 }
208
209 switch tok := tok.(type) {
210 case json.Delim:
211
212
213 if tok == '}' {
214 break LOOP
215 } else if tok != '{' {
216 return errors.Errorf(`expected '{', but got '%c'`, tok)
217 }
218 case string:
219 switch tok {
220 case "keys":
221 sawKeysField = true
222 var list []json.RawMessage
223 if err := dec.Decode(&list); err != nil {
224 return errors.Wrap(err, `failed to decode "keys"`)
225 }
226
227 for i, keysrc := range list {
228 key, err := ParseKey(keysrc, options...)
229 if err != nil {
230 if !ignoreParseError {
231 return errors.Wrapf(err, `failed to decode key #%d in "keys"`, i)
232 }
233 continue
234 }
235 s.keys = append(s.keys, key)
236 }
237 default:
238 var v interface{}
239 if err := dec.Decode(&v); err != nil {
240 return errors.Wrapf(err, `failed to decode value for key %q`, tok)
241 }
242 s.privateParams[tok] = v
243 }
244 }
245 }
246
247
248
249
250
251
252 if !sawKeysField {
253 key, err := ParseKey(data, options...)
254 if err != nil {
255 return errors.Wrapf(err, `failed to parse sole key in key set`)
256 }
257 s.keys = append(s.keys, key)
258 }
259 return nil
260 }
261
262 func (s *set) LookupKeyID(kid string) (Key, bool) {
263 s.mu.RLock()
264 defer s.mu.RUnlock()
265
266 n := s.Len()
267 for i := 0; i < n; i++ {
268 key, ok := s.Get(i)
269 if !ok {
270 return nil, false
271 }
272 if key.KeyID() == kid {
273 return key, true
274 }
275 }
276 return nil, false
277 }
278
279 func (s *set) DecodeCtx() DecodeCtx {
280 s.mu.RLock()
281 defer s.mu.RUnlock()
282 return s.dc
283 }
284
285 func (s *set) SetDecodeCtx(dc DecodeCtx) {
286 s.mu.Lock()
287 defer s.mu.Unlock()
288 s.dc = dc
289 }
290
291 func (s *set) Clone() (Set, error) {
292 s2 := &set{}
293
294 s.mu.RLock()
295 defer s.mu.RUnlock()
296
297 s2.keys = make([]Key, len(s.keys))
298
299 for i := 0; i < len(s.keys); i++ {
300 s2.keys[i] = s.keys[i]
301 }
302 return s2, nil
303 }
304
View as plain text