1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package appconfig
16
17 import (
18 "bytes"
19 "encoding/base64"
20 "fmt"
21 "io/ioutil"
22 "net/http"
23 "strings"
24
25 "github.com/pkg/errors"
26 "gopkg.in/yaml.v2"
27 )
28
// RequestMatcher decides whether a rule applies to a given request.
type RequestMatcher interface {
	// Matches reports whether the request r should be handled by the rule
	// that owns this matcher. body carries the request payload that was
	// already read from r.Body, so implementations can inspect it without
	// consuming the request themselves.
	Matches(r *http.Request, body []byte) bool
}
32
// ExactPathMatcher is a RequestMatcher that accepts a request only when its
// URL path is identical to the matcher's own string value.
type ExactPathMatcher string

// Matches reports whether r.URL.Path equals m exactly. The comparison is
// case-sensitive and the body is not inspected.
func (m ExactPathMatcher) Matches(r *http.Request, body []byte) bool {
	want := string(m)
	return want == r.URL.Path
}
38
39 type Rule struct {
40 Matcher RequestMatcher
41 Count int
42
43 responses []SavedResponse
44 err error
45 }
46
47 type ResponsePlayer struct {
48 Rules []*Rule
49 }
50
51 func (rp *ResponsePlayer) AddRule(matcher RequestMatcher, file string) *Rule {
52 rule := &Rule{Matcher: matcher}
53 rp.Rules = append(rp.Rules, rule)
54
55 d, err := ioutil.ReadFile(file)
56 if err != nil {
57 rule.err = errors.Wrapf(err, "failed to read response file: %s", file)
58 return rule
59 }
60
61 if err := yaml.Unmarshal(d, &rule.responses); err != nil {
62 rule.err = errors.Wrapf(err, "failed to unmarshal response file: %s", file)
63 return rule
64 }
65
66 return rule
67 }
68
// SavedResponse is one canned HTTP response as stored in a response file.
type SavedResponse struct {
	// Status is the numeric HTTP status code.
	Status int `yaml:"status"`
	// Headers maps each header name to its single value.
	Headers map[string]string `yaml:"headers"`
	// Body is the payload: literal text, or standard base64 when Binary is
	// set.
	Body string `yaml:"body"`
	// Binary marks Body as base64-encoded. The tag spells out the key that
	// yaml.v2 already derives by default (the lowercased field name), so
	// existing files decode the same way.
	Binary bool `yaml:"binary"`
}

// Response builds a new *http.Response for req from the saved data.
//
// Every call returns a fresh body reader, so the same SavedResponse can be
// replayed any number of times. The returned response reports HTTP/1.1 and a
// ContentLength equal to the decoded body size.
//
// Response panics if Binary is set and Body is not valid standard base64.
func (r *SavedResponse) Response(req *http.Request) *http.Response {
	header := make(http.Header, len(r.Headers))
	for name, value := range r.Headers {
		header.Add(name, value)
	}

	var body []byte
	if r.Binary {
		decoded, err := base64.StdEncoding.DecodeString(r.Body)
		if err != nil {
			panic("invalid base64 encoded binary body")
		}
		body = decoded
	} else {
		body = []byte(r.Body)
	}

	return &http.Response{
		// net/http documents Status as the full status text, e.g. "200 OK",
		// not just the reason phrase.
		Status:     fmt.Sprintf("%d %s", r.Status, http.StatusText(r.Status)),
		StatusCode: r.Status,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,

		Header:        header,
		Body:          ioutil.NopCloser(bytes.NewReader(body)),
		ContentLength: int64(len(body)),

		Request: req,
	}
}
107
108 func (rp *ResponsePlayer) findMatch(req *http.Request) *Rule {
109 var body []byte
110 if req.Body != nil {
111 body, _ = ioutil.ReadAll(req.Body)
112 _ = req.Body.Close()
113 }
114
115 for _, rule := range rp.Rules {
116 if rule.Matcher.Matches(req, body) {
117 return rule
118 }
119 }
120 return nil
121 }
122
123 func (rp *ResponsePlayer) RoundTrip(req *http.Request) (*http.Response, error) {
124 rule := rp.findMatch(req)
125 if rule == nil {
126 return errorResponse(req, http.StatusGone, fmt.Sprintf("no matching rule for \"%s %s\"", req.Method, req.URL.Path))
127 }
128
129
130 if rule.err != nil {
131 return nil, rule.err
132 }
133
134
135 if len(rule.responses) == 0 {
136 return errorResponse(req, http.StatusGone, fmt.Sprintf("no responses for \"%s %s\"", req.Method, req.URL.Path))
137 }
138
139 index := rule.Count % len(rule.responses)
140 rule.Count++
141
142 return rule.responses[index].Response(req), nil
143 }
144
// errorResponse builds a plain HTTP/1.1 response with status code and msg as
// its body, attached to req. The returned error is always nil; the two-value
// form lets callers return the result directly from RoundTrip.
func errorResponse(req *http.Request, code int, msg string) (*http.Response, error) {
	body := strings.NewReader(msg)

	return &http.Response{
		// net/http documents Status as the full status text, e.g. "410 Gone",
		// not just the reason phrase.
		Status:     fmt.Sprintf("%d %s", code, http.StatusText(code)),
		StatusCode: code,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,

		Header:        make(http.Header),
		Body:          ioutil.NopCloser(body),
		ContentLength: body.Size(),

		Request: req,
	}, nil
}
162
View as plain text