...
1 package testutil
2
3 import (
4 "bytes"
5 "fmt"
6 "io"
7 "io/ioutil"
8 "net/http"
9 "net/url"
10 "sort"
11 "strings"
12 )
13
14
// RequestResponseMap is an ordered list of request/response mappings.
// NewHandler queues the responses of mappings whose requests render to the
// same string in list order, and the handler serves them front to back.
type RequestResponseMap []RequestResponseMapping
16
17
18
// RequestResponseMapping pairs a Request with the Response that the test
// handler should serve when it receives that request.
type RequestResponseMapping struct {
	Request  Request
	Response Response
}
23
24
// Request describes an HTTP request that the test handler can recognize.
// Two requests are treated as the same request when their String forms are
// equal.
type Request struct {
	// Method is the HTTP method, e.g. "GET".
	Method string

	// Route is the URL path of the request, without the query string.
	Route string

	// QueryParams holds the decoded query parameters, keyed by parameter name.
	QueryParams map[string][]string

	// Body is the raw request body.
	Body []byte

	// Headers holds the request headers that take part in matching.
	Headers http.Header
}

// String renders the request in a deterministic text form:
//
//	METHOD ROUTE[?QUERY]
//	[HEADER:VALUE ...]
//	BODY
//
// Query parameters and headers are emitted in sorted key order, with the
// values of each key kept in slice order, so that equal requests always yield
// the same string regardless of map iteration order.
//
// Both the names and the values of query parameters are query-escaped.
// Escaping the names as well keeps requests distinct when a name contains
// "=" or "&"; otherwise {"a=1&b": ["2"]} and {"a": ["1"], "b": ["2"]} would
// render identically.
func (r Request) String() string {
	query := ""
	if len(r.QueryParams) > 0 {
		names := make([]string, 0, len(r.QueryParams))
		for name := range r.QueryParams {
			names = append(names, name)
		}
		sort.Strings(names)

		pairs := make([]string, 0, len(names))
		for _, name := range names {
			escapedName := url.QueryEscape(name)
			for _, value := range r.QueryParams[name] {
				pairs = append(pairs, escapedName+"="+url.QueryEscape(value))
			}
		}
		query = "?" + strings.Join(pairs, "&")
	}

	// headers stays nil when there is nothing to print; %s renders that as "[]".
	var headers []string
	if len(r.Headers) > 0 {
		names := make([]string, 0, len(r.Headers))
		for name := range r.Headers {
			names = append(names, name)
		}
		sort.Strings(names)

		for _, name := range names {
			for _, value := range r.Headers[name] {
				headers = append(headers, name+":"+value)
			}
		}
	}

	return fmt.Sprintf("%s %s%s\n%s\n%s", r.Method, r.Route, query, headers, r.Body)
}
75
76
// Response describes the HTTP response the test handler sends for a matched
// request.
type Response struct {
	// StatusCode is the HTTP status code passed to WriteHeader.
	StatusCode int

	// Headers are copied onto the response headers before the status is written.
	Headers http.Header

	// Body is written as the response body.
	Body []byte
}
87
88
89
// testHandler is the http.Handler implementation returned by NewHandler.
type testHandler struct {
	// responseMap maps the String form of a Request to the responses that are
	// still waiting to be served for it, in registration order.
	responseMap map[string][]Response
}
93
94
95
96
97
98 func NewHandler(requestResponseMap RequestResponseMap) http.Handler {
99 responseMap := make(map[string][]Response)
100 for _, mapping := range requestResponseMap {
101 responses, ok := responseMap[mapping.Request.String()]
102 if ok {
103 responseMap[mapping.Request.String()] = append(responses, mapping.Response)
104 } else {
105 responseMap[mapping.Request.String()] = []Response{mapping.Response}
106 }
107 }
108 return &testHandler{responseMap: responseMap}
109 }
110
111 func (app *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
112 defer r.Body.Close()
113
114 requestBody, _ := ioutil.ReadAll(r.Body)
115 request := Request{
116 Method: r.Method,
117 Route: r.URL.Path,
118 QueryParams: r.URL.Query(),
119 Body: requestBody,
120 Headers: make(map[string][]string),
121 }
122
123
124 for k, v := range r.Header {
125 if k == "If-None-Match" {
126 request.Headers[k] = v
127 }
128 }
129
130 responses, ok := app.responseMap[request.String()]
131
132 if !ok || len(responses) == 0 {
133 http.NotFound(w, r)
134 return
135 }
136
137 response := responses[0]
138 app.responseMap[request.String()] = responses[1:]
139
140 responseHeader := w.Header()
141 for k, v := range response.Headers {
142 responseHeader[k] = v
143 }
144
145 w.WriteHeader(response.StatusCode)
146
147 io.Copy(w, bytes.NewReader(response.Body))
148 }
149
View as plain text