1
2
3
4
5
6
7 package driver
8
9 import (
10 "bytes"
11 "compress/zlib"
12 "fmt"
13 "io"
14 "sync"
15
16 "github.com/golang/snappy"
17 "github.com/klauspost/compress/zstd"
18 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
19 )
20
21
22 type CompressionOpts struct {
23 Compressor wiremessage.CompressorID
24 ZlibLevel int
25 ZstdLevel int
26 UncompressedSize int32
27 }
28
29
30
31
32 func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
33 enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl))
34 if err != nil {
35 panic(err)
36 }
37 return enc
38 }
39
40 var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{
41 0: nil,
42 zstd.SpeedFastest: mustZstdNewWriter(zstd.SpeedFastest),
43 zstd.SpeedDefault: mustZstdNewWriter(zstd.SpeedDefault),
44 zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression),
45 zstd.SpeedBestCompression: mustZstdNewWriter(zstd.SpeedBestCompression),
46 }
47
48 func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
49 if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression {
50 return zstdEncoders[level], nil
51 }
52
53 return nil, fmt.Errorf("invalid zstd compression level: %d", level)
54 }
55
56
57
// zlibEncodersOffset turns a zlib compression level into an index into
// zlibEncoders: the lowest accepted level, zlib.HuffmanOnly, maps to index 0.
const zlibEncodersOffset = -zlib.HuffmanOnly

// zlibEncoders keeps one pool of reusable *zlibEncoder values for each zlib
// level from zlib.HuffmanOnly through zlib.BestCompression inclusive; the pool
// for a level lives at index level+zlibEncodersOffset.
var zlibEncoders [zlib.BestCompression - zlib.HuffmanOnly + 1]sync.Pool
61
62 func getZlibEncoder(level int) (*zlibEncoder, error) {
63 if zlib.HuffmanOnly <= level && level <= zlib.BestCompression {
64 if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil {
65 return enc, nil
66 }
67 writer, err := zlib.NewWriterLevel(nil, level)
68 if err != nil {
69 return nil, err
70 }
71 enc := &zlibEncoder{writer: writer, level: level}
72 return enc, nil
73 }
74
75 return nil, fmt.Errorf("invalid zlib compression level: %d", level)
76 }
77
78 func putZlibEncoder(enc *zlibEncoder) {
79 if enc != nil {
80 zlibEncoders[enc.level+zlibEncodersOffset].Put(enc)
81 }
82 }
83
// zlibEncoder bundles a reusable zlib writer with the buffer it writes into,
// so the pair can be pooled per compression level (see getZlibEncoder).
type zlibEncoder struct {
	// writer compresses into buf; Encode resets it before each use.
	writer *zlib.Writer

	// buf receives the compressed bytes of the most recent Encode call.
	buf bytes.Buffer

	// level is the zlib level writer was created with; it selects the pool
	// in zlibEncoders that the encoder is returned to.
	level int
}
89
90 func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
91 defer putZlibEncoder(e)
92
93 e.buf.Reset()
94 e.writer.Reset(&e.buf)
95
96 _, err := e.writer.Write(src)
97 if err != nil {
98 return nil, err
99 }
100 err = e.writer.Close()
101 if err != nil {
102 return nil, err
103 }
104 dst = append(dst[:0], e.buf.Bytes()...)
105 return dst, nil
106 }
107
108
109 func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
110 switch opts.Compressor {
111 case wiremessage.CompressorNoOp:
112 return in, nil
113 case wiremessage.CompressorSnappy:
114 return snappy.Encode(nil, in), nil
115 case wiremessage.CompressorZLib:
116 encoder, err := getZlibEncoder(opts.ZlibLevel)
117 if err != nil {
118 return nil, err
119 }
120 return encoder.Encode(nil, in)
121 case wiremessage.CompressorZstd:
122 encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
123 if err != nil {
124 return nil, err
125 }
126 return encoder.EncodeAll(in, nil), nil
127 default:
128 return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
129 }
130 }
131
132 var zstdReaderPool = sync.Pool{
133 New: func() interface{} {
134 r, _ := zstd.NewReader(nil)
135 return r
136 },
137 }
138
139
140 func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
141 switch opts.Compressor {
142 case wiremessage.CompressorNoOp:
143 return in, nil
144 case wiremessage.CompressorSnappy:
145 l, err := snappy.DecodedLen(in)
146 if err != nil {
147 return nil, fmt.Errorf("decoding compressed length %w", err)
148 } else if int32(l) != opts.UncompressedSize {
149 return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
150 }
151 out := make([]byte, opts.UncompressedSize)
152 return snappy.Decode(out, in)
153 case wiremessage.CompressorZLib:
154 r, err := zlib.NewReader(bytes.NewReader(in))
155 if err != nil {
156 return nil, err
157 }
158 out := make([]byte, opts.UncompressedSize)
159 if _, err := io.ReadFull(r, out); err != nil {
160 return nil, err
161 }
162 if err := r.Close(); err != nil {
163 return nil, err
164 }
165 return out, nil
166 case wiremessage.CompressorZstd:
167 buf := make([]byte, 0, opts.UncompressedSize)
168
169
170 r := zstdReaderPool.Get().(*zstd.Decoder)
171 out, err := r.DecodeAll(in, buf)
172 zstdReaderPool.Put(r)
173 return out, err
174 default:
175 return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
176 }
177 }
178
View as plain text