1 package jwt
2
import (
	"context"
	"reflect"
	"strconv"
	"time"

	"github.com/pkg/errors"
)
10
// Clock is the time source consulted while validating tokens.
type Clock interface {
	// Now reports the current time.
	Now() time.Time
}

// ClockFunc adapts a plain function to the Clock interface.
type ClockFunc func() time.Time

// Now calls fn and returns its result.
func (fn ClockFunc) Now() time.Time {
	return fn()
}
19
20 func isSupportedTimeClaim(c string) error {
21 switch c {
22 case ExpirationKey, IssuedAtKey, NotBeforeKey:
23 return nil
24 }
25 return NewValidationError(errors.Errorf(`unsupported time claim %s`, strconv.Quote(c)))
26 }
27
28 func timeClaim(t Token, clock Clock, c string) time.Time {
29 switch c {
30 case ExpirationKey:
31 return t.Expiration()
32 case IssuedAtKey:
33 return t.IssuedAt()
34 case NotBeforeKey:
35 return t.NotBefore()
36 case "":
37 return clock.Now()
38 }
39 return time.Time{}
40 }
41
42
43
44
45
46 func Validate(t Token, options ...ValidateOption) error {
47 ctx := context.Background()
48 var clock Clock = ClockFunc(time.Now)
49 var skew time.Duration
50 var validators = []Validator{
51 IsIssuedAtValid(),
52 IsExpirationValid(),
53 IsNbfValid(),
54 }
55 for _, o := range options {
56
57 switch o.Ident() {
58 case identClock{}:
59 clock = o.Value().(Clock)
60 case identAcceptableSkew{}:
61 skew = o.Value().(time.Duration)
62 case identContext{}:
63 ctx = o.Value().(context.Context)
64 case identValidator{}:
65 v := o.Value().(Validator)
66 switch v := v.(type) {
67 case *isInTimeRange:
68 if v.c1 != "" {
69 if err := isSupportedTimeClaim(v.c1); err != nil {
70 return err
71 }
72 validators = append(validators, IsRequired(v.c1))
73 }
74 if v.c2 != "" {
75 if err := isSupportedTimeClaim(v.c2); err != nil {
76 return err
77 }
78 validators = append(validators, IsRequired(v.c2))
79 }
80 }
81 validators = append(validators, v)
82 }
83 }
84
85 ctx = SetValidationCtxSkew(ctx, skew)
86 ctx = SetValidationCtxClock(ctx, clock)
87 for _, v := range validators {
88 if err := v.Validate(ctx, t); err != nil {
89 return err
90 }
91 }
92
93 return nil
94 }
95
96 type isInTimeRange struct {
97 c1 string
98 c2 string
99 dur time.Duration
100 less bool
101 }
102
103
104 func MaxDeltaIs(c1, c2 string, dur time.Duration) Validator {
105 return &isInTimeRange{
106 c1: c1,
107 c2: c2,
108 dur: dur,
109 less: true,
110 }
111 }
112
113
114 func MinDeltaIs(c1, c2 string, dur time.Duration) Validator {
115 return &isInTimeRange{
116 c1: c1,
117 c2: c2,
118 dur: dur,
119 less: false,
120 }
121 }
122
123 func (iitr *isInTimeRange) Validate(ctx context.Context, t Token) error {
124 clock := ValidationCtxClock(ctx)
125 skew := ValidationCtxSkew(ctx)
126
127
128 t1 := timeClaim(t, clock, iitr.c1).Truncate(time.Second)
129 t2 := timeClaim(t, clock, iitr.c2).Truncate(time.Second)
130 if iitr.less {
131
132 if t1.Sub(t2) > iitr.dur+skew {
133 return NewValidationError(errors.Errorf(`iitr between %s and %s exceeds %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
134 }
135 } else {
136 if t1.Sub(t2) < iitr.dur-skew {
137 return NewValidationError(errors.Errorf(`iitr between %s and %s is less than %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
138 }
139 }
140 return nil
141 }
142
// ValidationError is the error type produced by NewValidationError. Its
// unexported marker method means only types in this package can implement it.
type ValidationError interface {
	error
	isValidationError()
}

// NewValidationError wraps err so that it satisfies ValidationError. The
// returned error reports err's message, and err remains reachable through
// errors.Unwrap, errors.Is and errors.As.
func NewValidationError(err error) ValidationError {
	return &validationError{error: err}
}

// validationError is this package's ValidationError implementation. It embeds
// the underlying error to reuse its Error method.
type validationError struct {
	error
}

func (validationError) isValidationError() {}

// Unwrap returns the wrapped error. Embedding an error only promotes Error(),
// so without this method errors.Is and errors.As could not see the cause.
func (e validationError) Unwrap() error {
	return e.error
}
158
159 var errTokenExpired = NewValidationError(errors.New(`exp not satisfied`))
160 var errInvalidIssuedAt = NewValidationError(errors.New(`iat not satisfied`))
161 var errTokenNotYetValid = NewValidationError(errors.New(`nbf not satisfied`))
162
163
164
165 func ErrTokenExpired() error {
166 return errTokenExpired
167 }
168
169
170
171 func ErrInvalidIssuedAt() error {
172 return errInvalidIssuedAt
173 }
174
175 func ErrTokenNotYetValid() error {
176 return errTokenNotYetValid
177 }
178
179
180 type Validator interface {
181
182
183
184
185 Validate(context.Context, Token) error
186 }
187
188
189
190 type ValidatorFunc func(context.Context, Token) error
191
192 func (vf ValidatorFunc) Validate(ctx context.Context, tok Token) error {
193 return vf(ctx, tok)
194 }
195
196 type identValidationCtxClock struct{}
197 type identValidationCtxSkew struct{}
198
199 func SetValidationCtxClock(ctx context.Context, cl Clock) context.Context {
200 return context.WithValue(ctx, identValidationCtxClock{}, cl)
201 }
202
203
204
205
206 func ValidationCtxClock(ctx context.Context) Clock {
207
208 return ctx.Value(identValidationCtxClock{}).(Clock)
209 }
210
211 func SetValidationCtxSkew(ctx context.Context, dur time.Duration) context.Context {
212 return context.WithValue(ctx, identValidationCtxSkew{}, dur)
213 }
214
215 func ValidationCtxSkew(ctx context.Context) time.Duration {
216
217 return ctx.Value(identValidationCtxSkew{}).(time.Duration)
218 }
219
220
221
222
223
224
225
226
227 func IsExpirationValid() Validator {
228 return ValidatorFunc(isExpirationValid)
229 }
230
231 func isExpirationValid(ctx context.Context, t Token) error {
232 if tv := t.Expiration(); !tv.IsZero() && tv.Unix() != 0 {
233 clock := ValidationCtxClock(ctx)
234 now := clock.Now().Truncate(time.Second)
235 ttv := tv.Truncate(time.Second)
236 skew := ValidationCtxSkew(ctx)
237 if !now.Before(ttv.Add(skew)) {
238 return ErrTokenExpired()
239 }
240 }
241 return nil
242 }
243
244
245
246
247
248
249
250
251 func IsIssuedAtValid() Validator {
252 return ValidatorFunc(isIssuedAtValid)
253 }
254
255 func isIssuedAtValid(ctx context.Context, t Token) error {
256 if tv := t.IssuedAt(); !tv.IsZero() && tv.Unix() != 0 {
257 clock := ValidationCtxClock(ctx)
258 now := clock.Now().Truncate(time.Second)
259 ttv := tv.Truncate(time.Second)
260 skew := ValidationCtxSkew(ctx)
261 if now.Before(ttv.Add(-1 * skew)) {
262 return ErrInvalidIssuedAt()
263 }
264 }
265 return nil
266 }
267
268
269
270
271
272
273
274
275 func IsNbfValid() Validator {
276 return ValidatorFunc(isNbfValid)
277 }
278
279 func isNbfValid(ctx context.Context, t Token) error {
280 if tv := t.NotBefore(); !tv.IsZero() && tv.Unix() != 0 {
281 clock := ValidationCtxClock(ctx)
282 now := clock.Now().Truncate(time.Second)
283 ttv := tv.Truncate(time.Second)
284 skew := ValidationCtxSkew(ctx)
285
286 if !now.Equal(ttv) && !now.After(ttv.Add(-1*skew)) {
287 return ErrTokenNotYetValid()
288 }
289 }
290 return nil
291 }
292
293 type claimContainsString struct {
294 name string
295 value string
296 }
297
298
299
300
301 func ClaimContainsString(name, value string) Validator {
302 return claimContainsString{
303 name: name,
304 value: value,
305 }
306 }
307
308
309 func IsValidationError(err error) bool {
310 switch err {
311 case errTokenExpired, errTokenNotYetValid, errInvalidIssuedAt:
312 return true
313 default:
314 switch err.(type) {
315 case *validationError:
316 return true
317 default:
318 return false
319 }
320 }
321 }
322
323 func (ccs claimContainsString) Validate(_ context.Context, t Token) error {
324 v, ok := t.Get(ccs.name)
325 if !ok {
326 return NewValidationError(errors.Errorf(`claim %q not found`, ccs.name))
327 }
328
329 list, ok := v.([]string)
330 if !ok {
331 return NewValidationError(errors.Errorf(`claim %q must be a []string (got %T)`, ccs.name, v))
332 }
333
334 var found bool
335 for _, v := range list {
336 if v == ccs.value {
337 found = true
338 break
339 }
340 }
341 if !found {
342 return NewValidationError(errors.Errorf(`%s not satisfied`, ccs.name))
343 }
344 return nil
345 }
346
347 type claimValueIs struct {
348 name string
349 value interface{}
350 }
351
352
353
354
355
356 func ClaimValueIs(name string, value interface{}) Validator {
357 return &claimValueIs{name: name, value: value}
358 }
359
360 func (cv *claimValueIs) Validate(_ context.Context, t Token) error {
361 v, ok := t.Get(cv.name)
362 if !ok {
363 return NewValidationError(errors.Errorf(`%q not satisfied: claim %q does not exist`, cv.name, cv.name))
364 }
365 if v != cv.value {
366 return NewValidationError(errors.Errorf(`%q not satisfied: values do not match`, cv.name))
367 }
368 return nil
369 }
370
371
372
373 func IsRequired(name string) Validator {
374 return isRequired(name)
375 }
376
377 type isRequired string
378
379 func (ir isRequired) Validate(_ context.Context, t Token) error {
380 _, ok := t.Get(string(ir))
381 if !ok {
382 return NewValidationError(errors.Errorf(`required claim %q was not found`, string(ir)))
383 }
384 return nil
385 }
386
View as plain text