1 package utils
2
3 import (
4 "encoding/json"
5 "encoding/xml"
6 "errors"
7 "io"
8 "net/http"
9 "net/http/httptest"
10 "reflect"
11 "strings"
12 "sync"
13 "sync/atomic"
14 )
15
// MIME types that ReadByContentType knows how to decode.
const (
	jsonContentType = "application/json"
	xmlContentType  = "application/xml"
	formContentType = "application/x-www-form-urlencoded"
)

// HTTP header names used by this package.
const (
	contentTypeHeader = "Content-Type"
	acceptHeader      = "Accept"
)
23
24
// MockHTTPTestServer bundles a running httptest.Server with the routing table
// that its request handler consults. Build one with NewMockHTTPTestServer.
type MockHTTPTestServer struct {
	// Server is the underlying test server, normally created by NewTestServer.
	Server *httptest.Server
	// AllowedContentTypes lists the accepted values of the request's
	// Content-Type header. Only key presence is checked, so a key mapped to
	// false is still allowed. Any other header value is answered with
	// 400 Bad Request.
	AllowedContentTypes map[string]bool
	// Routes are the registered mock endpoints that accepted requests are
	// matched against.
	Routes []*Route
	// NotFound, when non-nil, replaces the default bare 404 written for
	// requests that reach route matching but match no route.
	NotFound func(http.ResponseWriter, *http.Request)
	// wg is a WaitGroup reserved for the request handler's internal use.
	// NOTE(review): it is a single field shared by every in-flight request,
	// so it cannot safely track per-request work when requests overlap.
	wg sync.WaitGroup
}
32
33
// Route describes one mock endpoint: how a request is recognised and what is
// done with it once recognised.
type Route struct {
	// Name is a free-form label and plays no part in matching. Method and
	// Path are compared with the request's Method and URL.String() whenever
	// AssertionFuncOverride is nil.
	Name, Method, Path string

	// responseWriter and request record the writer and request most recently
	// dispatched by the mock server. Read them through GetResponseWriter and
	// GetRequest.
	responseWriter http.ResponseWriter
	request        *http.Request

	// Callback handles a request once this route has matched it.
	Callback func(http.ResponseWriter, *http.Request)
	// AssertionFuncOverride, when non-nil, replaces the Method/Path comparison:
	// the route matches exactly when it returns true.
	AssertionFuncOverride func(http.ResponseWriter, *http.Request) bool
}
43
44
45 func NewMockHTTPTestServer() *MockHTTPTestServer {
46 mockServer := &MockHTTPTestServer{
47 Routes: make([]*Route, 0),
48 AllowedContentTypes: make(map[string]bool, 0),
49 }
50 mockServer.Server = NewTestServer(mockServer)
51 return mockServer
52 }
53
54
55 func NewTestServer(m *MockHTTPTestServer) *httptest.Server {
56 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57 if _, exists := m.AllowedContentTypes[r.Header.Get("Content-Type")]; !exists {
58 WriteBadResponse(w, nil)
59 return
60 }
61 if len(m.Routes) > 0 {
62 var found uint64
63 for _, route := range m.Routes {
64 m.wg.Add(1)
65 route.responseWriter = w
66 route.request = r
67 go func(route *Route) {
68 defer m.wg.Done()
69 if route.AssertionFuncOverride != nil {
70 if route.AssertionFuncOverride(w, r) {
71 route.Callback(w, r)
72 atomic.AddUint64(&found, 1)
73 return
74 }
75 } else {
76 if route.Path == r.URL.String() && (route.Method == r.Method) {
77 route.Callback(w, r)
78 atomic.AddUint64(&found, 1)
79 return
80 }
81 }
82 }(route)
83 }
84 m.wg.Wait()
85 if found > 0 {
86 return
87 }
88 if m.NotFound != nil {
89 m.NotFound(w, r)
90 } else {
91 WriteCustomResponse(w, http.StatusNotFound, nil)
92 }
93 return
94 }
95 }))
96 }
97
98
99 func (m *MockHTTPTestServer) SetAllowedContentType(contentType string) *MockHTTPTestServer {
100 m.AllowedContentTypes[contentType] = true
101 return m
102 }
103
104
105 func (m *MockHTTPTestServer) AddAllowedContentType(contentTypes ...string) *MockHTTPTestServer {
106 for _, contentType := range contentTypes {
107 m.AllowedContentTypes[contentType] = true
108 }
109 return m
110 }
111
112
113 func (m *MockHTTPTestServer) AddNotFound(notFound func(http.ResponseWriter, *http.Request)) *MockHTTPTestServer {
114 m.NotFound = notFound
115 return m
116 }
117
118
119 func (m *MockHTTPTestServer) DefaultNotFound() *MockHTTPTestServer {
120 m.NotFound = NotFound
121 return m
122 }
123
124
125 func (m *MockHTTPTestServer) AddRoute(route *Route) *MockHTTPTestServer {
126 m.Routes = append(m.Routes, route)
127 return m
128 }
129
130
131 func (m *MockHTTPTestServer) Post(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
132 m.Routes = append(m.Routes, &Route{
133 Name: name,
134 Method: http.MethodPost,
135 Path: path,
136 Callback: callback,
137 AssertionFuncOverride: assertionFuncOverride,
138 })
139 }
140
141
142 func (m *MockHTTPTestServer) Put(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
143 m.Routes = append(m.Routes, &Route{
144 Name: name,
145 Method: http.MethodPut,
146 Path: path,
147 Callback: callback,
148 AssertionFuncOverride: assertionFuncOverride,
149 })
150 }
151
152
153 func (m *MockHTTPTestServer) Patch(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
154 m.Routes = append(m.Routes, &Route{
155 Name: name,
156 Method: http.MethodPatch,
157 Path: path,
158 Callback: callback,
159 AssertionFuncOverride: assertionFuncOverride,
160 })
161 }
162
163
164 func (m *MockHTTPTestServer) Get(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
165 m.Routes = append(m.Routes, &Route{
166 Name: name,
167 Method: http.MethodGet,
168 Path: path,
169 Callback: callback,
170 AssertionFuncOverride: assertionFuncOverride,
171 })
172 }
173
174
175 func (m *MockHTTPTestServer) Head(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
176 m.Routes = append(m.Routes, &Route{
177 Name: name,
178 Method: http.MethodHead,
179 Path: path,
180 Callback: callback,
181 AssertionFuncOverride: assertionFuncOverride,
182 })
183 }
184
185
186 func (m *MockHTTPTestServer) Any(name string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
187 m.Routes = append(m.Routes, &Route{
188 Name: name,
189 Callback: callback,
190 AssertionFuncOverride: assertionFuncOverride,
191 })
192 }
193
194
195 func (m *MockHTTPTestServer) SetRoutes(route []*Route) *MockHTTPTestServer {
196 m.Routes = route
197 return m
198 }
199
200
201 func (m *MockHTTPTestServer) Close() {
202 m.Server.Close()
203 }
204
205
206 func NewRoute() *Route {
207 return &Route{}
208 }
209
210
211 func (m *Route) GetResponseWriter() http.ResponseWriter {
212 return m.responseWriter
213 }
214
215
216 func (m *Route) GetRequest() *http.Request {
217 return m.request
218 }
219
220
// WriteBadResponse answers with status 400. The body is written as given, or
// as the text "Bad Request" when body is empty. The write error is
// intentionally discarded.
func WriteBadResponse(w http.ResponseWriter, body []byte) {
	payload := body
	if len(payload) == 0 {
		payload = []byte("Bad Request")
	}
	w.WriteHeader(http.StatusBadRequest)
	_, _ = w.Write(payload)
}
228
229
// WriteOkResponse answers with status 200 followed by body, which may be
// empty. The write error is intentionally discarded.
func WriteOkResponse(w http.ResponseWriter, body []byte) {
	status := http.StatusOK
	w.WriteHeader(status)
	_, _ = w.Write(body)
}
234
235
// WriteCustomResponse sends statusCode as the response status, then response
// as the body (no body bytes when it is empty). The write error is
// intentionally discarded.
func WriteCustomResponse(w http.ResponseWriter, statusCode int, response []byte) {
	code, payload := statusCode, response
	w.WriteHeader(code)
	_, _ = w.Write(payload)
}
240
241
242
// FromJSON decodes the JSON document in data into receiver using
// encoding/json. receiver must be a non-nil pointer, as json.Unmarshal
// requires.
func FromJSON(data []byte, receiver interface{}) error {
	err := json.Unmarshal(data, receiver)
	return err
}
246
247
248
// FromXML decodes the XML document in data into receiver using encoding/xml.
// receiver must be a non-nil pointer, as xml.Unmarshal requires.
func FromXML(data []byte, receiver interface{}) error {
	err := xml.Unmarshal(data, receiver)
	return err
}
252
253
254
255 func ReadRequestBody(r *http.Request, receiver interface{}) error {
256 body, err := io.ReadAll(r.Body)
257 if err != nil {
258 return err
259 }
260 return ReadByContentType(r, body, receiver)
261 }
262
263
264
265 func ReadByContentType(r *http.Request, body []byte, receiver interface{}) error {
266 if len(r.Header[contentTypeHeader]) == 0 {
267 return errors.New("content type header is empty")
268 }
269 switch r.Header[contentTypeHeader][0] {
270 case jsonContentType:
271 return readJSON(body, receiver)
272 case xmlContentType:
273 return readXML(body, receiver)
274 case formContentType:
275 return readFormData(body, receiver)
276 default:
277 return errors.New("content type not supported")
278 }
279 }
280
281
// readFormData parses an application/x-www-form-urlencoded body into receiver,
// which must be a non-nil map[string]string.
//
// Pairs are separated by '&'. Each pair is split at its first '=', so values
// may themselves contain '='. A pair without '=' is stored with an empty value.
// Empty pairs, such as those from an empty body or from "a=1&&b=2", are
// skipped. A repeated key keeps its last value. A receiver of the wrong kind,
// the wrong map type, or a nil map yields an error instead of a panic.
//
// NOTE(review): keys and values are stored exactly as they appear in the body.
// Percent-escapes and '+' are not decoded; net/url.ParseQuery would be needed
// for that.
func readFormData(body []byte, receiver interface{}) error {
	if reflect.ValueOf(receiver).Kind() != reflect.Map {
		return errors.New("readFormData receiver must be a map")
	}
	details, ok := receiver.(map[string]string)
	if !ok {
		return errors.New("readFormData receiver must be a map[string]string")
	}
	if details == nil {
		return errors.New("readFormData receiver must be a non-nil map")
	}

	for _, pair := range strings.Split(string(body), "&") {
		if pair == "" {
			continue
		}
		kv := strings.SplitN(pair, "=", 2)
		value := ""
		if len(kv) == 2 {
			value = kv[1]
		}
		details[kv[0]] = value
	}
	return nil
}
294
295
// readJSON decodes a JSON request body into receiver, which must be a non-nil
// pointer. This is the same operation FromJSON performs.
func readJSON(body []byte, receiver interface{}) error {
	if err := json.Unmarshal(body, receiver); err != nil {
		return err
	}
	return nil
}
299
300
// readXML decodes an XML request body into receiver, which must be a non-nil
// pointer. This is the same operation FromXML performs.
func readXML(body []byte, receiver interface{}) error {
	if err := xml.Unmarshal(body, receiver); err != nil {
		return err
	}
	return nil
}
304
305
// ToJSON encodes data with encoding/json. It returns the encoded bytes, or
// the marshalling error for values json.Marshal cannot encode.
func ToJSON(data interface{}) ([]byte, error) {
	encoded, err := json.Marshal(data)
	return encoded, err
}
309
310
311 func WriteJSON(w http.ResponseWriter, data interface{}) {
312 responseBody, err := ToJSON(data)
313 if err != nil {
314 WriteCustomResponse(w, http.StatusInternalServerError, []byte("Internal Server Error"))
315 return
316 }
317 w.Header().Add(contentTypeHeader, jsonContentType)
318 _, _ = w.Write(responseBody)
319 }
320
321
// NotFound is a ready-made not-found handler. It answers 404 with the body
// "404 Not Found" and ignores the request. Install it with DefaultNotFound.
func NotFound(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusNotFound)
	_, _ = w.Write([]byte("404 Not Found"))
}
325
View as plain text