...

Source file src/github.com/grpc-ecosystem/grpc-gateway/v2/runtime/query.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/v2/runtime

     1  package runtime
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/url"
     7  	"regexp"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
    13  	"google.golang.org/grpc/grpclog"
    14  	"google.golang.org/protobuf/encoding/protojson"
    15  	"google.golang.org/protobuf/proto"
    16  	"google.golang.org/protobuf/reflect/protoreflect"
    17  	"google.golang.org/protobuf/reflect/protoregistry"
    18  	"google.golang.org/protobuf/types/known/durationpb"
    19  	field_mask "google.golang.org/protobuf/types/known/fieldmaskpb"
    20  	"google.golang.org/protobuf/types/known/structpb"
    21  	"google.golang.org/protobuf/types/known/timestamppb"
    22  	"google.golang.org/protobuf/types/known/wrapperspb"
    23  )
    24  
    25  var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
    26  
    27  var currentQueryParser QueryParameterParser = &DefaultQueryParser{}
    28  
// QueryParameterParser defines the interface implemented by all query
// parameter parsers.
type QueryParameterParser interface {
	// Parse receives the target message, the URL query values, and a filter,
	// and reports any failure as an error. DefaultQueryParser, for example,
	// writes the values into msg and skips keys whose field path shares a
	// prefix with an entry of filter.
	Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
}
    33  
    34  // PopulateQueryParameters parses query parameters
    35  // into "msg" using current query parser
    36  func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
    37  	return currentQueryParser.Parse(msg, values, filter)
    38  }
    39  
    40  // DefaultQueryParser is a QueryParameterParser which implements the default
    41  // query parameters parsing behavior.
    42  //
    43  // See https://github.com/grpc-ecosystem/grpc-gateway/issues/2632 for more context.
    44  type DefaultQueryParser struct{}
    45  
    46  // Parse populates "values" into "msg".
    47  // A value is ignored if its key starts with one of the elements in "filter".
    48  func (*DefaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
    49  	for key, values := range values {
    50  		if match := valuesKeyRegexp.FindStringSubmatch(key); len(match) == 3 {
    51  			key = match[1]
    52  			values = append([]string{match[2]}, values...)
    53  		}
    54  
    55  		msgValue := msg.ProtoReflect()
    56  		fieldPath := normalizeFieldPath(msgValue, strings.Split(key, "."))
    57  		if filter.HasCommonPrefix(fieldPath) {
    58  			continue
    59  		}
    60  		if err := populateFieldValueFromPath(msgValue, fieldPath, values); err != nil {
    61  			return err
    62  		}
    63  	}
    64  	return nil
    65  }
    66  
    67  // PopulateFieldFromPath sets a value in a nested Protobuf structure.
    68  func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
    69  	fieldPath := strings.Split(fieldPathString, ".")
    70  	return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value})
    71  }
    72  
    73  func normalizeFieldPath(msgValue protoreflect.Message, fieldPath []string) []string {
    74  	newFieldPath := make([]string, 0, len(fieldPath))
    75  	for i, fieldName := range fieldPath {
    76  		fields := msgValue.Descriptor().Fields()
    77  		fieldDesc := fields.ByTextName(fieldName)
    78  		if fieldDesc == nil {
    79  			fieldDesc = fields.ByJSONName(fieldName)
    80  		}
    81  		if fieldDesc == nil {
    82  			// return initial field path values if no matching  message field was found
    83  			return fieldPath
    84  		}
    85  
    86  		newFieldPath = append(newFieldPath, string(fieldDesc.Name()))
    87  
    88  		// If this is the last element, we're done
    89  		if i == len(fieldPath)-1 {
    90  			break
    91  		}
    92  
    93  		// Only singular message fields are allowed
    94  		if fieldDesc.Message() == nil || fieldDesc.Cardinality() == protoreflect.Repeated {
    95  			return fieldPath
    96  		}
    97  
    98  		// Get the nested message
    99  		msgValue = msgValue.Get(fieldDesc).Message()
   100  	}
   101  
   102  	return newFieldPath
   103  }
   104  
   105  func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error {
   106  	if len(fieldPath) < 1 {
   107  		return errors.New("no field path")
   108  	}
   109  	if len(values) < 1 {
   110  		return errors.New("no value provided")
   111  	}
   112  
   113  	var fieldDescriptor protoreflect.FieldDescriptor
   114  	for i, fieldName := range fieldPath {
   115  		fields := msgValue.Descriptor().Fields()
   116  
   117  		// Get field by name
   118  		fieldDescriptor = fields.ByName(protoreflect.Name(fieldName))
   119  		if fieldDescriptor == nil {
   120  			fieldDescriptor = fields.ByJSONName(fieldName)
   121  			if fieldDescriptor == nil {
   122  				// We're not returning an error here because this could just be
   123  				// an extra query parameter that isn't part of the request.
   124  				grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, "."))
   125  				return nil
   126  			}
   127  		}
   128  
   129  		// If this is the last element, we're done
   130  		if i == len(fieldPath)-1 {
   131  			break
   132  		}
   133  
   134  		// Only singular message fields are allowed
   135  		if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
   136  			return fmt.Errorf("invalid path: %q is not a message", fieldName)
   137  		}
   138  
   139  		// Get the nested message
   140  		msgValue = msgValue.Mutable(fieldDescriptor).Message()
   141  	}
   142  
   143  	// Check if oneof already set
   144  	if of := fieldDescriptor.ContainingOneof(); of != nil {
   145  		if f := msgValue.WhichOneof(of); f != nil {
   146  			return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
   147  		}
   148  	}
   149  
   150  	switch {
   151  	case fieldDescriptor.IsList():
   152  		return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
   153  	case fieldDescriptor.IsMap():
   154  		return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
   155  	}
   156  
   157  	if len(values) > 1 {
   158  		return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
   159  	}
   160  
   161  	return populateField(fieldDescriptor, msgValue, values[0])
   162  }
   163  
   164  func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
   165  	v, err := parseField(fieldDescriptor, value)
   166  	if err != nil {
   167  		return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err)
   168  	}
   169  
   170  	msgValue.Set(fieldDescriptor, v)
   171  	return nil
   172  }
   173  
   174  func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
   175  	for _, value := range values {
   176  		v, err := parseField(fieldDescriptor, value)
   177  		if err != nil {
   178  			return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
   179  		}
   180  		list.Append(v)
   181  	}
   182  
   183  	return nil
   184  }
   185  
   186  func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
   187  	if len(values) != 2 {
   188  		return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
   189  	}
   190  
   191  	key, err := parseField(fieldDescriptor.MapKey(), values[0])
   192  	if err != nil {
   193  		return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
   194  	}
   195  
   196  	value, err := parseField(fieldDescriptor.MapValue(), values[1])
   197  	if err != nil {
   198  		return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
   199  	}
   200  
   201  	mp.Set(key.MapKey(), value)
   202  
   203  	return nil
   204  }
   205  
   206  func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
   207  	switch fieldDescriptor.Kind() {
   208  	case protoreflect.BoolKind:
   209  		v, err := strconv.ParseBool(value)
   210  		if err != nil {
   211  			return protoreflect.Value{}, err
   212  		}
   213  		return protoreflect.ValueOfBool(v), nil
   214  	case protoreflect.EnumKind:
   215  		enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
   216  		if err != nil {
   217  			if errors.Is(err, protoregistry.NotFound) {
   218  				return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
   219  			}
   220  			return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
   221  		}
   222  		// Look for enum by name
   223  		v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
   224  		if v == nil {
   225  			i, err := strconv.Atoi(value)
   226  			if err != nil {
   227  				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
   228  			}
   229  			// Look for enum by number
   230  			if v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)); v == nil {
   231  				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
   232  			}
   233  		}
   234  		return protoreflect.ValueOfEnum(v.Number()), nil
   235  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   236  		v, err := strconv.ParseInt(value, 10, 32)
   237  		if err != nil {
   238  			return protoreflect.Value{}, err
   239  		}
   240  		return protoreflect.ValueOfInt32(int32(v)), nil
   241  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   242  		v, err := strconv.ParseInt(value, 10, 64)
   243  		if err != nil {
   244  			return protoreflect.Value{}, err
   245  		}
   246  		return protoreflect.ValueOfInt64(v), nil
   247  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   248  		v, err := strconv.ParseUint(value, 10, 32)
   249  		if err != nil {
   250  			return protoreflect.Value{}, err
   251  		}
   252  		return protoreflect.ValueOfUint32(uint32(v)), nil
   253  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   254  		v, err := strconv.ParseUint(value, 10, 64)
   255  		if err != nil {
   256  			return protoreflect.Value{}, err
   257  		}
   258  		return protoreflect.ValueOfUint64(v), nil
   259  	case protoreflect.FloatKind:
   260  		v, err := strconv.ParseFloat(value, 32)
   261  		if err != nil {
   262  			return protoreflect.Value{}, err
   263  		}
   264  		return protoreflect.ValueOfFloat32(float32(v)), nil
   265  	case protoreflect.DoubleKind:
   266  		v, err := strconv.ParseFloat(value, 64)
   267  		if err != nil {
   268  			return protoreflect.Value{}, err
   269  		}
   270  		return protoreflect.ValueOfFloat64(v), nil
   271  	case protoreflect.StringKind:
   272  		return protoreflect.ValueOfString(value), nil
   273  	case protoreflect.BytesKind:
   274  		v, err := Bytes(value)
   275  		if err != nil {
   276  			return protoreflect.Value{}, err
   277  		}
   278  		return protoreflect.ValueOfBytes(v), nil
   279  	case protoreflect.MessageKind, protoreflect.GroupKind:
   280  		return parseMessage(fieldDescriptor.Message(), value)
   281  	default:
   282  		panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
   283  	}
   284  }
   285  
   286  func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
   287  	var msg proto.Message
   288  	switch msgDescriptor.FullName() {
   289  	case "google.protobuf.Timestamp":
   290  		t, err := time.Parse(time.RFC3339Nano, value)
   291  		if err != nil {
   292  			return protoreflect.Value{}, err
   293  		}
   294  		msg = timestamppb.New(t)
   295  	case "google.protobuf.Duration":
   296  		d, err := time.ParseDuration(value)
   297  		if err != nil {
   298  			return protoreflect.Value{}, err
   299  		}
   300  		msg = durationpb.New(d)
   301  	case "google.protobuf.DoubleValue":
   302  		v, err := strconv.ParseFloat(value, 64)
   303  		if err != nil {
   304  			return protoreflect.Value{}, err
   305  		}
   306  		msg = wrapperspb.Double(v)
   307  	case "google.protobuf.FloatValue":
   308  		v, err := strconv.ParseFloat(value, 32)
   309  		if err != nil {
   310  			return protoreflect.Value{}, err
   311  		}
   312  		msg = wrapperspb.Float(float32(v))
   313  	case "google.protobuf.Int64Value":
   314  		v, err := strconv.ParseInt(value, 10, 64)
   315  		if err != nil {
   316  			return protoreflect.Value{}, err
   317  		}
   318  		msg = wrapperspb.Int64(v)
   319  	case "google.protobuf.Int32Value":
   320  		v, err := strconv.ParseInt(value, 10, 32)
   321  		if err != nil {
   322  			return protoreflect.Value{}, err
   323  		}
   324  		msg = wrapperspb.Int32(int32(v))
   325  	case "google.protobuf.UInt64Value":
   326  		v, err := strconv.ParseUint(value, 10, 64)
   327  		if err != nil {
   328  			return protoreflect.Value{}, err
   329  		}
   330  		msg = wrapperspb.UInt64(v)
   331  	case "google.protobuf.UInt32Value":
   332  		v, err := strconv.ParseUint(value, 10, 32)
   333  		if err != nil {
   334  			return protoreflect.Value{}, err
   335  		}
   336  		msg = wrapperspb.UInt32(uint32(v))
   337  	case "google.protobuf.BoolValue":
   338  		v, err := strconv.ParseBool(value)
   339  		if err != nil {
   340  			return protoreflect.Value{}, err
   341  		}
   342  		msg = wrapperspb.Bool(v)
   343  	case "google.protobuf.StringValue":
   344  		msg = wrapperspb.String(value)
   345  	case "google.protobuf.BytesValue":
   346  		v, err := Bytes(value)
   347  		if err != nil {
   348  			return protoreflect.Value{}, err
   349  		}
   350  		msg = wrapperspb.Bytes(v)
   351  	case "google.protobuf.FieldMask":
   352  		fm := &field_mask.FieldMask{}
   353  		fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
   354  		msg = fm
   355  	case "google.protobuf.Value":
   356  		var v structpb.Value
   357  		if err := protojson.Unmarshal([]byte(value), &v); err != nil {
   358  			return protoreflect.Value{}, err
   359  		}
   360  		msg = &v
   361  	case "google.protobuf.Struct":
   362  		var v structpb.Struct
   363  		if err := protojson.Unmarshal([]byte(value), &v); err != nil {
   364  			return protoreflect.Value{}, err
   365  		}
   366  		msg = &v
   367  	default:
   368  		return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
   369  	}
   370  
   371  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   372  }
   373  

View as plain text