package utils import ( "encoding/json" "encoding/xml" "errors" "io" "net/http" "net/http/httptest" "reflect" "strings" "sync" "sync/atomic" ) const ( formContentType = "application/x-www-form-urlencoded" jsonContentType = "application/json" xmlContentType = "application/xml" contentTypeHeader = "Content-Type" acceptHeader = "Accept" ) // MockHTTPTestServer represents a mock http server. type MockHTTPTestServer struct { Server *httptest.Server AllowedContentTypes map[string]bool Routes []*Route NotFound func(http.ResponseWriter, *http.Request) wg sync.WaitGroup } // Route represents a mock hhtp server route. type Route struct { Name string Method string Path string responseWriter http.ResponseWriter request *http.Request Callback func(http.ResponseWriter, *http.Request) AssertionFuncOverride func(http.ResponseWriter, *http.Request) bool } // NewMockHTTPTestServer returns a new mock http server. func NewMockHTTPTestServer() *MockHTTPTestServer { mockServer := &MockHTTPTestServer{ Routes: make([]*Route, 0), AllowedContentTypes: make(map[string]bool, 0), } mockServer.Server = NewTestServer(mockServer) return mockServer } // AddTestServer attaches a http test server to the mock server. 
func NewTestServer(m *MockHTTPTestServer) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, exists := m.AllowedContentTypes[r.Header.Get("Content-Type")]; !exists { WriteBadResponse(w, nil) return } if len(m.Routes) > 0 { //nolint complexity var found uint64 for _, route := range m.Routes { m.wg.Add(1) route.responseWriter = w route.request = r go func(route *Route) { defer m.wg.Done() if route.AssertionFuncOverride != nil { if route.AssertionFuncOverride(w, r) { route.Callback(w, r) atomic.AddUint64(&found, 1) return } } else { if route.Path == r.URL.String() && (route.Method == r.Method) { route.Callback(w, r) atomic.AddUint64(&found, 1) return } } }(route) } m.wg.Wait() if found > 0 { return } if m.NotFound != nil { m.NotFound(w, r) } else { WriteCustomResponse(w, http.StatusNotFound, nil) } return } })) } // SetAllowedContentType sets the allowed content types for the mock server. func (m *MockHTTPTestServer) SetAllowedContentType(contentType string) *MockHTTPTestServer { m.AllowedContentTypes[contentType] = true return m } // AddAllowedContentType adds a content type to the allowed content types list. func (m *MockHTTPTestServer) AddAllowedContentType(contentTypes ...string) *MockHTTPTestServer { for _, contentType := range contentTypes { m.AllowedContentTypes[contentType] = true } return m } // AddNotFound adds a route that is called if none of the provided routes matches. func (m *MockHTTPTestServer) AddNotFound(notFound func(http.ResponseWriter, *http.Request)) *MockHTTPTestServer { m.NotFound = notFound return m } // DefaultNotFound adds a default route that is called if none of the provided routes matches. func (m *MockHTTPTestServer) DefaultNotFound() *MockHTTPTestServer { m.NotFound = NotFound return m } // AddRoute adds a new route to the mock http server. 
func (m *MockHTTPTestServer) AddRoute(route *Route) *MockHTTPTestServer { m.Routes = append(m.Routes, route) return m } // Post adds a post method route to the mock http server. func (m *MockHTTPTestServer) Post(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) { m.Routes = append(m.Routes, &Route{ Name: name, Method: http.MethodPost, Path: path, Callback: callback, AssertionFuncOverride: assertionFuncOverride, }) } // Put adds a put method route to the mock http server. func (m *MockHTTPTestServer) Put(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) { m.Routes = append(m.Routes, &Route{ Name: name, Method: http.MethodPut, Path: path, Callback: callback, AssertionFuncOverride: assertionFuncOverride, }) } // Patch adds a Patch method route to the mock http server. func (m *MockHTTPTestServer) Patch(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) { m.Routes = append(m.Routes, &Route{ Name: name, Method: http.MethodPatch, Path: path, Callback: callback, AssertionFuncOverride: assertionFuncOverride, }) } // Get adds a Get method route to the mock http server. func (m *MockHTTPTestServer) Get(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) { m.Routes = append(m.Routes, &Route{ Name: name, Method: http.MethodGet, Path: path, Callback: callback, AssertionFuncOverride: assertionFuncOverride, }) } // Head adds a Head method route to the mock http server. 
func (m *MockHTTPTestServer) Head(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
	m.Routes = append(m.Routes, &Route{
		Name:                  name,
		Method:                http.MethodHead,
		Path:                  path,
		Callback:              callback,
		AssertionFuncOverride: assertionFuncOverride,
	})
}

// Any adds a route that is not tied to an HTTP method or path: matching is
// left entirely to assertionFuncOverride. With a nil assertionFuncOverride the
// route falls back to comparing an empty Method, which presumably never
// matches a real request — pass an override.
func (m *MockHTTPTestServer) Any(name string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
	m.Routes = append(m.Routes, &Route{
		Name:                  name,
		Callback:              callback,
		AssertionFuncOverride: assertionFuncOverride,
	})
}

// SetRoutes replaces the mock server's routes with route and returns m for
// chaining. The slice is used directly, not copied.
func (m *MockHTTPTestServer) SetRoutes(route []*Route) *MockHTTPTestServer {
	m.Routes = route
	return m
}

// Close shuts down the underlying test server.
func (m *MockHTTPTestServer) Close() {
	m.Server.Close()
}

// NewRoute returns a new, empty route.
func NewRoute() *Route {
	return &Route{}
}

// GetResponseWriter returns the response writer of the most recent request
// dispatched to the route. The field is written without synchronisation.
func (m *Route) GetResponseWriter() http.ResponseWriter {
	return m.responseWriter
}

// GetRequest returns the most recent request dispatched to the route. The
// field is written without synchronisation.
func (m *Route) GetRequest() *http.Request {
	return m.request
}

// WriteBadResponse writes a 400 Bad Request status followed by body; an empty
// body is replaced with "Bad Request".
func WriteBadResponse(w http.ResponseWriter, body []byte) {
	w.WriteHeader(http.StatusBadRequest)
	if len(body) == 0 {
		body = []byte("Bad Request")
	}
	// Best effort: a mock server has no useful way to report a failed write.
	_, _ = w.Write(body)
}

// WriteOkResponse writes a 200 OK status followed by body.
func WriteOkResponse(w http.ResponseWriter, body []byte) {
	w.WriteHeader(http.StatusOK)
	// Best effort: a mock server has no useful way to report a failed write.
	_, _ = w.Write(body)
}

// WriteCustomResponse writes statusCode followed by response as the body.
func WriteCustomResponse(w http.ResponseWriter, statusCode int, response []byte) { w.WriteHeader(statusCode) _, _ = w.Write(response) } // FromJSON unmarshals the json data into the receiver. // Note: receiver must be a pointer. func FromJSON(data []byte, receiver interface{}) error { return json.Unmarshal(data, receiver) } // FromXML unmarshals the xml data into the receiver. // Note: receiver must be a pointer. func FromXML(data []byte, receiver interface{}) error { return xml.Unmarshal(data, receiver) } // ReadRequestBody reads the http request body into the receiver // Note: receiver must be a pointer. func ReadRequestBody(r *http.Request, receiver interface{}) error { body, err := io.ReadAll(r.Body) if err != nil { return err } return ReadByContentType(r, body, receiver) } // ReadByContentType reads the request body into the receiver // Supports JSON and x-www-form-urlencoded. func ReadByContentType(r *http.Request, body []byte, receiver interface{}) error { if len(r.Header[contentTypeHeader]) == 0 { return errors.New("content type header is empty") } switch r.Header[contentTypeHeader][0] { case jsonContentType: return readJSON(body, receiver) case xmlContentType: return readXML(body, receiver) case formContentType: return readFormData(body, receiver) default: return errors.New("content type not supported") } } // readFormData reads the x-www-form-urlencoded body into receiver as a map[string]string func readFormData(body []byte, receiver interface{}) error { if reflect.ValueOf(receiver).Kind() != reflect.Map { return errors.New("readFormData receiver must be a map") } details := receiver.(map[string]string) form := strings.Split(string(body), "&") for _, formVal := range form { formDetails := strings.Split(formVal, "=") details[formDetails[0]] = formDetails[1] } return nil } // readJSON reads the json body into receiver. func readJSON(body []byte, receiver interface{}) error { return FromJSON(body, receiver) } // readXML reads the xml body into receiver. 
func readXML(body []byte, receiver interface{}) error {
	return FromXML(body, receiver)
}

// ToJSON marshals data to JSON.
func ToJSON(data interface{}) ([]byte, error) {
	return json.Marshal(data)
}

// WriteJSON marshals data and writes it as the response body with a JSON
// Content-Type. The status is the implicit 200 unless one was already written.
// If marshalling fails, a 500 with the body "Internal Server Error" is written
// instead.
func WriteJSON(w http.ResponseWriter, data interface{}) {
	responseBody, err := ToJSON(data)
	if err != nil {
		WriteCustomResponse(w, http.StatusInternalServerError, []byte("Internal Server Error"))
		return
	}
	// Set rather than Add: an earlier Content-Type is replaced instead of
	// producing a duplicate header.
	w.Header().Set(contentTypeHeader, jsonContentType)
	// Best effort: a mock server has no useful way to report a failed write.
	_, _ = w.Write(responseBody)
}

// NotFound writes a 404 Not Found response with the body "404 Not Found".
func NotFound(w http.ResponseWriter, _ *http.Request) {
	WriteCustomResponse(w, http.StatusNotFound, []byte("404 Not Found"))
}