...

Source file src/github.com/grpc-ecosystem/grpc-gateway/runtime/query.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/runtime

     1  package runtime
     2  
     3  import (
     4  	"encoding/base64"
     5  	"fmt"
     6  	"net/url"
     7  	"reflect"
     8  	"regexp"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/golang/protobuf/proto"
    14  	"github.com/grpc-ecosystem/grpc-gateway/utilities"
    15  	"google.golang.org/grpc/grpclog"
    16  )
    17  
    18  var valuesKeyRegexp = regexp.MustCompile("^(.*)\\[(.*)\\]$")
    19  
    20  var currentQueryParser QueryParameterParser = &defaultQueryParser{}
    21  
    22  // QueryParameterParser defines interface for all query parameter parsers
    23  type QueryParameterParser interface {
    24  	Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
    25  }
    26  
    27  // PopulateQueryParameters parses query parameters
    28  // into "msg" using current query parser
    29  func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
    30  	return currentQueryParser.Parse(msg, values, filter)
    31  }
    32  
    33  type defaultQueryParser struct{}
    34  
    35  // Parse populates "values" into "msg".
    36  // A value is ignored if its key starts with one of the elements in "filter".
    37  func (*defaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
    38  	for key, values := range values {
    39  		match := valuesKeyRegexp.FindStringSubmatch(key)
    40  		if len(match) == 3 {
    41  			key = match[1]
    42  			values = append([]string{match[2]}, values...)
    43  		}
    44  		fieldPath := strings.Split(key, ".")
    45  		if filter.HasCommonPrefix(fieldPath) {
    46  			continue
    47  		}
    48  		if err := populateFieldValueFromPath(msg, fieldPath, values); err != nil {
    49  			return err
    50  		}
    51  	}
    52  	return nil
    53  }
    54  
    55  // PopulateFieldFromPath sets a value in a nested Protobuf structure.
    56  // It instantiates missing protobuf fields as it goes.
    57  func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
    58  	fieldPath := strings.Split(fieldPathString, ".")
    59  	return populateFieldValueFromPath(msg, fieldPath, []string{value})
    60  }
    61  
    62  func populateFieldValueFromPath(msg proto.Message, fieldPath []string, values []string) error {
    63  	m := reflect.ValueOf(msg)
    64  	if m.Kind() != reflect.Ptr {
    65  		return fmt.Errorf("unexpected type %T: %v", msg, msg)
    66  	}
    67  	var props *proto.Properties
    68  	m = m.Elem()
    69  	for i, fieldName := range fieldPath {
    70  		isLast := i == len(fieldPath)-1
    71  		if !isLast && m.Kind() != reflect.Struct {
    72  			return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, "."))
    73  		}
    74  		var f reflect.Value
    75  		var err error
    76  		f, props, err = fieldByProtoName(m, fieldName)
    77  		if err != nil {
    78  			return err
    79  		} else if !f.IsValid() {
    80  			grpclog.Infof("field not found in %T: %s", msg, strings.Join(fieldPath, "."))
    81  			return nil
    82  		}
    83  
    84  		switch f.Kind() {
    85  		case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64:
    86  			if !isLast {
    87  				return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
    88  			}
    89  			m = f
    90  		case reflect.Slice:
    91  			if !isLast {
    92  				return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, "."))
    93  			}
    94  			// Handle []byte
    95  			if f.Type().Elem().Kind() == reflect.Uint8 {
    96  				m = f
    97  				break
    98  			}
    99  			return populateRepeatedField(f, values, props)
   100  		case reflect.Ptr:
   101  			if f.IsNil() {
   102  				m = reflect.New(f.Type().Elem())
   103  				f.Set(m.Convert(f.Type()))
   104  			}
   105  			m = f.Elem()
   106  			continue
   107  		case reflect.Struct:
   108  			m = f
   109  			continue
   110  		case reflect.Map:
   111  			if !isLast {
   112  				return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
   113  			}
   114  			return populateMapField(f, values, props)
   115  		default:
   116  			return fmt.Errorf("unexpected type %s in %T", f.Type(), msg)
   117  		}
   118  	}
   119  	switch len(values) {
   120  	case 0:
   121  		return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, "."))
   122  	case 1:
   123  	default:
   124  		grpclog.Infof("too many field values: %s", strings.Join(fieldPath, "."))
   125  	}
   126  	return populateField(m, values[0], props)
   127  }
   128  
   129  // fieldByProtoName looks up a field whose corresponding protobuf field name is "name".
   130  // "m" must be a struct value. It returns zero reflect.Value if no such field found.
   131  func fieldByProtoName(m reflect.Value, name string) (reflect.Value, *proto.Properties, error) {
   132  	props := proto.GetProperties(m.Type())
   133  
   134  	// look up field name in oneof map
   135  	for _, op := range props.OneofTypes {
   136  		if name == op.Prop.OrigName || name == op.Prop.JSONName {
   137  			v := reflect.New(op.Type.Elem())
   138  			field := m.Field(op.Field)
   139  			if !field.IsNil() {
   140  				return reflect.Value{}, nil, fmt.Errorf("field already set for %s oneof", props.Prop[op.Field].OrigName)
   141  			}
   142  			field.Set(v)
   143  			return v.Elem().Field(0), op.Prop, nil
   144  		}
   145  	}
   146  
   147  	for _, p := range props.Prop {
   148  		if p.OrigName == name {
   149  			return m.FieldByName(p.Name), p, nil
   150  		}
   151  		if p.JSONName == name {
   152  			return m.FieldByName(p.Name), p, nil
   153  		}
   154  	}
   155  	return reflect.Value{}, nil, nil
   156  }
   157  
   158  func populateMapField(f reflect.Value, values []string, props *proto.Properties) error {
   159  	if len(values) != 2 {
   160  		return fmt.Errorf("more than one value provided for key %s in map %s", values[0], props.Name)
   161  	}
   162  
   163  	key, value := values[0], values[1]
   164  	keyType := f.Type().Key()
   165  	valueType := f.Type().Elem()
   166  	if f.IsNil() {
   167  		f.Set(reflect.MakeMap(f.Type()))
   168  	}
   169  
   170  	keyConv, ok := convFromType[keyType.Kind()]
   171  	if !ok {
   172  		return fmt.Errorf("unsupported key type %s in map %s", keyType, props.Name)
   173  	}
   174  	valueConv, ok := convFromType[valueType.Kind()]
   175  	if !ok {
   176  		return fmt.Errorf("unsupported value type %s in map %s", valueType, props.Name)
   177  	}
   178  
   179  	keyV := keyConv.Call([]reflect.Value{reflect.ValueOf(key)})
   180  	if err := keyV[1].Interface(); err != nil {
   181  		return err.(error)
   182  	}
   183  	valueV := valueConv.Call([]reflect.Value{reflect.ValueOf(value)})
   184  	if err := valueV[1].Interface(); err != nil {
   185  		return err.(error)
   186  	}
   187  
   188  	f.SetMapIndex(keyV[0].Convert(keyType), valueV[0].Convert(valueType))
   189  
   190  	return nil
   191  }
   192  
   193  func populateRepeatedField(f reflect.Value, values []string, props *proto.Properties) error {
   194  	elemType := f.Type().Elem()
   195  
   196  	// is the destination field a slice of an enumeration type?
   197  	if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
   198  		return populateFieldEnumRepeated(f, values, enumValMap)
   199  	}
   200  
   201  	conv, ok := convFromType[elemType.Kind()]
   202  	if !ok {
   203  		return fmt.Errorf("unsupported field type %s", elemType)
   204  	}
   205  	f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
   206  	for i, v := range values {
   207  		result := conv.Call([]reflect.Value{reflect.ValueOf(v)})
   208  		if err := result[1].Interface(); err != nil {
   209  			return err.(error)
   210  		}
   211  		f.Index(i).Set(result[0].Convert(f.Index(i).Type()))
   212  	}
   213  	return nil
   214  }
   215  
   216  func populateField(f reflect.Value, value string, props *proto.Properties) error {
   217  	i := f.Addr().Interface()
   218  
   219  	// Handle protobuf well known types
   220  	var name string
   221  	switch m := i.(type) {
   222  	case interface{ XXX_WellKnownType() string }:
   223  		name = m.XXX_WellKnownType()
   224  	case proto.Message:
   225  		const wktPrefix = "google.protobuf."
   226  		if fullName := proto.MessageName(m); strings.HasPrefix(fullName, wktPrefix) {
   227  			name = fullName[len(wktPrefix):]
   228  		}
   229  	}
   230  	switch name {
   231  	case "Timestamp":
   232  		if value == "null" {
   233  			f.FieldByName("Seconds").SetInt(0)
   234  			f.FieldByName("Nanos").SetInt(0)
   235  			return nil
   236  		}
   237  
   238  		t, err := time.Parse(time.RFC3339Nano, value)
   239  		if err != nil {
   240  			return fmt.Errorf("bad Timestamp: %v", err)
   241  		}
   242  		f.FieldByName("Seconds").SetInt(int64(t.Unix()))
   243  		f.FieldByName("Nanos").SetInt(int64(t.Nanosecond()))
   244  		return nil
   245  	case "Duration":
   246  		if value == "null" {
   247  			f.FieldByName("Seconds").SetInt(0)
   248  			f.FieldByName("Nanos").SetInt(0)
   249  			return nil
   250  		}
   251  		d, err := time.ParseDuration(value)
   252  		if err != nil {
   253  			return fmt.Errorf("bad Duration: %v", err)
   254  		}
   255  
   256  		ns := d.Nanoseconds()
   257  		s := ns / 1e9
   258  		ns %= 1e9
   259  		f.FieldByName("Seconds").SetInt(s)
   260  		f.FieldByName("Nanos").SetInt(ns)
   261  		return nil
   262  	case "DoubleValue":
   263  		fallthrough
   264  	case "FloatValue":
   265  		float64Val, err := strconv.ParseFloat(value, 64)
   266  		if err != nil {
   267  			return fmt.Errorf("bad DoubleValue: %s", value)
   268  		}
   269  		f.FieldByName("Value").SetFloat(float64Val)
   270  		return nil
   271  	case "Int64Value":
   272  		fallthrough
   273  	case "Int32Value":
   274  		int64Val, err := strconv.ParseInt(value, 10, 64)
   275  		if err != nil {
   276  			return fmt.Errorf("bad DoubleValue: %s", value)
   277  		}
   278  		f.FieldByName("Value").SetInt(int64Val)
   279  		return nil
   280  	case "UInt64Value":
   281  		fallthrough
   282  	case "UInt32Value":
   283  		uint64Val, err := strconv.ParseUint(value, 10, 64)
   284  		if err != nil {
   285  			return fmt.Errorf("bad DoubleValue: %s", value)
   286  		}
   287  		f.FieldByName("Value").SetUint(uint64Val)
   288  		return nil
   289  	case "BoolValue":
   290  		if value == "true" {
   291  			f.FieldByName("Value").SetBool(true)
   292  		} else if value == "false" {
   293  			f.FieldByName("Value").SetBool(false)
   294  		} else {
   295  			return fmt.Errorf("bad BoolValue: %s", value)
   296  		}
   297  		return nil
   298  	case "StringValue":
   299  		f.FieldByName("Value").SetString(value)
   300  		return nil
   301  	case "BytesValue":
   302  		bytesVal, err := base64.StdEncoding.DecodeString(value)
   303  		if err != nil {
   304  			return fmt.Errorf("bad BytesValue: %s", value)
   305  		}
   306  		f.FieldByName("Value").SetBytes(bytesVal)
   307  		return nil
   308  	case "FieldMask":
   309  		p := f.FieldByName("Paths")
   310  		for _, v := range strings.Split(value, ",") {
   311  			if v != "" {
   312  				p.Set(reflect.Append(p, reflect.ValueOf(v)))
   313  			}
   314  		}
   315  		return nil
   316  	}
   317  
   318  	// Handle Time and Duration stdlib types
   319  	switch t := i.(type) {
   320  	case *time.Time:
   321  		pt, err := time.Parse(time.RFC3339Nano, value)
   322  		if err != nil {
   323  			return fmt.Errorf("bad Timestamp: %v", err)
   324  		}
   325  		*t = pt
   326  		return nil
   327  	case *time.Duration:
   328  		d, err := time.ParseDuration(value)
   329  		if err != nil {
   330  			return fmt.Errorf("bad Duration: %v", err)
   331  		}
   332  		*t = d
   333  		return nil
   334  	}
   335  
   336  	// is the destination field an enumeration type?
   337  	if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
   338  		return populateFieldEnum(f, value, enumValMap)
   339  	}
   340  
   341  	conv, ok := convFromType[f.Kind()]
   342  	if !ok {
   343  		return fmt.Errorf("field type %T is not supported in query parameters", i)
   344  	}
   345  	result := conv.Call([]reflect.Value{reflect.ValueOf(value)})
   346  	if err := result[1].Interface(); err != nil {
   347  		return err.(error)
   348  	}
   349  	f.Set(result[0].Convert(f.Type()))
   350  	return nil
   351  }
   352  
   353  func convertEnum(value string, t reflect.Type, enumValMap map[string]int32) (reflect.Value, error) {
   354  	// see if it's an enumeration string
   355  	if enumVal, ok := enumValMap[value]; ok {
   356  		return reflect.ValueOf(enumVal).Convert(t), nil
   357  	}
   358  
   359  	// check for an integer that matches an enumeration value
   360  	eVal, err := strconv.Atoi(value)
   361  	if err != nil {
   362  		return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
   363  	}
   364  	for _, v := range enumValMap {
   365  		if v == int32(eVal) {
   366  			return reflect.ValueOf(eVal).Convert(t), nil
   367  		}
   368  	}
   369  	return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
   370  }
   371  
   372  func populateFieldEnum(f reflect.Value, value string, enumValMap map[string]int32) error {
   373  	cval, err := convertEnum(value, f.Type(), enumValMap)
   374  	if err != nil {
   375  		return err
   376  	}
   377  	f.Set(cval)
   378  	return nil
   379  }
   380  
   381  func populateFieldEnumRepeated(f reflect.Value, values []string, enumValMap map[string]int32) error {
   382  	elemType := f.Type().Elem()
   383  	f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
   384  	for i, v := range values {
   385  		result, err := convertEnum(v, elemType, enumValMap)
   386  		if err != nil {
   387  			return err
   388  		}
   389  		f.Index(i).Set(result)
   390  	}
   391  	return nil
   392  }
   393  
   394  var (
   395  	convFromType = map[reflect.Kind]reflect.Value{
   396  		reflect.String:  reflect.ValueOf(String),
   397  		reflect.Bool:    reflect.ValueOf(Bool),
   398  		reflect.Float64: reflect.ValueOf(Float64),
   399  		reflect.Float32: reflect.ValueOf(Float32),
   400  		reflect.Int64:   reflect.ValueOf(Int64),
   401  		reflect.Int32:   reflect.ValueOf(Int32),
   402  		reflect.Uint64:  reflect.ValueOf(Uint64),
   403  		reflect.Uint32:  reflect.ValueOf(Uint32),
   404  		reflect.Slice:   reflect.ValueOf(Bytes),
   405  	}
   406  )
   407  

View as plain text