1 package gzhttp
2
3 import (
4 "bufio"
5 "crypto/rand"
6 "crypto/sha256"
7 "encoding/binary"
8 "errors"
9 "fmt"
10 "hash/crc32"
11 "io"
12 "math"
13 "math/bits"
14 "mime"
15 "net"
16 "net/http"
17 "strconv"
18 "strings"
19 "sync"
20
21 "github.com/klauspost/compress/gzhttp/writer"
22 "github.com/klauspost/compress/gzhttp/writer/gzkp"
23 "github.com/klauspost/compress/gzip"
24 )
25
const (
	// HeaderNoCompression is a response header a handler can set to opt out of
	// compression for that response. Its presence is tested in Write, Close and
	// Flush, and the header is removed again before the response is sent.
	HeaderNoCompression = "No-Gzip-Compression"

	// Canonical names of the HTTP headers read or modified in this file.
	vary            = "Vary"
	acceptEncoding  = "Accept-Encoding"
	contentEncoding = "Content-Encoding"
	contentRange    = "Content-Range"
	acceptRanges    = "Accept-Ranges"
	contentType     = "Content-Type"
	contentLength   = "Content-Length"
	eTag            = "ETag"
)
41
// codings maps a lower-cased content-coding name to its q-value, as filled in
// by parseEncodings.
type codings map[string]float64
43
const (
	// DefaultQValue is the q-value parseCoding assigns to a content-coding
	// that carries no explicit "q=" parameter.
	DefaultQValue = 1.0

	// DefaultMinSize is the minSize NewWrapper starts with before options are
	// applied. Buffered bodies shorter than minSize are sent uncompressed.
	DefaultMinSize = 1024
)
57
58
59
60
61
// GzipResponseWriter wraps an http.ResponseWriter. It buffers the start of the
// body until it can decide whether the response is written through a gzip
// writer or passed on unchanged.
type GzipResponseWriter struct {
	http.ResponseWriter

	// level is the compression level handed to gwFactory.New.
	level int

	// gwFactory creates the gzip writer when compression starts.
	gwFactory writer.GzipWriterFactory

	// gw is the active gzip writer; nil until compression has started.
	gw writer.GzipWriter

	// code is the status code stored by WriteHeader; 0 means none is pending.
	code int

	// minSize is the number of body bytes required before compressing.
	minSize int

	// buf holds body bytes until the compress/plain decision is made.
	buf []byte

	// ignore is set once the response is being written uncompressed.
	ignore bool

	// keepAcceptRanges keeps the Accept-Ranges header on compressed responses.
	keepAcceptRanges bool

	// setContentType stores a detected Content-Type when none was set.
	setContentType bool

	// suffixETag is inserted into the ETag header of compressed responses.
	suffixETag string

	// dropETag removes the ETag header from compressed responses.
	dropETag bool

	// sha256Jitter selects SHA-256 instead of CRC-32 for content-based jitter.
	sha256Jitter bool

	// randomJitter is the padding text a prefix of which is written as the
	// gzip header comment; empty disables jitter.
	randomJitter string

	// jitterBuffer is the number of body bytes hashed to derive the jitter;
	// when it is not positive, crypto/rand is used instead.
	jitterBuffer int

	// contentTypeFilter reports whether a content type should be compressed.
	contentTypeFilter func(ct string) bool
}
83
// GzipResponseWriterWithCloseNotify is a GzipResponseWriter that additionally
// exposes CloseNotify. NewWrapper uses it when the wrapped ResponseWriter
// implements http.CloseNotifier.
type GzipResponseWriterWithCloseNotify struct {
	*GzipResponseWriter
}
87
88 func (w GzipResponseWriterWithCloseNotify) CloseNotify() <-chan bool {
89 return w.ResponseWriter.(http.CloseNotifier).CloseNotify()
90 }
91
92
// Write buffers the beginning of the body until it can decide whether the
// response is compressed, then sends the buffered bytes and everything that
// follows either to the gzip writer or to the underlying ResponseWriter.
// While the decision is still pending it reports len(b) even though the bytes
// are only buffered.
func (w *GzipResponseWriter) Write(b []byte) (int, error) {
	// Compression has already started: everything goes through the gzip writer.
	if w.gw != nil {
		return w.gw.Write(b)
	}

	// The response was already committed as uncompressed.
	if w.ignore {
		return w.ResponseWriter.Write(b)
	}

	// Buffer up to the largest of 512 bytes, minSize and jitterBuffer.
	wantBuf := 512
	if w.minSize > wantBuf {
		wantBuf = w.minSize
	}
	if w.jitterBuffer > 0 && w.jitterBuffer > wantBuf {
		wantBuf = w.jitterBuffer
	}
	// Only the part of b that still fits is copied into the buffer; the rest
	// ("remain") is written directly once the decision has been made.
	toAdd := len(b)
	if len(w.buf)+toAdd > wantBuf {
		toAdd = wantBuf - len(w.buf)
	}
	w.buf = append(w.buf, b[:toAdd]...)
	remain := b[toAdd:]
	hdr := w.Header()

	// Compression is only considered when the handler did not opt out and the
	// response carries neither a Content-Encoding nor a Content-Range.
	if len(hdr[HeaderNoCompression]) == 0 && hdr.Get(contentEncoding) == "" && hdr.Get(contentRange) == "" {
		// cl is 0 when no Content-Length header was set.
		cl, _ := atoi(hdr.Get(contentLength))
		ct := hdr.Get(contentType)
		// Note the precedence: cl == 0 || (cl >= minSize && (ct == "" || filter(ct))).
		if cl == 0 || cl >= w.minSize && (ct == "" || w.contentTypeFilter(ct)) {
			// Keep buffering while the size is unknown and below minSize, or
			// while the jitter buffer has not been filled yet.
			if len(w.buf) < w.minSize && cl == 0 || (w.jitterBuffer > 0 && len(w.buf) < w.jitterBuffer) {
				return len(b), nil
			}

			// Either the declared length or the buffered data reaches minSize.
			if cl >= w.minSize || len(w.buf) >= w.minSize {
				// No explicit content type: sniff it from the buffered bytes.
				if ct == "" {
					ct = http.DetectContentType(w.buf)
				}

				// Store the content type only when the header key is absent
				// and setContentType is enabled.
				if _, ok := hdr[contentType]; w.setContentType && !ok {
					hdr.Set(contentType, ct)
				}

				// Compress when the (possibly sniffed) content type is accepted.
				if w.contentTypeFilter(ct) {
					if err := w.startGzip(remain); err != nil {
						return 0, err
					}
					if len(remain) > 0 {
						if _, err := w.gw.Write(remain); err != nil {
							return 0, err
						}
					}
					return len(b), nil
				}
			}
		}
	}

	// Every other path ends up here: send the response uncompressed.
	if err := w.startPlain(); err != nil {
		return 0, err
	}
	if len(remain) > 0 {
		if _, err := w.ResponseWriter.Write(remain); err != nil {
			return 0, err
		}
	}
	return len(b), nil
}
171
172 func (w *GzipResponseWriter) Unwrap() http.ResponseWriter {
173 return w.ResponseWriter
174 }
175
// castagnoliTable is the CRC-32 (Castagnoli) table startGzip uses to derive
// content-based jitter when SHA-256 jitter is not enabled.
var castagnoliTable = crc32.MakeTable(crc32.Castagnoli)
177
178
// startGzip commits the response as gzip-compressed: it adjusts the response
// headers, forwards any pending status code and, when data is buffered,
// creates the gzip writer and writes the buffer through it.
// remain is the part of the current Write call that did not fit into the
// buffer; it is only read here to feed the jitter hash and is not written.
func (w *GzipResponseWriter) startGzip(remain []byte) error {
	// Announce the encoding.
	w.Header().Set(contentEncoding, "gzip")

	// Any Content-Length set by the handler describes the uncompressed body,
	// so it is removed.
	w.Header().Del(contentLength)

	// Accept-Ranges is removed unless explicitly kept.
	if !w.keepAcceptRanges {
		w.Header().Del(acceptRanges)
	}

	// Insert the suffix before the last double quote of an existing ETag, or
	// append it when the value contains no quote.
	if w.suffixETag != "" && !w.dropETag && w.Header().Get(eTag) != "" {
		orig := w.Header().Get(eTag)
		insertPoint := strings.LastIndex(orig, `"`)
		if insertPoint == -1 {
			insertPoint = len(orig)
		}
		w.Header().Set(eTag, orig[:insertPoint]+w.suffixETag+orig[insertPoint:])
	}

	// dropETag takes precedence over suffixETag.
	if w.dropETag {
		w.Header().Del(eTag)
	}

	// Forward the status code stored by WriteHeader, exactly once.
	if w.code != 0 {
		w.ResponseWriter.WriteHeader(w.code)
		// Ensure that no other WriteHeader's happen
		w.code = 0
	}

	// The gzip writer is only created when there is buffered data to write;
	// with an empty buffer w.gw stays nil.
	if len(w.buf) > 0 {
		// Initialize the GZIP response.
		w.init()

		// Jitter: a prefix of randomJitter is stored as the gzip header
		// comment; its length is derived from jitRNG.
		if len(w.randomJitter) > 0 {
			var jitRNG uint32
			if w.jitterBuffer > 0 {
				// Content-based jitter: hash at most jitterBuffer bytes taken
				// from the buffer followed by the start of remain.
				if w.sha256Jitter {
					h := sha256.New()
					h.Write(w.buf)
					// Use only up to "w.jitterBuffer", otherwise the output depends on write sizes.
					if len(remain) > 0 && len(w.buf) < w.jitterBuffer {
						remain := remain
						if len(remain)+len(w.buf) > w.jitterBuffer {
							remain = remain[:w.jitterBuffer-len(w.buf)]
						}
						h.Write(remain)
					}
					var tmp [sha256.Size]byte
					jitRNG = binary.LittleEndian.Uint32(h.Sum(tmp[:0]))
				} else {
					h := crc32.Update(0, castagnoliTable, w.buf)
					// Use only up to "w.jitterBuffer", otherwise the output depends on write sizes.
					if len(remain) > 0 && len(w.buf) < w.jitterBuffer {
						remain := remain
						if len(remain)+len(w.buf) > w.jitterBuffer {
							remain = remain[:w.jitterBuffer-len(w.buf)]
						}
						h = crc32.Update(h, castagnoliTable, remain)
					}
					// Scramble the checksum before using it as the jitter source.
					jitRNG = bits.RotateLeft32(h, 19) ^ 0xab0755de
				}
			} else {
				// No positive jitter buffer: take the jitter from crypto/rand.
				var tmp [4]byte
				_, err := rand.Read(tmp[:])
				if err != nil {
					return fmt.Errorf("gzhttp: %w", err)
				}
				jitRNG = binary.LittleEndian.Uint32(tmp[:])
			}
			// The comment is between 1 and len(randomJitter)-1 bytes long.
			jit := w.randomJitter[:1+jitRNG%uint32(len(w.randomJitter)-1)]
			w.gw.(writer.GzipWriterExt).SetHeader(writer.Header{Comment: jit})
		}
		n, err := w.gw.Write(w.buf)

		// This should never happen (per io.Writer docs), but if the write didn't
		// accept the entire buffer but returned no specific error, we have no clue
		// what's going on, so abort just to be safe.
		if err == nil && n < len(w.buf) {
			err = io.ErrShortWrite
		}
		// The buffer is emptied whether or not the write succeeded.
		w.buf = w.buf[:0]
		return err
	}
	return nil
}
277
278
279 func (w *GzipResponseWriter) startPlain() error {
280 w.Header().Del(HeaderNoCompression)
281 if w.code != 0 {
282 w.ResponseWriter.WriteHeader(w.code)
283
284 w.code = 0
285 }
286
287 w.ignore = true
288
289 if len(w.buf) == 0 {
290 return nil
291 }
292 n, err := w.ResponseWriter.Write(w.buf)
293
294
295
296 if err == nil && n < len(w.buf) {
297 err = io.ErrShortWrite
298 }
299
300 w.buf = w.buf[:0]
301 return err
302 }
303
304
305
306 func (w *GzipResponseWriter) WriteHeader(code int) {
307
308
309 if shouldWrite1xxResponses() && code >= 100 && code <= 199 {
310 w.ResponseWriter.WriteHeader(code)
311 return
312 }
313
314 if w.code == 0 {
315 w.code = code
316 }
317 }
318
319
320
321 func (w *GzipResponseWriter) init() {
322
323
324 w.gw = w.gwFactory.New(w.ResponseWriter, w.level)
325 }
326
327
328 func (w *GzipResponseWriter) Close() error {
329 if w.ignore {
330 return nil
331 }
332 if w.gw == nil {
333 var (
334 ct = w.Header().Get(contentType)
335 ce = w.Header().Get(contentEncoding)
336 cr = w.Header().Get(contentRange)
337 )
338 if ct == "" {
339 ct = http.DetectContentType(w.buf)
340
341
342
343 if _, ok := w.Header()[contentType]; w.setContentType && !ok {
344 w.Header().Set(contentType, ct)
345 }
346 }
347
348 if len(w.buf) == 0 || len(w.buf) < w.minSize || len(w.Header()[HeaderNoCompression]) != 0 || ce != "" || cr != "" || !w.contentTypeFilter(ct) {
349
350 return w.startPlain()
351 }
352 err := w.startGzip(nil)
353 if err != nil {
354 return err
355 }
356 }
357
358 err := w.gw.Close()
359 w.gw = nil
360 return err
361 }
362
363
364
365
366
367
368
// Flush forces the compress/plain decision for any buffered data, then
// flushes the gzip writer (if one is active) and the underlying
// ResponseWriter (if it implements http.Flusher).
// When nothing has been buffered and no decision was made yet, it returns
// without flushing anything.
func (w *GzipResponseWriter) Flush() {
	if w.gw == nil && !w.ignore {
		if len(w.buf) == 0 {
			// Nothing to decide on yet; the underlying writer is not
			// flushed in this case.
			return
		}
		var (
			cl, _ = atoi(w.Header().Get(contentLength))
			ct    = w.Header().Get(contentType)
			ce    = w.Header().Get(contentEncoding)
			cr    = w.Header().Get(contentRange)
		)

		if ct == "" {
			ct = http.DetectContentType(w.buf)

			// Store the sniffed content type only when the header key is
			// absent and setContentType is enabled.
			if _, ok := w.Header()[contentType]; w.setContentType && !ok {
				w.Header().Set(contentType, ct)
			}
		}
		if cl == 0 {
			// Without a Content-Length, treat the body as large enough, so the
			// minSize check below passes.
			cl = w.minSize
		}

		// Same preconditions as in Write, except that the amount of buffered
		// data is not taken into account.
		// Flush cannot return an error, so failures from startGzip and
		// startPlain are dropped here.
		if len(w.Header()[HeaderNoCompression]) == 0 && ce == "" && cr == "" && cl >= w.minSize && w.contentTypeFilter(ct) {
			w.startGzip(nil)
		} else {
			w.startPlain()
		}
	}

	if w.gw != nil {
		w.gw.Flush()
	}

	if fw, ok := w.ResponseWriter.(http.Flusher); ok {
		fw.Flush()
	}
}
412
413
414
415 func (w *GzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
416 if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
417 return hj.Hijack()
418 }
419 return nil, nil, fmt.Errorf("http.Hijacker interface is not supported")
420 }
421
422
// Compile-time check that *GzipResponseWriter implements http.Hijacker.
var _ http.Hijacker = &GzipResponseWriter{}
424
// onceDefault guards the one-time construction of defaultWrapper.
var onceDefault sync.Once

