1
2
3
4
5
6
7 package driver
8
9 import (
10 "bytes"
11 "compress/zlib"
12 "os"
13 "testing"
14
15 "github.com/golang/snappy"
16 "github.com/klauspost/compress/zstd"
17
18 "go.mongodb.org/mongo-driver/internal/assert"
19 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
20 )
21
22 func TestCompression(t *testing.T) {
23 compressors := []wiremessage.CompressorID{
24 wiremessage.CompressorNoOp,
25 wiremessage.CompressorSnappy,
26 wiremessage.CompressorZLib,
27 wiremessage.CompressorZstd,
28 }
29
30 for _, compressor := range compressors {
31 t.Run(compressor.String(), func(t *testing.T) {
32 payload := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt")
33 opts := CompressionOpts{
34 Compressor: compressor,
35 ZlibLevel: wiremessage.DefaultZlibLevel,
36 ZstdLevel: wiremessage.DefaultZstdLevel,
37 UncompressedSize: int32(len(payload)),
38 }
39 compressed, err := CompressPayload(payload, opts)
40 assert.NoError(t, err)
41 assert.NotEqual(t, 0, len(compressed))
42 decompressed, err := DecompressPayload(compressed, opts)
43 assert.NoError(t, err)
44 assert.Equal(t, payload, decompressed)
45 })
46 }
47 }
48
49 func TestCompressionLevels(t *testing.T) {
50 in := []byte("abc")
51 wr := new(bytes.Buffer)
52
53 t.Run("ZLib", func(t *testing.T) {
54 opts := CompressionOpts{
55 Compressor: wiremessage.CompressorZLib,
56 }
57 for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ {
58 opts.ZlibLevel = lvl
59 _, err1 := CompressPayload(in, opts)
60 _, err2 := zlib.NewWriterLevel(wr, lvl)
61 if err2 != nil {
62 assert.Error(t, err1, "expected an error for ZLib level %d", lvl)
63 } else {
64 assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl)
65 }
66 }
67 })
68
69 t.Run("Zstd", func(t *testing.T) {
70 opts := CompressionOpts{
71 Compressor: wiremessage.CompressorZstd,
72 }
73 for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ {
74 opts.ZstdLevel = int(lvl)
75 _, err1 := CompressPayload(in, opts)
76 _, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel)))
77 if err2 != nil {
78 assert.Error(t, err1, "expected an error for Zstd level %d", lvl)
79 } else {
80 assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl)
81 }
82 }
83 })
84 }
85
86 func TestDecompressFailures(t *testing.T) {
87 t.Parallel()
88
89 t.Run("snappy decompress huge size", func(t *testing.T) {
90 t.Parallel()
91
92 opts := CompressionOpts{
93 Compressor: wiremessage.CompressorSnappy,
94 UncompressedSize: 100,
95 }
96
97
98
99 compressedData, err := CompressPayload(make([]byte, opts.UncompressedSize*2), opts)
100 assert.NoError(t, err, "premature error making compressed example")
101
102 _, err = DecompressPayload(compressedData, opts)
103 assert.Error(t, err)
104 })
105 }
106
// Benchmark fixtures, filled in once by initCompressionPayload.
var (
	// compressionPayload is the uncompressed benchmark input; a non-nil value
	// also marks the fixtures as initialised.
	compressionPayload []byte

	// Encodings of compressionPayload, one per benchmarked compressor.
	compressedSnappyPayload, compressedZLibPayload, compressedZstdPayload []byte
)
113
114 func initCompressionPayload(b *testing.B) {
115 if compressionPayload != nil {
116 return
117 }
118 data, err := os.ReadFile("testdata/compression.go")
119 if err != nil {
120 b.Fatal(err)
121 }
122 for i := 1; i < 10; i++ {
123 data = append(data, data...)
124 }
125 compressionPayload = data
126
127 compressedSnappyPayload = snappy.Encode(compressedSnappyPayload[:0], data)
128
129 {
130 var buf bytes.Buffer
131 enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.SpeedDefault))
132 if err != nil {
133 b.Fatal(err)
134 }
135 compressedZstdPayload = enc.EncodeAll(data, nil)
136 }
137
138 {
139 var buf bytes.Buffer
140 enc := zlib.NewWriter(&buf)
141 if _, err := enc.Write(data); err != nil {
142 b.Fatal(err)
143 }
144 if err := enc.Close(); err != nil {
145 b.Fatal(err)
146 }
147 if err := enc.Close(); err != nil {
148 b.Fatal(err)
149 }
150 compressedZLibPayload = append(compressedZLibPayload[:0], buf.Bytes()...)
151 }
152
153 b.ResetTimer()
154 }
155
156 func BenchmarkCompressPayload(b *testing.B) {
157 initCompressionPayload(b)
158
159 compressors := []wiremessage.CompressorID{
160 wiremessage.CompressorSnappy,
161 wiremessage.CompressorZLib,
162 wiremessage.CompressorZstd,
163 }
164
165 for _, compressor := range compressors {
166 b.Run(compressor.String(), func(b *testing.B) {
167 opts := CompressionOpts{
168 Compressor: compressor,
169 ZlibLevel: wiremessage.DefaultZlibLevel,
170 ZstdLevel: wiremessage.DefaultZstdLevel,
171 }
172 payload := compressionPayload
173 b.SetBytes(int64(len(payload)))
174 b.ReportAllocs()
175 b.RunParallel(func(pb *testing.PB) {
176 for pb.Next() {
177 _, err := CompressPayload(payload, opts)
178 if err != nil {
179 b.Error(err)
180 }
181 }
182 })
183 })
184 }
185 }
186
187 func BenchmarkDecompressPayload(b *testing.B) {
188 initCompressionPayload(b)
189
190 benchmarks := []struct {
191 compressor wiremessage.CompressorID
192 payload []byte
193 }{
194 {wiremessage.CompressorSnappy, compressedSnappyPayload},
195 {wiremessage.CompressorZLib, compressedZLibPayload},
196 {wiremessage.CompressorZstd, compressedZstdPayload},
197 }
198
199 for _, bench := range benchmarks {
200 b.Run(bench.compressor.String(), func(b *testing.B) {
201 opts := CompressionOpts{
202 Compressor: bench.compressor,
203 ZlibLevel: wiremessage.DefaultZlibLevel,
204 ZstdLevel: wiremessage.DefaultZstdLevel,
205 UncompressedSize: int32(len(compressionPayload)),
206 }
207 payload := bench.payload
208 b.SetBytes(int64(len(compressionPayload)))
209 b.ReportAllocs()
210 b.RunParallel(func(pb *testing.PB) {
211 for pb.Next() {
212 _, err := DecompressPayload(payload, opts)
213 if err != nil {
214 b.Fatal(err)
215 }
216 }
217 })
218 })
219 }
220 }
221
View as plain text