1
2
3 package jwk
4
5 import (
6 "bytes"
7 "context"
8 "crypto/x509"
9 "fmt"
10 "sort"
11 "sync"
12
13 "github.com/lestrrat-go/iter/mapiter"
14 "github.com/lestrrat-go/jwx/internal/base64"
15 "github.com/lestrrat-go/jwx/internal/iter"
16 "github.com/lestrrat-go/jwx/internal/json"
17 "github.com/lestrrat-go/jwx/internal/pool"
18 "github.com/lestrrat-go/jwx/jwa"
19 "github.com/pkg/errors"
20 )
21
// SymmetricOctetsKey is the JWK member name under which a symmetric key's
// raw key octets are stored and serialized.
const SymmetricOctetsKey = "k"
25
26 type SymmetricKey interface {
27 Key
28 FromRaw([]byte) error
29 Octets() []byte
30 }
31
// symmetricKey is the concrete SymmetricKey implementation; its KeyType is
// always jwa.OctetSeq. Each standard member is held by pointer (or as a nil
// slice), so nil means "not set"; the accessors return zero values then.
type symmetricKey struct {
	algorithm              *string           // value of the AlgorithmKey member
	keyID                  *string           // value of the KeyIDKey member
	keyOps                 *KeyOperationList // value of the KeyOpsKey member
	keyUsage               *string           // value of the KeyUsageKey member
	octets                 []byte            // raw key material (SymmetricOctetsKey); UnmarshalJSON requires it
	x509CertChain          *CertificateChain // value of the X509CertChainKey member
	x509CertThumbprint     *string           // value of the X509CertThumbprintKey member
	x509CertThumbprintS256 *string           // value of the X509CertThumbprintS256Key member
	x509URL                *string           // value of the X509URLKey member
	privateParams          map[string]interface{} // members without a dedicated field, keyed by member name
	// mu guards the fields in Get, Set, Remove, makePairs, DecodeCtx and
	// SetDecodeCtx; the simple accessor methods read fields without it.
	// Being a pointer, it is shared by copies made by value-receiver
	// methods such as KeyType and MarshalJSON.
	mu *sync.RWMutex
	// dc is an optional decode context; UnmarshalJSON consults its
	// Registry() to decode private members before the package registry.
	dc json.DecodeCtx
}
46
47 func NewSymmetricKey() SymmetricKey {
48 return newSymmetricKey()
49 }
50
51 func newSymmetricKey() *symmetricKey {
52 return &symmetricKey{
53 mu: &sync.RWMutex{},
54 privateParams: make(map[string]interface{}),
55 }
56 }
57
58 func (h symmetricKey) KeyType() jwa.KeyType {
59 return jwa.OctetSeq
60 }
61
62 func (h *symmetricKey) Algorithm() string {
63 if h.algorithm != nil {
64 return *(h.algorithm)
65 }
66 return ""
67 }
68
69 func (h *symmetricKey) KeyID() string {
70 if h.keyID != nil {
71 return *(h.keyID)
72 }
73 return ""
74 }
75
76 func (h *symmetricKey) KeyOps() KeyOperationList {
77 if h.keyOps != nil {
78 return *(h.keyOps)
79 }
80 return nil
81 }
82
83 func (h *symmetricKey) KeyUsage() string {
84 if h.keyUsage != nil {
85 return *(h.keyUsage)
86 }
87 return ""
88 }
89
90 func (h *symmetricKey) Octets() []byte {
91 return h.octets
92 }
93
94 func (h *symmetricKey) X509CertChain() []*x509.Certificate {
95 if h.x509CertChain != nil {
96 return h.x509CertChain.Get()
97 }
98 return nil
99 }
100
101 func (h *symmetricKey) X509CertThumbprint() string {
102 if h.x509CertThumbprint != nil {
103 return *(h.x509CertThumbprint)
104 }
105 return ""
106 }
107
108 func (h *symmetricKey) X509CertThumbprintS256() string {
109 if h.x509CertThumbprintS256 != nil {
110 return *(h.x509CertThumbprintS256)
111 }
112 return ""
113 }
114
115 func (h *symmetricKey) X509URL() string {
116 if h.x509URL != nil {
117 return *(h.x509URL)
118 }
119 return ""
120 }
121
122 func (h *symmetricKey) makePairs() []*HeaderPair {
123 h.mu.RLock()
124 defer h.mu.RUnlock()
125
126 var pairs []*HeaderPair
127 pairs = append(pairs, &HeaderPair{Key: "kty", Value: jwa.OctetSeq})
128 if h.algorithm != nil {
129 pairs = append(pairs, &HeaderPair{Key: AlgorithmKey, Value: *(h.algorithm)})
130 }
131 if h.keyID != nil {
132 pairs = append(pairs, &HeaderPair{Key: KeyIDKey, Value: *(h.keyID)})
133 }
134 if h.keyOps != nil {
135 pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: *(h.keyOps)})
136 }
137 if h.keyUsage != nil {
138 pairs = append(pairs, &HeaderPair{Key: KeyUsageKey, Value: *(h.keyUsage)})
139 }
140 if h.octets != nil {
141 pairs = append(pairs, &HeaderPair{Key: SymmetricOctetsKey, Value: h.octets})
142 }
143 if h.x509CertChain != nil {
144 pairs = append(pairs, &HeaderPair{Key: X509CertChainKey, Value: *(h.x509CertChain)})
145 }
146 if h.x509CertThumbprint != nil {
147 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintKey, Value: *(h.x509CertThumbprint)})
148 }
149 if h.x509CertThumbprintS256 != nil {
150 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintS256Key, Value: *(h.x509CertThumbprintS256)})
151 }
152 if h.x509URL != nil {
153 pairs = append(pairs, &HeaderPair{Key: X509URLKey, Value: *(h.x509URL)})
154 }
155 for k, v := range h.privateParams {
156 pairs = append(pairs, &HeaderPair{Key: k, Value: v})
157 }
158 return pairs
159 }
160
161 func (h *symmetricKey) PrivateParams() map[string]interface{} {
162 return h.privateParams
163 }
164
165 func (h *symmetricKey) Get(name string) (interface{}, bool) {
166 h.mu.RLock()
167 defer h.mu.RUnlock()
168 switch name {
169 case KeyTypeKey:
170 return h.KeyType(), true
171 case AlgorithmKey:
172 if h.algorithm == nil {
173 return nil, false
174 }
175 return *(h.algorithm), true
176 case KeyIDKey:
177 if h.keyID == nil {
178 return nil, false
179 }
180 return *(h.keyID), true
181 case KeyOpsKey:
182 if h.keyOps == nil {
183 return nil, false
184 }
185 return *(h.keyOps), true
186 case KeyUsageKey:
187 if h.keyUsage == nil {
188 return nil, false
189 }
190 return *(h.keyUsage), true
191 case SymmetricOctetsKey:
192 if h.octets == nil {
193 return nil, false
194 }
195 return h.octets, true
196 case X509CertChainKey:
197 if h.x509CertChain == nil {
198 return nil, false
199 }
200 return h.x509CertChain.Get(), true
201 case X509CertThumbprintKey:
202 if h.x509CertThumbprint == nil {
203 return nil, false
204 }
205 return *(h.x509CertThumbprint), true
206 case X509CertThumbprintS256Key:
207 if h.x509CertThumbprintS256 == nil {
208 return nil, false
209 }
210 return *(h.x509CertThumbprintS256), true
211 case X509URLKey:
212 if h.x509URL == nil {
213 return nil, false
214 }
215 return *(h.x509URL), true
216 default:
217 v, ok := h.privateParams[name]
218 return v, ok
219 }
220 }
221
222 func (h *symmetricKey) Set(name string, value interface{}) error {
223 h.mu.Lock()
224 defer h.mu.Unlock()
225 return h.setNoLock(name, value)
226 }
227
228 func (h *symmetricKey) setNoLock(name string, value interface{}) error {
229 switch name {
230 case "kty":
231 return nil
232 case AlgorithmKey:
233 switch v := value.(type) {
234 case string:
235 h.algorithm = &v
236 case fmt.Stringer:
237 tmp := v.String()
238 h.algorithm = &tmp
239 default:
240 return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value)
241 }
242 return nil
243 case KeyIDKey:
244 if v, ok := value.(string); ok {
245 h.keyID = &v
246 return nil
247 }
248 return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
249 case KeyOpsKey:
250 var acceptor KeyOperationList
251 if err := acceptor.Accept(value); err != nil {
252 return errors.Wrapf(err, `invalid value for %s key`, KeyOpsKey)
253 }
254 h.keyOps = &acceptor
255 return nil
256 case KeyUsageKey:
257 switch v := value.(type) {
258 case KeyUsageType:
259 switch v {
260 case ForSignature, ForEncryption:
261 tmp := v.String()
262 h.keyUsage = &tmp
263 default:
264 return errors.Errorf(`invalid key usage type %s`, v)
265 }
266 case string:
267 h.keyUsage = &v
268 default:
269 return errors.Errorf(`invalid key usage type %s`, v)
270 }
271 case SymmetricOctetsKey:
272 if v, ok := value.([]byte); ok {
273 h.octets = v
274 return nil
275 }
276 return errors.Errorf(`invalid value for %s key: %T`, SymmetricOctetsKey, value)
277 case X509CertChainKey:
278 var acceptor CertificateChain
279 if err := acceptor.Accept(value); err != nil {
280 return errors.Wrapf(err, `invalid value for %s key`, X509CertChainKey)
281 }
282 h.x509CertChain = &acceptor
283 return nil
284 case X509CertThumbprintKey:
285 if v, ok := value.(string); ok {
286 h.x509CertThumbprint = &v
287 return nil
288 }
289 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintKey, value)
290 case X509CertThumbprintS256Key:
291 if v, ok := value.(string); ok {
292 h.x509CertThumbprintS256 = &v
293 return nil
294 }
295 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintS256Key, value)
296 case X509URLKey:
297 if v, ok := value.(string); ok {
298 h.x509URL = &v
299 return nil
300 }
301 return errors.Errorf(`invalid value for %s key: %T`, X509URLKey, value)
302 default:
303 if h.privateParams == nil {
304 h.privateParams = map[string]interface{}{}
305 }
306 h.privateParams[name] = value
307 }
308 return nil
309 }
310
311 func (k *symmetricKey) Remove(key string) error {
312 k.mu.Lock()
313 defer k.mu.Unlock()
314 switch key {
315 case AlgorithmKey:
316 k.algorithm = nil
317 case KeyIDKey:
318 k.keyID = nil
319 case KeyOpsKey:
320 k.keyOps = nil
321 case KeyUsageKey:
322 k.keyUsage = nil
323 case SymmetricOctetsKey:
324 k.octets = nil
325 case X509CertChainKey:
326 k.x509CertChain = nil
327 case X509CertThumbprintKey:
328 k.x509CertThumbprint = nil
329 case X509CertThumbprintS256Key:
330 k.x509CertThumbprintS256 = nil
331 case X509URLKey:
332 k.x509URL = nil
333 default:
334 delete(k.privateParams, key)
335 }
336 return nil
337 }
338
339 func (k *symmetricKey) Clone() (Key, error) {
340 return cloneKey(k)
341 }
342
343 func (k *symmetricKey) DecodeCtx() json.DecodeCtx {
344 k.mu.RLock()
345 defer k.mu.RUnlock()
346 return k.dc
347 }
348
349 func (k *symmetricKey) SetDecodeCtx(dc json.DecodeCtx) {
350 k.mu.Lock()
351 defer k.mu.Unlock()
352 k.dc = dc
353 }
354
355 func (h *symmetricKey) UnmarshalJSON(buf []byte) error {
356 h.algorithm = nil
357 h.keyID = nil
358 h.keyOps = nil
359 h.keyUsage = nil
360 h.octets = nil
361 h.x509CertChain = nil
362 h.x509CertThumbprint = nil
363 h.x509CertThumbprintS256 = nil
364 h.x509URL = nil
365 dec := json.NewDecoder(bytes.NewReader(buf))
366 LOOP:
367 for {
368 tok, err := dec.Token()
369 if err != nil {
370 return errors.Wrap(err, `error reading token`)
371 }
372 switch tok := tok.(type) {
373 case json.Delim:
374
375
376 if tok == '}' {
377 break LOOP
378 } else if tok != '{' {
379 return errors.Errorf(`expected '{', but got '%c'`, tok)
380 }
381 case string:
382 switch tok {
383 case KeyTypeKey:
384 val, err := json.ReadNextStringToken(dec)
385 if err != nil {
386 return errors.Wrap(err, `error reading token`)
387 }
388 if val != jwa.OctetSeq.String() {
389 return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val)
390 }
391 case AlgorithmKey:
392 if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil {
393 return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey)
394 }
395 case KeyIDKey:
396 if err := json.AssignNextStringToken(&h.keyID, dec); err != nil {
397 return errors.Wrapf(err, `failed to decode value for key %s`, KeyIDKey)
398 }
399 case KeyOpsKey:
400 var decoded KeyOperationList
401 if err := dec.Decode(&decoded); err != nil {
402 return errors.Wrapf(err, `failed to decode value for key %s`, KeyOpsKey)
403 }
404 h.keyOps = &decoded
405 case KeyUsageKey:
406 if err := json.AssignNextStringToken(&h.keyUsage, dec); err != nil {
407 return errors.Wrapf(err, `failed to decode value for key %s`, KeyUsageKey)
408 }
409 case SymmetricOctetsKey:
410 if err := json.AssignNextBytesToken(&h.octets, dec); err != nil {
411 return errors.Wrapf(err, `failed to decode value for key %s`, SymmetricOctetsKey)
412 }
413 case X509CertChainKey:
414 var decoded CertificateChain
415 if err := dec.Decode(&decoded); err != nil {
416 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertChainKey)
417 }
418 h.x509CertChain = &decoded
419 case X509CertThumbprintKey:
420 if err := json.AssignNextStringToken(&h.x509CertThumbprint, dec); err != nil {
421 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintKey)
422 }
423 case X509CertThumbprintS256Key:
424 if err := json.AssignNextStringToken(&h.x509CertThumbprintS256, dec); err != nil {
425 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintS256Key)
426 }
427 case X509URLKey:
428 if err := json.AssignNextStringToken(&h.x509URL, dec); err != nil {
429 return errors.Wrapf(err, `failed to decode value for key %s`, X509URLKey)
430 }
431 default:
432 if dc := h.dc; dc != nil {
433 if localReg := dc.Registry(); localReg != nil {
434 decoded, err := localReg.Decode(dec, tok)
435 if err == nil {
436 h.setNoLock(tok, decoded)
437 continue
438 }
439 }
440 }
441 decoded, err := registry.Decode(dec, tok)
442 if err == nil {
443 h.setNoLock(tok, decoded)
444 continue
445 }
446 return errors.Wrapf(err, `could not decode field %s`, tok)
447 }
448 default:
449 return errors.Errorf(`invalid token %T`, tok)
450 }
451 }
452 if h.octets == nil {
453 return errors.Errorf(`required field k is missing`)
454 }
455 return nil
456 }
457
458 func (h symmetricKey) MarshalJSON() ([]byte, error) {
459 data := make(map[string]interface{})
460 fields := make([]string, 0, 9)
461 for _, pair := range h.makePairs() {
462 fields = append(fields, pair.Key.(string))
463 data[pair.Key.(string)] = pair.Value
464 }
465
466 sort.Strings(fields)
467 buf := pool.GetBytesBuffer()
468 defer pool.ReleaseBytesBuffer(buf)
469 buf.WriteByte('{')
470 enc := json.NewEncoder(buf)
471 for i, f := range fields {
472 if i > 0 {
473 buf.WriteRune(',')
474 }
475 buf.WriteRune('"')
476 buf.WriteString(f)
477 buf.WriteString(`":`)
478 v := data[f]
479 switch v := v.(type) {
480 case []byte:
481 buf.WriteRune('"')
482 buf.WriteString(base64.EncodeToString(v))
483 buf.WriteRune('"')
484 default:
485 if err := enc.Encode(v); err != nil {
486 return nil, errors.Wrapf(err, `failed to encode value for field %s`, f)
487 }
488 buf.Truncate(buf.Len() - 1)
489 }
490 }
491 buf.WriteByte('}')
492 ret := make([]byte, buf.Len())
493 copy(ret, buf.Bytes())
494 return ret, nil
495 }
496
497 func (h *symmetricKey) Iterate(ctx context.Context) HeaderIterator {
498 pairs := h.makePairs()
499 ch := make(chan *HeaderPair, len(pairs))
500 go func(ctx context.Context, ch chan *HeaderPair, pairs []*HeaderPair) {
501 defer close(ch)
502 for _, pair := range pairs {
503 select {
504 case <-ctx.Done():
505 return
506 case ch <- pair:
507 }
508 }
509 }(ctx, ch, pairs)
510 return mapiter.New(ch)
511 }
512
513 func (h *symmetricKey) Walk(ctx context.Context, visitor HeaderVisitor) error {
514 return iter.WalkMap(ctx, h, visitor)
515 }
516
517 func (h *symmetricKey) AsMap(ctx context.Context) (map[string]interface{}, error) {
518 return iter.AsMap(ctx, h)
519 }
520
View as plain text