1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package middleware
16
17 import (
18 "mime"
19 "net/http"
20 "strings"
21
22 "github.com/go-openapi/errors"
23 "github.com/go-openapi/swag"
24
25 "github.com/go-openapi/runtime"
26 )
27
28 type validation struct {
29 context *Context
30 result []error
31 request *http.Request
32 route *MatchedRoute
33 bound map[string]interface{}
34 }
35
36
37 func validateContentType(allowed []string, actual string) error {
38 if len(allowed) == 0 {
39 return nil
40 }
41 mt, _, err := mime.ParseMediaType(actual)
42 if err != nil {
43 return errors.InvalidContentType(actual, allowed)
44 }
45 if swag.ContainsStringsCI(allowed, mt) {
46 return nil
47 }
48 if swag.ContainsStringsCI(allowed, "*/*") {
49 return nil
50 }
51 parts := strings.Split(actual, "/")
52 if len(parts) == 2 && swag.ContainsStringsCI(allowed, parts[0]+"/*") {
53 return nil
54 }
55 return errors.InvalidContentType(actual, allowed)
56 }
57
58 func validateRequest(ctx *Context, request *http.Request, route *MatchedRoute) *validation {
59 validate := &validation{
60 context: ctx,
61 request: request,
62 route: route,
63 bound: make(map[string]interface{}),
64 }
65 validate.debugLogf("validating request %s %s", request.Method, request.URL.EscapedPath())
66
67 validate.contentType()
68 if len(validate.result) == 0 {
69 validate.responseFormat()
70 }
71 if len(validate.result) == 0 {
72 validate.parameters()
73 }
74
75 return validate
76 }
77
78 func (v *validation) debugLogf(format string, args ...any) {
79 v.context.debugLogf(format, args...)
80 }
81
82 func (v *validation) parameters() {
83 v.debugLogf("validating request parameters for %s %s", v.request.Method, v.request.URL.EscapedPath())
84 if result := v.route.Binder.Bind(v.request, v.route.Params, v.route.Consumer, v.bound); result != nil {
85 if result.Error() == "validation failure list" {
86 for _, e := range result.(*errors.Validation).Value.([]interface{}) {
87 v.result = append(v.result, e.(error))
88 }
89 return
90 }
91 v.result = append(v.result, result)
92 }
93 }
94
95 func (v *validation) contentType() {
96 if len(v.result) == 0 && runtime.HasBody(v.request) {
97 v.debugLogf("validating body content type for %s %s", v.request.Method, v.request.URL.EscapedPath())
98 ct, _, req, err := v.context.ContentType(v.request)
99 if err != nil {
100 v.result = append(v.result, err)
101 } else {
102 v.request = req
103 }
104
105 if len(v.result) == 0 {
106 v.debugLogf("validating content type for %q against [%s]", ct, strings.Join(v.route.Consumes, ", "))
107 if err := validateContentType(v.route.Consumes, ct); err != nil {
108 v.result = append(v.result, err)
109 }
110 }
111 if ct != "" && v.route.Consumer == nil {
112 cons, ok := v.route.Consumers[ct]
113 if !ok {
114 v.result = append(v.result, errors.New(500, "no consumer registered for %s", ct))
115 } else {
116 v.route.Consumer = cons
117 }
118 }
119 }
120 }
121
122 func (v *validation) responseFormat() {
123
124
125
126 if str, rCtx := v.context.ResponseFormat(v.request, v.route.Produces); str == "" && len(v.route.Produces) > 0 {
127 v.request = rCtx
128 v.result = append(v.result, errors.InvalidResponseFormat(v.request.Header.Get(runtime.HeaderAccept), v.route.Produces))
129 }
130 }
131
View as plain text