...
1
2
3
4
5 package gzhttp
6
7 import (
8 "io"
9 "net/http"
10 "strings"
11 "sync"
12
13 "github.com/klauspost/compress/gzip"
14 "github.com/klauspost/compress/zstd"
15 )
16
17
18
19
20 func Transport(parent http.RoundTripper, opts ...transportOption) http.RoundTripper {
21 g := gzRoundtripper{parent: parent, withZstd: true, withGzip: true}
22 for _, o := range opts {
23 o(&g)
24 }
25 var ae []string
26 if g.withZstd {
27 ae = append(ae, "zstd")
28 }
29 if g.withGzip {
30 ae = append(ae, "gzip")
31 }
32 g.acceptEncoding = strings.Join(ae, ",")
33 return &g
34 }
35
36 type transportOption func(c *gzRoundtripper)
37
38
39
40 func TransportEnableZstd(b bool) transportOption {
41 return func(c *gzRoundtripper) {
42 c.withZstd = b
43 }
44 }
45
46
47
48 func TransportEnableGzip(b bool) transportOption {
49 return func(c *gzRoundtripper) {
50 c.withGzip = b
51 }
52 }
53
54
55
56
57
58 func TransportCustomEval(fn func(header http.Header) bool) transportOption {
59 return func(c *gzRoundtripper) {
60 c.customEval = fn
61 }
62 }
63
// gzRoundtripper is the http.RoundTripper built by Transport.
type gzRoundtripper struct {
	// parent performs the actual HTTP exchange.
	parent http.RoundTripper

	// acceptEncoding is the precomputed Accept-Encoding value (for example
	// "zstd,gzip"); it is empty when both encodings are disabled.
	acceptEncoding string

	// withZstd and withGzip select which encodings are requested and decoded.
	withZstd, withGzip bool

	// customEval optionally vetoes or forces decoding per response header set.
	customEval func(header http.Header) bool
}
70
71 func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
72 var requestedComp bool
73 if req.Header.Get("Accept-Encoding") == "" &&
74 req.Header.Get("Range") == "" &&
75 req.Method != "HEAD" {
76
77
78
79
80
81
82
83
84
85
86
87
88 requestedComp = len(g.acceptEncoding) > 0
89 req.Header.Set("Accept-Encoding", g.acceptEncoding)
90 }
91
92 resp, err := g.parent.RoundTrip(req)
93 if err != nil || !requestedComp {
94 return resp, err
95 }
96 decompress := false
97 if g.customEval != nil {
98 if !g.customEval(resp.Header) {
99 return resp, nil
100 }
101 decompress = true
102 }
103
104 if (decompress || g.withGzip) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
105 resp.Body = &gzipReader{body: resp.Body}
106 resp.Header.Del("Content-Encoding")
107 resp.Header.Del("Content-Length")
108 resp.ContentLength = -1
109 resp.Uncompressed = true
110 }
111 if (decompress || g.withZstd) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
112 resp.Body = &zstdReader{body: resp.Body}
113 resp.Header.Del("Content-Encoding")
114 resp.Header.Del("Content-Length")
115 resp.ContentLength = -1
116 resp.Uncompressed = true
117 }
118
119 return resp, nil
120 }
121
// gzReaderPool recycles *gzip.Reader values across responses. It has no New
// function, so Get yields nil when nothing has been pooled yet.
var gzReaderPool = sync.Pool{}
123
124
125
126 type gzipReader struct {
127 body io.ReadCloser
128 zr *gzip.Reader
129 zerr error
130 }
131
132 func (gz *gzipReader) Read(p []byte) (n int, err error) {
133 if gz.zr == nil {
134 if gz.zerr == nil {
135 zr, ok := gzReaderPool.Get().(*gzip.Reader)
136 if ok {
137 gz.zr, gz.zerr = zr, zr.Reset(gz.body)
138 } else {
139 gz.zr, gz.zerr = gzip.NewReader(gz.body)
140 }
141 }
142 if gz.zerr != nil {
143 return 0, gz.zerr
144 }
145 }
146
147 return gz.zr.Read(p)
148 }
149
150 func (gz *gzipReader) Close() error {
151 if gz.zr != nil {
152 gzReaderPool.Put(gz.zr)
153 gz.zr = nil
154 }
155 return gz.body.Close()
156 }
157
158
159
// asciiEqualFold reports whether s and t are equal when ASCII letters are
// compared without regard to case. All other bytes must match exactly; unlike
// strings.EqualFold, no Unicode case folding is performed.
func asciiEqualFold(s, t string) bool {
	if len(s) != len(t) {
		return false
	}
	for i := len(s) - 1; i >= 0; i-- {
		if lower(s[i]) != lower(t[i]) {
			return false
		}
	}
	return true
}

// lower converts an ASCII upper-case letter to lower case and returns any
// other byte unchanged.
func lower(b byte) byte {
	if b < 'A' || b > 'Z' {
		return b
	}
	return b - 'A' + 'a'
}
179
180
// zstdReaderPool recycles *zstd.Decoder values across responses. It has no
// New function, so Get yields nil when nothing has been pooled yet.
var zstdReaderPool = sync.Pool{}
182
183
184
185 type zstdReader struct {
186 body io.ReadCloser
187 zr *zstd.Decoder
188 zerr error
189 }
190
191 func (zr *zstdReader) Read(p []byte) (n int, err error) {
192 if zr.zerr != nil {
193 return 0, zr.zerr
194 }
195 if zr.zr == nil {
196 if zr.zerr == nil {
197 reader, ok := zstdReaderPool.Get().(*zstd.Decoder)
198 if ok {
199 zr.zerr = reader.Reset(zr.body)
200 zr.zr = reader
201 } else {
202 zr.zr, zr.zerr = zstd.NewReader(zr.body, zstd.WithDecoderLowmem(true), zstd.WithDecoderMaxWindow(32<<20), zstd.WithDecoderConcurrency(1))
203 }
204 }
205 if zr.zerr != nil {
206 return 0, zr.zerr
207 }
208 }
209 n, err = zr.zr.Read(p)
210 if err != nil {
211
212
213 zr.zr.Reset(nil)
214 zstdReaderPool.Put(zr.zr)
215 zr.zr = nil
216 zr.zerr = err
217 }
218 return
219 }
220
221 func (zr *zstdReader) Close() error {
222 if zr.zr != nil {
223 zr.zr.Reset(nil)
224 zstdReaderPool.Put(zr.zr)
225 zr.zr = nil
226 }
227 return zr.body.Close()
228 }
229
View as plain text