// defaultWrapper is the wrapper returned by NewWrapper without options; it is
// built on the first GzipHandler call.
var defaultWrapper func(http.Handler) http.HandlerFunc
427
428
429 func GzipHandler(h http.Handler) http.HandlerFunc {
430 onceDefault.Do(func() {
431 var err error
432 defaultWrapper, err = NewWrapper()
433 if err != nil {
434 panic(err)
435 }
436 })
437
438 return defaultWrapper(h)
439 }
440
// grwPool recycles GzipResponseWriter values between requests; NewWrapper
// keeps the pooled writer's buffer when it reinitialises the value.
var grwPool = sync.Pool{New: func() interface{} { return &GzipResponseWriter{} }}
442
443
// NewWrapper builds a middleware constructor from the given options.
// The returned function wraps an http.Handler so that responses to requests
// accepting gzip are served through a GzipResponseWriter; all other requests
// are served through a NoGzipResponseWriter.
// It returns an error when the resulting configuration fails validation.
func NewWrapper(opts ...option) (func(http.Handler) http.HandlerFunc, error) {
	// Defaults; every option may override them.
	c := &config{
		level:   gzip.DefaultCompression,
		minSize: DefaultMinSize,
		writer: writer.GzipWriterFactory{
			Levels: gzkp.Levels,
			New:    gzkp.NewWriter,
		},
		contentTypes:   DefaultContentTypeFilter,
		setContentType: true,
	}

	// Options are applied in the order given, so later ones win.
	for _, o := range opts {
		o(c)
	}

	if err := c.validate(); err != nil {
		return nil, err
	}

	return func(h http.Handler) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			// The response depends on Accept-Encoding whether or not it ends
			// up compressed.
			w.Header().Add(vary, acceptEncoding)
			if acceptsGzip(r) {
				gw := grwPool.Get().(*GzipResponseWriter)
				// Reinitialise the pooled value, keeping only its buffer.
				*gw = GzipResponseWriter{
					ResponseWriter:    w,
					gwFactory:         c.writer,
					level:             c.level,
					minSize:           c.minSize,
					contentTypeFilter: c.contentTypes,
					keepAcceptRanges:  c.keepAcceptRanges,
					dropETag:          c.dropETag,
					suffixETag:        c.suffixETag,
					buf:               gw.buf,
					setContentType:    c.setContentType,
					randomJitter:      c.randomJitter,
					jitterBuffer:      c.jitterBuffer,
					sha256Jitter:      c.sha256Jitter,
				}
				if len(gw.buf) > 0 {
					gw.buf = gw.buf[:0]
				}
				// Finish the response and return the writer to the pool once
				// the handler is done; the error from Close is not reported.
				defer func() {
					gw.Close()
					gw.ResponseWriter = nil
					grwPool.Put(gw)
				}()

				// Expose CloseNotify only when the underlying writer has it.
				if _, ok := w.(http.CloseNotifier); ok {
					gwcn := GzipResponseWriterWithCloseNotify{gw}
					h.ServeHTTP(gwcn, r)
				} else {
					h.ServeHTTP(gw, r)
				}
				w.Header().Del(HeaderNoCompression)
			} else {
				h.ServeHTTP(newNoGzipResponseWriter(w), r)
				w.Header().Del(HeaderNoCompression)
			}
		}
	}, nil
}
507
508
509
// parsedContentType is a media type together with its parameters, in the form
// returned by mime.ParseMediaType.
type parsedContentType struct {
	mediaType string
	params    map[string]string
}

