1
2
3
4 package zstd
5
6 import (
7 "bytes"
8 "fmt"
9 "io"
10 rdebug "runtime/debug"
11 "testing"
12
13 "github.com/klauspost/compress/internal/cpuinfo"
14 "github.com/klauspost/compress/internal/fuzz"
15 )
16
17 func FuzzDecodeAll(f *testing.F) {
18 fuzz.AddFromZip(f, "testdata/decode-regression.zip", fuzz.TypeRaw, false)
19 fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
20 fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-encoded.zip", fuzz.TypeGoFuzz, testing.Short())
21
22 f.Fuzz(func(t *testing.T, b []byte) {
23
24 defer func() {
25 if r := recover(); r != nil {
26 rdebug.PrintStack()
27 t.Fatal(r)
28 }
29 }()
30
31 decLow, err := NewReader(nil, WithDecoderLowmem(true), WithDecoderConcurrency(2), WithDecoderMaxMemory(20<<20), WithDecoderMaxWindow(1<<20), IgnoreChecksum(true))
32 if err != nil {
33 t.Fatal(err)
34 }
35 defer decLow.Close()
36 decHi, err := NewReader(nil, WithDecoderLowmem(false), WithDecoderConcurrency(2), WithDecoderMaxMemory(20<<20), WithDecoderMaxWindow(1<<20), IgnoreChecksum(true))
37 if err != nil {
38 t.Fatal(err)
39 }
40 defer decHi.Close()
41 b1, err1 := decLow.DecodeAll(b, make([]byte, 0, len(b)))
42 b2, err2 := decHi.DecodeAll(b, make([]byte, 0, len(b)))
43 if err1 != err2 {
44 if (err1 == nil) != (err2 == nil) {
45 t.Errorf("err low: %v, hi: %v", err1, err2)
46 }
47 }
48 if err1 != nil {
49 b1, b2 = b1[:0], b2[:0]
50 }
51 if !bytes.Equal(b1, b2) {
52 t.Fatalf("Output mismatch, low: %v, hi: %v", err1, err2)
53 }
54 })
55 }
56
57 func FuzzDecAllNoBMI2(f *testing.F) {
58 if !cpuinfo.HasBMI2() {
59 f.Skip("No BMI, so already tested")
60 return
61 }
62 defer cpuinfo.DisableBMI2()()
63 FuzzDecodeAll(f)
64 }
65
66 func FuzzDecoder(f *testing.F) {
67 fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
68 fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-encoded.zip", fuzz.TypeGoFuzz, testing.Short())
69
70
71 brLow := newBytesReader(nil)
72 brHi := newBytesReader(nil)
73 f.Fuzz(func(t *testing.T, b []byte) {
74
75 defer func() {
76 if r := recover(); r != nil {
77 rdebug.PrintStack()
78 t.Fatal(r)
79 }
80 }()
81 brLow.Reset(b)
82 brHi.Reset(b)
83 decLow, err := NewReader(brLow, WithDecoderLowmem(true), WithDecoderConcurrency(2), WithDecoderMaxMemory(20<<20), WithDecoderMaxWindow(1<<20), IgnoreChecksum(true), WithDecodeBuffersBelow(8<<10))
84 if err != nil {
85 t.Fatal(err)
86 }
87 defer decLow.Close()
88
89
90 decHi, err := NewReader(brHi, WithDecoderLowmem(false), WithDecoderConcurrency(1), WithDecoderMaxMemory(20<<20), WithDecoderMaxWindow(1<<20), IgnoreChecksum(true), WithDecodeBuffersBelow(8<<10))
91 if err != nil {
92 t.Fatal(err)
93 }
94 defer decHi.Close()
95
96 if debugDecoder {
97 fmt.Println("LOW CONCURRENT")
98 }
99 b1, err1 := io.ReadAll(decLow)
100
101 if debugDecoder {
102 fmt.Println("HI NOT CONCURRENT")
103 }
104 b2, err2 := io.ReadAll(decHi)
105 if err1 != err2 {
106 if (err1 == nil) != (err2 == nil) {
107 t.Errorf("err low concurrent: %v, hi: %v", err1, err2)
108 }
109 }
110 if err1 != nil {
111 b1, b2 = b1[:0], b2[:0]
112 }
113 if !bytes.Equal(b1, b2) {
114 t.Fatalf("Output mismatch, low concurrent: %v, hi: %v", err1, err2)
115 }
116 })
117 }
118
119 func FuzzNoBMI2Dec(f *testing.F) {
120 if !cpuinfo.HasBMI2() {
121 f.Skip("No BMI, so already tested")
122 return
123 }
124 defer cpuinfo.DisableBMI2()()
125 FuzzDecoder(f)
126 }
127
128 func FuzzEncoding(f *testing.F) {
129 fuzz.AddFromZip(f, "testdata/fuzz/encode-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
130 fuzz.AddFromZip(f, "testdata/comp-crashers.zip", fuzz.TypeRaw, false)
131 fuzz.AddFromZip(f, "testdata/fuzz/encode-corpus-encoded.zip", fuzz.TypeGoFuzz, testing.Short())
132
133 const (
134
135 startFuzz = SpeedFastest
136 endFuzz = SpeedBestCompression
137
138
139 testDicts = true
140 )
141
142 var dec *Decoder
143 var encs [SpeedBestCompression + 1]*Encoder
144 var encsD [SpeedBestCompression + 1]*Encoder
145
146 var dicts [][]byte
147 if testDicts {
148 zr := testCreateZipReader("testdata/dict-tests-small.zip", f)
149 dicts = readDicts(f, zr)
150 }
151
152 if testing.Short() && *fuzzEndF > int(SpeedBetterCompression) {
153 *fuzzEndF = int(SpeedBetterCompression)
154 }
155
156 initEnc := func() func() {
157 var err error
158 dec, err = NewReader(nil, WithDecoderConcurrency(2), WithDecoderDicts(dicts...), WithDecoderMaxWindow(64<<10), WithDecoderMaxMemory(uint64(*fuzzMaxF)))
159 if err != nil {
160 panic(err)
161 }
162 for level := startFuzz; level <= endFuzz; level++ {
163 encs[level], err = NewWriter(nil, WithEncoderCRC(true), WithEncoderLevel(level), WithEncoderConcurrency(2), WithWindowSize(64<<10), WithZeroFrames(true), WithLowerEncoderMem(true))
164 if testDicts {
165 encsD[level], err = NewWriter(nil, WithEncoderCRC(true), WithEncoderLevel(level), WithEncoderConcurrency(2), WithWindowSize(64<<10), WithZeroFrames(true), WithEncoderDict(dicts[int(level)%len(dicts)]), WithLowerEncoderMem(true), WithLowerEncoderMem(true))
166 }
167 }
168 return func() {
169 dec.Close()
170 for _, enc := range encs {
171 if enc != nil {
172 enc.Close()
173 }
174 }
175 if testDicts {
176 for _, enc := range encsD {
177 if enc != nil {
178 enc.Close()
179 }
180 }
181 }
182 }
183 }
184
185 f.Cleanup(initEnc())
186
187 var dst bytes.Buffer
188
189 f.Fuzz(func(t *testing.T, data []byte) {
190
191 defer func() {
192 if r := recover(); r != nil {
193 stack := rdebug.Stack()
194 t.Fatalf("%v:\n%v", r, string(stack))
195 }
196 }()
197 if len(data) > *fuzzMaxF {
198 return
199 }
200 var bufSize = len(data)
201 if bufSize > 2 {
202
203 bufSize = int(data[0]) | int(data[1])<<8
204 if bufSize >= len(data) {
205 bufSize = len(data) / 2
206 }
207 }
208
209 for level := *fuzzStartF; level <= *fuzzEndF; level++ {
210 enc := encs[level]
211 dst.Reset()
212 enc.Reset(&dst)
213 n, err := enc.Write(data)
214 if err != nil {
215 t.Fatal(err)
216 }
217 if n != len(data) {
218 t.Fatal(fmt.Sprintln("Level", level, "Short write, got:", n, "want:", len(data)))
219 }
220
221 encoded := enc.EncodeAll(data, make([]byte, 0, bufSize))
222 if len(encoded) > enc.MaxEncodedSize(len(data)) {
223 t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
224 }
225
226 got, err := dec.DecodeAll(encoded, make([]byte, 0, bufSize))
227 if err != nil {
228 t.Fatal(fmt.Sprintln("Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
229 }
230 if !bytes.Equal(got, data) {
231 t.Fatal(fmt.Sprintln("Level", level, "DecodeAll output mismatch\n", len(got), "org: \n", len(data), "(want)", "\nencoded:", len(encoded)))
232 }
233
234 err = enc.Close()
235 if err != nil {
236 t.Fatal(fmt.Sprintln("Level", level, "Close (buffer) error:", err))
237 }
238 encoded2 := dst.Bytes()
239 if len(encoded2) > enc.MaxEncodedSize(len(data)) {
240 t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
241 }
242 if !bytes.Equal(encoded, encoded2) {
243 got, err = dec.DecodeAll(encoded2, got[:0])
244 if err != nil {
245 t.Fatal(fmt.Sprintln("Level", level, "DecodeAll (buffer) error:", err, "\norg:", len(data), "\nencoded", len(encoded2)))
246 }
247 if !bytes.Equal(got, data) {
248 t.Fatal(fmt.Sprintln("Level", level, "DecodeAll (buffer) output mismatch\n", len(got), "org: \n", len(data), "(want)", "\nencoded:", len(encoded2)))
249 }
250 }
251 if !testDicts {
252 continue
253 }
254 enc = encsD[level]
255 dst.Reset()
256 enc.Reset(&dst)
257 n, err = enc.Write(data)
258 if err != nil {
259 t.Fatal(err)
260 }
261 if n != len(data) {
262 t.Fatal(fmt.Sprintln("Dict Level", level, "Short write, got:", n, "want:", len(data)))
263 }
264
265 encoded = enc.EncodeAll(data, encoded[:0])
266 if len(encoded) > enc.MaxEncodedSize(len(data)) {
267 t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
268 }
269 got, err = dec.DecodeAll(encoded, got[:0])
270 if err != nil {
271 t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
272 }
273 if !bytes.Equal(got, data) {
274 t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll output mismatch\n", len(got), "org: \n", len(data), "(want)", "\nencoded:", len(encoded)))
275 }
276
277 err = enc.Close()
278 if err != nil {
279 t.Fatal(fmt.Sprintln("Dict Level", level, "Close (buffer) error:", err))
280 }
281 encoded2 = dst.Bytes()
282 if len(encoded2) > enc.MaxEncodedSize(len(data)) {
283 t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
284 }
285 if !bytes.Equal(encoded, encoded2) {
286 got, err = dec.DecodeAll(encoded2, got[:0])
287 if err != nil {
288 t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll (buffer) error:", err, "\norg:", len(data), "\nencoded", len(encoded2)))
289 }
290 if !bytes.Equal(got, data) {
291 t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll (buffer) output mismatch\n", len(got), "org: \n", len(data), "(want)", "\nencoded:", len(encoded2)))
292 }
293 }
294 }
295 })
296 }
297
View as plain text