...

Source file src/go.einride.tech/aip/fieldbehavior/fieldbehavior.go

Documentation: go.einride.tech/aip/fieldbehavior

     1  package fieldbehavior
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"google.golang.org/genproto/googleapis/api/annotations"
     7  	"google.golang.org/protobuf/proto"
     8  	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
     9  )
    10  
    11  // Get returns the field behavior of the provided field descriptor.
    12  func Get(field protoreflect.FieldDescriptor) []annotations.FieldBehavior {
    13  	if behaviors, ok := proto.GetExtension(
    14  		field.Options(), annotations.E_FieldBehavior,
    15  	).([]annotations.FieldBehavior); ok {
    16  		return behaviors
    17  	}
    18  	return nil
    19  }
    20  
    21  // Has returns true if the provided field descriptor has the wanted field behavior.
    22  func Has(field protoreflect.FieldDescriptor, want annotations.FieldBehavior) bool {
    23  	for _, got := range Get(field) {
    24  		if got == want {
    25  			return true
    26  		}
    27  	}
    28  	return false
    29  }
    30  
    31  // ClearFields clears all fields annotated with any of the provided behaviors.
    32  // This can be used to ignore fields provided as input that have field_behavior's
    33  // such as OUTPUT_ONLY and IMMUTABLE.
    34  //
    35  // See: https://google.aip.dev/161#output-only-fields
    36  func ClearFields(message proto.Message, behaviorsToClear ...annotations.FieldBehavior) {
    37  	clearFieldsWithBehaviors(message, behaviorsToClear...)
    38  }
    39  
    40  // CopyFields copies all fields annotated with any of the provided behaviors from src to dst.
    41  func CopyFields(dst, src proto.Message, behaviorsToCopy ...annotations.FieldBehavior) {
    42  	dstReflect := dst.ProtoReflect()
    43  	srcReflect := src.ProtoReflect()
    44  	if dstReflect.Descriptor() != srcReflect.Descriptor() {
    45  		panic(fmt.Sprintf(
    46  			"different types of dst (%s) and src (%s)",
    47  			dstReflect.Type().Descriptor().FullName(),
    48  			srcReflect.Type().Descriptor().FullName(),
    49  		))
    50  	}
    51  	for i := 0; i < dstReflect.Descriptor().Fields().Len(); i++ {
    52  		dstField := dstReflect.Descriptor().Fields().Get(i)
    53  		if hasAnyBehavior(Get(dstField), behaviorsToCopy) {
    54  			srcField := srcReflect.Descriptor().Fields().Get(i)
    55  			if isMessageFieldPresent(srcReflect, srcField) {
    56  				dstReflect.Set(dstField, srcReflect.Get(srcField))
    57  			} else {
    58  				dstReflect.Clear(dstField)
    59  			}
    60  		}
    61  	}
    62  }
    63  
    64  func isMessageFieldPresent(m protoreflect.Message, f protoreflect.FieldDescriptor) bool {
    65  	return isPresent(m.Get(f), f)
    66  }
    67  
    68  func isPresent(v protoreflect.Value, f protoreflect.FieldDescriptor) bool {
    69  	if !v.IsValid() {
    70  		return false
    71  	}
    72  	if f.IsList() {
    73  		return v.List().Len() > 0
    74  	}
    75  	if f.IsMap() {
    76  		return v.Map().Len() > 0
    77  	}
    78  	switch f.Kind() {
    79  	case protoreflect.EnumKind:
    80  		return v.Enum() != 0
    81  	case protoreflect.BoolKind:
    82  		return v.Bool()
    83  	case protoreflect.Int32Kind,
    84  		protoreflect.Sint32Kind,
    85  		protoreflect.Int64Kind,
    86  		protoreflect.Sint64Kind,
    87  		protoreflect.Sfixed32Kind,
    88  		protoreflect.Sfixed64Kind:
    89  		return v.Int() != 0
    90  	case protoreflect.Uint32Kind,
    91  		protoreflect.Uint64Kind,
    92  		protoreflect.Fixed32Kind,
    93  		protoreflect.Fixed64Kind:
    94  		return v.Uint() != 0
    95  	case protoreflect.FloatKind,
    96  		protoreflect.DoubleKind:
    97  		return v.Float() != 0
    98  	case protoreflect.StringKind:
    99  		return len(v.String()) > 0
   100  	case protoreflect.BytesKind:
   101  		return len(v.Bytes()) > 0
   102  	case protoreflect.MessageKind:
   103  		return v.Message().IsValid()
   104  	case protoreflect.GroupKind:
   105  		return v.IsValid()
   106  	default:
   107  		return v.IsValid()
   108  	}
   109  }
   110  
   111  func clearFieldsWithBehaviors(m proto.Message, behaviorsToClear ...annotations.FieldBehavior) {
   112  	rangeFieldsWithBehaviors(
   113  		m.ProtoReflect(),
   114  		func(
   115  			m protoreflect.Message,
   116  			f protoreflect.FieldDescriptor,
   117  			_ protoreflect.Value,
   118  			behaviors []annotations.FieldBehavior,
   119  		) bool {
   120  			if hasAnyBehavior(behaviors, behaviorsToClear) {
   121  				m.Clear(f)
   122  			}
   123  			return true
   124  		},
   125  	)
   126  }
   127  
   128  func rangeFieldsWithBehaviors(
   129  	m protoreflect.Message,
   130  	fn func(
   131  		protoreflect.Message,
   132  		protoreflect.FieldDescriptor,
   133  		protoreflect.Value,
   134  		[]annotations.FieldBehavior,
   135  	) bool,
   136  ) {
   137  	m.Range(
   138  		func(f protoreflect.FieldDescriptor, v protoreflect.Value) bool {
   139  			if behaviors, ok := proto.GetExtension(
   140  				f.Options(),
   141  				annotations.E_FieldBehavior,
   142  			).([]annotations.FieldBehavior); ok {
   143  				fn(m, f, v, behaviors)
   144  			}
   145  
   146  			switch {
   147  			// if field is repeated, traverse the nested message for field behaviors
   148  			case f.IsList() && f.Kind() == protoreflect.MessageKind:
   149  				for i := 0; i < v.List().Len(); i++ {
   150  					rangeFieldsWithBehaviors(
   151  						v.List().Get(i).Message(),
   152  						fn,
   153  					)
   154  				}
   155  				return true
   156  			// if field is map, traverse the nested message for field behaviors
   157  			case f.IsMap() && f.MapValue().Kind() == protoreflect.MessageKind:
   158  				v.Map().Range(func(_ protoreflect.MapKey, mv protoreflect.Value) bool {
   159  					rangeFieldsWithBehaviors(
   160  						mv.Message(),
   161  						fn,
   162  					)
   163  					return true
   164  				})
   165  				return true
   166  			// if field is message, traverse the message
   167  			// maps are also treated as Kind message and should not be traversed as messages
   168  			case f.Kind() == protoreflect.MessageKind && !f.IsMap():
   169  				rangeFieldsWithBehaviors(
   170  					v.Message(),
   171  					fn,
   172  				)
   173  				return true
   174  			default:
   175  				return true
   176  			}
   177  		})
   178  }
   179  
   180  func hasAnyBehavior(haystack, needles []annotations.FieldBehavior) bool {
   181  	for _, needle := range needles {
   182  		if hasBehavior(haystack, needle) {
   183  			return true
   184  		}
   185  	}
   186  	return false
   187  }
   188  
   189  func hasBehavior(haystack []annotations.FieldBehavior, needle annotations.FieldBehavior) bool {
   190  	for _, straw := range haystack {
   191  		if straw == needle {
   192  			return true
   193  		}
   194  	}
   195  	return false
   196  }
   197  

View as plain text