...
1 package fieldbehavior
2
3 import (
4 "fmt"
5
6 "google.golang.org/genproto/googleapis/api/annotations"
7 "google.golang.org/protobuf/proto"
8 "google.golang.org/protobuf/reflect/protoreflect"
9 "google.golang.org/protobuf/types/known/fieldmaskpb"
10 )
11
12
13
14 func ValidateRequiredFields(m proto.Message) error {
15 return validateRequiredFields(
16 m.ProtoReflect(),
17 &fieldmaskpb.FieldMask{Paths: []string{"*"}},
18 "",
19 )
20 }
21
22 func ValidateRequiredFieldsWithMask(m proto.Message, mask *fieldmaskpb.FieldMask) error {
23 return validateRequiredFields(m.ProtoReflect(), mask, "")
24 }
25
26 func validateRequiredFields(reflectMessage protoreflect.Message, mask *fieldmaskpb.FieldMask, path string) error {
27
28
29
30 if len(mask.GetPaths()) == 0 {
31 return nil
32 }
33 for i := 0; i < reflectMessage.Descriptor().Fields().Len(); i++ {
34 field := reflectMessage.Descriptor().Fields().Get(i)
35 currPath := path
36 if len(currPath) > 0 {
37 currPath += "."
38 }
39 currPath += string(field.Name())
40 if !isMessageFieldPresent(reflectMessage, field) {
41 if Has(field, annotations.FieldBehavior_REQUIRED) && hasPath(mask, currPath) {
42 return fmt.Errorf("missing required field: %s", currPath)
43 }
44 } else if field.Kind() == protoreflect.MessageKind {
45 value := reflectMessage.Get(field)
46 switch {
47 case field.IsList():
48 for i := 0; i < value.List().Len(); i++ {
49 if err := validateRequiredFields(value.List().Get(i).Message(), mask, currPath); err != nil {
50 return err
51 }
52 }
53 case field.IsMap():
54 if field.MapValue().Kind() != protoreflect.MessageKind {
55 continue
56 }
57 var mapErr error
58 value.Map().Range(func(_ protoreflect.MapKey, value protoreflect.Value) bool {
59 if err := validateRequiredFields(value.Message(), mask, currPath); err != nil {
60 mapErr = err
61 return false
62 }
63
64 return true
65 })
66 if mapErr != nil {
67 return mapErr
68 }
69 default:
70 if err := validateRequiredFields(value.Message(), mask, currPath); err != nil {
71 return err
72 }
73 }
74 }
75 }
76 return nil
77 }
78
79 func isEmpty(mask *fieldmaskpb.FieldMask) bool {
80 return mask == nil || len(mask.GetPaths()) == 0
81 }
82
83 func hasPath(mask *fieldmaskpb.FieldMask, needle string) bool {
84 if isEmpty(mask) {
85 return true
86 }
87 for _, straw := range mask.GetPaths() {
88 if straw == "*" || straw == needle {
89 return true
90 }
91 }
92 return false
93 }
94
View as plain text