...

Source file src/go.einride.tech/aip/fieldmask/validate.go

Documentation: go.einride.tech/aip/fieldmask

     1  package fieldmask
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"google.golang.org/protobuf/proto"
     8  	"google.golang.org/protobuf/reflect/protoreflect"
     9  	"google.golang.org/protobuf/types/known/fieldmaskpb"
    10  )
    11  
    12  // Validate validates that the paths in the provided field mask are syntactically valid and
    13  // refer to known fields in the specified message type.
    14  func Validate(fm *fieldmaskpb.FieldMask, m proto.Message) error {
    15  	// special case for '*'
    16  	if stringsContain(WildcardPath, fm.GetPaths()) {
    17  		if len(fm.GetPaths()) != 1 {
    18  			return fmt.Errorf("invalid field path: '*' must not be used with other paths")
    19  		}
    20  		return nil
    21  	}
    22  	md0 := m.ProtoReflect().Descriptor()
    23  	for _, path := range fm.GetPaths() {
    24  		md := md0
    25  		if !rangeFields(path, func(field string) bool {
    26  			// Search the field within the message.
    27  			if md == nil {
    28  				return false // not within a message
    29  			}
    30  			fd := md.Fields().ByName(protoreflect.Name(field))
    31  			// The real field name of a group is the message name.
    32  			if fd == nil {
    33  				gd := md.Fields().ByName(protoreflect.Name(strings.ToLower(field)))
    34  				if gd != nil && gd.Kind() == protoreflect.GroupKind && string(gd.Message().Name()) == field {
    35  					fd = gd
    36  				}
    37  			} else if fd.Kind() == protoreflect.GroupKind && string(fd.Message().Name()) != field {
    38  				fd = nil
    39  			}
    40  			if fd == nil {
    41  				return false // message has does not have this field
    42  			}
    43  			// Identify the next message to search within.
    44  			md = fd.Message() // may be nil
    45  			if fd.IsMap() {
    46  				md = fd.MapValue().Message() // may be nil
    47  			}
    48  			return true
    49  		}) {
    50  			return fmt.Errorf("invalid field path: %s", path)
    51  		}
    52  	}
    53  	return nil
    54  }
    55  
    56  func stringsContain(str string, ss []string) bool {
    57  	for _, s := range ss {
    58  		if s == str {
    59  			return true
    60  		}
    61  	}
    62  	return false
    63  }
    64  
    65  func rangeFields(path string, f func(field string) bool) bool {
    66  	for {
    67  		var field string
    68  		if i := strings.IndexByte(path, '.'); i >= 0 {
    69  			field, path = path[:i], path[i:]
    70  		} else {
    71  			field, path = path, ""
    72  		}
    73  		if !f(field) {
    74  			return false
    75  		}
    76  		if len(path) == 0 {
    77  			return true
    78  		}
    79  		path = strings.TrimPrefix(path, ".")
    80  	}
    81  }
    82  

View as plain text