1 package validator
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "reflect"
8 "strings"
9 "sync"
10 "time"
11
12 ut "github.com/go-playground/universal-translator"
13 )
14
const (
	// defaultTagName is the struct tag key a new Validate starts with; it can
	// be replaced through SetTagName.
	defaultTagName = "validate"

	// Textual hex codes for ',' (0x2C) and '|' (0x7C).
	// NOTE(review): presumably used to escape those characters inside tag
	// parameters — confirm against the tag parser.
	utf8HexComma = "0x2C"
	utf8Pipe     = "0x7C"

	// Separators used inside a tag string.
	tagSeparator    = ","
	orSeparator     = "|"
	tagKeySeparator = "="

	// Structural and control tags.
	structOnlyTag    = "structonly"
	noStructLevelTag = "nostructlevel"
	omitempty        = "omitempty"
	omitnil          = "omitnil"
	isdefault        = "isdefault"

	// Conditional tags. New registers every one of these so that it also runs
	// when the value is nil.
	requiredWithoutAllTag = "required_without_all"
	requiredWithoutTag    = "required_without"
	requiredWithTag       = "required_with"
	requiredWithAllTag    = "required_with_all"
	requiredIfTag         = "required_if"
	requiredUnlessTag     = "required_unless"
	skipUnlessTag         = "skip_unless"
	excludedWithoutAllTag = "excluded_without_all"
	excludedWithoutTag    = "excluded_without"
	excludedWithTag       = "excluded_with"
	excludedWithAllTag    = "excluded_with_all"
	excludedIfTag         = "excluded_if"
	excludedUnlessTag     = "excluded_unless"

	// skipValidationTag makes VarCtx and VarWithValueCtx return immediately.
	skipValidationTag = "-"

	diveTag     = "dive"
	keysTag     = "keys"
	endKeysTag  = "endkeys"
	requiredTag = "required"

	// Pieces of a field namespace such as "Type.Field[0]".
	namespaceSeparator = "."
	leftBracket        = "["
	rightBracket       = "]"

	// restrictedTagChars lists the characters that may not appear in a custom
	// tag or alias name; registration panics with the matching message below.
	restrictedTagChars = ".[],|=+()`~!@#$%^&*\\\"/?<>{}"
	restrictedAliasErr = "Alias '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
	restrictedTagErr   = "Tag '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
)
52
53 var (
54 timeDurationType = reflect.TypeOf(time.Duration(0))
55 timeType = reflect.TypeOf(time.Time{})
56
57 byteSliceType = reflect.TypeOf([]byte{})
58
59 defaultCField = &cField{namesEqual: true}
60 )
61
62
63
64
65
// FilterFunc is the callback accepted by StructFiltered and StructFilteredCtx.
// It receives a namespace as raw bytes and returns a boolean decision.
// NOTE(review): presumably returning true excludes that field from
// validation — confirm against validateStruct.
type FilterFunc func(ns []byte) bool
67
68
69
70
// CustomTypeFunc is the callback registered through RegisterCustomTypeFunc.
// It is given the reflected field value and returns an arbitrary value.
// NOTE(review): presumably the returned value is validated in place of the
// original field — confirm against the traversal code.
type CustomTypeFunc func(field reflect.Value) interface{}
72
73
// TagNameFunc is the callback installed through RegisterTagNameFunc. It maps a
// struct field description to a string.
// NOTE(review): presumably the string is an alternate field name used in
// reported errors — confirm against the struct cache code.
type TagNameFunc func(field reflect.StructField) string
75
76 type internalValidationFuncWrapper struct {
77 fn FuncCtx
78 runValidatinOnNil bool
79 }
80
81
82 type Validate struct {
83 tagName string
84 pool *sync.Pool
85 tagNameFunc TagNameFunc
86 structLevelFuncs map[reflect.Type]StructLevelFuncCtx
87 customFuncs map[reflect.Type]CustomTypeFunc
88 aliases map[string]string
89 validations map[string]internalValidationFuncWrapper
90 transTagFunc map[ut.Translator]map[string]TranslationFunc
91 rules map[reflect.Type]map[string]string
92 tagCache *tagCache
93 structCache *structCache
94 hasCustomFuncs bool
95 hasTagNameFunc bool
96 requiredStructEnabled bool
97 privateFieldValidation bool
98 }
99
100
101
102
103
104
105 func New(options ...Option) *Validate {
106
107 tc := new(tagCache)
108 tc.m.Store(make(map[string]*cTag))
109
110 sc := new(structCache)
111 sc.m.Store(make(map[reflect.Type]*cStruct))
112
113 v := &Validate{
114 tagName: defaultTagName,
115 aliases: make(map[string]string, len(bakedInAliases)),
116 validations: make(map[string]internalValidationFuncWrapper, len(bakedInValidators)),
117 tagCache: tc,
118 structCache: sc,
119 }
120
121
122 for k, val := range bakedInAliases {
123 v.RegisterAlias(k, val)
124 }
125
126
127 for k, val := range bakedInValidators {
128
129 switch k {
130
131 case requiredIfTag, requiredUnlessTag, requiredWithTag, requiredWithAllTag, requiredWithoutTag, requiredWithoutAllTag,
132 excludedIfTag, excludedUnlessTag, excludedWithTag, excludedWithAllTag, excludedWithoutTag, excludedWithoutAllTag,
133 skipUnlessTag:
134 _ = v.registerValidation(k, wrapFunc(val), true, true)
135 default:
136
137 _ = v.registerValidation(k, wrapFunc(val), true, false)
138 }
139 }
140
141 v.pool = &sync.Pool{
142 New: func() interface{} {
143 return &validate{
144 v: v,
145 ns: make([]byte, 0, 64),
146 actualNs: make([]byte, 0, 64),
147 misc: make([]byte, 32),
148 }
149 },
150 }
151
152 for _, o := range options {
153 o(v)
154 }
155 return v
156 }
157
158
159 func (v *Validate) SetTagName(name string) {
160 v.tagName = name
161 }
162
163
164
165 func (v Validate) ValidateMapCtx(ctx context.Context, data map[string]interface{}, rules map[string]interface{}) map[string]interface{} {
166 errs := make(map[string]interface{})
167 for field, rule := range rules {
168 if ruleObj, ok := rule.(map[string]interface{}); ok {
169 if dataObj, ok := data[field].(map[string]interface{}); ok {
170 err := v.ValidateMapCtx(ctx, dataObj, ruleObj)
171 if len(err) > 0 {
172 errs[field] = err
173 }
174 } else if dataObjs, ok := data[field].([]map[string]interface{}); ok {
175 for _, obj := range dataObjs {
176 err := v.ValidateMapCtx(ctx, obj, ruleObj)
177 if len(err) > 0 {
178 errs[field] = err
179 }
180 }
181 } else {
182 errs[field] = errors.New("The field: '" + field + "' is not a map to dive")
183 }
184 } else if ruleStr, ok := rule.(string); ok {
185 err := v.VarCtx(ctx, data[field], ruleStr)
186 if err != nil {
187 errs[field] = err
188 }
189 }
190 }
191 return errs
192 }
193
194
195 func (v *Validate) ValidateMap(data map[string]interface{}, rules map[string]interface{}) map[string]interface{} {
196 return v.ValidateMapCtx(context.Background(), data, rules)
197 }
198
199
200
201
202
203
204
205
206
207
208
209
210
211 func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) {
212 v.tagNameFunc = fn
213 v.hasTagNameFunc = true
214 }
215
216
217
218
219
220
221 func (v *Validate) RegisterValidation(tag string, fn Func, callValidationEvenIfNull ...bool) error {
222 return v.RegisterValidationCtx(tag, wrapFunc(fn), callValidationEvenIfNull...)
223 }
224
225
226
227 func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx, callValidationEvenIfNull ...bool) error {
228 var nilCheckable bool
229 if len(callValidationEvenIfNull) > 0 {
230 nilCheckable = callValidationEvenIfNull[0]
231 }
232 return v.registerValidation(tag, fn, false, nilCheckable)
233 }
234
235 func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool, nilCheckable bool) error {
236 if len(tag) == 0 {
237 return errors.New("function Key cannot be empty")
238 }
239
240 if fn == nil {
241 return errors.New("function cannot be empty")
242 }
243
244 _, ok := restrictedTags[tag]
245 if !bakedIn && (ok || strings.ContainsAny(tag, restrictedTagChars)) {
246 panic(fmt.Sprintf(restrictedTagErr, tag))
247 }
248 v.validations[tag] = internalValidationFuncWrapper{fn: fn, runValidatinOnNil: nilCheckable}
249 return nil
250 }
251
252
253
254
255
256
257 func (v *Validate) RegisterAlias(alias, tags string) {
258
259 _, ok := restrictedTags[alias]
260
261 if ok || strings.ContainsAny(alias, restrictedTagChars) {
262 panic(fmt.Sprintf(restrictedAliasErr, alias))
263 }
264
265 v.aliases[alias] = tags
266 }
267
268
269
270
271
272 func (v *Validate) RegisterStructValidation(fn StructLevelFunc, types ...interface{}) {
273 v.RegisterStructValidationCtx(wrapStructLevelFunc(fn), types...)
274 }
275
276
277
278
279
280
281 func (v *Validate) RegisterStructValidationCtx(fn StructLevelFuncCtx, types ...interface{}) {
282
283 if v.structLevelFuncs == nil {
284 v.structLevelFuncs = make(map[reflect.Type]StructLevelFuncCtx)
285 }
286
287 for _, t := range types {
288 tv := reflect.ValueOf(t)
289 if tv.Kind() == reflect.Ptr {
290 t = reflect.Indirect(tv).Interface()
291 }
292
293 v.structLevelFuncs[reflect.TypeOf(t)] = fn
294 }
295 }
296
297
298
299
300
301 func (v *Validate) RegisterStructValidationMapRules(rules map[string]string, types ...interface{}) {
302 if v.rules == nil {
303 v.rules = make(map[reflect.Type]map[string]string)
304 }
305
306 deepCopyRules := make(map[string]string)
307 for i, rule := range rules {
308 deepCopyRules[i] = rule
309 }
310
311 for _, t := range types {
312 typ := reflect.TypeOf(t)
313
314 if typ.Kind() == reflect.Ptr {
315 typ = typ.Elem()
316 }
317
318 if typ.Kind() != reflect.Struct {
319 continue
320 }
321 v.rules[typ] = deepCopyRules
322 }
323 }
324
325
326
327
328 func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...interface{}) {
329
330 if v.customFuncs == nil {
331 v.customFuncs = make(map[reflect.Type]CustomTypeFunc)
332 }
333
334 for _, t := range types {
335 v.customFuncs[reflect.TypeOf(t)] = fn
336 }
337
338 v.hasCustomFuncs = true
339 }
340
341
342 func (v *Validate) RegisterTranslation(tag string, trans ut.Translator, registerFn RegisterTranslationsFunc, translationFn TranslationFunc) (err error) {
343
344 if v.transTagFunc == nil {
345 v.transTagFunc = make(map[ut.Translator]map[string]TranslationFunc)
346 }
347
348 if err = registerFn(trans); err != nil {
349 return
350 }
351
352 m, ok := v.transTagFunc[trans]
353 if !ok {
354 m = make(map[string]TranslationFunc)
355 v.transTagFunc[trans] = m
356 }
357
358 m[tag] = translationFn
359
360 return
361 }
362
363
364
365
366
367 func (v *Validate) Struct(s interface{}) error {
368 return v.StructCtx(context.Background(), s)
369 }
370
371
372
373
374
375
376 func (v *Validate) StructCtx(ctx context.Context, s interface{}) (err error) {
377
378 val := reflect.ValueOf(s)
379 top := val
380
381 if val.Kind() == reflect.Ptr && !val.IsNil() {
382 val = val.Elem()
383 }
384
385 if val.Kind() != reflect.Struct || val.Type().ConvertibleTo(timeType) {
386 return &InvalidValidationError{Type: reflect.TypeOf(s)}
387 }
388
389
390 vd := v.pool.Get().(*validate)
391 vd.top = top
392 vd.isPartial = false
393
394
395 vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
396
397 if len(vd.errs) > 0 {
398 err = vd.errs
399 vd.errs = nil
400 }
401
402 v.pool.Put(vd)
403
404 return
405 }
406
407
408
409
410
411
412 func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) error {
413 return v.StructFilteredCtx(context.Background(), s, fn)
414 }
415
416
417
418
419
420
421
422 func (v *Validate) StructFilteredCtx(ctx context.Context, s interface{}, fn FilterFunc) (err error) {
423 val := reflect.ValueOf(s)
424 top := val
425
426 if val.Kind() == reflect.Ptr && !val.IsNil() {
427 val = val.Elem()
428 }
429
430 if val.Kind() != reflect.Struct || val.Type().ConvertibleTo(timeType) {
431 return &InvalidValidationError{Type: reflect.TypeOf(s)}
432 }
433
434
435 vd := v.pool.Get().(*validate)
436 vd.top = top
437 vd.isPartial = true
438 vd.ffn = fn
439
440
441 vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
442
443 if len(vd.errs) > 0 {
444 err = vd.errs
445 vd.errs = nil
446 }
447
448 v.pool.Put(vd)
449
450 return
451 }
452
453
454
455
456
457
458
459 func (v *Validate) StructPartial(s interface{}, fields ...string) error {
460 return v.StructPartialCtx(context.Background(), s, fields...)
461 }
462
463
464
465
466
467
468
469
470 func (v *Validate) StructPartialCtx(ctx context.Context, s interface{}, fields ...string) (err error) {
471 val := reflect.ValueOf(s)
472 top := val
473
474 if val.Kind() == reflect.Ptr && !val.IsNil() {
475 val = val.Elem()
476 }
477
478 if val.Kind() != reflect.Struct || val.Type().ConvertibleTo(timeType) {
479 return &InvalidValidationError{Type: reflect.TypeOf(s)}
480 }
481
482
483 vd := v.pool.Get().(*validate)
484 vd.top = top
485 vd.isPartial = true
486 vd.ffn = nil
487 vd.hasExcludes = false
488 vd.includeExclude = make(map[string]struct{})
489
490 typ := val.Type()
491 name := typ.Name()
492
493 for _, k := range fields {
494
495 flds := strings.Split(k, namespaceSeparator)
496 if len(flds) > 0 {
497
498 vd.misc = append(vd.misc[0:0], name...)
499
500 if len(vd.misc) != 0 {
501 vd.misc = append(vd.misc, '.')
502 }
503
504 for _, s := range flds {
505
506 idx := strings.Index(s, leftBracket)
507
508 if idx != -1 {
509 for idx != -1 {
510 vd.misc = append(vd.misc, s[:idx]...)
511 vd.includeExclude[string(vd.misc)] = struct{}{}
512
513 idx2 := strings.Index(s, rightBracket)
514 idx2++
515 vd.misc = append(vd.misc, s[idx:idx2]...)
516 vd.includeExclude[string(vd.misc)] = struct{}{}
517 s = s[idx2:]
518 idx = strings.Index(s, leftBracket)
519 }
520 } else {
521
522 vd.misc = append(vd.misc, s...)
523 vd.includeExclude[string(vd.misc)] = struct{}{}
524 }
525
526 vd.misc = append(vd.misc, '.')
527 }
528 }
529 }
530
531 vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
532
533 if len(vd.errs) > 0 {
534 err = vd.errs
535 vd.errs = nil
536 }
537
538 v.pool.Put(vd)
539
540 return
541 }
542
543
544
545
546
547
548
549 func (v *Validate) StructExcept(s interface{}, fields ...string) error {
550 return v.StructExceptCtx(context.Background(), s, fields...)
551 }
552
553
554
555
556
557
558
559
560 func (v *Validate) StructExceptCtx(ctx context.Context, s interface{}, fields ...string) (err error) {
561 val := reflect.ValueOf(s)
562 top := val
563
564 if val.Kind() == reflect.Ptr && !val.IsNil() {
565 val = val.Elem()
566 }
567
568 if val.Kind() != reflect.Struct || val.Type().ConvertibleTo(timeType) {
569 return &InvalidValidationError{Type: reflect.TypeOf(s)}
570 }
571
572
573 vd := v.pool.Get().(*validate)
574 vd.top = top
575 vd.isPartial = true
576 vd.ffn = nil
577 vd.hasExcludes = true
578 vd.includeExclude = make(map[string]struct{})
579
580 typ := val.Type()
581 name := typ.Name()
582
583 for _, key := range fields {
584
585 vd.misc = vd.misc[0:0]
586
587 if len(name) > 0 {
588 vd.misc = append(vd.misc, name...)
589 vd.misc = append(vd.misc, '.')
590 }
591
592 vd.misc = append(vd.misc, key...)
593 vd.includeExclude[string(vd.misc)] = struct{}{}
594 }
595
596 vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
597
598 if len(vd.errs) > 0 {
599 err = vd.errs
600 vd.errs = nil
601 }
602
603 v.pool.Put(vd)
604
605 return
606 }
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621 func (v *Validate) Var(field interface{}, tag string) error {
622 return v.VarCtx(context.Background(), field, tag)
623 }
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639 func (v *Validate) VarCtx(ctx context.Context, field interface{}, tag string) (err error) {
640 if len(tag) == 0 || tag == skipValidationTag {
641 return nil
642 }
643
644 ctag := v.fetchCacheTag(tag)
645
646 val := reflect.ValueOf(field)
647 vd := v.pool.Get().(*validate)
648 vd.top = val
649 vd.isPartial = false
650 vd.traverseField(ctx, val, val, vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
651
652 if len(vd.errs) > 0 {
653 err = vd.errs
654 vd.errs = nil
655 }
656 v.pool.Put(vd)
657 return
658 }
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674 func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string) error {
675 return v.VarWithValueCtx(context.Background(), field, other, tag)
676 }
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693 func (v *Validate) VarWithValueCtx(ctx context.Context, field interface{}, other interface{}, tag string) (err error) {
694 if len(tag) == 0 || tag == skipValidationTag {
695 return nil
696 }
697 ctag := v.fetchCacheTag(tag)
698 otherVal := reflect.ValueOf(other)
699 vd := v.pool.Get().(*validate)
700 vd.top = otherVal
701 vd.isPartial = false
702 vd.traverseField(ctx, otherVal, reflect.ValueOf(field), vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
703
704 if len(vd.errs) > 0 {
705 err = vd.errs
706 vd.errs = nil
707 }
708 v.pool.Put(vd)
709 return
710 }
711
View as plain text