...

Source file src/github.com/lestrrat-go/jwx/jwt/validate.go

Documentation: github.com/lestrrat-go/jwx/jwt

     1  package jwt
     2  
     3  import (
     4  	"context"
     5  	"strconv"
     6  	"time"
     7  
     8  	"github.com/pkg/errors"
     9  )
    10  
    11  type Clock interface {
    12  	Now() time.Time
    13  }
    14  type ClockFunc func() time.Time
    15  
    16  func (f ClockFunc) Now() time.Time {
    17  	return f()
    18  }
    19  
    20  func isSupportedTimeClaim(c string) error {
    21  	switch c {
    22  	case ExpirationKey, IssuedAtKey, NotBeforeKey:
    23  		return nil
    24  	}
    25  	return NewValidationError(errors.Errorf(`unsupported time claim %s`, strconv.Quote(c)))
    26  }
    27  
    28  func timeClaim(t Token, clock Clock, c string) time.Time {
    29  	switch c {
    30  	case ExpirationKey:
    31  		return t.Expiration()
    32  	case IssuedAtKey:
    33  		return t.IssuedAt()
    34  	case NotBeforeKey:
    35  		return t.NotBefore()
    36  	case "":
    37  		return clock.Now()
    38  	}
    39  	return time.Time{} // should *NEVER* reach here, but...
    40  }
    41  
    42  // Validate makes sure that the essential claims stand.
    43  //
    44  // See the various `WithXXX` functions for optional parameters
    45  // that can control the behavior of this method.
    46  func Validate(t Token, options ...ValidateOption) error {
    47  	ctx := context.Background()
    48  	var clock Clock = ClockFunc(time.Now)
    49  	var skew time.Duration
    50  	var validators = []Validator{
    51  		IsIssuedAtValid(),
    52  		IsExpirationValid(),
    53  		IsNbfValid(),
    54  	}
    55  	for _, o := range options {
    56  		//nolint:forcetypeassert
    57  		switch o.Ident() {
    58  		case identClock{}:
    59  			clock = o.Value().(Clock)
    60  		case identAcceptableSkew{}:
    61  			skew = o.Value().(time.Duration)
    62  		case identContext{}:
    63  			ctx = o.Value().(context.Context)
    64  		case identValidator{}:
    65  			v := o.Value().(Validator)
    66  			switch v := v.(type) {
    67  			case *isInTimeRange:
    68  				if v.c1 != "" {
    69  					if err := isSupportedTimeClaim(v.c1); err != nil {
    70  						return err
    71  					}
    72  					validators = append(validators, IsRequired(v.c1))
    73  				}
    74  				if v.c2 != "" {
    75  					if err := isSupportedTimeClaim(v.c2); err != nil {
    76  						return err
    77  					}
    78  					validators = append(validators, IsRequired(v.c2))
    79  				}
    80  			}
    81  			validators = append(validators, v)
    82  		}
    83  	}
    84  
    85  	ctx = SetValidationCtxSkew(ctx, skew)
    86  	ctx = SetValidationCtxClock(ctx, clock)
    87  	for _, v := range validators {
    88  		if err := v.Validate(ctx, t); err != nil {
    89  			return err
    90  		}
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  type isInTimeRange struct {
    97  	c1   string
    98  	c2   string
    99  	dur  time.Duration
   100  	less bool // if true, d =< c1 - c2. otherwise d >= c1 - c2
   101  }
   102  
   103  // MaxDeltaIs implements the logic behind `WithMaxDelta()` option
   104  func MaxDeltaIs(c1, c2 string, dur time.Duration) Validator {
   105  	return &isInTimeRange{
   106  		c1:   c1,
   107  		c2:   c2,
   108  		dur:  dur,
   109  		less: true,
   110  	}
   111  }
   112  
   113  // MinDeltaIs implements the logic behind `WithMinDelta()` option
   114  func MinDeltaIs(c1, c2 string, dur time.Duration) Validator {
   115  	return &isInTimeRange{
   116  		c1:   c1,
   117  		c2:   c2,
   118  		dur:  dur,
   119  		less: false,
   120  	}
   121  }
   122  
   123  func (iitr *isInTimeRange) Validate(ctx context.Context, t Token) error {
   124  	clock := ValidationCtxClock(ctx) // MUST be populated
   125  	skew := ValidationCtxSkew(ctx)   // MUST be populated
   126  	// We don't check if the claims already exist, because we already did that
   127  	// by piggybacking on `required` check.
   128  	t1 := timeClaim(t, clock, iitr.c1).Truncate(time.Second)
   129  	t2 := timeClaim(t, clock, iitr.c2).Truncate(time.Second)
   130  	if iitr.less { // t1 - t2 <= iitr.dur
   131  		// t1 - t2 < iitr.dur + skew
   132  		if t1.Sub(t2) > iitr.dur+skew {
   133  			return NewValidationError(errors.Errorf(`iitr between %s and %s exceeds %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
   134  		}
   135  	} else {
   136  		if t1.Sub(t2) < iitr.dur-skew {
   137  			return NewValidationError(errors.Errorf(`iitr between %s and %s is less than %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
   138  		}
   139  	}
   140  	return nil
   141  }
   142  
   143  type ValidationError interface {
   144  	error
   145  	isValidationError()
   146  }
   147  
   148  func NewValidationError(err error) ValidationError {
   149  	return &validationError{error: err}
   150  }
   151  
   152  // This is a generic validation error.
   153  type validationError struct {
   154  	error
   155  }
   156  
   157  func (validationError) isValidationError() {}
   158  
   159  var errTokenExpired = NewValidationError(errors.New(`exp not satisfied`))
   160  var errInvalidIssuedAt = NewValidationError(errors.New(`iat not satisfied`))
   161  var errTokenNotYetValid = NewValidationError(errors.New(`nbf not satisfied`))
   162  
   163  // ErrTokenExpired returns the immutable error used when `exp` claim
   164  // is not satisfied
   165  func ErrTokenExpired() error {
   166  	return errTokenExpired
   167  }
   168  
   169  // ErrInvalidIssuedAt returns the immutable error used when `iat` claim
   170  // is not satisfied
   171  func ErrInvalidIssuedAt() error {
   172  	return errInvalidIssuedAt
   173  }
   174  
   175  func ErrTokenNotYetValid() error {
   176  	return errTokenNotYetValid
   177  }
   178  
   179  // Validator describes interface to validate a Token.
   180  type Validator interface {
   181  	// Validate should return an error if a required conditions is not met.
   182  	// This method will be changed in the next major release to return
   183  	// jwt.ValidationError instead of error to force users to return
   184  	// a validation error even for user-specified validators
   185  	Validate(context.Context, Token) error
   186  }
   187  
   188  // ValidatorFunc is a type of Validator that does not have any
   189  // state, that is implemented as a function
   190  type ValidatorFunc func(context.Context, Token) error
   191  
   192  func (vf ValidatorFunc) Validate(ctx context.Context, tok Token) error {
   193  	return vf(ctx, tok)
   194  }
   195  
   196  type identValidationCtxClock struct{}
   197  type identValidationCtxSkew struct{}
   198  
   199  func SetValidationCtxClock(ctx context.Context, cl Clock) context.Context {
   200  	return context.WithValue(ctx, identValidationCtxClock{}, cl)
   201  }
   202  
   203  // ValidationCtxClock returns the Clock object associated with
   204  // the current validation context. This value will always be available
   205  // during validation of tokens.
   206  func ValidationCtxClock(ctx context.Context) Clock {
   207  	//nolint:forcetypeassert
   208  	return ctx.Value(identValidationCtxClock{}).(Clock)
   209  }
   210  
   211  func SetValidationCtxSkew(ctx context.Context, dur time.Duration) context.Context {
   212  	return context.WithValue(ctx, identValidationCtxSkew{}, dur)
   213  }
   214  
   215  func ValidationCtxSkew(ctx context.Context) time.Duration {
   216  	//nolint:forcetypeassert
   217  	return ctx.Value(identValidationCtxSkew{}).(time.Duration)
   218  }
   219  
   220  // IsExpirationValid is one of the default validators that will be executed.
   221  // It does not need to be specified by users, but it exists as an
   222  // exported field so that you can check what it does.
   223  //
   224  // The supplied context.Context object must have the "clock" and "skew"
   225  // populated with appropriate values using SetValidationCtxClock() and
   226  // SetValidationCtxSkew()
   227  func IsExpirationValid() Validator {
   228  	return ValidatorFunc(isExpirationValid)
   229  }
   230  
   231  func isExpirationValid(ctx context.Context, t Token) error {
   232  	if tv := t.Expiration(); !tv.IsZero() && tv.Unix() != 0 {
   233  		clock := ValidationCtxClock(ctx) // MUST be populated
   234  		now := clock.Now().Truncate(time.Second)
   235  		ttv := tv.Truncate(time.Second)
   236  		skew := ValidationCtxSkew(ctx) // MUST be populated
   237  		if !now.Before(ttv.Add(skew)) {
   238  			return ErrTokenExpired()
   239  		}
   240  	}
   241  	return nil
   242  }
   243  
   244  // IsIssuedAtValid is one of the default validators that will be executed.
   245  // It does not need to be specified by users, but it exists as an
   246  // exported field so that you can check what it does.
   247  //
   248  // The supplied context.Context object must have the "clock" and "skew"
   249  // populated with appropriate values using SetValidationCtxClock() and
   250  // SetValidationCtxSkew()
   251  func IsIssuedAtValid() Validator {
   252  	return ValidatorFunc(isIssuedAtValid)
   253  }
   254  
   255  func isIssuedAtValid(ctx context.Context, t Token) error {
   256  	if tv := t.IssuedAt(); !tv.IsZero() && tv.Unix() != 0 {
   257  		clock := ValidationCtxClock(ctx) // MUST be populated
   258  		now := clock.Now().Truncate(time.Second)
   259  		ttv := tv.Truncate(time.Second)
   260  		skew := ValidationCtxSkew(ctx) // MUST be populated
   261  		if now.Before(ttv.Add(-1 * skew)) {
   262  			return ErrInvalidIssuedAt()
   263  		}
   264  	}
   265  	return nil
   266  }
   267  
   268  // IsNbfValid is one of the default validators that will be executed.
   269  // It does not need to be specified by users, but it exists as an
   270  // exported field so that you can check what it does.
   271  //
   272  // The supplied context.Context object must have the "clock" and "skew"
   273  // populated with appropriate values using SetValidationCtxClock() and
   274  // SetValidationCtxSkew()
   275  func IsNbfValid() Validator {
   276  	return ValidatorFunc(isNbfValid)
   277  }
   278  
   279  func isNbfValid(ctx context.Context, t Token) error {
   280  	if tv := t.NotBefore(); !tv.IsZero() && tv.Unix() != 0 {
   281  		clock := ValidationCtxClock(ctx) // MUST be populated
   282  		now := clock.Now().Truncate(time.Second)
   283  		ttv := tv.Truncate(time.Second)
   284  		skew := ValidationCtxSkew(ctx) // MUST be populated
   285  		// now cannot be before t, so we check for now > t - skew
   286  		if !now.Equal(ttv) && !now.After(ttv.Add(-1*skew)) {
   287  			return ErrTokenNotYetValid()
   288  		}
   289  	}
   290  	return nil
   291  }
   292  
   293  type claimContainsString struct {
   294  	name  string
   295  	value string
   296  }
   297  
   298  // ClaimContainsString can be used to check if the claim called `name`, which is
   299  // expected to be a list of strings, contains `value`. Currently because of the
   300  // implementation this will probably only work for `aud` fields.
   301  func ClaimContainsString(name, value string) Validator {
   302  	return claimContainsString{
   303  		name:  name,
   304  		value: value,
   305  	}
   306  }
   307  
   308  // IsValidationError returns true if the error is a validation error
   309  func IsValidationError(err error) bool {
   310  	switch err {
   311  	case errTokenExpired, errTokenNotYetValid, errInvalidIssuedAt:
   312  		return true
   313  	default:
   314  		switch err.(type) {
   315  		case *validationError:
   316  			return true
   317  		default:
   318  			return false
   319  		}
   320  	}
   321  }
   322  
   323  func (ccs claimContainsString) Validate(_ context.Context, t Token) error {
   324  	v, ok := t.Get(ccs.name)
   325  	if !ok {
   326  		return NewValidationError(errors.Errorf(`claim %q not found`, ccs.name))
   327  	}
   328  
   329  	list, ok := v.([]string)
   330  	if !ok {
   331  		return NewValidationError(errors.Errorf(`claim %q must be a []string (got %T)`, ccs.name, v))
   332  	}
   333  
   334  	var found bool
   335  	for _, v := range list {
   336  		if v == ccs.value {
   337  			found = true
   338  			break
   339  		}
   340  	}
   341  	if !found {
   342  		return NewValidationError(errors.Errorf(`%s not satisfied`, ccs.name))
   343  	}
   344  	return nil
   345  }
   346  
   347  type claimValueIs struct {
   348  	name  string
   349  	value interface{}
   350  }
   351  
   352  // ClaimValueIs creates a Validator that checks if the value of claim `name`
   353  // matches `value`. The comparison is done using a simple `==` comparison,
   354  // and therefore complex comparisons may fail using this code. If you
   355  // need to do more, use a custom Validator.
   356  func ClaimValueIs(name string, value interface{}) Validator {
   357  	return &claimValueIs{name: name, value: value}
   358  }
   359  
   360  func (cv *claimValueIs) Validate(_ context.Context, t Token) error {
   361  	v, ok := t.Get(cv.name)
   362  	if !ok {
   363  		return NewValidationError(errors.Errorf(`%q not satisfied: claim %q does not exist`, cv.name, cv.name))
   364  	}
   365  	if v != cv.value {
   366  		return NewValidationError(errors.Errorf(`%q not satisfied: values do not match`, cv.name))
   367  	}
   368  	return nil
   369  }
   370  
   371  // IsRequired creates a Validator that checks if the required claim `name`
   372  // exists in the token
   373  func IsRequired(name string) Validator {
   374  	return isRequired(name)
   375  }
   376  
   377  type isRequired string
   378  
   379  func (ir isRequired) Validate(_ context.Context, t Token) error {
   380  	_, ok := t.Get(string(ir))
   381  	if !ok {
   382  		return NewValidationError(errors.Errorf(`required claim %q was not found`, string(ir)))
   383  	}
   384  	return nil
   385  }
   386  

View as plain text