...

Source file src/github.com/go-openapi/runtime/middleware/parameter.go

Documentation: github.com/go-openapi/runtime/middleware

     1  // Copyright 2015 go-swagger maintainers
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package middleware
    16  
    17  import (
    18  	"encoding"
    19  	"encoding/base64"
    20  	"fmt"
    21  	"io"
    22  	"net/http"
    23  	"reflect"
    24  	"strconv"
    25  
    26  	"github.com/go-openapi/errors"
    27  	"github.com/go-openapi/spec"
    28  	"github.com/go-openapi/strfmt"
    29  	"github.com/go-openapi/swag"
    30  	"github.com/go-openapi/validate"
    31  
    32  	"github.com/go-openapi/runtime"
    33  )
    34  
    35  const defaultMaxMemory = 32 << 20
    36  
    37  const (
    38  	typeString = "string"
    39  	typeArray  = "array"
    40  )
    41  
    42  var textUnmarshalType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
    43  
    44  func newUntypedParamBinder(param spec.Parameter, spec *spec.Swagger, formats strfmt.Registry) *untypedParamBinder {
    45  	binder := new(untypedParamBinder)
    46  	binder.Name = param.Name
    47  	binder.parameter = &param
    48  	binder.formats = formats
    49  	if param.In != "body" {
    50  		binder.validator = validate.NewParamValidator(&param, formats)
    51  	} else {
    52  		binder.validator = validate.NewSchemaValidator(param.Schema, spec, param.Name, formats)
    53  	}
    54  
    55  	return binder
    56  }
    57  
    58  type untypedParamBinder struct {
    59  	parameter *spec.Parameter
    60  	formats   strfmt.Registry
    61  	Name      string
    62  	validator validate.EntityValidator
    63  }
    64  
    65  func (p *untypedParamBinder) Type() reflect.Type {
    66  	return p.typeForSchema(p.parameter.Type, p.parameter.Format, p.parameter.Items)
    67  }
    68  
    69  func (p *untypedParamBinder) typeForSchema(tpe, format string, items *spec.Items) reflect.Type {
    70  	switch tpe {
    71  	case "boolean":
    72  		return reflect.TypeOf(true)
    73  
    74  	case typeString:
    75  		if tt, ok := p.formats.GetType(format); ok {
    76  			return tt
    77  		}
    78  		return reflect.TypeOf("")
    79  
    80  	case "integer":
    81  		switch format {
    82  		case "int8":
    83  			return reflect.TypeOf(int8(0))
    84  		case "int16":
    85  			return reflect.TypeOf(int16(0))
    86  		case "int32":
    87  			return reflect.TypeOf(int32(0))
    88  		case "int64":
    89  			return reflect.TypeOf(int64(0))
    90  		default:
    91  			return reflect.TypeOf(int64(0))
    92  		}
    93  
    94  	case "number":
    95  		switch format {
    96  		case "float":
    97  			return reflect.TypeOf(float32(0))
    98  		case "double":
    99  			return reflect.TypeOf(float64(0))
   100  		}
   101  
   102  	case typeArray:
   103  		if items == nil {
   104  			return nil
   105  		}
   106  		itemsType := p.typeForSchema(items.Type, items.Format, items.Items)
   107  		if itemsType == nil {
   108  			return nil
   109  		}
   110  		return reflect.MakeSlice(reflect.SliceOf(itemsType), 0, 0).Type()
   111  
   112  	case "file":
   113  		return reflect.TypeOf(&runtime.File{}).Elem()
   114  
   115  	case "object":
   116  		return reflect.TypeOf(map[string]interface{}{})
   117  	}
   118  	return nil
   119  }
   120  
   121  func (p *untypedParamBinder) allowsMulti() bool {
   122  	return p.parameter.In == "query" || p.parameter.In == "formData"
   123  }
   124  
   125  func (p *untypedParamBinder) readValue(values runtime.Gettable, target reflect.Value) ([]string, bool, bool, error) {
   126  	name, in, cf, tpe := p.parameter.Name, p.parameter.In, p.parameter.CollectionFormat, p.parameter.Type
   127  	if tpe == typeArray {
   128  		if cf == "multi" {
   129  			if !p.allowsMulti() {
   130  				return nil, false, false, errors.InvalidCollectionFormat(name, in, cf)
   131  			}
   132  			vv, hasKey, _ := values.GetOK(name)
   133  			return vv, false, hasKey, nil
   134  		}
   135  
   136  		v, hk, hv := values.GetOK(name)
   137  		if !hv {
   138  			return nil, false, hk, nil
   139  		}
   140  		d, c, e := p.readFormattedSliceFieldValue(v[len(v)-1], target)
   141  		return d, c, hk, e
   142  	}
   143  
   144  	vv, hk, _ := values.GetOK(name)
   145  	return vv, false, hk, nil
   146  }
   147  
   148  func (p *untypedParamBinder) Bind(request *http.Request, routeParams RouteParams, consumer runtime.Consumer, target reflect.Value) error {
   149  	// fmt.Println("binding", p.name, "as", p.Type())
   150  	switch p.parameter.In {
   151  	case "query":
   152  		data, custom, hasKey, err := p.readValue(runtime.Values(request.URL.Query()), target)
   153  		if err != nil {
   154  			return err
   155  		}
   156  		if custom {
   157  			return nil
   158  		}
   159  
   160  		return p.bindValue(data, hasKey, target)
   161  
   162  	case "header":
   163  		data, custom, hasKey, err := p.readValue(runtime.Values(request.Header), target)
   164  		if err != nil {
   165  			return err
   166  		}
   167  		if custom {
   168  			return nil
   169  		}
   170  		return p.bindValue(data, hasKey, target)
   171  
   172  	case "path":
   173  		data, custom, hasKey, err := p.readValue(routeParams, target)
   174  		if err != nil {
   175  			return err
   176  		}
   177  		if custom {
   178  			return nil
   179  		}
   180  		return p.bindValue(data, hasKey, target)
   181  
   182  	case "formData":
   183  		var err error
   184  		var mt string
   185  
   186  		mt, _, e := runtime.ContentType(request.Header)
   187  		if e != nil {
   188  			// because of the interface conversion go thinks the error is not nil
   189  			// so we first check for nil and then set the err var if it's not nil
   190  			err = e
   191  		}
   192  
   193  		if err != nil {
   194  			return errors.InvalidContentType("", []string{"multipart/form-data", "application/x-www-form-urlencoded"})
   195  		}
   196  
   197  		if mt != "multipart/form-data" && mt != "application/x-www-form-urlencoded" {
   198  			return errors.InvalidContentType(mt, []string{"multipart/form-data", "application/x-www-form-urlencoded"})
   199  		}
   200  
   201  		if mt == "multipart/form-data" {
   202  			if err = request.ParseMultipartForm(defaultMaxMemory); err != nil {
   203  				return errors.NewParseError(p.Name, p.parameter.In, "", err)
   204  			}
   205  		}
   206  
   207  		if err = request.ParseForm(); err != nil {
   208  			return errors.NewParseError(p.Name, p.parameter.In, "", err)
   209  		}
   210  
   211  		if p.parameter.Type == "file" {
   212  			file, header, ffErr := request.FormFile(p.parameter.Name)
   213  			if ffErr != nil {
   214  				if p.parameter.Required {
   215  					return errors.NewParseError(p.Name, p.parameter.In, "", ffErr)
   216  				}
   217  
   218  				return nil
   219  			}
   220  
   221  			target.Set(reflect.ValueOf(runtime.File{Data: file, Header: header}))
   222  			return nil
   223  		}
   224  
   225  		if request.MultipartForm != nil {
   226  			data, custom, hasKey, rvErr := p.readValue(runtime.Values(request.MultipartForm.Value), target)
   227  			if rvErr != nil {
   228  				return rvErr
   229  			}
   230  			if custom {
   231  				return nil
   232  			}
   233  			return p.bindValue(data, hasKey, target)
   234  		}
   235  		data, custom, hasKey, err := p.readValue(runtime.Values(request.PostForm), target)
   236  		if err != nil {
   237  			return err
   238  		}
   239  		if custom {
   240  			return nil
   241  		}
   242  		return p.bindValue(data, hasKey, target)
   243  
   244  	case "body":
   245  		newValue := reflect.New(target.Type())
   246  		if !runtime.HasBody(request) {
   247  			if p.parameter.Default != nil {
   248  				target.Set(reflect.ValueOf(p.parameter.Default))
   249  			}
   250  
   251  			return nil
   252  		}
   253  		if err := consumer.Consume(request.Body, newValue.Interface()); err != nil {
   254  			if err == io.EOF && p.parameter.Default != nil {
   255  				target.Set(reflect.ValueOf(p.parameter.Default))
   256  				return nil
   257  			}
   258  			tpe := p.parameter.Type
   259  			if p.parameter.Format != "" {
   260  				tpe = p.parameter.Format
   261  			}
   262  			return errors.InvalidType(p.Name, p.parameter.In, tpe, nil)
   263  		}
   264  		target.Set(reflect.Indirect(newValue))
   265  		return nil
   266  	default:
   267  		return errors.New(500, fmt.Sprintf("invalid parameter location %q", p.parameter.In))
   268  	}
   269  }
   270  
   271  func (p *untypedParamBinder) bindValue(data []string, hasKey bool, target reflect.Value) error {
   272  	if p.parameter.Type == typeArray {
   273  		return p.setSliceFieldValue(target, p.parameter.Default, data, hasKey)
   274  	}
   275  	var d string
   276  	if len(data) > 0 {
   277  		d = data[len(data)-1]
   278  	}
   279  	return p.setFieldValue(target, p.parameter.Default, d, hasKey)
   280  }
   281  
   282  func (p *untypedParamBinder) setFieldValue(target reflect.Value, defaultValue interface{}, data string, hasKey bool) error { //nolint:gocyclo
   283  	tpe := p.parameter.Type
   284  	if p.parameter.Format != "" {
   285  		tpe = p.parameter.Format
   286  	}
   287  
   288  	if (!hasKey || (!p.parameter.AllowEmptyValue && data == "")) && p.parameter.Required && p.parameter.Default == nil {
   289  		return errors.Required(p.Name, p.parameter.In, data)
   290  	}
   291  
   292  	ok, err := p.tryUnmarshaler(target, defaultValue, data)
   293  	if err != nil {
   294  		return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
   295  	}
   296  	if ok {
   297  		return nil
   298  	}
   299  
   300  	defVal := reflect.Zero(target.Type())
   301  	if defaultValue != nil {
   302  		defVal = reflect.ValueOf(defaultValue)
   303  	}
   304  
   305  	if tpe == "byte" {
   306  		if data == "" {
   307  			if target.CanSet() {
   308  				target.SetBytes(defVal.Bytes())
   309  			}
   310  			return nil
   311  		}
   312  
   313  		b, err := base64.StdEncoding.DecodeString(data)
   314  		if err != nil {
   315  			b, err = base64.URLEncoding.DecodeString(data)
   316  			if err != nil {
   317  				return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
   318  			}
   319  		}
   320  		if target.CanSet() {
   321  			target.SetBytes(b)
   322  		}
   323  		return nil
   324  	}
   325  
   326  	switch target.Kind() { //nolint:exhaustive // we want to check only types that map from a swagger parameter
   327  	case reflect.Bool:
   328  		if data == "" {
   329  			if target.CanSet() {
   330  				target.SetBool(defVal.Bool())
   331  			}
   332  			return nil
   333  		}
   334  		b, err := swag.ConvertBool(data)
   335  		if err != nil {
   336  			return err
   337  		}
   338  		if target.CanSet() {
   339  			target.SetBool(b)
   340  		}
   341  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   342  		if data == "" {
   343  			if target.CanSet() {
   344  				rd := defVal.Convert(reflect.TypeOf(int64(0)))
   345  				target.SetInt(rd.Int())
   346  			}
   347  			return nil
   348  		}
   349  		i, err := strconv.ParseInt(data, 10, 64)
   350  		if err != nil {
   351  			return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
   352  		}
   353  		if target.OverflowInt(i) {
   354  			return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
   355  		}
   356  		if target.CanSet() {
   357  			target.SetInt(i)
   358  		}
   359  
   360  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   361  		if data == "" {
   362  			if target.CanSet() {
   363  				rd := defVal.Convert(reflect.TypeOf(uint64(0)))
   364  				target.SetUint(rd.Uint())
   365  			}
   366  			return nil
   367  		}
   368  		u, err := strconv.ParseUint(data, 10, 64)
   369  		if err != nil {
   370  			return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
   371  		}
   372  		if target.OverflowUint(u) {
   373  			return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
   374  		}
   375  		if target.CanSet() {
   376  			target.SetUint(u)
   377  		}
   378  
   379  	case reflect.Float32, reflect.Float64:
   380  		if data == "" {
   381  			if target.CanSet() {
   382  				rd := defVal.Convert(reflect.TypeOf(float64(0)))
   383  				target.SetFloat(rd.Float())
   384  			}
   385  			return nil
   386  		}
   387  		f, err := strconv.ParseFloat(data, 64)
   388  		if err != nil {
   389  			return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
   390  		}
   391  		if target.OverflowFloat(f) {
   392  			return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
   393  		}
   394  		if target.CanSet() {
   395  			target.SetFloat(f)
   396  		}
   397  
   398  	case reflect.String:
   399  		value := data
   400  		if value == "" {
   401  			value = defVal.String()
   402  		}
   403  		// validate string
   404  		if target.CanSet() {
   405  			target.SetString(value)
   406  		}
   407  
   408  	case reflect.Ptr:
   409  		if data == "" && defVal.Kind() == reflect.Ptr {
   410  			if target.CanSet() {
   411  				target.Set(defVal)
   412  			}
   413  			return nil
   414  		}
   415  		newVal := reflect.New(target.Type().Elem())
   416  		if err := p.setFieldValue(reflect.Indirect(newVal), defVal, data, hasKey); err != nil {
   417  			return err
   418  		}
   419  		if target.CanSet() {
   420  			target.Set(newVal)
   421  		}
   422  
   423  	default:
   424  		return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
   425  	}
   426  	return nil
   427  }
   428  
   429  func (p *untypedParamBinder) tryUnmarshaler(target reflect.Value, defaultValue interface{}, data string) (bool, error) {
   430  	if !target.CanSet() {
   431  		return false, nil
   432  	}
   433  	// When a type implements encoding.TextUnmarshaler we'll use that instead of reflecting some more
   434  	if reflect.PtrTo(target.Type()).Implements(textUnmarshalType) {
   435  		if defaultValue != nil && len(data) == 0 {
   436  			target.Set(reflect.ValueOf(defaultValue))
   437  			return true, nil
   438  		}
   439  		value := reflect.New(target.Type())
   440  		if err := value.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(data)); err != nil {
   441  			return true, err
   442  		}
   443  		target.Set(reflect.Indirect(value))
   444  		return true, nil
   445  	}
   446  	return false, nil
   447  }
   448  
   449  func (p *untypedParamBinder) readFormattedSliceFieldValue(data string, target reflect.Value) ([]string, bool, error) {
   450  	ok, err := p.tryUnmarshaler(target, p.parameter.Default, data)
   451  	if err != nil {
   452  		return nil, true, err
   453  	}
   454  	if ok {
   455  		return nil, true, nil
   456  	}
   457  
   458  	return swag.SplitByFormat(data, p.parameter.CollectionFormat), false, nil
   459  }
   460  
   461  func (p *untypedParamBinder) setSliceFieldValue(target reflect.Value, defaultValue interface{}, data []string, hasKey bool) error {
   462  	sz := len(data)
   463  	if (!hasKey || (!p.parameter.AllowEmptyValue && (sz == 0 || (sz == 1 && data[0] == "")))) && p.parameter.Required && defaultValue == nil {
   464  		return errors.Required(p.Name, p.parameter.In, data)
   465  	}
   466  
   467  	defVal := reflect.Zero(target.Type())
   468  	if defaultValue != nil {
   469  		defVal = reflect.ValueOf(defaultValue)
   470  	}
   471  
   472  	if !target.CanSet() {
   473  		return nil
   474  	}
   475  	if sz == 0 {
   476  		target.Set(defVal)
   477  		return nil
   478  	}
   479  
   480  	value := reflect.MakeSlice(reflect.SliceOf(target.Type().Elem()), sz, sz)
   481  
   482  	for i := 0; i < sz; i++ {
   483  		if err := p.setFieldValue(value.Index(i), nil, data[i], hasKey); err != nil {
   484  			return err
   485  		}
   486  	}
   487  
   488  	target.Set(value)
   489  
   490  	return nil
   491  }
   492  

View as plain text