// equals reports whether the given media type and parameters match pct.
// A pct without parameters matches any parameter set; otherwise both parameter
// sets must contain exactly the same keys and values.
func (pct parsedContentType) equals(mediaType string, params map[string]string) bool {
	switch {
	case pct.mediaType != mediaType:
		return false
	case len(pct.params) == 0:
		// No parameters on our side: the media type alone decides.
		return true
	case len(pct.params) != len(params):
		return false
	}

	for key, want := range pct.params {
		got, found := params[key]
		if !found || got != want {
			return false
		}
	}
	return true
}
536
537
// config collects the settings assembled by NewWrapper from its options; the
// values are copied into every GzipResponseWriter it creates.
type config struct {
	// minSize is the minimum body size for compression; validate rejects
	// negative values.
	minSize int

	// level is the compression level; validate checks it against the range
	// reported by writer.Levels.
	level int

	// writer is the factory used to create gzip writers.
	writer writer.GzipWriterFactory

	// contentTypes reports whether a content type should be compressed.
	contentTypes func(ct string) bool

	// keepAcceptRanges keeps the Accept-Ranges header on compressed responses.
	keepAcceptRanges bool

	// setContentType stores a detected Content-Type when none was set.
	setContentType bool

	// suffixETag is inserted into the ETag header of compressed responses.
	suffixETag string

	// dropETag removes the ETag header from compressed responses.
	dropETag bool

	// jitterBuffer is the number of body bytes hashed to derive the jitter.
	jitterBuffer int

	// randomJitter is the padding text used for jitter; empty disables it.
	randomJitter string

	// sha256Jitter selects SHA-256 instead of CRC-32 for content-based jitter.
	sha256Jitter bool
}
551
552 func (c *config) validate() error {
553 min, max := c.writer.Levels()
554 if c.level < min || c.level > max {
555 return fmt.Errorf("invalid compression level requested: %d, valid range %d -> %d", c.level, min, max)
556 }
557
558 if c.minSize < 0 {
559 return fmt.Errorf("minimum size must be more than zero")
560 }
561 if len(c.randomJitter) >= math.MaxUint16 {
562 return fmt.Errorf("random jitter size exceeded")
563 }
564 if len(c.randomJitter) > 0 {
565 gzw, ok := c.writer.New(io.Discard, c.level).(writer.GzipWriterExt)
566 if !ok {
567 return errors.New("the custom compressor does not allow setting headers for random jitter")
568 }
569 gzw.Close()
570 }
571 return nil
572 }
573
// option modifies a config. NewWrapper applies options in the order given.
type option func(c *config)
575
576 func MinSize(size int) option {
577 return func(c *config) {
578 c.minSize = size
579 }
580 }
581
582
583 func CompressionLevel(level int) option {
584 return func(c *config) {
585 c.level = level
586 }
587 }
588
589
590
591
592 func SetContentType(b bool) option {
593 return func(c *config) {
594 c.setContentType = b
595 }
596 }
597
598
599
600
601
602
603 func Implementation(writer writer.GzipWriterFactory) option {
604 return func(c *config) {
605 c.writer = writer
606 }
607 }
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628 func ContentTypes(types []string) option {
629 return func(c *config) {
630 var contentTypes []parsedContentType
631 for _, v := range types {
632 mediaType, params, err := mime.ParseMediaType(v)
633 if err == nil {
634 contentTypes = append(contentTypes, parsedContentType{mediaType, params})
635 }
636 }
637 c.contentTypes = func(ct string) bool {
638 return handleContentType(contentTypes, ct)
639 }
640 }
641 }
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662 func ExceptContentTypes(types []string) option {
663 return func(c *config) {
664 var contentTypes []parsedContentType
665 for _, v := range types {
666 mediaType, params, err := mime.ParseMediaType(v)
667 if err == nil {
668 contentTypes = append(contentTypes, parsedContentType{mediaType, params})
669 }
670 }
671 c.contentTypes = func(ct string) bool {
672 return !handleContentType(contentTypes, ct)
673 }
674 }
675 }
676
677
678
679
680 func KeepAcceptRanges() option {
681 return func(c *config) {
682 c.keepAcceptRanges = true
683 }
684 }
685
686
687
688
689
690
691
692
693
694
695 func ContentTypeFilter(compress func(ct string) bool) option {
696 return func(c *config) {
697 c.contentTypes = compress
698 }
699 }
700
701
702
703
704
705
706
707
708
709 func SuffixETag(suffix string) option {
710 return func(c *config) {
711 c.suffixETag = suffix
712 }
713 }
714
715
716
717
718
719
720
721
722
723
724 func DropETag() option {
725 return func(c *config) {
726 c.dropETag = true
727 }
728 }
729
730
731
732
733
734
735
736
737
738
739 func RandomJitter(n, buffer int, paranoid bool) option {
740 return func(c *config) {
741 if n > 0 {
742 c.sha256Jitter = paranoid
743 c.randomJitter = strings.Repeat("Padding-", 1+(n/8))[:n+1]
744 c.jitterBuffer = buffer
745 if c.jitterBuffer == 0 {
746 c.jitterBuffer = 64 << 10
747 }
748 } else {
749 c.randomJitter = ""
750 c.jitterBuffer = 0
751 }
752 }
753 }
754
755
756
757 func acceptsGzip(r *http.Request) bool {
758
759
760
761
762 return r.Method != http.MethodHead && parseEncodingGzip(r.Header.Get(acceptEncoding)) > 0
763 }
764
765
766 func handleContentType(contentTypes []parsedContentType, ct string) bool {
767
768 if len(contentTypes) == 0 {
769 return true
770 }
771
772 mediaType, params, err := mime.ParseMediaType(ct)
773 if err != nil {
774 return false
775 }
776
777 for _, c := range contentTypes {
778 if c.equals(mediaType, params) {
779 return true
780 }
781 }
782
783 return false
784 }
785
786
787 func parseEncodingGzip(s string) float64 {
788 s = strings.TrimSpace(s)
789
790 for len(s) > 0 {
791 stop := strings.IndexByte(s, ',')
792 if stop < 0 {
793 stop = len(s)
794 }
795 coding, qvalue, _ := parseCoding(s[:stop])
796
797 if coding == "gzip" {
798 return qvalue
799 }
800 if stop == len(s) {
801 break
802 }
803 s = s[stop+1:]
804 }
805 return 0
806 }
807
808 func parseEncodings(s string) (codings, error) {
809 split := strings.Split(s, ",")
810 c := make(codings, len(split))
811 var e []string
812
813 for _, ss := range split {
814 coding, qvalue, err := parseCoding(ss)
815
816 if err != nil {
817 e = append(e, err.Error())
818 } else {
819 c[coding] = qvalue
820 }
821 }
822
823
824
825 if len(e) > 0 {
826 return c, fmt.Errorf("errors while parsing encodings: %s", strings.Join(e, ", "))
827 }
828
829 return c, nil
830 }
831
832 var errEmptyEncoding = errors.New("empty content-coding")
833
834
835
836
837 func parseCoding(s string) (coding string, qvalue float64, err error) {
838
839 if len(s) == 0 {
840 return "", 0, errEmptyEncoding
841 }
842 if !strings.ContainsRune(s, ';') {
843 coding = strings.ToLower(strings.TrimSpace(s))
844 if coding == "" {
845 err = errEmptyEncoding
846 }
847 return coding, DefaultQValue, err
848 }
849 for n, part := range strings.Split(s, ";") {
850 part = strings.TrimSpace(part)
851 qvalue = DefaultQValue
852
853 if n == 0 {
854 coding = strings.ToLower(part)
855 } else if strings.HasPrefix(part, "q=") {
856 qvalue, err = strconv.ParseFloat(strings.TrimPrefix(part, "q="), 64)
857
858 if qvalue < 0.0 {
859 qvalue = 0.0
860 } else if qvalue > 1.0 {
861 qvalue = 1.0
862 }
863 }
864 }
865
866 if coding == "" {
867 err = errEmptyEncoding
868 }
869
870 return
871 }
872
873
// excludePrefixDefault lists content type prefixes that
// DefaultContentTypeFilter refuses to compress.
var excludePrefixDefault = []string{"video/", "audio/", "image/jp"}

