...
1 package matchers
2
3 import (
4 "fmt"
5 "net/http"
6 "net/http/httptest"
7
8 "github.com/onsi/gomega/format"
9 "github.com/onsi/gomega/types"
10 )
11
// HaveHTTPHeaderWithValueMatcher checks a single HTTP header of a response
// against an expected value.
type HaveHTTPHeaderWithValueMatcher struct {
	// Header is the name of the HTTP header to look up on the actual value.
	Header string

	// Value is the expectation for the header. It must be either a string
	// (compared via EqualMatcher) or a types.GomegaMatcher; anything else is
	// rejected when the matcher is evaluated.
	Value interface{}
}
16
17 func (matcher *HaveHTTPHeaderWithValueMatcher) Match(actual interface{}) (success bool, err error) {
18 headerValue, err := matcher.extractHeader(actual)
19 if err != nil {
20 return false, err
21 }
22
23 headerMatcher, err := matcher.getSubMatcher()
24 if err != nil {
25 return false, err
26 }
27
28 return headerMatcher.Match(headerValue)
29 }
30
31 func (matcher *HaveHTTPHeaderWithValueMatcher) FailureMessage(actual interface{}) string {
32 headerValue, err := matcher.extractHeader(actual)
33 if err != nil {
34 panic(err)
35 }
36
37 headerMatcher, err := matcher.getSubMatcher()
38 if err != nil {
39 panic(err)
40 }
41
42 diff := format.IndentString(headerMatcher.FailureMessage(headerValue), 1)
43 return fmt.Sprintf("HTTP header %q:\n%s", matcher.Header, diff)
44 }
45
46 func (matcher *HaveHTTPHeaderWithValueMatcher) NegatedFailureMessage(actual interface{}) (message string) {
47 headerValue, err := matcher.extractHeader(actual)
48 if err != nil {
49 panic(err)
50 }
51
52 headerMatcher, err := matcher.getSubMatcher()
53 if err != nil {
54 panic(err)
55 }
56
57 diff := format.IndentString(headerMatcher.NegatedFailureMessage(headerValue), 1)
58 return fmt.Sprintf("HTTP header %q:\n%s", matcher.Header, diff)
59 }
60
61 func (matcher *HaveHTTPHeaderWithValueMatcher) getSubMatcher() (types.GomegaMatcher, error) {
62 switch m := matcher.Value.(type) {
63 case string:
64 return &EqualMatcher{Expected: matcher.Value}, nil
65 case types.GomegaMatcher:
66 return m, nil
67 default:
68 return nil, fmt.Errorf("HaveHTTPHeaderWithValue matcher must be passed a string or a GomegaMatcher. Got:\n%s", format.Object(matcher.Value, 1))
69 }
70 }
71
72 func (matcher *HaveHTTPHeaderWithValueMatcher) extractHeader(actual interface{}) (string, error) {
73 switch r := actual.(type) {
74 case *http.Response:
75 return r.Header.Get(matcher.Header), nil
76 case *httptest.ResponseRecorder:
77 return r.Result().Header.Get(matcher.Header), nil
78 default:
79 return "", fmt.Errorf("HaveHTTPHeaderWithValue matcher expects *http.Response or *httptest.ResponseRecorder. Got:\n%s", format.Object(actual, 1))
80 }
81 }
82
View as plain text