1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package proxy
16
17 import (
18 "bytes"
19 "encoding/json"
20 "errors"
21 "fmt"
22 "log"
23 "net/http"
24 "os"
25 "reflect"
26 "sync"
27
28 "github.com/google/martian/v3/martianlog"
29 )
30
31
32 func ForReplaying(filename string, port int) (*Proxy, error) {
33 p, err := newProxy(filename)
34 if err != nil {
35 return nil, err
36 }
37 lg, err := readLog(filename)
38 if err != nil {
39 return nil, err
40 }
41 calls, err := constructCalls(lg)
42 if err != nil {
43 return nil, err
44 }
45 p.Initial = lg.Initial
46 p.mproxy.SetRoundTripper(&replayRoundTripper{
47 calls: calls,
48 ignoreHeaders: p.ignoreHeaders,
49 conv: lg.Converter,
50 })
51
52
53
54 logger := martianlog.NewLogger()
55 logger.SetDecode(true)
56 p.mproxy.SetRequestModifier(logger)
57 p.mproxy.SetResponseModifier(logger)
58
59 if err := p.start(port); err != nil {
60 return nil, err
61 }
62 return p, nil
63 }
64
65 func readLog(filename string) (*Log, error) {
66 bytes, err := os.ReadFile(filename)
67 if err != nil {
68 return nil, err
69 }
70 var lg Log
71 if err := json.Unmarshal(bytes, &lg); err != nil {
72 return nil, fmt.Errorf("%s: %v", filename, err)
73 }
74 if lg.Version != LogVersion {
75 return nil, fmt.Errorf(
76 "httpreplay: read log version %s but current version is %s; re-record the log",
77 lg.Version, LogVersion)
78 }
79 return &lg, nil
80 }
81
82
// A call pairs a recorded request with the response recorded for it.
// constructCalls builds calls from log entries that share an ID, and
// replayRoundTripper serves res when an incoming request matches req.
type call struct {
	req *Request
	res *Response
}
87
88 func constructCalls(lg *Log) ([]*call, error) {
89 ignoreIDs := map[string]bool{}
90 callsByID := map[string]*call{}
91 var calls []*call
92 for _, e := range lg.Entries {
93 if ignoreIDs[e.ID] {
94 continue
95 }
96 c, ok := callsByID[e.ID]
97 switch {
98 case !ok:
99 if e.Request == nil {
100 return nil, fmt.Errorf("first entry for ID %s does not have a request", e.ID)
101 }
102 if e.Request.Method == "CONNECT" {
103
104 ignoreIDs[e.ID] = true
105 } else {
106 c := &call{e.Request, e.Response}
107 calls = append(calls, c)
108 callsByID[e.ID] = c
109 }
110 case e.Request != nil:
111 if e.Response != nil {
112 return nil, errors.New("entry has both request and response")
113 }
114 c.req = e.Request
115 case e.Response != nil:
116 c.res = e.Response
117 default:
118 return nil, errors.New("entry has neither request nor response")
119 }
120 }
121 for _, c := range calls {
122 if c.req == nil || c.res == nil {
123 return nil, fmt.Errorf("missing request or response: %+v", c)
124 }
125 }
126 return calls, nil
127 }
128
// replayRoundTripper answers requests from a list of recorded calls; see its
// RoundTrip method.
type replayRoundTripper struct {
	// mu is held by RoundTrip while it scans and modifies calls.
	mu sync.Mutex
	// calls holds the recorded calls; RoundTrip sets an element to nil once
	// it has been served.
	calls []*call
	// ignoreHeaders is passed to requestsMatch as the set of header names to
	// skip when comparing requests.
	ignoreHeaders map[string]bool
	// conv converts each incoming *http.Request before it is compared with
	// the recorded requests.
	conv *Converter
}
135
136 func (r *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
137 if req.Body != nil {
138 defer req.Body.Close()
139 }
140 creq, err := r.conv.convertRequest(req)
141 if err != nil {
142 return nil, err
143 }
144 r.mu.Lock()
145 defer r.mu.Unlock()
146 for i, call := range r.calls {
147 if call == nil {
148 continue
149 }
150 if requestsMatch(creq, call.req, r.ignoreHeaders) {
151 r.calls[i] = nil
152 return toHTTPResponse(call.res, req), nil
153 }
154 }
155 return nil, fmt.Errorf("no matching request for %+v", req)
156 }
157
158
159 func requestsMatch(in, cand *Request, ignoreHeaders map[string]bool) bool {
160 if in.Method != cand.Method {
161 return false
162 }
163 if in.URL != cand.URL {
164 return false
165 }
166 if in.MediaType != cand.MediaType {
167 return false
168 }
169 if len(in.BodyParts) != len(cand.BodyParts) {
170 return false
171 }
172 for i, p1 := range in.BodyParts {
173 if !bytes.Equal(p1, cand.BodyParts[i]) {
174 return false
175 }
176 }
177
178 return headersMatch(in.Header, cand.Header, ignoreHeaders)
179 }
180
181
182
183
// DebugHeaders controls diagnostic output from headersMatch. When true, the
// header mismatch that makes headersMatch return false is reported with
// log.Printf.
var DebugHeaders = false

// headersMatch reports whether the incoming headers in and the candidate
// headers cand agree once the header names in ignores are skipped.
//
// Every non-ignored header of in must exist in cand with a non-nil value
// slice that is deeply equal (same values, same order), and every non-ignored
// header of cand must have a non-nil value slice in in. Names are compared
// exactly as stored in the maps.
func headersMatch(in, cand http.Header, ignores map[string]bool) bool {
	// debugf logs only when DebugHeaders is set.
	debugf := func(format string, args ...interface{}) {
		if DebugHeaders {
			log.Printf(format, args...)
		}
	}

	for name, inVals := range in {
		if ignores[name] {
			continue
		}
		candVals := cand[name]
		switch {
		case candVals == nil:
			debugf("header %s: present in incoming request but not candidate", name)
			return false
		case !reflect.DeepEqual(inVals, candVals):
			debugf("header %s: incoming %v, candidate %v", name, inVals, candVals)
			return false
		}
	}

	// Values were compared above; only presence needs checking here.
	for name := range cand {
		if !ignores[name] && in[name] == nil {
			debugf("header %s: not in incoming request but present in candidate", name)
			return false
		}
	}
	return true
}
218
View as plain text