// excludeContainsDefault lists substrings that, when found anywhere in a
// content type, make DefaultContentTypeFilter refuse to compress it.
var excludeContainsDefault = []string{"compress", "zip", "snappy", "lzma", "xz", "zstd", "brotli", "stuffit"}

// DefaultContentTypeFilter reports whether a response with content type ct
// should be compressed. The comparison is made on the lower-cased, trimmed
// value: an empty content type is accepted, and a type is rejected when it
// contains an entry of excludeContainsDefault or starts with an entry of
// excludePrefixDefault.
func DefaultContentTypeFilter(ct string) bool {
	normalized := strings.TrimSpace(strings.ToLower(ct))
	if normalized == "" {
		return true
	}

	for i := range excludeContainsDefault {
		if strings.Contains(normalized, excludeContainsDefault[i]) {
			return false
		}
	}
	for i := range excludePrefixDefault {
		if strings.HasPrefix(normalized, excludePrefixDefault[i]) {
			return false
		}
	}
	return true
}
899
900
// CompressAllContentTypeFilter is a content type filter that accepts every
// content type.
func CompressAllContentTypeFilter(ct string) bool {
	// The content type is deliberately not inspected.
	const compressEverything = true
	return compressEverything
}
904
// intSize is the size of an int in bits: 32 or 64.
const intSize = 32 << (^uint(0) >> 63)

