...
1
2
3
4
5
6
7
8
9
10
11 package protocmp
12
13 import (
14 "reflect"
15 "strconv"
16
17 "github.com/google/go-cmp/cmp"
18
19 "google.golang.org/protobuf/encoding/protowire"
20 "google.golang.org/protobuf/internal/genid"
21 "google.golang.org/protobuf/internal/msgfmt"
22 "google.golang.org/protobuf/proto"
23 "google.golang.org/protobuf/reflect/protoreflect"
24 "google.golang.org/protobuf/reflect/protoregistry"
25 "google.golang.org/protobuf/runtime/protoiface"
26 "google.golang.org/protobuf/runtime/protoimpl"
27 )
28
29 var (
30 enumV2Type = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem()
31 messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem()
32 messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem()
33 )
34
35
36
37 type Enum struct {
38 num protoreflect.EnumNumber
39 ed protoreflect.EnumDescriptor
40 }
41
42
43
44 func (e Enum) Descriptor() protoreflect.EnumDescriptor {
45 return e.ed
46 }
47
48
49 func (e Enum) Number() protoreflect.EnumNumber {
50 return e.num
51 }
52
53
54 func (e1 Enum) Equal(e2 Enum) bool {
55 if e1.ed.FullName() != e2.ed.FullName() {
56 return false
57 }
58 return e1.num == e2.num
59 }
60
61
62
63 func (e Enum) String() string {
64 if ev := e.ed.Values().ByNumber(e.num); ev != nil {
65 return string(ev.Name())
66 }
67 return strconv.Itoa(int(e.num))
68 }
69
// Reserved keys of a Message map. Both begin with '@', so they cannot
// collide with the plain field-name or field-number keys stored alongside
// them.
const (
	// messageTypeKey maps to the messageMeta describing the message: the
	// original message, its descriptor and any populated extension fields.
	messageTypeKey = "@type"

	// messageInvalidKey is present, with the value true, only in the
	// Message produced for an invalid (e.g. nil) message.
	messageInvalidKey = "@invalid"
)
81
82 type messageMeta struct {
83 m proto.Message
84 md protoreflect.MessageDescriptor
85 xds map[string]protoreflect.ExtensionDescriptor
86 }
87
88 func (t messageMeta) String() string {
89 return string(t.md.FullName())
90 }
91
92 func (t1 messageMeta) Equal(t2 messageMeta) bool {
93 return t1.md.FullName() == t2.md.FullName()
94 }
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118 type Message map[string]interface{}
119
120
121
122 func (m Message) Unwrap() proto.Message {
123 mm, _ := m[messageTypeKey].(messageMeta)
124 return mm.m
125 }
126
127
128
129 func (m Message) Descriptor() protoreflect.MessageDescriptor {
130 mm, _ := m[messageTypeKey].(messageMeta)
131 return mm.md
132 }
133
134
135
136
137 func (m Message) ProtoReflect() protoreflect.Message {
138 return (reflectMessage)(m)
139 }
140
141
142 func (m Message) ProtoMessage() {}
143
144
145 func (m Message) Reset() {
146 panic("invalid mutation of a read-only message")
147 }
148
149
150
151
152 func (m Message) String() string {
153 switch {
154 case m == nil:
155 return "<nil>"
156 case !m.ProtoReflect().IsValid():
157 return "<invalid>"
158 default:
159 return msgfmt.Format(m)
160 }
161 }
162
163 type transformer struct {
164 resolver protoregistry.MessageTypeResolver
165 }
166
167 func newTransformer(opts ...option) *transformer {
168 xf := &transformer{
169 resolver: protoregistry.GlobalTypes,
170 }
171 for _, opt := range opts {
172 opt(xf)
173 }
174 return xf
175 }
176
177 type option func(*transformer)
178
179
180
181
182
183
184 func MessageTypeResolver(r protoregistry.MessageTypeResolver) option {
185 return func(xf *transformer) {
186 xf.resolver = r
187 }
188 }
189
190
191
192
193
194
195
196
197
198
199
200 func Transform(opts ...option) cmp.Option {
201 xf := newTransformer(opts...)
202
203
204 addrType := func(t reflect.Type) reflect.Type {
205 if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr {
206 return t
207 }
208 return reflect.PtrTo(t)
209 }
210
211
212 return cmp.FilterPath(func(p cmp.Path) bool {
213 ps := p.Last()
214 if isMessageType(addrType(ps.Type())) {
215 return true
216 }
217
218
219
220 if ps.Type().Kind() == reflect.Interface {
221 vx, vy := ps.Values()
222 if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() {
223 return false
224 }
225 return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type()))
226 }
227
228 return false
229 }, cmp.Transformer("protocmp.Transform", func(v interface{}) Message {
230
231
232 if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) {
233 pv := reflect.New(rv.Type())
234 pv.Elem().Set(rv)
235 v = pv.Interface()
236 }
237
238 m := protoimpl.X.MessageOf(v)
239 switch {
240 case m == nil:
241 return nil
242 case !m.IsValid():
243 return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
244 default:
245 return xf.transformMessage(m)
246 }
247 }))
248 }
249
250 func isMessageType(t reflect.Type) bool {
251
252 if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) {
253 return false
254 }
255 return t.Implements(messageV1Type) || t.Implements(messageV2Type)
256 }
257
258 func (xf *transformer) transformMessage(m protoreflect.Message) Message {
259 mx := Message{}
260 mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
261
262
263 m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
264 s := fd.TextName()
265 if fd.IsExtension() {
266 mt.xds[s] = fd
267 }
268 switch {
269 case fd.IsList():
270 mx[s] = xf.transformList(fd, v.List())
271 case fd.IsMap():
272 mx[s] = xf.transformMap(fd, v.Map())
273 default:
274 mx[s] = xf.transformSingular(fd, v)
275 }
276 return true
277 })
278
279
280 for b := m.GetUnknown(); len(b) > 0; {
281 num, _, n := protowire.ConsumeField(b)
282 s := strconv.Itoa(int(num))
283 b2, _ := mx[s].(protoreflect.RawFields)
284 mx[s] = append(b2, b[:n]...)
285 b = b[n:]
286 }
287
288
289 if mt.md.FullName() == genid.Any_message_fullname {
290 s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string)
291 b, _ := mx[string(genid.Any_Value_field_name)].([]byte)
292 mt, err := xf.resolver.FindMessageByURL(s)
293 if mt != nil && err == nil {
294 m2 := mt.New()
295 err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface())
296 if err == nil {
297 mx[string(genid.Any_Value_field_name)] = xf.transformMessage(m2)
298 }
299 }
300 }
301
302 mx[messageTypeKey] = mt
303 return mx
304 }
305
306 func (xf *transformer) transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} {
307 t := protoKindToGoType(fd.Kind())
308 rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len())
309 for i := 0; i < lv.Len(); i++ {
310 v := reflect.ValueOf(xf.transformSingular(fd, lv.Get(i)))
311 rv.Index(i).Set(v)
312 }
313 return rv.Interface()
314 }
315
316 func (xf *transformer) transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} {
317 kfd := fd.MapKey()
318 vfd := fd.MapValue()
319 kt := protoKindToGoType(kfd.Kind())
320 vt := protoKindToGoType(vfd.Kind())
321 rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len())
322 mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
323 kv := reflect.ValueOf(xf.transformSingular(kfd, k.Value()))
324 vv := reflect.ValueOf(xf.transformSingular(vfd, v))
325 rv.SetMapIndex(kv, vv)
326 return true
327 })
328 return rv.Interface()
329 }
330
331 func (xf *transformer) transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} {
332 switch fd.Kind() {
333 case protoreflect.EnumKind:
334 return Enum{num: v.Enum(), ed: fd.Enum()}
335 case protoreflect.MessageKind, protoreflect.GroupKind:
336 return xf.transformMessage(v.Message())
337 case protoreflect.BytesKind:
338
339
340
341 if len(v.Bytes()) == 0 {
342 return []byte{}
343 }
344 return v.Bytes()
345 default:
346 return v.Interface()
347 }
348 }
349
350 func protoKindToGoType(k protoreflect.Kind) reflect.Type {
351 switch k {
352 case protoreflect.BoolKind:
353 return reflect.TypeOf(bool(false))
354 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
355 return reflect.TypeOf(int32(0))
356 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
357 return reflect.TypeOf(int64(0))
358 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
359 return reflect.TypeOf(uint32(0))
360 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
361 return reflect.TypeOf(uint64(0))
362 case protoreflect.FloatKind:
363 return reflect.TypeOf(float32(0))
364 case protoreflect.DoubleKind:
365 return reflect.TypeOf(float64(0))
366 case protoreflect.StringKind:
367 return reflect.TypeOf(string(""))
368 case protoreflect.BytesKind:
369 return reflect.TypeOf([]byte(nil))
370 case protoreflect.EnumKind:
371 return reflect.TypeOf(Enum{})
372 case protoreflect.MessageKind, protoreflect.GroupKind:
373 return reflect.TypeOf(Message{})
374 default:
375 panic("invalid kind")
376 }
377 }
378
View as plain text