...

Source file src/github.com/lestrrat-go/jwx/jwt/token_gen.go

Documentation: github.com/lestrrat-go/jwx/jwt

     1  // This file is auto-generated by jwt/internal/cmd/gentoken/main.go. DO NOT EDIT
     2  
     3  package jwt
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"sort"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/lestrrat-go/iter/mapiter"
    13  	"github.com/lestrrat-go/jwx/internal/base64"
    14  	"github.com/lestrrat-go/jwx/internal/iter"
    15  	"github.com/lestrrat-go/jwx/internal/json"
    16  	"github.com/lestrrat-go/jwx/internal/pool"
    17  	"github.com/lestrrat-go/jwx/jwt/internal/types"
    18  	"github.com/pkg/errors"
    19  )
    20  
    21  const (
    22  	AudienceKey   = "aud"
    23  	ExpirationKey = "exp"
    24  	IssuedAtKey   = "iat"
    25  	IssuerKey     = "iss"
    26  	JwtIDKey      = "jti"
    27  	NotBeforeKey  = "nbf"
    28  	SubjectKey    = "sub"
    29  )
    30  
    31  // Token represents a generic JWT token.
    32  // which are type-aware (to an extent). Other claims may be accessed via the `Get`/`Set`
    33  // methods but their types are not taken into consideration at all. If you have non-standard
    34  // claims that you must frequently access, consider creating accessors functions
    35  // like the following
    36  //
    37  // func SetFoo(tok jwt.Token) error
    38  // func GetFoo(tok jwt.Token) (*Customtyp, error)
    39  //
    40  // Embedding jwt.Token into another struct is not recommended, because
    41  // jwt.Token needs to handle private claims, and this really does not
    42  // work well when it is embedded in other structure
    43  type Token interface {
    44  
    45  	// Audience returns the value for "aud" field of the token
    46  	Audience() []string
    47  
    48  	// Expiration returns the value for "exp" field of the token
    49  	Expiration() time.Time
    50  
    51  	// IssuedAt returns the value for "iat" field of the token
    52  	IssuedAt() time.Time
    53  
    54  	// Issuer returns the value for "iss" field of the token
    55  	Issuer() string
    56  
    57  	// JwtID returns the value for "jti" field of the token
    58  	JwtID() string
    59  
    60  	// NotBefore returns the value for "nbf" field of the token
    61  	NotBefore() time.Time
    62  
    63  	// Subject returns the value for "sub" field of the token
    64  	Subject() string
    65  
    66  	// PrivateClaims return the entire set of fields (claims) in the token
    67  	// *other* than the pre-defined fields such as `iss`, `nbf`, `iat`, etc.
    68  	PrivateClaims() map[string]interface{}
    69  
    70  	// Get returns the value of the corresponding field in the token, such as
    71  	// `nbf`, `exp`, `iat`, and other user-defined fields. If the field does not
    72  	// exist in the token, the second return value will be `false`
    73  	//
    74  	// If you need to access fields like `alg`, `kid`, `jku`, etc, you need
    75  	// to access the corresponding fields in the JWS/JWE message. For this,
    76  	// you will need to access them by directly parsing the payload using
    77  	// `jws.Parse` and `jwe.Parse`
    78  	Get(string) (interface{}, bool)
    79  
    80  	// Set assigns a value to the corresponding field in the token. Some
    81  	// pre-defined fields such as `nbf`, `iat`, `iss` need their values to
    82  	// be of a specific type. See the other getter methods in this interface
    83  	// for the types of each of these fields
    84  	Set(string, interface{}) error
    85  	Remove(string) error
    86  	Clone() (Token, error)
    87  	Iterate(context.Context) Iterator
    88  	Walk(context.Context, Visitor) error
    89  	AsMap(context.Context) (map[string]interface{}, error)
    90  }
    91  type stdToken struct {
    92  	mu            *sync.RWMutex
    93  	dc            DecodeCtx          // per-object context for decoding
    94  	audience      types.StringList   // https://tools.ietf.org/html/rfc7519#section-4.1.3
    95  	expiration    *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.4
    96  	issuedAt      *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.6
    97  	issuer        *string            // https://tools.ietf.org/html/rfc7519#section-4.1.1
    98  	jwtID         *string            // https://tools.ietf.org/html/rfc7519#section-4.1.7
    99  	notBefore     *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.5
   100  	subject       *string            // https://tools.ietf.org/html/rfc7519#section-4.1.2
   101  	privateClaims map[string]interface{}
   102  }
   103  
   104  // New creates a standard token, with minimal knowledge of
   105  // possible claims. Standard claims include"aud", "exp", "iat", "iss", "jti", "nbf" and "sub".
   106  // Convenience accessors are provided for these standard claims
   107  func New() Token {
   108  	return &stdToken{
   109  		mu:            &sync.RWMutex{},
   110  		privateClaims: make(map[string]interface{}),
   111  	}
   112  }
   113  
   114  func (t *stdToken) Get(name string) (interface{}, bool) {
   115  	t.mu.RLock()
   116  	defer t.mu.RUnlock()
   117  	switch name {
   118  	case AudienceKey:
   119  		if t.audience == nil {
   120  			return nil, false
   121  		}
   122  		v := t.audience.Get()
   123  		return v, true
   124  	case ExpirationKey:
   125  		if t.expiration == nil {
   126  			return nil, false
   127  		}
   128  		v := t.expiration.Get()
   129  		return v, true
   130  	case IssuedAtKey:
   131  		if t.issuedAt == nil {
   132  			return nil, false
   133  		}
   134  		v := t.issuedAt.Get()
   135  		return v, true
   136  	case IssuerKey:
   137  		if t.issuer == nil {
   138  			return nil, false
   139  		}
   140  		v := *(t.issuer)
   141  		return v, true
   142  	case JwtIDKey:
   143  		if t.jwtID == nil {
   144  			return nil, false
   145  		}
   146  		v := *(t.jwtID)
   147  		return v, true
   148  	case NotBeforeKey:
   149  		if t.notBefore == nil {
   150  			return nil, false
   151  		}
   152  		v := t.notBefore.Get()
   153  		return v, true
   154  	case SubjectKey:
   155  		if t.subject == nil {
   156  			return nil, false
   157  		}
   158  		v := *(t.subject)
   159  		return v, true
   160  	default:
   161  		v, ok := t.privateClaims[name]
   162  		return v, ok
   163  	}
   164  }
   165  
   166  func (t *stdToken) Remove(key string) error {
   167  	t.mu.Lock()
   168  	defer t.mu.Unlock()
   169  	switch key {
   170  	case AudienceKey:
   171  		t.audience = nil
   172  	case ExpirationKey:
   173  		t.expiration = nil
   174  	case IssuedAtKey:
   175  		t.issuedAt = nil
   176  	case IssuerKey:
   177  		t.issuer = nil
   178  	case JwtIDKey:
   179  		t.jwtID = nil
   180  	case NotBeforeKey:
   181  		t.notBefore = nil
   182  	case SubjectKey:
   183  		t.subject = nil
   184  	default:
   185  		delete(t.privateClaims, key)
   186  	}
   187  	return nil
   188  }
   189  
   190  func (t *stdToken) Set(name string, value interface{}) error {
   191  	t.mu.Lock()
   192  	defer t.mu.Unlock()
   193  	return t.setNoLock(name, value)
   194  }
   195  
   196  func (t *stdToken) DecodeCtx() DecodeCtx {
   197  	t.mu.RLock()
   198  	defer t.mu.RUnlock()
   199  	return t.dc
   200  }
   201  
   202  func (t *stdToken) SetDecodeCtx(v DecodeCtx) {
   203  	t.mu.Lock()
   204  	defer t.mu.Unlock()
   205  	t.dc = v
   206  }
   207  
   208  func (t *stdToken) setNoLock(name string, value interface{}) error {
   209  	switch name {
   210  	case AudienceKey:
   211  		var acceptor types.StringList
   212  		if err := acceptor.Accept(value); err != nil {
   213  			return errors.Wrapf(err, `invalid value for %s key`, AudienceKey)
   214  		}
   215  		t.audience = acceptor
   216  		return nil
   217  	case ExpirationKey:
   218  		var acceptor types.NumericDate
   219  		if err := acceptor.Accept(value); err != nil {
   220  			return errors.Wrapf(err, `invalid value for %s key`, ExpirationKey)
   221  		}
   222  		t.expiration = &acceptor
   223  		return nil
   224  	case IssuedAtKey:
   225  		var acceptor types.NumericDate
   226  		if err := acceptor.Accept(value); err != nil {
   227  			return errors.Wrapf(err, `invalid value for %s key`, IssuedAtKey)
   228  		}
   229  		t.issuedAt = &acceptor
   230  		return nil
   231  	case IssuerKey:
   232  		if v, ok := value.(string); ok {
   233  			t.issuer = &v
   234  			return nil
   235  		}
   236  		return errors.Errorf(`invalid value for %s key: %T`, IssuerKey, value)
   237  	case JwtIDKey:
   238  		if v, ok := value.(string); ok {
   239  			t.jwtID = &v
   240  			return nil
   241  		}
   242  		return errors.Errorf(`invalid value for %s key: %T`, JwtIDKey, value)
   243  	case NotBeforeKey:
   244  		var acceptor types.NumericDate
   245  		if err := acceptor.Accept(value); err != nil {
   246  			return errors.Wrapf(err, `invalid value for %s key`, NotBeforeKey)
   247  		}
   248  		t.notBefore = &acceptor
   249  		return nil
   250  	case SubjectKey:
   251  		if v, ok := value.(string); ok {
   252  			t.subject = &v
   253  			return nil
   254  		}
   255  		return errors.Errorf(`invalid value for %s key: %T`, SubjectKey, value)
   256  	default:
   257  		if t.privateClaims == nil {
   258  			t.privateClaims = map[string]interface{}{}
   259  		}
   260  		t.privateClaims[name] = value
   261  	}
   262  	return nil
   263  }
   264  
   265  func (t *stdToken) Audience() []string {
   266  	t.mu.RLock()
   267  	defer t.mu.RUnlock()
   268  	if t.audience != nil {
   269  		return t.audience.Get()
   270  	}
   271  	return nil
   272  }
   273  
   274  func (t *stdToken) Expiration() time.Time {
   275  	t.mu.RLock()
   276  	defer t.mu.RUnlock()
   277  	if t.expiration != nil {
   278  		return t.expiration.Get()
   279  	}
   280  	return time.Time{}
   281  }
   282  
   283  func (t *stdToken) IssuedAt() time.Time {
   284  	t.mu.RLock()
   285  	defer t.mu.RUnlock()
   286  	if t.issuedAt != nil {
   287  		return t.issuedAt.Get()
   288  	}
   289  	return time.Time{}
   290  }
   291  
   292  func (t *stdToken) Issuer() string {
   293  	t.mu.RLock()
   294  	defer t.mu.RUnlock()
   295  	if t.issuer != nil {
   296  		return *(t.issuer)
   297  	}
   298  	return ""
   299  }
   300  
   301  func (t *stdToken) JwtID() string {
   302  	t.mu.RLock()
   303  	defer t.mu.RUnlock()
   304  	if t.jwtID != nil {
   305  		return *(t.jwtID)
   306  	}
   307  	return ""
   308  }
   309  
   310  func (t *stdToken) NotBefore() time.Time {
   311  	t.mu.RLock()
   312  	defer t.mu.RUnlock()
   313  	if t.notBefore != nil {
   314  		return t.notBefore.Get()
   315  	}
   316  	return time.Time{}
   317  }
   318  
   319  func (t *stdToken) Subject() string {
   320  	t.mu.RLock()
   321  	defer t.mu.RUnlock()
   322  	if t.subject != nil {
   323  		return *(t.subject)
   324  	}
   325  	return ""
   326  }
   327  
   328  func (t *stdToken) PrivateClaims() map[string]interface{} {
   329  	t.mu.RLock()
   330  	defer t.mu.RUnlock()
   331  	return t.privateClaims
   332  }
   333  
   334  func (t *stdToken) makePairs() []*ClaimPair {
   335  	t.mu.RLock()
   336  	defer t.mu.RUnlock()
   337  
   338  	pairs := make([]*ClaimPair, 0, 7)
   339  	if t.audience != nil {
   340  		v := t.audience.Get()
   341  		pairs = append(pairs, &ClaimPair{Key: AudienceKey, Value: v})
   342  	}
   343  	if t.expiration != nil {
   344  		v := t.expiration.Get()
   345  		pairs = append(pairs, &ClaimPair{Key: ExpirationKey, Value: v})
   346  	}
   347  	if t.issuedAt != nil {
   348  		v := t.issuedAt.Get()
   349  		pairs = append(pairs, &ClaimPair{Key: IssuedAtKey, Value: v})
   350  	}
   351  	if t.issuer != nil {
   352  		v := *(t.issuer)
   353  		pairs = append(pairs, &ClaimPair{Key: IssuerKey, Value: v})
   354  	}
   355  	if t.jwtID != nil {
   356  		v := *(t.jwtID)
   357  		pairs = append(pairs, &ClaimPair{Key: JwtIDKey, Value: v})
   358  	}
   359  	if t.notBefore != nil {
   360  		v := t.notBefore.Get()
   361  		pairs = append(pairs, &ClaimPair{Key: NotBeforeKey, Value: v})
   362  	}
   363  	if t.subject != nil {
   364  		v := *(t.subject)
   365  		pairs = append(pairs, &ClaimPair{Key: SubjectKey, Value: v})
   366  	}
   367  	for k, v := range t.privateClaims {
   368  		pairs = append(pairs, &ClaimPair{Key: k, Value: v})
   369  	}
   370  	sort.Slice(pairs, func(i, j int) bool {
   371  		return pairs[i].Key.(string) < pairs[j].Key.(string)
   372  	})
   373  	return pairs
   374  }
   375  
   376  func (t *stdToken) UnmarshalJSON(buf []byte) error {
   377  	t.mu.Lock()
   378  	defer t.mu.Unlock()
   379  	t.audience = nil
   380  	t.expiration = nil
   381  	t.issuedAt = nil
   382  	t.issuer = nil
   383  	t.jwtID = nil
   384  	t.notBefore = nil
   385  	t.subject = nil
   386  	dec := json.NewDecoder(bytes.NewReader(buf))
   387  LOOP:
   388  	for {
   389  		tok, err := dec.Token()
   390  		if err != nil {
   391  			return errors.Wrap(err, `error reading token`)
   392  		}
   393  		switch tok := tok.(type) {
   394  		case json.Delim:
   395  			// Assuming we're doing everything correctly, we should ONLY
   396  			// get either '{' or '}' here.
   397  			if tok == '}' { // End of object
   398  				break LOOP
   399  			} else if tok != '{' {
   400  				return errors.Errorf(`expected '{', but got '%c'`, tok)
   401  			}
   402  		case string: // Objects can only have string keys
   403  			switch tok {
   404  			case AudienceKey:
   405  				var decoded types.StringList
   406  				if err := dec.Decode(&decoded); err != nil {
   407  					return errors.Wrapf(err, `failed to decode value for key %s`, AudienceKey)
   408  				}
   409  				t.audience = decoded
   410  			case ExpirationKey:
   411  				var decoded types.NumericDate
   412  				if err := dec.Decode(&decoded); err != nil {
   413  					return errors.Wrapf(err, `failed to decode value for key %s`, ExpirationKey)
   414  				}
   415  				t.expiration = &decoded
   416  			case IssuedAtKey:
   417  				var decoded types.NumericDate
   418  				if err := dec.Decode(&decoded); err != nil {
   419  					return errors.Wrapf(err, `failed to decode value for key %s`, IssuedAtKey)
   420  				}
   421  				t.issuedAt = &decoded
   422  			case IssuerKey:
   423  				if err := json.AssignNextStringToken(&t.issuer, dec); err != nil {
   424  					return errors.Wrapf(err, `failed to decode value for key %s`, IssuerKey)
   425  				}
   426  			case JwtIDKey:
   427  				if err := json.AssignNextStringToken(&t.jwtID, dec); err != nil {
   428  					return errors.Wrapf(err, `failed to decode value for key %s`, JwtIDKey)
   429  				}
   430  			case NotBeforeKey:
   431  				var decoded types.NumericDate
   432  				if err := dec.Decode(&decoded); err != nil {
   433  					return errors.Wrapf(err, `failed to decode value for key %s`, NotBeforeKey)
   434  				}
   435  				t.notBefore = &decoded
   436  			case SubjectKey:
   437  				if err := json.AssignNextStringToken(&t.subject, dec); err != nil {
   438  					return errors.Wrapf(err, `failed to decode value for key %s`, SubjectKey)
   439  				}
   440  			default:
   441  				if dc := t.dc; dc != nil {
   442  					if localReg := dc.Registry(); localReg != nil {
   443  						decoded, err := localReg.Decode(dec, tok)
   444  						if err == nil {
   445  							t.setNoLock(tok, decoded)
   446  							continue
   447  						}
   448  					}
   449  				}
   450  				decoded, err := registry.Decode(dec, tok)
   451  				if err == nil {
   452  					t.setNoLock(tok, decoded)
   453  					continue
   454  				}
   455  				return errors.Wrapf(err, `could not decode field %s`, tok)
   456  			}
   457  		default:
   458  			return errors.Errorf(`invalid token %T`, tok)
   459  		}
   460  	}
   461  	return nil
   462  }
   463  
   464  func (t stdToken) MarshalJSON() ([]byte, error) {
   465  	t.mu.RLock()
   466  	defer t.mu.RUnlock()
   467  	buf := pool.GetBytesBuffer()
   468  	defer pool.ReleaseBytesBuffer(buf)
   469  	buf.WriteByte('{')
   470  	enc := json.NewEncoder(buf)
   471  	for i, pair := range t.makePairs() {
   472  		f := pair.Key.(string)
   473  		if i > 0 {
   474  			buf.WriteByte(',')
   475  		}
   476  		buf.WriteRune('"')
   477  		buf.WriteString(f)
   478  		buf.WriteString(`":`)
   479  		switch f {
   480  		case AudienceKey:
   481  			if err := json.EncodeAudience(enc, pair.Value.([]string)); err != nil {
   482  				return nil, errors.Wrap(err, `failed to encode "aud"`)
   483  			}
   484  			continue
   485  		case ExpirationKey, IssuedAtKey, NotBeforeKey:
   486  			enc.Encode(pair.Value.(time.Time).Unix())
   487  			continue
   488  		}
   489  		switch v := pair.Value.(type) {
   490  		case []byte:
   491  			buf.WriteRune('"')
   492  			buf.WriteString(base64.EncodeToString(v))
   493  			buf.WriteRune('"')
   494  		default:
   495  			if err := enc.Encode(v); err != nil {
   496  				return nil, errors.Wrapf(err, `failed to marshal field %s`, f)
   497  			}
   498  			buf.Truncate(buf.Len() - 1)
   499  		}
   500  	}
   501  	buf.WriteByte('}')
   502  	ret := make([]byte, buf.Len())
   503  	copy(ret, buf.Bytes())
   504  	return ret, nil
   505  }
   506  
   507  func (t *stdToken) Iterate(ctx context.Context) Iterator {
   508  	pairs := t.makePairs()
   509  	ch := make(chan *ClaimPair, len(pairs))
   510  	go func(ctx context.Context, ch chan *ClaimPair, pairs []*ClaimPair) {
   511  		defer close(ch)
   512  		for _, pair := range pairs {
   513  			select {
   514  			case <-ctx.Done():
   515  				return
   516  			case ch <- pair:
   517  			}
   518  		}
   519  	}(ctx, ch, pairs)
   520  	return mapiter.New(ch)
   521  }
   522  
   523  func (t *stdToken) Walk(ctx context.Context, visitor Visitor) error {
   524  	return iter.WalkMap(ctx, t, visitor)
   525  }
   526  
   527  func (t *stdToken) AsMap(ctx context.Context) (map[string]interface{}, error) {
   528  	return iter.AsMap(ctx, t)
   529  }
   530  

View as plain text