1
16
17 package test
18
19 import (
20 "bytes"
21 "fmt"
22 "io"
23 "net"
24 "strings"
25 "sync"
26 "time"
27
28 "golang.org/x/net/http2"
29 "golang.org/x/net/http2/hpack"
30 )
31
32 type listenerWrapper struct {
33 net.Listener
34 mu sync.Mutex
35 rcw *rawConnWrapper
36 }
37
38 func listenWithConnControl(network, address string) (net.Listener, error) {
39 l, err := net.Listen(network, address)
40 if err != nil {
41 return nil, err
42 }
43 return &listenerWrapper{Listener: l}, nil
44 }
45
46
47
48 func (l *listenerWrapper) Accept() (net.Conn, error) {
49 c, err := l.Listener.Accept()
50 if err != nil {
51 return nil, err
52 }
53 l.mu.Lock()
54 l.rcw = newRawConnWrapperFromConn(c)
55 l.mu.Unlock()
56 return c, nil
57 }
58
59 func (l *listenerWrapper) getLastConn() *rawConnWrapper {
60 l.mu.Lock()
61 defer l.mu.Unlock()
62 return l.rcw
63 }
64
65 type dialerWrapper struct {
66 c net.Conn
67 rcw *rawConnWrapper
68 }
69
70 func (d *dialerWrapper) dialer(target string, t time.Duration) (net.Conn, error) {
71 c, err := net.DialTimeout("tcp", target, t)
72 d.c = c
73 d.rcw = newRawConnWrapperFromConn(c)
74 return c, err
75 }
76
77 func (d *dialerWrapper) getRawConnWrapper() *rawConnWrapper {
78 return d.rcw
79 }
80
81 type rawConnWrapper struct {
82 cc io.ReadWriteCloser
83 fr *http2.Framer
84
85
86 headerBuf bytes.Buffer
87 hpackEnc *hpack.Encoder
88
89
90 frc chan http2.Frame
91 frErrc chan error
92 }
93
94 func newRawConnWrapperFromConn(cc io.ReadWriteCloser) *rawConnWrapper {
95 rcw := &rawConnWrapper{
96 cc: cc,
97 frc: make(chan http2.Frame, 1),
98 frErrc: make(chan error, 1),
99 }
100 rcw.hpackEnc = hpack.NewEncoder(&rcw.headerBuf)
101 rcw.fr = http2.NewFramer(cc, cc)
102 rcw.fr.ReadMetaHeaders = hpack.NewDecoder(4096 , nil)
103
104 return rcw
105 }
106
107 func (rcw *rawConnWrapper) Close() error {
108 return rcw.cc.Close()
109 }
110
111 func (rcw *rawConnWrapper) encodeHeaderField(k, v string) error {
112 err := rcw.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
113 if err != nil {
114 return fmt.Errorf("HPACK encoding error for %q/%q: %v", k, v, err)
115 }
116 return nil
117 }
118
119
120
121 func (rcw *rawConnWrapper) encodeRawHeader(headers ...string) []byte {
122 if len(headers)%2 == 1 {
123 panic("odd number of kv args")
124 }
125
126 rcw.headerBuf.Reset()
127
128 pseudoCount := map[string]int{}
129 var keys []string
130 vals := map[string][]string{}
131
132 for len(headers) > 0 {
133 k, v := headers[0], headers[1]
134 headers = headers[2:]
135 if _, ok := vals[k]; !ok {
136 keys = append(keys, k)
137 }
138 if strings.HasPrefix(k, ":") {
139 pseudoCount[k]++
140 if pseudoCount[k] == 1 {
141 vals[k] = []string{v}
142 } else {
143
144 vals[k] = append(vals[k], v)
145 }
146 } else {
147 vals[k] = append(vals[k], v)
148 }
149 }
150 for _, k := range keys {
151 for _, v := range vals[k] {
152 rcw.encodeHeaderField(k, v)
153 }
154 }
155 return rcw.headerBuf.Bytes()
156 }
157
158
159
160
161
162
163
164 func (rcw *rawConnWrapper) encodeHeader(headers ...string) []byte {
165 if len(headers)%2 == 1 {
166 panic("odd number of kv args")
167 }
168
169 rcw.headerBuf.Reset()
170
171 if len(headers) == 0 {
172
173
174 rcw.encodeHeaderField(":method", "GET")
175 rcw.encodeHeaderField(":path", "/")
176 rcw.encodeHeaderField(":scheme", "https")
177 return rcw.headerBuf.Bytes()
178 }
179
180 if len(headers) == 2 && headers[0] == ":method" {
181
182 rcw.encodeHeaderField(":method", headers[1])
183 rcw.encodeHeaderField(":path", "/")
184 rcw.encodeHeaderField(":scheme", "https")
185 return rcw.headerBuf.Bytes()
186 }
187
188 pseudoCount := map[string]int{}
189 keys := []string{":method", ":path", ":scheme"}
190 vals := map[string][]string{
191 ":method": {"GET"},
192 ":path": {"/"},
193 ":scheme": {"https"},
194 }
195 for len(headers) > 0 {
196 k, v := headers[0], headers[1]
197 headers = headers[2:]
198 if _, ok := vals[k]; !ok {
199 keys = append(keys, k)
200 }
201 if strings.HasPrefix(k, ":") {
202 pseudoCount[k]++
203 if pseudoCount[k] == 1 {
204 vals[k] = []string{v}
205 } else {
206
207 vals[k] = append(vals[k], v)
208 }
209 } else {
210 vals[k] = append(vals[k], v)
211 }
212 }
213 for _, k := range keys {
214 for _, v := range vals[k] {
215 rcw.encodeHeaderField(k, v)
216 }
217 }
218 return rcw.headerBuf.Bytes()
219 }
220
221 func (rcw *rawConnWrapper) writeHeaders(p http2.HeadersFrameParam) error {
222 if err := rcw.fr.WriteHeaders(p); err != nil {
223 return fmt.Errorf("error writing HEADERS: %v", err)
224 }
225 return nil
226 }
227
228 func (rcw *rawConnWrapper) writeRSTStream(streamID uint32, code http2.ErrCode) error {
229 if err := rcw.fr.WriteRSTStream(streamID, code); err != nil {
230 return fmt.Errorf("error writing RST_STREAM: %v", err)
231 }
232 return nil
233 }
234
235 func (rcw *rawConnWrapper) writeGoAway(maxStreamID uint32, code http2.ErrCode, debugData []byte) error {
236 if err := rcw.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
237 return fmt.Errorf("error writing GoAway: %v", err)
238 }
239 return nil
240 }
241
242 func (rcw *rawConnWrapper) writeRawFrame(t http2.FrameType, flags http2.Flags, streamID uint32, payload []byte) error {
243 if err := rcw.fr.WriteRawFrame(t, flags, streamID, payload); err != nil {
244 return fmt.Errorf("error writing Raw Frame: %v", err)
245 }
246 return nil
247 }
248
View as plain text