// atoi converts the decimal string s to an int and reports whether the
// conversion succeeded.
// Inputs short enough that they cannot overflow an int are converted by a
// hand-written loop; longer inputs are handed to strconv.ParseInt.
func atoi(s string) (int, bool) {
	if s == "" {
		return 0, false
	}

	// Longest input that is guaranteed to fit into an int.
	maxFast := 18
	if intSize == 32 {
		maxFast = 9
	}
	if len(s) > maxFast {
		// Slow path for large or unusual input.
		v, err := strconv.ParseInt(s, 10, 0)
		return int(v), err == nil
	}

	// Fast path: optional sign followed by at least one digit.
	digits := s
	negative := false
	switch s[0] {
	case '-':
		negative = true
		digits = s[1:]
	case '+':
		digits = s[1:]
	}
	if digits == "" {
		return 0, false
	}

	total := 0
	for i := 0; i < len(digits); i++ {
		// Bytes below '0' wrap around and are caught by the same check.
		d := digits[i] - '0'
		if d > 9 {
			return 0, false
		}
		total = total*10 + int(d)
	}
	if negative {
		total = -total
	}
	return total, true
}
942
// unwrapper is implemented by response writers that can return the
// http.ResponseWriter they wrap.
type unwrapper interface {
	Unwrap() http.ResponseWriter
}
946
947
948
949
950 func newNoGzipResponseWriter(w http.ResponseWriter) http.ResponseWriter {
951 n := &NoGzipResponseWriter{ResponseWriter: w}
952 if hj, ok := w.(http.Hijacker); ok {
953 x := struct {
954 http.ResponseWriter
955 http.Hijacker
956 http.Flusher
957 unwrapper
958 }{
959 ResponseWriter: n,
960 Hijacker: hj,
961 Flusher: n,
962 unwrapper: n,
963 }
964 return x
965 }
966
967 return n
968 }
969
970
// NoGzipResponseWriter passes the response through uncompressed and removes
// the HeaderNoCompression header before anything is sent.
type NoGzipResponseWriter struct {
	http.ResponseWriter

	// hdrCleaned records that HeaderNoCompression has already been removed.
	hdrCleaned bool
}
975
976 func (n *NoGzipResponseWriter) CloseNotify() <-chan bool {
977 if cn, ok := n.ResponseWriter.(http.CloseNotifier); ok {
978 return cn.CloseNotify()
979 }
980 return nil
981 }
982
983 func (n *NoGzipResponseWriter) Flush() {
984 if !n.hdrCleaned {
985 n.ResponseWriter.Header().Del(HeaderNoCompression)
986 n.hdrCleaned = true
987 }
988 if f, ok := n.ResponseWriter.(http.Flusher); ok {
989 f.Flush()
990 }
991 }
992
993 func (n *NoGzipResponseWriter) Header() http.Header {
994 return n.ResponseWriter.Header()
995 }
996
997 func (n *NoGzipResponseWriter) Write(bytes []byte) (int, error) {
998 if !n.hdrCleaned {
999 n.ResponseWriter.Header().Del(HeaderNoCompression)
1000 n.hdrCleaned = true
1001 }
1002 return n.ResponseWriter.Write(bytes)
1003 }
1004
1005 func (n *NoGzipResponseWriter) WriteHeader(statusCode int) {
1006 if !n.hdrCleaned {
1007 n.ResponseWriter.Header().Del(HeaderNoCompression)
1008 n.hdrCleaned = true
1009 }
1010 n.ResponseWriter.WriteHeader(statusCode)
1011 }
1012
1013 func (n *NoGzipResponseWriter) Unwrap() http.ResponseWriter {
1014 return n.ResponseWriter
1015 }
1016
View as plain text