...
1 package matchers
2
import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"reflect"
	"strings"

	"github.com/onsi/gomega/format"
	"github.com/onsi/gomega/internal/gutil"
)
13
// HaveHTTPStatusMatcher succeeds when an *http.Response or
// *httptest.ResponseRecorder has any of the statuses listed in Expected:
// an int entry is compared with the response's StatusCode, a string entry
// with its full Status text.
type HaveHTTPStatusMatcher struct{ Expected []interface{} }
17
18 func (matcher *HaveHTTPStatusMatcher) Match(actual interface{}) (success bool, err error) {
19 var resp *http.Response
20 switch a := actual.(type) {
21 case *http.Response:
22 resp = a
23 case *httptest.ResponseRecorder:
24 resp = a.Result()
25 default:
26 return false, fmt.Errorf("HaveHTTPStatus matcher expects *http.Response or *httptest.ResponseRecorder. Got:\n%s", format.Object(actual, 1))
27 }
28
29 if len(matcher.Expected) == 0 {
30 return false, fmt.Errorf("HaveHTTPStatus matcher must be passed an int or a string. Got nothing")
31 }
32
33 for _, expected := range matcher.Expected {
34 switch e := expected.(type) {
35 case int:
36 if resp.StatusCode == e {
37 return true, nil
38 }
39 case string:
40 if resp.Status == e {
41 return true, nil
42 }
43 default:
44 return false, fmt.Errorf("HaveHTTPStatus matcher must be passed int or string types. Got:\n%s", format.Object(expected, 1))
45 }
46 }
47
48 return false, nil
49 }
50
51 func (matcher *HaveHTTPStatusMatcher) FailureMessage(actual interface{}) (message string) {
52 return fmt.Sprintf("Expected\n%s\n%s\n%s", formatHttpResponse(actual), "to have HTTP status", matcher.expectedString())
53 }
54
55 func (matcher *HaveHTTPStatusMatcher) NegatedFailureMessage(actual interface{}) (message string) {
56 return fmt.Sprintf("Expected\n%s\n%s\n%s", formatHttpResponse(actual), "not to have HTTP status", matcher.expectedString())
57 }
58
59 func (matcher *HaveHTTPStatusMatcher) expectedString() string {
60 var lines []string
61 for _, expected := range matcher.Expected {
62 lines = append(lines, format.Object(expected, 1))
63 }
64 return strings.Join(lines, "\n")
65 }
66
67 func formatHttpResponse(input interface{}) string {
68 var resp *http.Response
69 switch r := input.(type) {
70 case *http.Response:
71 resp = r
72 case *httptest.ResponseRecorder:
73 resp = r.Result()
74 default:
75 return "cannot format invalid HTTP response"
76 }
77
78 body := "<nil>"
79 if resp.Body != nil {
80 defer resp.Body.Close()
81 data, err := gutil.ReadAll(resp.Body)
82 if err != nil {
83 data = []byte("<error reading body>")
84 }
85 body = format.Object(string(data), 0)
86 }
87
88 var s strings.Builder
89 s.WriteString(fmt.Sprintf("%s<%s>: {\n", format.Indent, reflect.TypeOf(input)))
90 s.WriteString(fmt.Sprintf("%s%sStatus: %s\n", format.Indent, format.Indent, format.Object(resp.Status, 0)))
91 s.WriteString(fmt.Sprintf("%s%sStatusCode: %s\n", format.Indent, format.Indent, format.Object(resp.StatusCode, 0)))
92 s.WriteString(fmt.Sprintf("%s%sBody: %s\n", format.Indent, format.Indent, body))
93 s.WriteString(fmt.Sprintf("%s}", format.Indent))
94
95 return s.String()
96 }
97
View as plain text