...
1
16
17 package testing
18
19 import (
20 "io"
21 "net/http"
22 "net/url"
23 "reflect"
24 "sync"
25 )
26
27
28
// TestInterface is the reporting surface FakeHandler's validation methods
// write to. It is the Errorf/Logf subset of *testing.T, declared as an
// interface so callers can inject their own implementation.
type TestInterface interface {
	Errorf(msg string, vals ...interface{})
	Logf(msg string, vals ...interface{})
}
33
34
// LogInterface is the minimal logger FakeHandler uses to report what it
// serves; a *testing.T satisfies it.
type LogInterface interface {
	Logf(msg string, vals ...interface{})
}
38
39
40
41
42 type FakeHandler struct {
43 RequestReceived *http.Request
44 RequestBody string
45 StatusCode int
46 ResponseBody string
47
48
49 T LogInterface
50
51
52 lock sync.Mutex
53 requestCount int
54 hasBeenChecked bool
55
56 SkipRequestFn func(verb string, url url.URL) bool
57 }
58
59 func (f *FakeHandler) SetResponseBody(responseBody string) {
60 f.lock.Lock()
61 defer f.lock.Unlock()
62 f.ResponseBody = responseBody
63 }
64
65 func (f *FakeHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) {
66 f.lock.Lock()
67 defer f.lock.Unlock()
68
69 if f.SkipRequestFn != nil && f.SkipRequestFn(request.Method, *request.URL) {
70 response.Header().Set("Content-Type", "application/json")
71 response.WriteHeader(f.StatusCode)
72 response.Write([]byte(f.ResponseBody))
73 return
74 }
75
76 f.requestCount++
77 if f.hasBeenChecked {
78 panic("got request after having been validated")
79 }
80
81 f.RequestReceived = request
82 response.Header().Set("Content-Type", "application/json")
83 response.WriteHeader(f.StatusCode)
84 response.Write([]byte(f.ResponseBody))
85
86 bodyReceived, err := io.ReadAll(request.Body)
87 if err != nil && f.T != nil {
88 f.T.Logf("Received read error: %v", err)
89 }
90 f.RequestBody = string(bodyReceived)
91 if f.T != nil {
92 f.T.Logf("request body: %s", f.RequestBody)
93 }
94 }
95
96 func (f *FakeHandler) ValidateRequestCount(t TestInterface, count int) bool {
97 ok := true
98 f.lock.Lock()
99 defer f.lock.Unlock()
100 if f.requestCount != count {
101 ok = false
102 t.Errorf("Expected %d call, but got %d. Only the last call is recorded and checked.", count, f.requestCount)
103 }
104 f.hasBeenChecked = true
105 return ok
106 }
107
108
109 func (f *FakeHandler) ValidateRequest(t TestInterface, expectedPath, expectedMethod string, body *string) {
110 f.lock.Lock()
111 defer f.lock.Unlock()
112 if f.requestCount != 1 {
113 t.Logf("Expected 1 call, but got %v. Only the last call is recorded and checked.", f.requestCount)
114 }
115 f.hasBeenChecked = true
116
117 expectURL, err := url.Parse(expectedPath)
118 if err != nil {
119 t.Errorf("Couldn't parse %v as a URL.", expectedPath)
120 }
121 if f.RequestReceived == nil {
122 t.Errorf("Unexpected nil request received for %s", expectedPath)
123 return
124 }
125 if f.RequestReceived.URL.Path != expectURL.Path {
126 t.Errorf("Unexpected request path for request %#v, received: %q, expected: %q", f.RequestReceived, f.RequestReceived.URL.Path, expectURL.Path)
127 }
128 if e, a := expectURL.Query(), f.RequestReceived.URL.Query(); !reflect.DeepEqual(e, a) {
129 t.Errorf("Unexpected query for request %#v, received: %q, expected: %q", f.RequestReceived, a, e)
130 }
131 if f.RequestReceived.Method != expectedMethod {
132 t.Errorf("Unexpected method: %q, expected: %q", f.RequestReceived.Method, expectedMethod)
133 }
134 if body != nil {
135 if *body != f.RequestBody {
136 t.Errorf("Received body:\n%s\n Doesn't match expected body:\n%s", f.RequestBody, *body)
137 }
138 }
139 }
140
View as plain text