...

Source file src/google.golang.org/protobuf/testing/protocmp/xform.go

Documentation: google.golang.org/protobuf/testing/protocmp

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package protocmp provides protobuf specific options for the
     6  // [github.com/google/go-cmp/cmp] package.
     7  //
// The primary feature is the [Transform] option, which transforms [proto.Message]
// types into a [Message] map that is suitable for cmp to introspect upon.
    10  // All other options in this package must be used in conjunction with [Transform].
    11  package protocmp
    12  
    13  import (
    14  	"reflect"
    15  	"strconv"
    16  
    17  	"github.com/google/go-cmp/cmp"
    18  
    19  	"google.golang.org/protobuf/encoding/protowire"
    20  	"google.golang.org/protobuf/internal/genid"
    21  	"google.golang.org/protobuf/internal/msgfmt"
    22  	"google.golang.org/protobuf/proto"
    23  	"google.golang.org/protobuf/reflect/protoreflect"
    24  	"google.golang.org/protobuf/reflect/protoregistry"
    25  	"google.golang.org/protobuf/runtime/protoiface"
    26  	"google.golang.org/protobuf/runtime/protoimpl"
    27  )
    28  
    29  var (
    30  	enumV2Type    = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem()
    31  	messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem()
    32  	messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem()
    33  )
    34  
    35  // Enum is a dynamic representation of a protocol buffer enum that is
    36  // suitable for [cmp.Equal] and [cmp.Diff] to compare upon.
    37  type Enum struct {
    38  	num protoreflect.EnumNumber
    39  	ed  protoreflect.EnumDescriptor
    40  }
    41  
    42  // Descriptor returns the enum descriptor.
    43  // It returns nil for a zero Enum value.
    44  func (e Enum) Descriptor() protoreflect.EnumDescriptor {
    45  	return e.ed
    46  }
    47  
    48  // Number returns the enum value as an integer.
    49  func (e Enum) Number() protoreflect.EnumNumber {
    50  	return e.num
    51  }
    52  
    53  // Equal reports whether e1 and e2 represent the same enum value.
    54  func (e1 Enum) Equal(e2 Enum) bool {
    55  	if e1.ed.FullName() != e2.ed.FullName() {
    56  		return false
    57  	}
    58  	return e1.num == e2.num
    59  }
    60  
    61  // String returns the name of the enum value if known (e.g., "ENUM_VALUE"),
    62  // otherwise it returns the formatted decimal enum number (e.g., "14").
    63  func (e Enum) String() string {
    64  	if ev := e.ed.Values().ByNumber(e.num); ev != nil {
    65  		return string(ev.Name())
    66  	}
    67  	return strconv.Itoa(int(e.num))
    68  }
    69  
    70  const (
    71  	// messageTypeKey indicates the protobuf message type.
    72  	// The value type is always messageMeta.
    73  	// From the public API, it presents itself as only the type, but the
    74  	// underlying data structure holds arbitrary metadata about the message.
    75  	messageTypeKey = "@type"
    76  
    77  	// messageInvalidKey indicates that the message is invalid.
    78  	// The value is always the boolean "true".
    79  	messageInvalidKey = "@invalid"
    80  )
    81  
    82  type messageMeta struct {
    83  	m   proto.Message
    84  	md  protoreflect.MessageDescriptor
    85  	xds map[string]protoreflect.ExtensionDescriptor
    86  }
    87  
    88  func (t messageMeta) String() string {
    89  	return string(t.md.FullName())
    90  }
    91  
    92  func (t1 messageMeta) Equal(t2 messageMeta) bool {
    93  	return t1.md.FullName() == t2.md.FullName()
    94  }
    95  
    96  // Message is a dynamic representation of a protocol buffer message that is
    97  // suitable for [cmp.Equal] and [cmp.Diff] to directly operate upon.
    98  //
    99  // Every populated known field (excluding extension fields) is stored in the map
   100  // with the key being the short name of the field (e.g., "field_name") and
   101  // the value determined by the kind and cardinality of the field.
   102  //
   103  // Singular scalars are represented by the same Go type as [protoreflect.Value],
   104  // singular messages are represented by the [Message] type,
   105  // singular enums are represented by the [Enum] type,
   106  // list fields are represented as a Go slice, and
   107  // map fields are represented as a Go map.
   108  //
   109  // Every populated extension field is stored in the map with the key being the
   110  // full name of the field surrounded by brackets (e.g., "[extension.full.name]")
   111  // and the value determined according to the same rules as known fields.
   112  //
   113  // Every unknown field is stored in the map with the key being the field number
   114  // encoded as a decimal string (e.g., "132") and the value being the raw bytes
   115  // of the encoded field (as the [protoreflect.RawFields] type).
   116  //
   117  // Message values must not be created by or mutated by users.
   118  type Message map[string]interface{}
   119  
   120  // Unwrap returns the original message value.
   121  // It returns nil if this Message was not constructed from another message.
   122  func (m Message) Unwrap() proto.Message {
   123  	mm, _ := m[messageTypeKey].(messageMeta)
   124  	return mm.m
   125  }
   126  
   127  // Descriptor return the message descriptor.
   128  // It returns nil for a zero Message value.
   129  func (m Message) Descriptor() protoreflect.MessageDescriptor {
   130  	mm, _ := m[messageTypeKey].(messageMeta)
   131  	return mm.md
   132  }
   133  
   134  // ProtoReflect returns a reflective view of m.
   135  // It only implements the read-only operations of [protoreflect.Message].
   136  // Calling any mutating operations on m panics.
   137  func (m Message) ProtoReflect() protoreflect.Message {
   138  	return (reflectMessage)(m)
   139  }
   140  
   141  // ProtoMessage is a marker method from the legacy message interface.
   142  func (m Message) ProtoMessage() {}
   143  
   144  // Reset is the required Reset method from the legacy message interface.
   145  func (m Message) Reset() {
   146  	panic("invalid mutation of a read-only message")
   147  }
   148  
   149  // String returns a formatted string for the message.
   150  // It is intended for human debugging and has no guarantees about its
   151  // exact format or the stability of its output.
   152  func (m Message) String() string {
   153  	switch {
   154  	case m == nil:
   155  		return "<nil>"
   156  	case !m.ProtoReflect().IsValid():
   157  		return "<invalid>"
   158  	default:
   159  		return msgfmt.Format(m)
   160  	}
   161  }
   162  
   163  type transformer struct {
   164  	resolver protoregistry.MessageTypeResolver
   165  }
   166  
   167  func newTransformer(opts ...option) *transformer {
   168  	xf := &transformer{
   169  		resolver: protoregistry.GlobalTypes,
   170  	}
   171  	for _, opt := range opts {
   172  		opt(xf)
   173  	}
   174  	return xf
   175  }
   176  
   177  type option func(*transformer)
   178  
   179  // MessageTypeResolver overrides the resolver used for messages packed
   180  // inside Any. The default is protoregistry.GlobalTypes, which is
   181  // sufficient for all compiled-in Protobuf messages. Overriding the
   182  // resolver is useful in tests that dynamically create Protobuf
   183  // descriptors and messages, e.g. in proxies using dynamicpb.
   184  func MessageTypeResolver(r protoregistry.MessageTypeResolver) option {
   185  	return func(xf *transformer) {
   186  		xf.resolver = r
   187  	}
   188  }
   189  
   190  // Transform returns a [cmp.Option] that converts each [proto.Message] to a [Message].
   191  // The transformation does not mutate nor alias any converted messages.
   192  //
   193  // The google.protobuf.Any message is automatically unmarshaled such that the
   194  // "value" field is a [Message] representing the underlying message value
   195  // assuming it could be resolved and properly unmarshaled.
   196  //
   197  // This does not directly transform higher-order composite Go types.
   198  // For example, []*foopb.Message is not transformed into []Message,
   199  // but rather the individual message elements of the slice are transformed.
   200  func Transform(opts ...option) cmp.Option {
   201  	xf := newTransformer(opts...)
   202  
   203  	// addrType returns a pointer to t if t isn't a pointer or interface.
   204  	addrType := func(t reflect.Type) reflect.Type {
   205  		if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr {
   206  			return t
   207  		}
   208  		return reflect.PtrTo(t)
   209  	}
   210  
   211  	// TODO: Should this transform protoreflect.Enum types to Enum as well?
   212  	return cmp.FilterPath(func(p cmp.Path) bool {
   213  		ps := p.Last()
   214  		if isMessageType(addrType(ps.Type())) {
   215  			return true
   216  		}
   217  
   218  		// Check whether the concrete values of an interface both satisfy
   219  		// the Message interface.
   220  		if ps.Type().Kind() == reflect.Interface {
   221  			vx, vy := ps.Values()
   222  			if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() {
   223  				return false
   224  			}
   225  			return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type()))
   226  		}
   227  
   228  		return false
   229  	}, cmp.Transformer("protocmp.Transform", func(v interface{}) Message {
   230  		// For user convenience, shallow copy the message value if necessary
   231  		// in order for it to implement the message interface.
   232  		if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) {
   233  			pv := reflect.New(rv.Type())
   234  			pv.Elem().Set(rv)
   235  			v = pv.Interface()
   236  		}
   237  
   238  		m := protoimpl.X.MessageOf(v)
   239  		switch {
   240  		case m == nil:
   241  			return nil
   242  		case !m.IsValid():
   243  			return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
   244  		default:
   245  			return xf.transformMessage(m)
   246  		}
   247  	}))
   248  }
   249  
   250  func isMessageType(t reflect.Type) bool {
   251  	// Avoid transforming the Message itself.
   252  	if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) {
   253  		return false
   254  	}
   255  	return t.Implements(messageV1Type) || t.Implements(messageV2Type)
   256  }
   257  
   258  func (xf *transformer) transformMessage(m protoreflect.Message) Message {
   259  	mx := Message{}
   260  	mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
   261  
   262  	// Handle known and extension fields.
   263  	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
   264  		s := fd.TextName()
   265  		if fd.IsExtension() {
   266  			mt.xds[s] = fd
   267  		}
   268  		switch {
   269  		case fd.IsList():
   270  			mx[s] = xf.transformList(fd, v.List())
   271  		case fd.IsMap():
   272  			mx[s] = xf.transformMap(fd, v.Map())
   273  		default:
   274  			mx[s] = xf.transformSingular(fd, v)
   275  		}
   276  		return true
   277  	})
   278  
   279  	// Handle unknown fields.
   280  	for b := m.GetUnknown(); len(b) > 0; {
   281  		num, _, n := protowire.ConsumeField(b)
   282  		s := strconv.Itoa(int(num))
   283  		b2, _ := mx[s].(protoreflect.RawFields)
   284  		mx[s] = append(b2, b[:n]...)
   285  		b = b[n:]
   286  	}
   287  
   288  	// Expand Any messages.
   289  	if mt.md.FullName() == genid.Any_message_fullname {
   290  		s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string)
   291  		b, _ := mx[string(genid.Any_Value_field_name)].([]byte)
   292  		mt, err := xf.resolver.FindMessageByURL(s)
   293  		if mt != nil && err == nil {
   294  			m2 := mt.New()
   295  			err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface())
   296  			if err == nil {
   297  				mx[string(genid.Any_Value_field_name)] = xf.transformMessage(m2)
   298  			}
   299  		}
   300  	}
   301  
   302  	mx[messageTypeKey] = mt
   303  	return mx
   304  }
   305  
   306  func (xf *transformer) transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} {
   307  	t := protoKindToGoType(fd.Kind())
   308  	rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len())
   309  	for i := 0; i < lv.Len(); i++ {
   310  		v := reflect.ValueOf(xf.transformSingular(fd, lv.Get(i)))
   311  		rv.Index(i).Set(v)
   312  	}
   313  	return rv.Interface()
   314  }
   315  
   316  func (xf *transformer) transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} {
   317  	kfd := fd.MapKey()
   318  	vfd := fd.MapValue()
   319  	kt := protoKindToGoType(kfd.Kind())
   320  	vt := protoKindToGoType(vfd.Kind())
   321  	rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len())
   322  	mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
   323  		kv := reflect.ValueOf(xf.transformSingular(kfd, k.Value()))
   324  		vv := reflect.ValueOf(xf.transformSingular(vfd, v))
   325  		rv.SetMapIndex(kv, vv)
   326  		return true
   327  	})
   328  	return rv.Interface()
   329  }
   330  
   331  func (xf *transformer) transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} {
   332  	switch fd.Kind() {
   333  	case protoreflect.EnumKind:
   334  		return Enum{num: v.Enum(), ed: fd.Enum()}
   335  	case protoreflect.MessageKind, protoreflect.GroupKind:
   336  		return xf.transformMessage(v.Message())
   337  	case protoreflect.BytesKind:
   338  		// The protoreflect API does not specify whether an empty bytes is
   339  		// guaranteed to be nil or not. Always return non-nil bytes to avoid
   340  		// leaking information about the concrete proto.Message implementation.
   341  		if len(v.Bytes()) == 0 {
   342  			return []byte{}
   343  		}
   344  		return v.Bytes()
   345  	default:
   346  		return v.Interface()
   347  	}
   348  }
   349  
   350  func protoKindToGoType(k protoreflect.Kind) reflect.Type {
   351  	switch k {
   352  	case protoreflect.BoolKind:
   353  		return reflect.TypeOf(bool(false))
   354  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   355  		return reflect.TypeOf(int32(0))
   356  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   357  		return reflect.TypeOf(int64(0))
   358  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   359  		return reflect.TypeOf(uint32(0))
   360  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   361  		return reflect.TypeOf(uint64(0))
   362  	case protoreflect.FloatKind:
   363  		return reflect.TypeOf(float32(0))
   364  	case protoreflect.DoubleKind:
   365  		return reflect.TypeOf(float64(0))
   366  	case protoreflect.StringKind:
   367  		return reflect.TypeOf(string(""))
   368  	case protoreflect.BytesKind:
   369  		return reflect.TypeOf([]byte(nil))
   370  	case protoreflect.EnumKind:
   371  		return reflect.TypeOf(Enum{})
   372  	case protoreflect.MessageKind, protoreflect.GroupKind:
   373  		return reflect.TypeOf(Message{})
   374  	default:
   375  		panic("invalid kind")
   376  	}
   377  }
   378  

View as plain text