1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package middleware
16
17 import (
18 "bytes"
19 stdcontext "context"
20 "net/http"
21 "net/http/httptest"
22 "strings"
23 "testing"
24
25 "github.com/go-openapi/errors"
26 "github.com/go-openapi/runtime"
27 "github.com/go-openapi/runtime/internal/testing/petstore"
28 "github.com/stretchr/testify/assert"
29 "github.com/stretchr/testify/require"
30 )
31
32 func newTestValidation(ctx *Context, next http.Handler) http.Handler {
33 return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
34 matched, rCtx, _ := ctx.RouteInfo(r)
35 if rCtx != nil {
36 r = rCtx
37 }
38 if matched == nil {
39 ctx.NotFound(rw, r)
40 return
41 }
42 _, r, result := ctx.BindAndValidate(r, matched)
43
44 if result != nil {
45 ctx.Respond(rw, r, matched.Produces, matched, result)
46 return
47 }
48
49 next.ServeHTTP(rw, r)
50 })
51 }
52
53 func TestContentTypeValidation(t *testing.T) {
54 spec, api := petstore.NewAPI(t)
55 context := NewContext(spec, api, nil)
56 context.router = DefaultRouter(spec, context.api)
57
58 mw := newTestValidation(context, http.HandlerFunc(terminator))
59
60 recorder := httptest.NewRecorder()
61 request, _ := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, "/api/pets", nil)
62 request.Header.Add("Accept", "*/*")
63 mw.ServeHTTP(recorder, request)
64 assert.Equal(t, http.StatusOK, recorder.Code)
65
66 recorder = httptest.NewRecorder()
67 request, _ = http.NewRequestWithContext(stdcontext.Background(), http.MethodPost, "/api/pets", nil)
68 request.Header.Add("content-type", "application(")
69 request.Header.Add("Accept", "application/json")
70 request.ContentLength = 1
71
72 mw.ServeHTTP(recorder, request)
73 assert.Equal(t, http.StatusBadRequest, recorder.Code)
74 assert.Equal(t, "application/json", recorder.Header().Get("content-type"))
75
76 recorder = httptest.NewRecorder()
77 request, _ = http.NewRequestWithContext(stdcontext.Background(), http.MethodPost, "/api/pets", nil)
78 request.Header.Add("Accept", "application/json")
79 request.Header.Add("content-type", "text/html")
80 request.ContentLength = 1
81
82 mw.ServeHTTP(recorder, request)
83 assert.Equal(t, http.StatusUnsupportedMediaType, recorder.Code)
84 assert.Equal(t, "application/json", recorder.Header().Get("content-type"))
85
86 recorder = httptest.NewRecorder()
87 request, _ = http.NewRequestWithContext(stdcontext.Background(), http.MethodPost, "/api/pets", strings.NewReader(`{"name":"dog"}`))
88 request.Header.Add("Accept", "application/json")
89 request.Header.Add("content-type", "text/html")
90 request.TransferEncoding = []string{"chunked"}
91
92 mw.ServeHTTP(recorder, request)
93 assert.Equal(t, http.StatusUnsupportedMediaType, recorder.Code)
94 assert.Equal(t, "application/json", recorder.Header().Get("content-type"))
95
96 recorder = httptest.NewRecorder()
97 request, _ = http.NewRequestWithContext(stdcontext.Background(), http.MethodPost, "/api/pets", nil)
98 request.Header.Add("Accept", "application/json+special")
99 request.Header.Add("content-type", "text/html")
100
101 mw.ServeHTTP(recorder, request)
102 assert.Equal(t, 406, recorder.Code)
103 assert.Equal(t, "application/json", recorder.Header().Get("content-type"))
104
105
106 recorder = httptest.NewRecorder()
107 request, _ = http.NewRequestWithContext(stdcontext.Background(), http.MethodPost, "/api/pets", nil)
108 request.Header.Add("Accept", "application/json")
109 request.Header.Add("content-type", "application/json+special")
110 request.ContentLength = 1
111
112 mw.ServeHTTP(recorder, request)
113 assert.Equal(t, 415, recorder.Code)
114 assert.Equal(t, "application/json", recorder.Header().Get("content-type"))
115
116
117 recorder = httptest.NewRecorder()
118 request, _ = http.NewRequestWithContext(stdcontext.Background(), http.MethodPost, "/api/pets", nil)
119 request.Header.Add("Accept", "application/json")
120 request.ContentLength = 1
121
122 mw.ServeHTTP(recorder, request)
123 assert.Equal(t, 415, recorder.Code)
124 assert.Equal(t, "application/json", recorder.Header().Get("content-type"))
125 }
126
127 func TestResponseFormatValidation(t *testing.T) {
128 spec, api := petstore.NewAPI(t)
129 context := NewContext(spec, api, nil)
130 context.router = DefaultRouter(spec, context.api)
131 mw := newTestValidation(context, http.HandlerFunc(terminator))
132
133 recorder := httptest.NewRecorder()
134 request, _ := http.NewRequestWithContext(stdcontext.Background(), http.MethodPost, "/api/pets", bytes.NewBufferString(`name: Dog`))
135 request.Header.Set(runtime.HeaderContentType, "application/x-yaml")
136 request.Header.Set(runtime.HeaderAccept, "application/x-yaml")
137
138 mw.ServeHTTP(recorder, request)
139 assert.Equal(t, 200, recorder.Code, recorder.Body.String())
140
141 recorder = httptest.NewRecorder()
142 request, _ = http.NewRequestWithContext(stdcontext.Background(), http.MethodPost, "/api/pets", bytes.NewBufferString(`name: Dog`))
143 request.Header.Set(runtime.HeaderContentType, "application/x-yaml")
144 request.Header.Set(runtime.HeaderAccept, "application/sml")
145
146 mw.ServeHTTP(recorder, request)
147 assert.Equal(t, http.StatusNotAcceptable, recorder.Code)
148 }
149
150 func TestValidateContentType(t *testing.T) {
151 data := []struct {
152 hdr string
153 allowed []string
154 err *errors.Validation
155 }{
156 {"application/json", []string{"application/json"}, nil},
157 {"application/json", []string{"application/x-yaml", "text/html"}, errors.InvalidContentType("application/json", []string{"application/x-yaml", "text/html"})},
158 {"text/html; charset=utf-8", []string{"text/html"}, nil},
159 {"text/html;charset=utf-8", []string{"text/html"}, nil},
160 {"", []string{"application/json"}, errors.InvalidContentType("", []string{"application/json"})},
161 {"text/html; charset=utf-8", []string{"application/json"}, errors.InvalidContentType("text/html; charset=utf-8", []string{"application/json"})},
162 {"application(", []string{"application/json"}, errors.InvalidContentType("application(", []string{"application/json"})},
163 {"application/json;char*", []string{"application/json"}, errors.InvalidContentType("application/json;char*", []string{"application/json"})},
164 {"application/octet-stream", []string{"image/jpeg", "application/*"}, nil},
165 {"image/png", []string{"*/*", "application/json"}, nil},
166 }
167
168 for _, v := range data {
169 err := validateContentType(v.allowed, v.hdr)
170 if v.err == nil {
171 require.NoError(t, err, "input: %q", v.hdr)
172 } else {
173 require.Error(t, err, "input: %q", v.hdr)
174 assert.IsType(t, &errors.Validation{}, err, "input: %q", v.hdr)
175 require.EqualErrorf(t, err, v.err.Error(), "input: %q", v.hdr)
176 assert.EqualValues(t, http.StatusUnsupportedMediaType, err.(*errors.Validation).Code())
177 }
178 }
179 }
180
View as plain text