...

Source file src/github.com/lestrrat-go/jwx/jwk/set.go

Documentation: github.com/lestrrat-go/jwx/jwk

     1  package jwk
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"sort"
     8  
     9  	"github.com/lestrrat-go/iter/arrayiter"
    10  	"github.com/lestrrat-go/jwx/internal/json"
    11  	"github.com/lestrrat-go/jwx/internal/pool"
    12  	"github.com/pkg/errors"
    13  )
    14  
    15  const keysKey = `keys` // appease linter
    16  
    17  // NewSet creates and empty `jwk.Set` object
    18  func NewSet() Set {
    19  	return &set{
    20  		privateParams: make(map[string]interface{}),
    21  	}
    22  }
    23  
    24  func (s *set) Set(n string, v interface{}) error {
    25  	s.mu.RLock()
    26  	defer s.mu.RUnlock()
    27  
    28  	if n == keysKey {
    29  		vl, ok := v.([]Key)
    30  		if !ok {
    31  			return errors.Errorf(`value for field "keys" must be []jwk.Key`)
    32  		}
    33  		s.keys = vl
    34  		return nil
    35  	}
    36  
    37  	s.privateParams[n] = v
    38  	return nil
    39  }
    40  
    41  func (s *set) Field(n string) (interface{}, bool) {
    42  	s.mu.RLock()
    43  	defer s.mu.RUnlock()
    44  
    45  	v, ok := s.privateParams[n]
    46  	return v, ok
    47  }
    48  
    49  func (s *set) Get(idx int) (Key, bool) {
    50  	s.mu.RLock()
    51  	defer s.mu.RUnlock()
    52  
    53  	if idx >= 0 && idx < len(s.keys) {
    54  		return s.keys[idx], true
    55  	}
    56  	return nil, false
    57  }
    58  
    59  func (s *set) Len() int {
    60  	s.mu.RLock()
    61  	defer s.mu.RUnlock()
    62  
    63  	return len(s.keys)
    64  }
    65  
    66  // indexNL is Index(), but without the locking
    67  func (s *set) indexNL(key Key) int {
    68  	for i, k := range s.keys {
    69  		if k == key {
    70  			return i
    71  		}
    72  	}
    73  	return -1
    74  }
    75  
    76  func (s *set) Index(key Key) int {
    77  	s.mu.RLock()
    78  	defer s.mu.RUnlock()
    79  
    80  	return s.indexNL(key)
    81  }
    82  
    83  func (s *set) Add(key Key) bool {
    84  	s.mu.Lock()
    85  	defer s.mu.Unlock()
    86  
    87  	if i := s.indexNL(key); i > -1 {
    88  		return false
    89  	}
    90  	s.keys = append(s.keys, key)
    91  	return true
    92  }
    93  
    94  func (s *set) Remove(key Key) bool {
    95  	s.mu.Lock()
    96  	defer s.mu.Unlock()
    97  
    98  	for i, k := range s.keys {
    99  		if k == key {
   100  			switch i {
   101  			case 0:
   102  				s.keys = s.keys[1:]
   103  			case len(s.keys) - 1:
   104  				s.keys = s.keys[:i]
   105  			default:
   106  				s.keys = append(s.keys[:i], s.keys[i+1:]...)
   107  			}
   108  			return true
   109  		}
   110  	}
   111  	return false
   112  }
   113  
   114  func (s *set) Clear() {
   115  	s.mu.Lock()
   116  	defer s.mu.Unlock()
   117  
   118  	s.keys = nil
   119  }
   120  
   121  func (s *set) Iterate(ctx context.Context) KeyIterator {
   122  	ch := make(chan *KeyPair, s.Len())
   123  	go iterate(ctx, s.keys, ch)
   124  	return arrayiter.New(ch)
   125  }
   126  
   127  func iterate(ctx context.Context, keys []Key, ch chan *KeyPair) {
   128  	defer close(ch)
   129  
   130  	for i, key := range keys {
   131  		pair := &KeyPair{Index: i, Value: key}
   132  		select {
   133  		case <-ctx.Done():
   134  			return
   135  		case ch <- pair:
   136  		}
   137  	}
   138  }
   139  
   140  func (s *set) MarshalJSON() ([]byte, error) {
   141  	s.mu.RLock()
   142  	defer s.mu.RUnlock()
   143  
   144  	buf := pool.GetBytesBuffer()
   145  	defer pool.ReleaseBytesBuffer(buf)
   146  	enc := json.NewEncoder(buf)
   147  
   148  	fields := []string{keysKey}
   149  	for k := range s.privateParams {
   150  		fields = append(fields, k)
   151  	}
   152  	sort.Strings(fields)
   153  
   154  	buf.WriteByte('{')
   155  	for i, field := range fields {
   156  		if i > 0 {
   157  			buf.WriteByte(',')
   158  		}
   159  		fmt.Fprintf(buf, `%q:`, field)
   160  		if field != keysKey {
   161  			if err := enc.Encode(s.privateParams[field]); err != nil {
   162  				return nil, errors.Wrapf(err, `failed to marshal field %q`, field)
   163  			}
   164  		} else {
   165  			buf.WriteByte('[')
   166  			for j, k := range s.keys {
   167  				if j > 0 {
   168  					buf.WriteByte(',')
   169  				}
   170  				if err := enc.Encode(k); err != nil {
   171  					return nil, errors.Wrapf(err, `failed to marshal key #%d`, i)
   172  				}
   173  			}
   174  			buf.WriteByte(']')
   175  		}
   176  	}
   177  	buf.WriteByte('}')
   178  
   179  	ret := make([]byte, buf.Len())
   180  	copy(ret, buf.Bytes())
   181  	return ret, nil
   182  }
   183  
   184  func (s *set) UnmarshalJSON(data []byte) error {
   185  	s.mu.Lock()
   186  	defer s.mu.Unlock()
   187  
   188  	s.privateParams = make(map[string]interface{})
   189  	s.keys = nil
   190  
   191  	var options []ParseOption
   192  	var ignoreParseError bool
   193  	if dc := s.dc; dc != nil {
   194  		if localReg := dc.Registry(); localReg != nil {
   195  			options = append(options, withLocalRegistry(localReg))
   196  		}
   197  		ignoreParseError = dc.IgnoreParseError()
   198  	}
   199  
   200  	var sawKeysField bool
   201  	dec := json.NewDecoder(bytes.NewReader(data))
   202  LOOP:
   203  	for {
   204  		tok, err := dec.Token()
   205  		if err != nil {
   206  			return errors.Wrap(err, `error reading token`)
   207  		}
   208  
   209  		switch tok := tok.(type) {
   210  		case json.Delim:
   211  			// Assuming we're doing everything correctly, we should ONLY
   212  			// get either '{' or '}' here.
   213  			if tok == '}' { // End of object
   214  				break LOOP
   215  			} else if tok != '{' {
   216  				return errors.Errorf(`expected '{', but got '%c'`, tok)
   217  			}
   218  		case string:
   219  			switch tok {
   220  			case "keys":
   221  				sawKeysField = true
   222  				var list []json.RawMessage
   223  				if err := dec.Decode(&list); err != nil {
   224  					return errors.Wrap(err, `failed to decode "keys"`)
   225  				}
   226  
   227  				for i, keysrc := range list {
   228  					key, err := ParseKey(keysrc, options...)
   229  					if err != nil {
   230  						if !ignoreParseError {
   231  							return errors.Wrapf(err, `failed to decode key #%d in "keys"`, i)
   232  						}
   233  						continue
   234  					}
   235  					s.keys = append(s.keys, key)
   236  				}
   237  			default:
   238  				var v interface{}
   239  				if err := dec.Decode(&v); err != nil {
   240  					return errors.Wrapf(err, `failed to decode value for key %q`, tok)
   241  				}
   242  				s.privateParams[tok] = v
   243  			}
   244  		}
   245  	}
   246  
   247  	// This is really silly, but we can only detect the
   248  	// lack of the "keys" field after going through the
   249  	// entire object once
   250  	// Not checking for len(s.keys) == 0, because it could be
   251  	// an empty key set
   252  	if !sawKeysField {
   253  		key, err := ParseKey(data, options...)
   254  		if err != nil {
   255  			return errors.Wrapf(err, `failed to parse sole key in key set`)
   256  		}
   257  		s.keys = append(s.keys, key)
   258  	}
   259  	return nil
   260  }
   261  
   262  func (s *set) LookupKeyID(kid string) (Key, bool) {
   263  	s.mu.RLock()
   264  	defer s.mu.RUnlock()
   265  
   266  	n := s.Len()
   267  	for i := 0; i < n; i++ {
   268  		key, ok := s.Get(i)
   269  		if !ok {
   270  			return nil, false
   271  		}
   272  		if key.KeyID() == kid {
   273  			return key, true
   274  		}
   275  	}
   276  	return nil, false
   277  }
   278  
   279  func (s *set) DecodeCtx() DecodeCtx {
   280  	s.mu.RLock()
   281  	defer s.mu.RUnlock()
   282  	return s.dc
   283  }
   284  
   285  func (s *set) SetDecodeCtx(dc DecodeCtx) {
   286  	s.mu.Lock()
   287  	defer s.mu.Unlock()
   288  	s.dc = dc
   289  }
   290  
   291  func (s *set) Clone() (Set, error) {
   292  	s2 := &set{}
   293  
   294  	s.mu.RLock()
   295  	defer s.mu.RUnlock()
   296  
   297  	s2.keys = make([]Key, len(s.keys))
   298  
   299  	for i := 0; i < len(s.keys); i++ {
   300  		s2.keys[i] = s.keys[i]
   301  	}
   302  	return s2, nil
   303  }
   304  

View as plain text