...

Source file src/github.com/golang/protobuf/proto/extensions.go

Documentation: github.com/golang/protobuf/proto

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package proto
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"reflect"
    11  
    12  	"google.golang.org/protobuf/encoding/protowire"
    13  	"google.golang.org/protobuf/proto"
    14  	"google.golang.org/protobuf/reflect/protoreflect"
    15  	"google.golang.org/protobuf/reflect/protoregistry"
    16  	"google.golang.org/protobuf/runtime/protoiface"
    17  	"google.golang.org/protobuf/runtime/protoimpl"
    18  )
    19  
    20  type (
    21  	// ExtensionDesc represents an extension descriptor and
    22  	// is used to interact with an extension field in a message.
    23  	//
    24  	// Variables of this type are generated in code by protoc-gen-go.
    25  	ExtensionDesc = protoimpl.ExtensionInfo
    26  
    27  	// ExtensionRange represents a range of message extensions.
    28  	// Used in code generated by protoc-gen-go.
    29  	ExtensionRange = protoiface.ExtensionRangeV1
    30  
    31  	// Deprecated: Do not use; this is an internal type.
    32  	Extension = protoimpl.ExtensionFieldV1
    33  
    34  	// Deprecated: Do not use; this is an internal type.
    35  	XXX_InternalExtensions = protoimpl.ExtensionFields
    36  )
    37  
    38  // ErrMissingExtension reports whether the extension was not present.
    39  var ErrMissingExtension = errors.New("proto: missing extension")
    40  
    41  var errNotExtendable = errors.New("proto: not an extendable proto.Message")
    42  
    43  // HasExtension reports whether the extension field is present in m
    44  // either as an explicitly populated field or as an unknown field.
    45  func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
    46  	mr := MessageReflect(m)
    47  	if mr == nil || !mr.IsValid() {
    48  		return false
    49  	}
    50  
    51  	// Check whether any populated known field matches the field number.
    52  	xtd := xt.TypeDescriptor()
    53  	if isValidExtension(mr.Descriptor(), xtd) {
    54  		has = mr.Has(xtd)
    55  	} else {
    56  		mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
    57  			has = int32(fd.Number()) == xt.Field
    58  			return !has
    59  		})
    60  	}
    61  
    62  	// Check whether any unknown field matches the field number.
    63  	for b := mr.GetUnknown(); !has && len(b) > 0; {
    64  		num, _, n := protowire.ConsumeField(b)
    65  		has = int32(num) == xt.Field
    66  		b = b[n:]
    67  	}
    68  	return has
    69  }
    70  
    71  // ClearExtension removes the extension field from m
    72  // either as an explicitly populated field or as an unknown field.
    73  func ClearExtension(m Message, xt *ExtensionDesc) {
    74  	mr := MessageReflect(m)
    75  	if mr == nil || !mr.IsValid() {
    76  		return
    77  	}
    78  
    79  	xtd := xt.TypeDescriptor()
    80  	if isValidExtension(mr.Descriptor(), xtd) {
    81  		mr.Clear(xtd)
    82  	} else {
    83  		mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
    84  			if int32(fd.Number()) == xt.Field {
    85  				mr.Clear(fd)
    86  				return false
    87  			}
    88  			return true
    89  		})
    90  	}
    91  	clearUnknown(mr, fieldNum(xt.Field))
    92  }
    93  
    94  // ClearAllExtensions clears all extensions from m.
    95  // This includes populated fields and unknown fields in the extension range.
    96  func ClearAllExtensions(m Message) {
    97  	mr := MessageReflect(m)
    98  	if mr == nil || !mr.IsValid() {
    99  		return
   100  	}
   101  
   102  	mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
   103  		if fd.IsExtension() {
   104  			mr.Clear(fd)
   105  		}
   106  		return true
   107  	})
   108  	clearUnknown(mr, mr.Descriptor().ExtensionRanges())
   109  }
   110  
   111  // GetExtension retrieves a proto2 extended field from m.
   112  //
   113  // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
   114  // then GetExtension parses the encoded field and returns a Go value of the specified type.
   115  // If the field is not present, then the default value is returned (if one is specified),
   116  // otherwise ErrMissingExtension is reported.
   117  //
   118  // If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
   119  // then GetExtension returns the raw encoded bytes for the extension field.
   120  func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
   121  	mr := MessageReflect(m)
   122  	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
   123  		return nil, errNotExtendable
   124  	}
   125  
   126  	// Retrieve the unknown fields for this extension field.
   127  	var bo protoreflect.RawFields
   128  	for bi := mr.GetUnknown(); len(bi) > 0; {
   129  		num, _, n := protowire.ConsumeField(bi)
   130  		if int32(num) == xt.Field {
   131  			bo = append(bo, bi[:n]...)
   132  		}
   133  		bi = bi[n:]
   134  	}
   135  
   136  	// For type incomplete descriptors, only retrieve the unknown fields.
   137  	if xt.ExtensionType == nil {
   138  		return []byte(bo), nil
   139  	}
   140  
   141  	// If the extension field only exists as unknown fields, unmarshal it.
   142  	// This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
   143  	xtd := xt.TypeDescriptor()
   144  	if !isValidExtension(mr.Descriptor(), xtd) {
   145  		return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
   146  	}
   147  	if !mr.Has(xtd) && len(bo) > 0 {
   148  		m2 := mr.New()
   149  		if err := (proto.UnmarshalOptions{
   150  			Resolver: extensionResolver{xt},
   151  		}.Unmarshal(bo, m2.Interface())); err != nil {
   152  			return nil, err
   153  		}
   154  		if m2.Has(xtd) {
   155  			mr.Set(xtd, m2.Get(xtd))
   156  			clearUnknown(mr, fieldNum(xt.Field))
   157  		}
   158  	}
   159  
   160  	// Check whether the message has the extension field set or a default.
   161  	var pv protoreflect.Value
   162  	switch {
   163  	case mr.Has(xtd):
   164  		pv = mr.Get(xtd)
   165  	case xtd.HasDefault():
   166  		pv = xtd.Default()
   167  	default:
   168  		return nil, ErrMissingExtension
   169  	}
   170  
   171  	v := xt.InterfaceOf(pv)
   172  	rv := reflect.ValueOf(v)
   173  	if isScalarKind(rv.Kind()) {
   174  		rv2 := reflect.New(rv.Type())
   175  		rv2.Elem().Set(rv)
   176  		v = rv2.Interface()
   177  	}
   178  	return v, nil
   179  }
   180  
   181  // extensionResolver is a custom extension resolver that stores a single
   182  // extension type that takes precedence over the global registry.
   183  type extensionResolver struct{ xt protoreflect.ExtensionType }
   184  
   185  func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
   186  	if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
   187  		return r.xt, nil
   188  	}
   189  	return protoregistry.GlobalTypes.FindExtensionByName(field)
   190  }
   191  
   192  func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
   193  	if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
   194  		return r.xt, nil
   195  	}
   196  	return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
   197  }
   198  
   199  // GetExtensions returns a list of the extensions values present in m,
   200  // corresponding with the provided list of extension descriptors, xts.
   201  // If an extension is missing in m, the corresponding value is nil.
   202  func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
   203  	mr := MessageReflect(m)
   204  	if mr == nil || !mr.IsValid() {
   205  		return nil, errNotExtendable
   206  	}
   207  
   208  	vs := make([]interface{}, len(xts))
   209  	for i, xt := range xts {
   210  		v, err := GetExtension(m, xt)
   211  		if err != nil {
   212  			if err == ErrMissingExtension {
   213  				continue
   214  			}
   215  			return vs, err
   216  		}
   217  		vs[i] = v
   218  	}
   219  	return vs, nil
   220  }
   221  
   222  // SetExtension sets an extension field in m to the provided value.
   223  func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
   224  	mr := MessageReflect(m)
   225  	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
   226  		return errNotExtendable
   227  	}
   228  
   229  	rv := reflect.ValueOf(v)
   230  	if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
   231  		return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
   232  	}
   233  	if rv.Kind() == reflect.Ptr {
   234  		if rv.IsNil() {
   235  			return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
   236  		}
   237  		if isScalarKind(rv.Elem().Kind()) {
   238  			v = rv.Elem().Interface()
   239  		}
   240  	}
   241  
   242  	xtd := xt.TypeDescriptor()
   243  	if !isValidExtension(mr.Descriptor(), xtd) {
   244  		return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
   245  	}
   246  	mr.Set(xtd, xt.ValueOf(v))
   247  	clearUnknown(mr, fieldNum(xt.Field))
   248  	return nil
   249  }
   250  
   251  // SetRawExtension inserts b into the unknown fields of m.
   252  //
   253  // Deprecated: Use Message.ProtoReflect.SetUnknown instead.
   254  func SetRawExtension(m Message, fnum int32, b []byte) {
   255  	mr := MessageReflect(m)
   256  	if mr == nil || !mr.IsValid() {
   257  		return
   258  	}
   259  
   260  	// Verify that the raw field is valid.
   261  	for b0 := b; len(b0) > 0; {
   262  		num, _, n := protowire.ConsumeField(b0)
   263  		if int32(num) != fnum {
   264  			panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
   265  		}
   266  		b0 = b0[n:]
   267  	}
   268  
   269  	ClearExtension(m, &ExtensionDesc{Field: fnum})
   270  	mr.SetUnknown(append(mr.GetUnknown(), b...))
   271  }
   272  
   273  // ExtensionDescs returns a list of extension descriptors found in m,
   274  // containing descriptors for both populated extension fields in m and
   275  // also unknown fields of m that are in the extension range.
   276  // For the later case, an type incomplete descriptor is provided where only
   277  // the ExtensionDesc.Field field is populated.
   278  // The order of the extension descriptors is undefined.
   279  func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
   280  	mr := MessageReflect(m)
   281  	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
   282  		return nil, errNotExtendable
   283  	}
   284  
   285  	// Collect a set of known extension descriptors.
   286  	extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
   287  	mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
   288  		if fd.IsExtension() {
   289  			xt := fd.(protoreflect.ExtensionTypeDescriptor)
   290  			if xd, ok := xt.Type().(*ExtensionDesc); ok {
   291  				extDescs[fd.Number()] = xd
   292  			}
   293  		}
   294  		return true
   295  	})
   296  
   297  	// Collect a set of unknown extension descriptors.
   298  	extRanges := mr.Descriptor().ExtensionRanges()
   299  	for b := mr.GetUnknown(); len(b) > 0; {
   300  		num, _, n := protowire.ConsumeField(b)
   301  		if extRanges.Has(num) && extDescs[num] == nil {
   302  			extDescs[num] = nil
   303  		}
   304  		b = b[n:]
   305  	}
   306  
   307  	// Transpose the set of descriptors into a list.
   308  	var xts []*ExtensionDesc
   309  	for num, xt := range extDescs {
   310  		if xt == nil {
   311  			xt = &ExtensionDesc{Field: int32(num)}
   312  		}
   313  		xts = append(xts, xt)
   314  	}
   315  	return xts, nil
   316  }
   317  
   318  // isValidExtension reports whether xtd is a valid extension descriptor for md.
   319  func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
   320  	return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
   321  }
   322  
   323  // isScalarKind reports whether k is a protobuf scalar kind (except bytes).
   324  // This function exists for historical reasons since the representation of
   325  // scalars differs between v1 and v2, where v1 uses *T and v2 uses T.
   326  func isScalarKind(k reflect.Kind) bool {
   327  	switch k {
   328  	case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
   329  		return true
   330  	default:
   331  		return false
   332  	}
   333  }
   334  
   335  // clearUnknown removes unknown fields from m where remover.Has reports true.
   336  func clearUnknown(m protoreflect.Message, remover interface {
   337  	Has(protoreflect.FieldNumber) bool
   338  }) {
   339  	var bo protoreflect.RawFields
   340  	for bi := m.GetUnknown(); len(bi) > 0; {
   341  		num, _, n := protowire.ConsumeField(bi)
   342  		if !remover.Has(num) {
   343  			bo = append(bo, bi[:n]...)
   344  		}
   345  		bi = bi[n:]
   346  	}
   347  	if bi := m.GetUnknown(); len(bi) != len(bo) {
   348  		m.SetUnknown(bo)
   349  	}
   350  }
   351  
   352  type fieldNum protoreflect.FieldNumber
   353  
   354  func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
   355  	return protoreflect.FieldNumber(n1) == n2
   356  }
   357  

View as plain text