1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package middleware
16
17 import (
18 "encoding"
19 "encoding/base64"
20 "fmt"
21 "io"
22 "net/http"
23 "reflect"
24 "strconv"
25
26 "github.com/go-openapi/errors"
27 "github.com/go-openapi/spec"
28 "github.com/go-openapi/strfmt"
29 "github.com/go-openapi/swag"
30 "github.com/go-openapi/validate"
31
32 "github.com/go-openapi/runtime"
33 )
34
// defaultMaxMemory is the maxMemory argument handed to
// http.Request.ParseMultipartForm when binding formData parameters (32 MiB).
const defaultMaxMemory = 32 * 1024 * 1024

// Swagger simple-type names that the binder compares parameter types against.
const (
	typeString = "string"
	typeArray  = "array"
)

// textUnmarshalType is the reflect.Type of the encoding.TextUnmarshaler
// interface; it is used to detect targets that can decode themselves from text.
var textUnmarshalType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
43
44 func newUntypedParamBinder(param spec.Parameter, spec *spec.Swagger, formats strfmt.Registry) *untypedParamBinder {
45 binder := new(untypedParamBinder)
46 binder.Name = param.Name
47 binder.parameter = ¶m
48 binder.formats = formats
49 if param.In != "body" {
50 binder.validator = validate.NewParamValidator(¶m, formats)
51 } else {
52 binder.validator = validate.NewSchemaValidator(param.Schema, spec, param.Name, formats)
53 }
54
55 return binder
56 }
57
// untypedParamBinder binds one Swagger parameter from an HTTP request into a
// reflect.Value. The Go type used for that value is derived from the
// parameter definition (see Type).
type untypedParamBinder struct {
	// parameter is the Swagger definition of the parameter being bound.
	parameter *spec.Parameter
	// formats maps string formats to Go types when the target type is
	// derived in typeForSchema.
	formats strfmt.Registry
	// Name identifies the parameter in the errors produced while binding.
	// newUntypedParamBinder initializes it from the parameter name.
	Name string
	// validator is a schema validator for "body" parameters and a parameter
	// validator for every other location (see newUntypedParamBinder).
	validator validate.EntityValidator
}
64
65 func (p *untypedParamBinder) Type() reflect.Type {
66 return p.typeForSchema(p.parameter.Type, p.parameter.Format, p.parameter.Items)
67 }
68
69 func (p *untypedParamBinder) typeForSchema(tpe, format string, items *spec.Items) reflect.Type {
70 switch tpe {
71 case "boolean":
72 return reflect.TypeOf(true)
73
74 case typeString:
75 if tt, ok := p.formats.GetType(format); ok {
76 return tt
77 }
78 return reflect.TypeOf("")
79
80 case "integer":
81 switch format {
82 case "int8":
83 return reflect.TypeOf(int8(0))
84 case "int16":
85 return reflect.TypeOf(int16(0))
86 case "int32":
87 return reflect.TypeOf(int32(0))
88 case "int64":
89 return reflect.TypeOf(int64(0))
90 default:
91 return reflect.TypeOf(int64(0))
92 }
93
94 case "number":
95 switch format {
96 case "float":
97 return reflect.TypeOf(float32(0))
98 case "double":
99 return reflect.TypeOf(float64(0))
100 }
101
102 case typeArray:
103 if items == nil {
104 return nil
105 }
106 itemsType := p.typeForSchema(items.Type, items.Format, items.Items)
107 if itemsType == nil {
108 return nil
109 }
110 return reflect.MakeSlice(reflect.SliceOf(itemsType), 0, 0).Type()
111
112 case "file":
113 return reflect.TypeOf(&runtime.File{}).Elem()
114
115 case "object":
116 return reflect.TypeOf(map[string]interface{}{})
117 }
118 return nil
119 }
120
121 func (p *untypedParamBinder) allowsMulti() bool {
122 return p.parameter.In == "query" || p.parameter.In == "formData"
123 }
124
125 func (p *untypedParamBinder) readValue(values runtime.Gettable, target reflect.Value) ([]string, bool, bool, error) {
126 name, in, cf, tpe := p.parameter.Name, p.parameter.In, p.parameter.CollectionFormat, p.parameter.Type
127 if tpe == typeArray {
128 if cf == "multi" {
129 if !p.allowsMulti() {
130 return nil, false, false, errors.InvalidCollectionFormat(name, in, cf)
131 }
132 vv, hasKey, _ := values.GetOK(name)
133 return vv, false, hasKey, nil
134 }
135
136 v, hk, hv := values.GetOK(name)
137 if !hv {
138 return nil, false, hk, nil
139 }
140 d, c, e := p.readFormattedSliceFieldValue(v[len(v)-1], target)
141 return d, c, hk, e
142 }
143
144 vv, hk, _ := values.GetOK(name)
145 return vv, false, hk, nil
146 }
147
148 func (p *untypedParamBinder) Bind(request *http.Request, routeParams RouteParams, consumer runtime.Consumer, target reflect.Value) error {
149
150 switch p.parameter.In {
151 case "query":
152 data, custom, hasKey, err := p.readValue(runtime.Values(request.URL.Query()), target)
153 if err != nil {
154 return err
155 }
156 if custom {
157 return nil
158 }
159
160 return p.bindValue(data, hasKey, target)
161
162 case "header":
163 data, custom, hasKey, err := p.readValue(runtime.Values(request.Header), target)
164 if err != nil {
165 return err
166 }
167 if custom {
168 return nil
169 }
170 return p.bindValue(data, hasKey, target)
171
172 case "path":
173 data, custom, hasKey, err := p.readValue(routeParams, target)
174 if err != nil {
175 return err
176 }
177 if custom {
178 return nil
179 }
180 return p.bindValue(data, hasKey, target)
181
182 case "formData":
183 var err error
184 var mt string
185
186 mt, _, e := runtime.ContentType(request.Header)
187 if e != nil {
188
189
190 err = e
191 }
192
193 if err != nil {
194 return errors.InvalidContentType("", []string{"multipart/form-data", "application/x-www-form-urlencoded"})
195 }
196
197 if mt != "multipart/form-data" && mt != "application/x-www-form-urlencoded" {
198 return errors.InvalidContentType(mt, []string{"multipart/form-data", "application/x-www-form-urlencoded"})
199 }
200
201 if mt == "multipart/form-data" {
202 if err = request.ParseMultipartForm(defaultMaxMemory); err != nil {
203 return errors.NewParseError(p.Name, p.parameter.In, "", err)
204 }
205 }
206
207 if err = request.ParseForm(); err != nil {
208 return errors.NewParseError(p.Name, p.parameter.In, "", err)
209 }
210
211 if p.parameter.Type == "file" {
212 file, header, ffErr := request.FormFile(p.parameter.Name)
213 if ffErr != nil {
214 if p.parameter.Required {
215 return errors.NewParseError(p.Name, p.parameter.In, "", ffErr)
216 }
217
218 return nil
219 }
220
221 target.Set(reflect.ValueOf(runtime.File{Data: file, Header: header}))
222 return nil
223 }
224
225 if request.MultipartForm != nil {
226 data, custom, hasKey, rvErr := p.readValue(runtime.Values(request.MultipartForm.Value), target)
227 if rvErr != nil {
228 return rvErr
229 }
230 if custom {
231 return nil
232 }
233 return p.bindValue(data, hasKey, target)
234 }
235 data, custom, hasKey, err := p.readValue(runtime.Values(request.PostForm), target)
236 if err != nil {
237 return err
238 }
239 if custom {
240 return nil
241 }
242 return p.bindValue(data, hasKey, target)
243
244 case "body":
245 newValue := reflect.New(target.Type())
246 if !runtime.HasBody(request) {
247 if p.parameter.Default != nil {
248 target.Set(reflect.ValueOf(p.parameter.Default))
249 }
250
251 return nil
252 }
253 if err := consumer.Consume(request.Body, newValue.Interface()); err != nil {
254 if err == io.EOF && p.parameter.Default != nil {
255 target.Set(reflect.ValueOf(p.parameter.Default))
256 return nil
257 }
258 tpe := p.parameter.Type
259 if p.parameter.Format != "" {
260 tpe = p.parameter.Format
261 }
262 return errors.InvalidType(p.Name, p.parameter.In, tpe, nil)
263 }
264 target.Set(reflect.Indirect(newValue))
265 return nil
266 default:
267 return errors.New(500, fmt.Sprintf("invalid parameter location %q", p.parameter.In))
268 }
269 }
270
271 func (p *untypedParamBinder) bindValue(data []string, hasKey bool, target reflect.Value) error {
272 if p.parameter.Type == typeArray {
273 return p.setSliceFieldValue(target, p.parameter.Default, data, hasKey)
274 }
275 var d string
276 if len(data) > 0 {
277 d = data[len(data)-1]
278 }
279 return p.setFieldValue(target, p.parameter.Default, d, hasKey)
280 }
281
282 func (p *untypedParamBinder) setFieldValue(target reflect.Value, defaultValue interface{}, data string, hasKey bool) error {
283 tpe := p.parameter.Type
284 if p.parameter.Format != "" {
285 tpe = p.parameter.Format
286 }
287
288 if (!hasKey || (!p.parameter.AllowEmptyValue && data == "")) && p.parameter.Required && p.parameter.Default == nil {
289 return errors.Required(p.Name, p.parameter.In, data)
290 }
291
292 ok, err := p.tryUnmarshaler(target, defaultValue, data)
293 if err != nil {
294 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
295 }
296 if ok {
297 return nil
298 }
299
300 defVal := reflect.Zero(target.Type())
301 if defaultValue != nil {
302 defVal = reflect.ValueOf(defaultValue)
303 }
304
305 if tpe == "byte" {
306 if data == "" {
307 if target.CanSet() {
308 target.SetBytes(defVal.Bytes())
309 }
310 return nil
311 }
312
313 b, err := base64.StdEncoding.DecodeString(data)
314 if err != nil {
315 b, err = base64.URLEncoding.DecodeString(data)
316 if err != nil {
317 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
318 }
319 }
320 if target.CanSet() {
321 target.SetBytes(b)
322 }
323 return nil
324 }
325
326 switch target.Kind() {
327 case reflect.Bool:
328 if data == "" {
329 if target.CanSet() {
330 target.SetBool(defVal.Bool())
331 }
332 return nil
333 }
334 b, err := swag.ConvertBool(data)
335 if err != nil {
336 return err
337 }
338 if target.CanSet() {
339 target.SetBool(b)
340 }
341 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
342 if data == "" {
343 if target.CanSet() {
344 rd := defVal.Convert(reflect.TypeOf(int64(0)))
345 target.SetInt(rd.Int())
346 }
347 return nil
348 }
349 i, err := strconv.ParseInt(data, 10, 64)
350 if err != nil {
351 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
352 }
353 if target.OverflowInt(i) {
354 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
355 }
356 if target.CanSet() {
357 target.SetInt(i)
358 }
359
360 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
361 if data == "" {
362 if target.CanSet() {
363 rd := defVal.Convert(reflect.TypeOf(uint64(0)))
364 target.SetUint(rd.Uint())
365 }
366 return nil
367 }
368 u, err := strconv.ParseUint(data, 10, 64)
369 if err != nil {
370 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
371 }
372 if target.OverflowUint(u) {
373 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
374 }
375 if target.CanSet() {
376 target.SetUint(u)
377 }
378
379 case reflect.Float32, reflect.Float64:
380 if data == "" {
381 if target.CanSet() {
382 rd := defVal.Convert(reflect.TypeOf(float64(0)))
383 target.SetFloat(rd.Float())
384 }
385 return nil
386 }
387 f, err := strconv.ParseFloat(data, 64)
388 if err != nil {
389 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
390 }
391 if target.OverflowFloat(f) {
392 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
393 }
394 if target.CanSet() {
395 target.SetFloat(f)
396 }
397
398 case reflect.String:
399 value := data
400 if value == "" {
401 value = defVal.String()
402 }
403
404 if target.CanSet() {
405 target.SetString(value)
406 }
407
408 case reflect.Ptr:
409 if data == "" && defVal.Kind() == reflect.Ptr {
410 if target.CanSet() {
411 target.Set(defVal)
412 }
413 return nil
414 }
415 newVal := reflect.New(target.Type().Elem())
416 if err := p.setFieldValue(reflect.Indirect(newVal), defVal, data, hasKey); err != nil {
417 return err
418 }
419 if target.CanSet() {
420 target.Set(newVal)
421 }
422
423 default:
424 return errors.InvalidType(p.Name, p.parameter.In, tpe, data)
425 }
426 return nil
427 }
428
429 func (p *untypedParamBinder) tryUnmarshaler(target reflect.Value, defaultValue interface{}, data string) (bool, error) {
430 if !target.CanSet() {
431 return false, nil
432 }
433
434 if reflect.PtrTo(target.Type()).Implements(textUnmarshalType) {
435 if defaultValue != nil && len(data) == 0 {
436 target.Set(reflect.ValueOf(defaultValue))
437 return true, nil
438 }
439 value := reflect.New(target.Type())
440 if err := value.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(data)); err != nil {
441 return true, err
442 }
443 target.Set(reflect.Indirect(value))
444 return true, nil
445 }
446 return false, nil
447 }
448
449 func (p *untypedParamBinder) readFormattedSliceFieldValue(data string, target reflect.Value) ([]string, bool, error) {
450 ok, err := p.tryUnmarshaler(target, p.parameter.Default, data)
451 if err != nil {
452 return nil, true, err
453 }
454 if ok {
455 return nil, true, nil
456 }
457
458 return swag.SplitByFormat(data, p.parameter.CollectionFormat), false, nil
459 }
460
461 func (p *untypedParamBinder) setSliceFieldValue(target reflect.Value, defaultValue interface{}, data []string, hasKey bool) error {
462 sz := len(data)
463 if (!hasKey || (!p.parameter.AllowEmptyValue && (sz == 0 || (sz == 1 && data[0] == "")))) && p.parameter.Required && defaultValue == nil {
464 return errors.Required(p.Name, p.parameter.In, data)
465 }
466
467 defVal := reflect.Zero(target.Type())
468 if defaultValue != nil {
469 defVal = reflect.ValueOf(defaultValue)
470 }
471
472 if !target.CanSet() {
473 return nil
474 }
475 if sz == 0 {
476 target.Set(defVal)
477 return nil
478 }
479
480 value := reflect.MakeSlice(reflect.SliceOf(target.Type().Elem()), sz, sz)
481
482 for i := 0; i < sz; i++ {
483 if err := p.setFieldValue(value.Index(i), nil, data[i], hasKey); err != nil {
484 return err
485 }
486 }
487
488 target.Set(value)
489
490 return nil
491 }
492
View as plain text