1 package fieldbehavior
2
3 import (
4 "fmt"
5
6 "google.golang.org/genproto/googleapis/api/annotations"
7 "google.golang.org/protobuf/proto"
8 protoreflect "google.golang.org/protobuf/reflect/protoreflect"
9 )
10
11
12 func Get(field protoreflect.FieldDescriptor) []annotations.FieldBehavior {
13 if behaviors, ok := proto.GetExtension(
14 field.Options(), annotations.E_FieldBehavior,
15 ).([]annotations.FieldBehavior); ok {
16 return behaviors
17 }
18 return nil
19 }
20
21
22 func Has(field protoreflect.FieldDescriptor, want annotations.FieldBehavior) bool {
23 for _, got := range Get(field) {
24 if got == want {
25 return true
26 }
27 }
28 return false
29 }
30
31
32
33
34
35
36 func ClearFields(message proto.Message, behaviorsToClear ...annotations.FieldBehavior) {
37 clearFieldsWithBehaviors(message, behaviorsToClear...)
38 }
39
40
41 func CopyFields(dst, src proto.Message, behaviorsToCopy ...annotations.FieldBehavior) {
42 dstReflect := dst.ProtoReflect()
43 srcReflect := src.ProtoReflect()
44 if dstReflect.Descriptor() != srcReflect.Descriptor() {
45 panic(fmt.Sprintf(
46 "different types of dst (%s) and src (%s)",
47 dstReflect.Type().Descriptor().FullName(),
48 srcReflect.Type().Descriptor().FullName(),
49 ))
50 }
51 for i := 0; i < dstReflect.Descriptor().Fields().Len(); i++ {
52 dstField := dstReflect.Descriptor().Fields().Get(i)
53 if hasAnyBehavior(Get(dstField), behaviorsToCopy) {
54 srcField := srcReflect.Descriptor().Fields().Get(i)
55 if isMessageFieldPresent(srcReflect, srcField) {
56 dstReflect.Set(dstField, srcReflect.Get(srcField))
57 } else {
58 dstReflect.Clear(dstField)
59 }
60 }
61 }
62 }
63
64 func isMessageFieldPresent(m protoreflect.Message, f protoreflect.FieldDescriptor) bool {
65 return isPresent(m.Get(f), f)
66 }
67
68 func isPresent(v protoreflect.Value, f protoreflect.FieldDescriptor) bool {
69 if !v.IsValid() {
70 return false
71 }
72 if f.IsList() {
73 return v.List().Len() > 0
74 }
75 if f.IsMap() {
76 return v.Map().Len() > 0
77 }
78 switch f.Kind() {
79 case protoreflect.EnumKind:
80 return v.Enum() != 0
81 case protoreflect.BoolKind:
82 return v.Bool()
83 case protoreflect.Int32Kind,
84 protoreflect.Sint32Kind,
85 protoreflect.Int64Kind,
86 protoreflect.Sint64Kind,
87 protoreflect.Sfixed32Kind,
88 protoreflect.Sfixed64Kind:
89 return v.Int() != 0
90 case protoreflect.Uint32Kind,
91 protoreflect.Uint64Kind,
92 protoreflect.Fixed32Kind,
93 protoreflect.Fixed64Kind:
94 return v.Uint() != 0
95 case protoreflect.FloatKind,
96 protoreflect.DoubleKind:
97 return v.Float() != 0
98 case protoreflect.StringKind:
99 return len(v.String()) > 0
100 case protoreflect.BytesKind:
101 return len(v.Bytes()) > 0
102 case protoreflect.MessageKind:
103 return v.Message().IsValid()
104 case protoreflect.GroupKind:
105 return v.IsValid()
106 default:
107 return v.IsValid()
108 }
109 }
110
111 func clearFieldsWithBehaviors(m proto.Message, behaviorsToClear ...annotations.FieldBehavior) {
112 rangeFieldsWithBehaviors(
113 m.ProtoReflect(),
114 func(
115 m protoreflect.Message,
116 f protoreflect.FieldDescriptor,
117 _ protoreflect.Value,
118 behaviors []annotations.FieldBehavior,
119 ) bool {
120 if hasAnyBehavior(behaviors, behaviorsToClear) {
121 m.Clear(f)
122 }
123 return true
124 },
125 )
126 }
127
128 func rangeFieldsWithBehaviors(
129 m protoreflect.Message,
130 fn func(
131 protoreflect.Message,
132 protoreflect.FieldDescriptor,
133 protoreflect.Value,
134 []annotations.FieldBehavior,
135 ) bool,
136 ) {
137 m.Range(
138 func(f protoreflect.FieldDescriptor, v protoreflect.Value) bool {
139 if behaviors, ok := proto.GetExtension(
140 f.Options(),
141 annotations.E_FieldBehavior,
142 ).([]annotations.FieldBehavior); ok {
143 fn(m, f, v, behaviors)
144 }
145
146 switch {
147
148 case f.IsList() && f.Kind() == protoreflect.MessageKind:
149 for i := 0; i < v.List().Len(); i++ {
150 rangeFieldsWithBehaviors(
151 v.List().Get(i).Message(),
152 fn,
153 )
154 }
155 return true
156
157 case f.IsMap() && f.MapValue().Kind() == protoreflect.MessageKind:
158 v.Map().Range(func(_ protoreflect.MapKey, mv protoreflect.Value) bool {
159 rangeFieldsWithBehaviors(
160 mv.Message(),
161 fn,
162 )
163 return true
164 })
165 return true
166
167
168 case f.Kind() == protoreflect.MessageKind && !f.IsMap():
169 rangeFieldsWithBehaviors(
170 v.Message(),
171 fn,
172 )
173 return true
174 default:
175 return true
176 }
177 })
178 }
179
180 func hasAnyBehavior(haystack, needles []annotations.FieldBehavior) bool {
181 for _, needle := range needles {
182 if hasBehavior(haystack, needle) {
183 return true
184 }
185 }
186 return false
187 }
188
189 func hasBehavior(haystack []annotations.FieldBehavior, needle annotations.FieldBehavior) bool {
190 for _, straw := range haystack {
191 if straw == needle {
192 return true
193 }
194 }
195 return false
196 }
197
View as plain text