...

Source file src/edge-infra.dev/pkg/edge/api/utils/mock_server.go

Documentation: edge-infra.dev/pkg/edge/api/utils

     1  package utils
     2  
     3  import (
     4  	"encoding/json"
     5  	"encoding/xml"
     6  	"errors"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"reflect"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  )
    15  
    16  const (
    17  	formContentType   = "application/x-www-form-urlencoded"
    18  	jsonContentType   = "application/json"
    19  	xmlContentType    = "application/xml"
    20  	contentTypeHeader = "Content-Type"
    21  	acceptHeader      = "Accept"
    22  )
    23  
    24  // MockHTTPTestServer represents a mock http server.
    25  type MockHTTPTestServer struct {
    26  	Server              *httptest.Server
    27  	AllowedContentTypes map[string]bool
    28  	Routes              []*Route
    29  	NotFound            func(http.ResponseWriter, *http.Request)
    30  	wg                  sync.WaitGroup
    31  }
    32  
    33  // Route represents a mock hhtp server route.
    34  type Route struct {
    35  	Name                  string
    36  	Method                string
    37  	Path                  string
    38  	responseWriter        http.ResponseWriter
    39  	request               *http.Request
    40  	Callback              func(http.ResponseWriter, *http.Request)
    41  	AssertionFuncOverride func(http.ResponseWriter, *http.Request) bool
    42  }
    43  
    44  // NewMockHTTPTestServer returns a new mock http server.
    45  func NewMockHTTPTestServer() *MockHTTPTestServer {
    46  	mockServer := &MockHTTPTestServer{
    47  		Routes:              make([]*Route, 0),
    48  		AllowedContentTypes: make(map[string]bool, 0),
    49  	}
    50  	mockServer.Server = NewTestServer(mockServer)
    51  	return mockServer
    52  }
    53  
    54  // AddTestServer attaches a http test server to the mock server.
    55  func NewTestServer(m *MockHTTPTestServer) *httptest.Server {
    56  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    57  		if _, exists := m.AllowedContentTypes[r.Header.Get("Content-Type")]; !exists {
    58  			WriteBadResponse(w, nil)
    59  			return
    60  		}
    61  		if len(m.Routes) > 0 { //nolint complexity
    62  			var found uint64
    63  			for _, route := range m.Routes {
    64  				m.wg.Add(1)
    65  				route.responseWriter = w
    66  				route.request = r
    67  				go func(route *Route) {
    68  					defer m.wg.Done()
    69  					if route.AssertionFuncOverride != nil {
    70  						if route.AssertionFuncOverride(w, r) {
    71  							route.Callback(w, r)
    72  							atomic.AddUint64(&found, 1)
    73  							return
    74  						}
    75  					} else {
    76  						if route.Path == r.URL.String() && (route.Method == r.Method) {
    77  							route.Callback(w, r)
    78  							atomic.AddUint64(&found, 1)
    79  							return
    80  						}
    81  					}
    82  				}(route)
    83  			}
    84  			m.wg.Wait()
    85  			if found > 0 {
    86  				return
    87  			}
    88  			if m.NotFound != nil {
    89  				m.NotFound(w, r)
    90  			} else {
    91  				WriteCustomResponse(w, http.StatusNotFound, nil)
    92  			}
    93  			return
    94  		}
    95  	}))
    96  }
    97  
    98  // SetAllowedContentType sets the allowed content types for the mock server.
    99  func (m *MockHTTPTestServer) SetAllowedContentType(contentType string) *MockHTTPTestServer {
   100  	m.AllowedContentTypes[contentType] = true
   101  	return m
   102  }
   103  
   104  // AddAllowedContentType adds a content type to the allowed content types list.
   105  func (m *MockHTTPTestServer) AddAllowedContentType(contentTypes ...string) *MockHTTPTestServer {
   106  	for _, contentType := range contentTypes {
   107  		m.AllowedContentTypes[contentType] = true
   108  	}
   109  	return m
   110  }
   111  
   112  // AddNotFound adds a route that is called if none of the provided routes matches.
   113  func (m *MockHTTPTestServer) AddNotFound(notFound func(http.ResponseWriter, *http.Request)) *MockHTTPTestServer {
   114  	m.NotFound = notFound
   115  	return m
   116  }
   117  
   118  // DefaultNotFound adds a default route that is called if none of the provided routes matches.
   119  func (m *MockHTTPTestServer) DefaultNotFound() *MockHTTPTestServer {
   120  	m.NotFound = NotFound
   121  	return m
   122  }
   123  
   124  // AddRoute adds a new route to the mock http server.
   125  func (m *MockHTTPTestServer) AddRoute(route *Route) *MockHTTPTestServer {
   126  	m.Routes = append(m.Routes, route)
   127  	return m
   128  }
   129  
   130  // Post adds a post method route to the mock http server.
   131  func (m *MockHTTPTestServer) Post(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
   132  	m.Routes = append(m.Routes, &Route{
   133  		Name:                  name,
   134  		Method:                http.MethodPost,
   135  		Path:                  path,
   136  		Callback:              callback,
   137  		AssertionFuncOverride: assertionFuncOverride,
   138  	})
   139  }
   140  
   141  // Put adds a put method route to the mock http server.
   142  func (m *MockHTTPTestServer) Put(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
   143  	m.Routes = append(m.Routes, &Route{
   144  		Name:                  name,
   145  		Method:                http.MethodPut,
   146  		Path:                  path,
   147  		Callback:              callback,
   148  		AssertionFuncOverride: assertionFuncOverride,
   149  	})
   150  }
   151  
   152  // Patch adds a Patch method route to the mock http server.
   153  func (m *MockHTTPTestServer) Patch(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
   154  	m.Routes = append(m.Routes, &Route{
   155  		Name:                  name,
   156  		Method:                http.MethodPatch,
   157  		Path:                  path,
   158  		Callback:              callback,
   159  		AssertionFuncOverride: assertionFuncOverride,
   160  	})
   161  }
   162  
   163  // Get adds a Get method route to the mock http server.
   164  func (m *MockHTTPTestServer) Get(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
   165  	m.Routes = append(m.Routes, &Route{
   166  		Name:                  name,
   167  		Method:                http.MethodGet,
   168  		Path:                  path,
   169  		Callback:              callback,
   170  		AssertionFuncOverride: assertionFuncOverride,
   171  	})
   172  }
   173  
   174  // Head adds a Head method route to the mock http server.
   175  func (m *MockHTTPTestServer) Head(name, path string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
   176  	m.Routes = append(m.Routes, &Route{
   177  		Name:                  name,
   178  		Method:                http.MethodHead,
   179  		Path:                  path,
   180  		Callback:              callback,
   181  		AssertionFuncOverride: assertionFuncOverride,
   182  	})
   183  }
   184  
   185  // Any adds a route to the mock http server that accepts any method and matches based on the supplied assert function override.
   186  func (m *MockHTTPTestServer) Any(name string, callback func(w http.ResponseWriter, r *http.Request), assertionFuncOverride func(w http.ResponseWriter, r *http.Request) bool) {
   187  	m.Routes = append(m.Routes, &Route{
   188  		Name:                  name,
   189  		Callback:              callback,
   190  		AssertionFuncOverride: assertionFuncOverride,
   191  	})
   192  }
   193  
   194  // SetRoutes sets the list of supported routes for the mock http server.
   195  func (m *MockHTTPTestServer) SetRoutes(route []*Route) *MockHTTPTestServer {
   196  	m.Routes = route
   197  	return m
   198  }
   199  
   200  // Close closes the mock http server from connections.
   201  func (m *MockHTTPTestServer) Close() {
   202  	m.Server.Close()
   203  }
   204  
   205  // NewRoute returns a new route.
   206  func NewRoute() *Route {
   207  	return &Route{}
   208  }
   209  
   210  // GetResponseWriter returns a http response writer for the route.
   211  func (m *Route) GetResponseWriter() http.ResponseWriter {
   212  	return m.responseWriter
   213  }
   214  
   215  // GetRequest returns the http request for the route.
   216  func (m *Route) GetRequest() *http.Request {
   217  	return m.request
   218  }
   219  
   220  // WriteBadResponse returns a http 500 bad request status for the route specified.
   221  func WriteBadResponse(w http.ResponseWriter, body []byte) {
   222  	w.WriteHeader(http.StatusBadRequest)
   223  	if len(body) == 0 {
   224  		body = []byte("Bad Request")
   225  	}
   226  	_, _ = w.Write(body)
   227  }
   228  
   229  // WriteOkResponse returns a http 200 ok status for the route specified.
   230  func WriteOkResponse(w http.ResponseWriter, body []byte) {
   231  	w.WriteHeader(http.StatusOK)
   232  	_, _ = w.Write(body)
   233  }
   234  
   235  // WriteCustomResponse returns the specified status code and body.
   236  func WriteCustomResponse(w http.ResponseWriter, statusCode int, response []byte) {
   237  	w.WriteHeader(statusCode)
   238  	_, _ = w.Write(response)
   239  }
   240  
   241  // FromJSON unmarshals the json data into the receiver.
   242  // Note: receiver must be a pointer.
   243  func FromJSON(data []byte, receiver interface{}) error {
   244  	return json.Unmarshal(data, receiver)
   245  }
   246  
   247  // FromXML unmarshals the xml data into the receiver.
   248  // Note: receiver must be a pointer.
   249  func FromXML(data []byte, receiver interface{}) error {
   250  	return xml.Unmarshal(data, receiver)
   251  }
   252  
   253  // ReadRequestBody reads the http request body into the receiver
   254  // Note: receiver must be a pointer.
   255  func ReadRequestBody(r *http.Request, receiver interface{}) error {
   256  	body, err := io.ReadAll(r.Body)
   257  	if err != nil {
   258  		return err
   259  	}
   260  	return ReadByContentType(r, body, receiver)
   261  }
   262  
   263  // ReadByContentType reads the request body into the receiver
   264  // Supports JSON and x-www-form-urlencoded.
   265  func ReadByContentType(r *http.Request, body []byte, receiver interface{}) error {
   266  	if len(r.Header[contentTypeHeader]) == 0 {
   267  		return errors.New("content type header is empty")
   268  	}
   269  	switch r.Header[contentTypeHeader][0] {
   270  	case jsonContentType:
   271  		return readJSON(body, receiver)
   272  	case xmlContentType:
   273  		return readXML(body, receiver)
   274  	case formContentType:
   275  		return readFormData(body, receiver)
   276  	default:
   277  		return errors.New("content type not supported")
   278  	}
   279  }
   280  
   281  // readFormData reads the x-www-form-urlencoded body into receiver as a map[string]string
   282  func readFormData(body []byte, receiver interface{}) error {
   283  	if reflect.ValueOf(receiver).Kind() != reflect.Map {
   284  		return errors.New("readFormData receiver must be a map")
   285  	}
   286  	details := receiver.(map[string]string)
   287  	form := strings.Split(string(body), "&")
   288  	for _, formVal := range form {
   289  		formDetails := strings.Split(formVal, "=")
   290  		details[formDetails[0]] = formDetails[1]
   291  	}
   292  	return nil
   293  }
   294  
   295  // readJSON reads the json body into receiver.
   296  func readJSON(body []byte, receiver interface{}) error {
   297  	return FromJSON(body, receiver)
   298  }
   299  
   300  // readXML reads the xml body into receiver.
   301  func readXML(body []byte, receiver interface{}) error {
   302  	return FromXML(body, receiver)
   303  }
   304  
   305  // ToJSON marshals the data to json.
   306  func ToJSON(data interface{}) ([]byte, error) {
   307  	return json.Marshal(data)
   308  }
   309  
   310  // WriteJSON writes json data to http response.
   311  func WriteJSON(w http.ResponseWriter, data interface{}) {
   312  	responseBody, err := ToJSON(data)
   313  	if err != nil {
   314  		WriteCustomResponse(w, http.StatusInternalServerError, []byte("Internal Server Error"))
   315  		return
   316  	}
   317  	w.Header().Add(contentTypeHeader, jsonContentType)
   318  	_, _ = w.Write(responseBody)
   319  }
   320  
   321  // NotFound returns a 404 Not Found response.
   322  func NotFound(w http.ResponseWriter, _ *http.Request) {
   323  	WriteCustomResponse(w, http.StatusNotFound, []byte("404 Not Found"))
   324  }
   325  

View as plain text