1
2
3
4
5 package proto
6
import (
	"errors"
	"fmt"
	"reflect"
	"sort"

	"google.golang.org/protobuf/encoding/protowire"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
	"google.golang.org/protobuf/runtime/protoiface"
	"google.golang.org/protobuf/runtime/protoimpl"
)
19
20 type (
21
22
23
24
25 ExtensionDesc = protoimpl.ExtensionInfo
26
27
28
29 ExtensionRange = protoiface.ExtensionRangeV1
30
31
32 Extension = protoimpl.ExtensionFieldV1
33
34
35 XXX_InternalExtensions = protoimpl.ExtensionFields
36 )
37
38
var (
	// ErrMissingExtension is returned by GetExtension when the requested
	// extension is neither populated in the message nor has a default value.
	// GetExtensions treats it as "absent" and leaves a nil entry.
	ErrMissingExtension = errors.New("proto: missing extension")

	// errNotExtendable is returned when the target message is nil or invalid,
	// or (for GetExtension, SetExtension and ExtensionDescs) declares no
	// extension ranges.
	errNotExtendable = errors.New("proto: not an extendable proto.Message")
)
42
43
44
45 func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
46 mr := MessageReflect(m)
47 if mr == nil || !mr.IsValid() {
48 return false
49 }
50
51
52 xtd := xt.TypeDescriptor()
53 if isValidExtension(mr.Descriptor(), xtd) {
54 has = mr.Has(xtd)
55 } else {
56 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
57 has = int32(fd.Number()) == xt.Field
58 return !has
59 })
60 }
61
62
63 for b := mr.GetUnknown(); !has && len(b) > 0; {
64 num, _, n := protowire.ConsumeField(b)
65 has = int32(num) == xt.Field
66 b = b[n:]
67 }
68 return has
69 }
70
71
72
73 func ClearExtension(m Message, xt *ExtensionDesc) {
74 mr := MessageReflect(m)
75 if mr == nil || !mr.IsValid() {
76 return
77 }
78
79 xtd := xt.TypeDescriptor()
80 if isValidExtension(mr.Descriptor(), xtd) {
81 mr.Clear(xtd)
82 } else {
83 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
84 if int32(fd.Number()) == xt.Field {
85 mr.Clear(fd)
86 return false
87 }
88 return true
89 })
90 }
91 clearUnknown(mr, fieldNum(xt.Field))
92 }
93
94
95
96 func ClearAllExtensions(m Message) {
97 mr := MessageReflect(m)
98 if mr == nil || !mr.IsValid() {
99 return
100 }
101
102 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
103 if fd.IsExtension() {
104 mr.Clear(fd)
105 }
106 return true
107 })
108 clearUnknown(mr, mr.Descriptor().ExtensionRanges())
109 }
110
111
112
113
114
115
116
117
118
119
120 func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
121 mr := MessageReflect(m)
122 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
123 return nil, errNotExtendable
124 }
125
126
127 var bo protoreflect.RawFields
128 for bi := mr.GetUnknown(); len(bi) > 0; {
129 num, _, n := protowire.ConsumeField(bi)
130 if int32(num) == xt.Field {
131 bo = append(bo, bi[:n]...)
132 }
133 bi = bi[n:]
134 }
135
136
137 if xt.ExtensionType == nil {
138 return []byte(bo), nil
139 }
140
141
142
143 xtd := xt.TypeDescriptor()
144 if !isValidExtension(mr.Descriptor(), xtd) {
145 return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
146 }
147 if !mr.Has(xtd) && len(bo) > 0 {
148 m2 := mr.New()
149 if err := (proto.UnmarshalOptions{
150 Resolver: extensionResolver{xt},
151 }.Unmarshal(bo, m2.Interface())); err != nil {
152 return nil, err
153 }
154 if m2.Has(xtd) {
155 mr.Set(xtd, m2.Get(xtd))
156 clearUnknown(mr, fieldNum(xt.Field))
157 }
158 }
159
160
161 var pv protoreflect.Value
162 switch {
163 case mr.Has(xtd):
164 pv = mr.Get(xtd)
165 case xtd.HasDefault():
166 pv = xtd.Default()
167 default:
168 return nil, ErrMissingExtension
169 }
170
171 v := xt.InterfaceOf(pv)
172 rv := reflect.ValueOf(v)
173 if isScalarKind(rv.Kind()) {
174 rv2 := reflect.New(rv.Type())
175 rv2.Elem().Set(rv)
176 v = rv2.Interface()
177 }
178 return v, nil
179 }
180
181
182
183 type extensionResolver struct{ xt protoreflect.ExtensionType }
184
185 func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
186 if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
187 return r.xt, nil
188 }
189 return protoregistry.GlobalTypes.FindExtensionByName(field)
190 }
191
192 func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
193 if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
194 return r.xt, nil
195 }
196 return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
197 }
198
199
200
201
202 func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
203 mr := MessageReflect(m)
204 if mr == nil || !mr.IsValid() {
205 return nil, errNotExtendable
206 }
207
208 vs := make([]interface{}, len(xts))
209 for i, xt := range xts {
210 v, err := GetExtension(m, xt)
211 if err != nil {
212 if err == ErrMissingExtension {
213 continue
214 }
215 return vs, err
216 }
217 vs[i] = v
218 }
219 return vs, nil
220 }
221
222
223 func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
224 mr := MessageReflect(m)
225 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
226 return errNotExtendable
227 }
228
229 rv := reflect.ValueOf(v)
230 if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
231 return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
232 }
233 if rv.Kind() == reflect.Ptr {
234 if rv.IsNil() {
235 return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
236 }
237 if isScalarKind(rv.Elem().Kind()) {
238 v = rv.Elem().Interface()
239 }
240 }
241
242 xtd := xt.TypeDescriptor()
243 if !isValidExtension(mr.Descriptor(), xtd) {
244 return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
245 }
246 mr.Set(xtd, xt.ValueOf(v))
247 clearUnknown(mr, fieldNum(xt.Field))
248 return nil
249 }
250
251
252
253
254 func SetRawExtension(m Message, fnum int32, b []byte) {
255 mr := MessageReflect(m)
256 if mr == nil || !mr.IsValid() {
257 return
258 }
259
260
261 for b0 := b; len(b0) > 0; {
262 num, _, n := protowire.ConsumeField(b0)
263 if int32(num) != fnum {
264 panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
265 }
266 b0 = b0[n:]
267 }
268
269 ClearExtension(m, &ExtensionDesc{Field: fnum})
270 mr.SetUnknown(append(mr.GetUnknown(), b...))
271 }
272
273
274
275
276
277
278
279 func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
280 mr := MessageReflect(m)
281 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
282 return nil, errNotExtendable
283 }
284
285
286 extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
287 mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
288 if fd.IsExtension() {
289 xt := fd.(protoreflect.ExtensionTypeDescriptor)
290 if xd, ok := xt.Type().(*ExtensionDesc); ok {
291 extDescs[fd.Number()] = xd
292 }
293 }
294 return true
295 })
296
297
298 extRanges := mr.Descriptor().ExtensionRanges()
299 for b := mr.GetUnknown(); len(b) > 0; {
300 num, _, n := protowire.ConsumeField(b)
301 if extRanges.Has(num) && extDescs[num] == nil {
302 extDescs[num] = nil
303 }
304 b = b[n:]
305 }
306
307
308 var xts []*ExtensionDesc
309 for num, xt := range extDescs {
310 if xt == nil {
311 xt = &ExtensionDesc{Field: int32(num)}
312 }
313 xts = append(xts, xt)
314 }
315 return xts, nil
316 }
317
318
319 func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
320 return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
321 }
322
323
324
325
// isScalarKind reports whether k is one of the Go kinds that GetExtension
// returns as a pointer, and that SetExtension dereferences when given a
// pointer: bool, string, 32/64-bit signed and unsigned integers, and 32/64-bit
// floats. Platform-sized int/uint and 8/16-bit integers are not included.
func isScalarKind(k reflect.Kind) bool {
	switch k {
	case reflect.Bool, reflect.String,
		reflect.Int32, reflect.Int64,
		reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}
334
335
336 func clearUnknown(m protoreflect.Message, remover interface {
337 Has(protoreflect.FieldNumber) bool
338 }) {
339 var bo protoreflect.RawFields
340 for bi := m.GetUnknown(); len(bi) > 0; {
341 num, _, n := protowire.ConsumeField(bi)
342 if !remover.Has(num) {
343 bo = append(bo, bi[:n]...)
344 }
345 bi = bi[n:]
346 }
347 if bi := m.GetUnknown(); len(bi) != len(bo) {
348 m.SetUnknown(bo)
349 }
350 }
351
352 type fieldNum protoreflect.FieldNumber
353
354 func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
355 return protoreflect.FieldNumber(n1) == n2
356 }
357
View as plain text