...

Source file src/google.golang.org/protobuf/internal/impl/validate.go

Documentation: google.golang.org/protobuf/internal/impl

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package impl
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"math/bits"
    11  	"reflect"
    12  	"unicode/utf8"
    13  
    14  	"google.golang.org/protobuf/encoding/protowire"
    15  	"google.golang.org/protobuf/internal/encoding/messageset"
    16  	"google.golang.org/protobuf/internal/flags"
    17  	"google.golang.org/protobuf/internal/genid"
    18  	"google.golang.org/protobuf/internal/strs"
    19  	"google.golang.org/protobuf/reflect/protoreflect"
    20  	"google.golang.org/protobuf/reflect/protoregistry"
    21  	"google.golang.org/protobuf/runtime/protoiface"
    22  )
    23  
    24  // ValidationStatus is the result of validating the wire-format encoding of a message.
    25  type ValidationStatus int
    26  
    27  const (
    28  	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
    29  	// The validator was unable to render a judgement.
    30  	//
    31  	// The only causes of this status are an aberrant message type appearing somewhere
    32  	// in the message or a failure in the extension resolver.
    33  	ValidationUnknown ValidationStatus = iota + 1
    34  
    35  	// ValidationInvalid indicates that unmarshaling the message will fail.
    36  	ValidationInvalid
    37  
    38  	// ValidationValid indicates that unmarshaling the message will succeed.
    39  	ValidationValid
    40  )
    41  
    42  func (v ValidationStatus) String() string {
    43  	switch v {
    44  	case ValidationUnknown:
    45  		return "ValidationUnknown"
    46  	case ValidationInvalid:
    47  		return "ValidationInvalid"
    48  	case ValidationValid:
    49  		return "ValidationValid"
    50  	default:
    51  		return fmt.Sprintf("ValidationStatus(%d)", int(v))
    52  	}
    53  }
    54  
    55  // Validate determines whether the contents of the buffer are a valid wire encoding
    56  // of the message type.
    57  //
    58  // This function is exposed for testing.
    59  func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
    60  	mi, ok := mt.(*MessageInfo)
    61  	if !ok {
    62  		return out, ValidationUnknown
    63  	}
    64  	if in.Resolver == nil {
    65  		in.Resolver = protoregistry.GlobalTypes
    66  	}
    67  	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
    68  		flags:    in.Flags,
    69  		resolver: in.Resolver,
    70  	})
    71  	if o.initialized {
    72  		out.Flags |= protoiface.UnmarshalInitialized
    73  	}
    74  	return out, st
    75  }
    76  
    77  type validationInfo struct {
    78  	mi               *MessageInfo
    79  	typ              validationType
    80  	keyType, valType validationType
    81  
    82  	// For non-required fields, requiredBit is 0.
    83  	//
    84  	// For required fields, requiredBit's nth bit is set, where n is a
    85  	// unique index in the range [0, MessageInfo.numRequiredFields).
    86  	//
    87  	// If there are more than 64 required fields, requiredBit is 0.
    88  	requiredBit uint64
    89  }
    90  
    91  type validationType uint8
    92  
    93  const (
    94  	validationTypeOther validationType = iota
    95  	validationTypeMessage
    96  	validationTypeGroup
    97  	validationTypeMap
    98  	validationTypeRepeatedVarint
    99  	validationTypeRepeatedFixed32
   100  	validationTypeRepeatedFixed64
   101  	validationTypeVarint
   102  	validationTypeFixed32
   103  	validationTypeFixed64
   104  	validationTypeBytes
   105  	validationTypeUTF8String
   106  	validationTypeMessageSetItem
   107  )
   108  
   109  func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
   110  	var vi validationInfo
   111  	switch {
   112  	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
   113  		switch fd.Kind() {
   114  		case protoreflect.MessageKind:
   115  			vi.typ = validationTypeMessage
   116  			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
   117  				vi.mi = getMessageInfo(ot.Field(0).Type)
   118  			}
   119  		case protoreflect.GroupKind:
   120  			vi.typ = validationTypeGroup
   121  			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
   122  				vi.mi = getMessageInfo(ot.Field(0).Type)
   123  			}
   124  		case protoreflect.StringKind:
   125  			if strs.EnforceUTF8(fd) {
   126  				vi.typ = validationTypeUTF8String
   127  			}
   128  		}
   129  	default:
   130  		vi = newValidationInfo(fd, ft)
   131  	}
   132  	if fd.Cardinality() == protoreflect.Required {
   133  		// Avoid overflow. The required field check is done with a 64-bit mask, with
   134  		// any message containing more than 64 required fields always reported as
   135  		// potentially uninitialized, so it is not important to get a precise count
   136  		// of the required fields past 64.
   137  		if mi.numRequiredFields < math.MaxUint8 {
   138  			mi.numRequiredFields++
   139  			vi.requiredBit = 1 << (mi.numRequiredFields - 1)
   140  		}
   141  	}
   142  	return vi
   143  }
   144  
   145  func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
   146  	var vi validationInfo
   147  	switch {
   148  	case fd.IsList():
   149  		switch fd.Kind() {
   150  		case protoreflect.MessageKind:
   151  			vi.typ = validationTypeMessage
   152  			if ft.Kind() == reflect.Slice {
   153  				vi.mi = getMessageInfo(ft.Elem())
   154  			}
   155  		case protoreflect.GroupKind:
   156  			vi.typ = validationTypeGroup
   157  			if ft.Kind() == reflect.Slice {
   158  				vi.mi = getMessageInfo(ft.Elem())
   159  			}
   160  		case protoreflect.StringKind:
   161  			vi.typ = validationTypeBytes
   162  			if strs.EnforceUTF8(fd) {
   163  				vi.typ = validationTypeUTF8String
   164  			}
   165  		default:
   166  			switch wireTypes[fd.Kind()] {
   167  			case protowire.VarintType:
   168  				vi.typ = validationTypeRepeatedVarint
   169  			case protowire.Fixed32Type:
   170  				vi.typ = validationTypeRepeatedFixed32
   171  			case protowire.Fixed64Type:
   172  				vi.typ = validationTypeRepeatedFixed64
   173  			}
   174  		}
   175  	case fd.IsMap():
   176  		vi.typ = validationTypeMap
   177  		switch fd.MapKey().Kind() {
   178  		case protoreflect.StringKind:
   179  			if strs.EnforceUTF8(fd) {
   180  				vi.keyType = validationTypeUTF8String
   181  			}
   182  		}
   183  		switch fd.MapValue().Kind() {
   184  		case protoreflect.MessageKind:
   185  			vi.valType = validationTypeMessage
   186  			if ft.Kind() == reflect.Map {
   187  				vi.mi = getMessageInfo(ft.Elem())
   188  			}
   189  		case protoreflect.StringKind:
   190  			if strs.EnforceUTF8(fd) {
   191  				vi.valType = validationTypeUTF8String
   192  			}
   193  		}
   194  	default:
   195  		switch fd.Kind() {
   196  		case protoreflect.MessageKind:
   197  			vi.typ = validationTypeMessage
   198  			if !fd.IsWeak() {
   199  				vi.mi = getMessageInfo(ft)
   200  			}
   201  		case protoreflect.GroupKind:
   202  			vi.typ = validationTypeGroup
   203  			vi.mi = getMessageInfo(ft)
   204  		case protoreflect.StringKind:
   205  			vi.typ = validationTypeBytes
   206  			if strs.EnforceUTF8(fd) {
   207  				vi.typ = validationTypeUTF8String
   208  			}
   209  		default:
   210  			switch wireTypes[fd.Kind()] {
   211  			case protowire.VarintType:
   212  				vi.typ = validationTypeVarint
   213  			case protowire.Fixed32Type:
   214  				vi.typ = validationTypeFixed32
   215  			case protowire.Fixed64Type:
   216  				vi.typ = validationTypeFixed64
   217  			case protowire.BytesType:
   218  				vi.typ = validationTypeBytes
   219  			}
   220  		}
   221  	}
   222  	return vi
   223  }
   224  
   225  func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
   226  	mi.init()
   227  	type validationState struct {
   228  		typ              validationType
   229  		keyType, valType validationType
   230  		endGroup         protowire.Number
   231  		mi               *MessageInfo
   232  		tail             []byte
   233  		requiredMask     uint64
   234  	}
   235  
   236  	// Pre-allocate some slots to avoid repeated slice reallocation.
   237  	states := make([]validationState, 0, 16)
   238  	states = append(states, validationState{
   239  		typ: validationTypeMessage,
   240  		mi:  mi,
   241  	})
   242  	if groupTag > 0 {
   243  		states[0].typ = validationTypeGroup
   244  		states[0].endGroup = groupTag
   245  	}
   246  	initialized := true
   247  	start := len(b)
   248  State:
   249  	for len(states) > 0 {
   250  		st := &states[len(states)-1]
   251  		for len(b) > 0 {
   252  			// Parse the tag (field number and wire type).
   253  			var tag uint64
   254  			if b[0] < 0x80 {
   255  				tag = uint64(b[0])
   256  				b = b[1:]
   257  			} else if len(b) >= 2 && b[1] < 128 {
   258  				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
   259  				b = b[2:]
   260  			} else {
   261  				var n int
   262  				tag, n = protowire.ConsumeVarint(b)
   263  				if n < 0 {
   264  					return out, ValidationInvalid
   265  				}
   266  				b = b[n:]
   267  			}
   268  			var num protowire.Number
   269  			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
   270  				return out, ValidationInvalid
   271  			} else {
   272  				num = protowire.Number(n)
   273  			}
   274  			wtyp := protowire.Type(tag & 7)
   275  
   276  			if wtyp == protowire.EndGroupType {
   277  				if st.endGroup == num {
   278  					goto PopState
   279  				}
   280  				return out, ValidationInvalid
   281  			}
   282  			var vi validationInfo
   283  			switch {
   284  			case st.typ == validationTypeMap:
   285  				switch num {
   286  				case genid.MapEntry_Key_field_number:
   287  					vi.typ = st.keyType
   288  				case genid.MapEntry_Value_field_number:
   289  					vi.typ = st.valType
   290  					vi.mi = st.mi
   291  					vi.requiredBit = 1
   292  				}
   293  			case flags.ProtoLegacy && st.mi.isMessageSet:
   294  				switch num {
   295  				case messageset.FieldItem:
   296  					vi.typ = validationTypeMessageSetItem
   297  				}
   298  			default:
   299  				var f *coderFieldInfo
   300  				if int(num) < len(st.mi.denseCoderFields) {
   301  					f = st.mi.denseCoderFields[num]
   302  				} else {
   303  					f = st.mi.coderFields[num]
   304  				}
   305  				if f != nil {
   306  					vi = f.validation
   307  					if vi.typ == validationTypeMessage && vi.mi == nil {
   308  						// Probable weak field.
   309  						//
   310  						// TODO: Consider storing the results of this lookup somewhere
   311  						// rather than recomputing it on every validation.
   312  						fd := st.mi.Desc.Fields().ByNumber(num)
   313  						if fd == nil || !fd.IsWeak() {
   314  							break
   315  						}
   316  						messageName := fd.Message().FullName()
   317  						messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
   318  						switch err {
   319  						case nil:
   320  							vi.mi, _ = messageType.(*MessageInfo)
   321  						case protoregistry.NotFound:
   322  							vi.typ = validationTypeBytes
   323  						default:
   324  							return out, ValidationUnknown
   325  						}
   326  					}
   327  					break
   328  				}
   329  				// Possible extension field.
   330  				//
   331  				// TODO: We should return ValidationUnknown when:
   332  				//   1. The resolver is not frozen. (More extensions may be added to it.)
   333  				//   2. The resolver returns preg.NotFound.
   334  				// In this case, a type added to the resolver in the future could cause
   335  				// unmarshaling to begin failing. Supporting this requires some way to
   336  				// determine if the resolver is frozen.
   337  				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
   338  				if err != nil && err != protoregistry.NotFound {
   339  					return out, ValidationUnknown
   340  				}
   341  				if err == nil {
   342  					vi = getExtensionFieldInfo(xt).validation
   343  				}
   344  			}
   345  			if vi.requiredBit != 0 {
   346  				// Check that the field has a compatible wire type.
   347  				// We only need to consider non-repeated field types,
   348  				// since repeated fields (and maps) can never be required.
   349  				ok := false
   350  				switch vi.typ {
   351  				case validationTypeVarint:
   352  					ok = wtyp == protowire.VarintType
   353  				case validationTypeFixed32:
   354  					ok = wtyp == protowire.Fixed32Type
   355  				case validationTypeFixed64:
   356  					ok = wtyp == protowire.Fixed64Type
   357  				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
   358  					ok = wtyp == protowire.BytesType
   359  				case validationTypeGroup:
   360  					ok = wtyp == protowire.StartGroupType
   361  				}
   362  				if ok {
   363  					st.requiredMask |= vi.requiredBit
   364  				}
   365  			}
   366  
   367  			switch wtyp {
   368  			case protowire.VarintType:
   369  				if len(b) >= 10 {
   370  					switch {
   371  					case b[0] < 0x80:
   372  						b = b[1:]
   373  					case b[1] < 0x80:
   374  						b = b[2:]
   375  					case b[2] < 0x80:
   376  						b = b[3:]
   377  					case b[3] < 0x80:
   378  						b = b[4:]
   379  					case b[4] < 0x80:
   380  						b = b[5:]
   381  					case b[5] < 0x80:
   382  						b = b[6:]
   383  					case b[6] < 0x80:
   384  						b = b[7:]
   385  					case b[7] < 0x80:
   386  						b = b[8:]
   387  					case b[8] < 0x80:
   388  						b = b[9:]
   389  					case b[9] < 0x80 && b[9] < 2:
   390  						b = b[10:]
   391  					default:
   392  						return out, ValidationInvalid
   393  					}
   394  				} else {
   395  					switch {
   396  					case len(b) > 0 && b[0] < 0x80:
   397  						b = b[1:]
   398  					case len(b) > 1 && b[1] < 0x80:
   399  						b = b[2:]
   400  					case len(b) > 2 && b[2] < 0x80:
   401  						b = b[3:]
   402  					case len(b) > 3 && b[3] < 0x80:
   403  						b = b[4:]
   404  					case len(b) > 4 && b[4] < 0x80:
   405  						b = b[5:]
   406  					case len(b) > 5 && b[5] < 0x80:
   407  						b = b[6:]
   408  					case len(b) > 6 && b[6] < 0x80:
   409  						b = b[7:]
   410  					case len(b) > 7 && b[7] < 0x80:
   411  						b = b[8:]
   412  					case len(b) > 8 && b[8] < 0x80:
   413  						b = b[9:]
   414  					case len(b) > 9 && b[9] < 2:
   415  						b = b[10:]
   416  					default:
   417  						return out, ValidationInvalid
   418  					}
   419  				}
   420  				continue State
   421  			case protowire.BytesType:
   422  				var size uint64
   423  				if len(b) >= 1 && b[0] < 0x80 {
   424  					size = uint64(b[0])
   425  					b = b[1:]
   426  				} else if len(b) >= 2 && b[1] < 128 {
   427  					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
   428  					b = b[2:]
   429  				} else {
   430  					var n int
   431  					size, n = protowire.ConsumeVarint(b)
   432  					if n < 0 {
   433  						return out, ValidationInvalid
   434  					}
   435  					b = b[n:]
   436  				}
   437  				if size > uint64(len(b)) {
   438  					return out, ValidationInvalid
   439  				}
   440  				v := b[:size]
   441  				b = b[size:]
   442  				switch vi.typ {
   443  				case validationTypeMessage:
   444  					if vi.mi == nil {
   445  						return out, ValidationUnknown
   446  					}
   447  					vi.mi.init()
   448  					fallthrough
   449  				case validationTypeMap:
   450  					if vi.mi != nil {
   451  						vi.mi.init()
   452  					}
   453  					states = append(states, validationState{
   454  						typ:     vi.typ,
   455  						keyType: vi.keyType,
   456  						valType: vi.valType,
   457  						mi:      vi.mi,
   458  						tail:    b,
   459  					})
   460  					b = v
   461  					continue State
   462  				case validationTypeRepeatedVarint:
   463  					// Packed field.
   464  					for len(v) > 0 {
   465  						_, n := protowire.ConsumeVarint(v)
   466  						if n < 0 {
   467  							return out, ValidationInvalid
   468  						}
   469  						v = v[n:]
   470  					}
   471  				case validationTypeRepeatedFixed32:
   472  					// Packed field.
   473  					if len(v)%4 != 0 {
   474  						return out, ValidationInvalid
   475  					}
   476  				case validationTypeRepeatedFixed64:
   477  					// Packed field.
   478  					if len(v)%8 != 0 {
   479  						return out, ValidationInvalid
   480  					}
   481  				case validationTypeUTF8String:
   482  					if !utf8.Valid(v) {
   483  						return out, ValidationInvalid
   484  					}
   485  				}
   486  			case protowire.Fixed32Type:
   487  				if len(b) < 4 {
   488  					return out, ValidationInvalid
   489  				}
   490  				b = b[4:]
   491  			case protowire.Fixed64Type:
   492  				if len(b) < 8 {
   493  					return out, ValidationInvalid
   494  				}
   495  				b = b[8:]
   496  			case protowire.StartGroupType:
   497  				switch {
   498  				case vi.typ == validationTypeGroup:
   499  					if vi.mi == nil {
   500  						return out, ValidationUnknown
   501  					}
   502  					vi.mi.init()
   503  					states = append(states, validationState{
   504  						typ:      validationTypeGroup,
   505  						mi:       vi.mi,
   506  						endGroup: num,
   507  					})
   508  					continue State
   509  				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
   510  					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
   511  					if err != nil {
   512  						return out, ValidationInvalid
   513  					}
   514  					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
   515  					switch {
   516  					case err == protoregistry.NotFound:
   517  						b = b[n:]
   518  					case err != nil:
   519  						return out, ValidationUnknown
   520  					default:
   521  						xvi := getExtensionFieldInfo(xt).validation
   522  						if xvi.mi != nil {
   523  							xvi.mi.init()
   524  						}
   525  						states = append(states, validationState{
   526  							typ:  xvi.typ,
   527  							mi:   xvi.mi,
   528  							tail: b[n:],
   529  						})
   530  						b = v
   531  						continue State
   532  					}
   533  				default:
   534  					n := protowire.ConsumeFieldValue(num, wtyp, b)
   535  					if n < 0 {
   536  						return out, ValidationInvalid
   537  					}
   538  					b = b[n:]
   539  				}
   540  			default:
   541  				return out, ValidationInvalid
   542  			}
   543  		}
   544  		if st.endGroup != 0 {
   545  			return out, ValidationInvalid
   546  		}
   547  		if len(b) != 0 {
   548  			return out, ValidationInvalid
   549  		}
   550  		b = st.tail
   551  	PopState:
   552  		numRequiredFields := 0
   553  		switch st.typ {
   554  		case validationTypeMessage, validationTypeGroup:
   555  			numRequiredFields = int(st.mi.numRequiredFields)
   556  		case validationTypeMap:
   557  			// If this is a map field with a message value that contains
   558  			// required fields, require that the value be present.
   559  			if st.mi != nil && st.mi.numRequiredFields > 0 {
   560  				numRequiredFields = 1
   561  			}
   562  		}
   563  		// If there are more than 64 required fields, this check will
   564  		// always fail and we will report that the message is potentially
   565  		// uninitialized.
   566  		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
   567  			initialized = false
   568  		}
   569  		states = states[:len(states)-1]
   570  	}
   571  	out.n = start - len(b)
   572  	if initialized {
   573  		out.initialized = true
   574  	}
   575  	return out, ValidationValid
   576  }
   577  

View as plain text