...

Source file src/go.einride.tech/aip/fieldbehavior/required.go

Documentation: go.einride.tech/aip/fieldbehavior

     1  package fieldbehavior
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"google.golang.org/genproto/googleapis/api/annotations"
     7  	"google.golang.org/protobuf/proto"
     8  	"google.golang.org/protobuf/reflect/protoreflect"
     9  	"google.golang.org/protobuf/types/known/fieldmaskpb"
    10  )
    11  
    12  // ValidateRequiredFields returns a validation error if any field annotated as required does not have a value.
    13  // See: https://aip.dev/203
    14  func ValidateRequiredFields(m proto.Message) error {
    15  	return validateRequiredFields(
    16  		m.ProtoReflect(),
    17  		&fieldmaskpb.FieldMask{Paths: []string{"*"}},
    18  		"",
    19  	)
    20  }
    21  
    22  func ValidateRequiredFieldsWithMask(m proto.Message, mask *fieldmaskpb.FieldMask) error {
    23  	return validateRequiredFields(m.ProtoReflect(), mask, "")
    24  }
    25  
    26  func validateRequiredFields(reflectMessage protoreflect.Message, mask *fieldmaskpb.FieldMask, path string) error {
    27  	// If no paths are provided, the field mask should be treated to be equivalent
    28  	// to all fields set on the wire. This means that no required fields can be missing,
    29  	// since if they were missing they're not set on the wire.
    30  	if len(mask.GetPaths()) == 0 {
    31  		return nil
    32  	}
    33  	for i := 0; i < reflectMessage.Descriptor().Fields().Len(); i++ {
    34  		field := reflectMessage.Descriptor().Fields().Get(i)
    35  		currPath := path
    36  		if len(currPath) > 0 {
    37  			currPath += "."
    38  		}
    39  		currPath += string(field.Name())
    40  		if !isMessageFieldPresent(reflectMessage, field) {
    41  			if Has(field, annotations.FieldBehavior_REQUIRED) && hasPath(mask, currPath) {
    42  				return fmt.Errorf("missing required field: %s", currPath)
    43  			}
    44  		} else if field.Kind() == protoreflect.MessageKind {
    45  			value := reflectMessage.Get(field)
    46  			switch {
    47  			case field.IsList():
    48  				for i := 0; i < value.List().Len(); i++ {
    49  					if err := validateRequiredFields(value.List().Get(i).Message(), mask, currPath); err != nil {
    50  						return err
    51  					}
    52  				}
    53  			case field.IsMap():
    54  				if field.MapValue().Kind() != protoreflect.MessageKind {
    55  					continue
    56  				}
    57  				var mapErr error
    58  				value.Map().Range(func(_ protoreflect.MapKey, value protoreflect.Value) bool {
    59  					if err := validateRequiredFields(value.Message(), mask, currPath); err != nil {
    60  						mapErr = err
    61  						return false
    62  					}
    63  
    64  					return true
    65  				})
    66  				if mapErr != nil {
    67  					return mapErr
    68  				}
    69  			default:
    70  				if err := validateRequiredFields(value.Message(), mask, currPath); err != nil {
    71  					return err
    72  				}
    73  			}
    74  		}
    75  	}
    76  	return nil
    77  }
    78  
    79  func isEmpty(mask *fieldmaskpb.FieldMask) bool {
    80  	return mask == nil || len(mask.GetPaths()) == 0
    81  }
    82  
    83  func hasPath(mask *fieldmaskpb.FieldMask, needle string) bool {
    84  	if isEmpty(mask) {
    85  		return true
    86  	}
    87  	for _, straw := range mask.GetPaths() {
    88  		if straw == "*" || straw == needle {
    89  			return true
    90  		}
    91  	}
    92  	return false
    93  }
    94  

View as plain text