1 package zstd
2
3 import (
4 "bytes"
5 "fmt"
6 "io"
7 "os"
8 "strings"
9 "testing"
10
11 "github.com/klauspost/compress/zip"
12 )
13
14 func TestDecoder_SmallDict(t *testing.T) {
15
16 zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
17 dicts := readDicts(t, zr)
18 dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
19 if err != nil {
20 t.Fatal(err)
21 return
22 }
23 defer dec.Close()
24 for _, tt := range zr.File {
25 if !strings.HasSuffix(tt.Name, ".zst") {
26 continue
27 }
28 t.Run("decodeall-"+tt.Name, func(t *testing.T) {
29 r, err := tt.Open()
30 if err != nil {
31 t.Fatal(err)
32 }
33 defer r.Close()
34 in, err := io.ReadAll(r)
35 if err != nil {
36 t.Fatal(err)
37 }
38 got, err := dec.DecodeAll(in, nil)
39 if err != nil {
40 t.Fatal(err)
41 }
42 _, err = dec.DecodeAll(in, got[:0])
43 if err != nil {
44 t.Fatal(err)
45 }
46 })
47 }
48 }
49
50 func TestEncoder_SmallDict(t *testing.T) {
51
52 zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
53 var dicts [][]byte
54 var encs []*Encoder
55 var noDictEncs []*Encoder
56 var encNames []string
57
58 for _, tt := range zr.File {
59 if !strings.HasSuffix(tt.Name, ".dict") {
60 continue
61 }
62 func() {
63 r, err := tt.Open()
64 if err != nil {
65 t.Fatal(err)
66 }
67 defer r.Close()
68 in, err := io.ReadAll(r)
69 if err != nil {
70 t.Fatal(err)
71 }
72 dicts = append(dicts, in)
73 for level := SpeedFastest; level < speedLast; level++ {
74 if isRaceTest && level >= SpeedBestCompression {
75 break
76 }
77 enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderDict(in), WithEncoderLevel(level), WithWindowSize(1<<17))
78 if err != nil {
79 t.Fatal(err)
80 }
81 encs = append(encs, enc)
82 encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
83
84 enc, err = NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithWindowSize(1<<17))
85 if err != nil {
86 t.Fatal(err)
87 }
88 noDictEncs = append(noDictEncs, enc)
89 }
90 }()
91 }
92 dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
93 if err != nil {
94 t.Fatal(err)
95 return
96 }
97 defer dec.Close()
98 for i, tt := range zr.File {
99 if testing.Short() && i > 100 {
100 break
101 }
102 if !strings.HasSuffix(tt.Name, ".zst") {
103 continue
104 }
105 r, err := tt.Open()
106 if err != nil {
107 t.Fatal(err)
108 }
109 defer r.Close()
110 in, err := io.ReadAll(r)
111 if err != nil {
112 t.Fatal(err)
113 }
114 decoded, err := dec.DecodeAll(in, nil)
115 if err != nil {
116 t.Fatal(err)
117 }
118 if testing.Short() && len(decoded) > 1000 {
119 continue
120 }
121
122 t.Run("encodeall-"+tt.Name, func(t *testing.T) {
123
124 var b []byte
125 var tmp []byte
126 for i := range encs {
127 i := i
128 t.Run(encNames[i], func(t *testing.T) {
129 b = encs[i].EncodeAll(decoded, b[:0])
130 tmp, err = dec.DecodeAll(in, tmp[:0])
131 if err != nil {
132 t.Fatal(err)
133 }
134 if !bytes.Equal(tmp, decoded) {
135 t.Fatal("output mismatch")
136 }
137
138 tmp = noDictEncs[i].EncodeAll(decoded, tmp[:0])
139
140 if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
141 t.Log("reference:", len(in), "no dict:", len(tmp), "with dict:", len(b), "SAVED:", len(tmp)-len(b))
142
143 if len(b) > 250 {
144 t.Error("output was bigger than expected")
145 }
146 }
147 })
148 }
149 })
150 t.Run("stream-"+tt.Name, func(t *testing.T) {
151
152 var tmp []byte
153 for i := range encs {
154 i := i
155 enc := encs[i]
156 t.Run(encNames[i], func(t *testing.T) {
157 var buf bytes.Buffer
158 enc.ResetContentSize(&buf, int64(len(decoded)))
159 _, err := enc.Write(decoded)
160 if err != nil {
161 t.Fatal(err)
162 }
163 err = enc.Close()
164 if err != nil {
165 t.Fatal(err)
166 }
167 tmp, err = dec.DecodeAll(buf.Bytes(), tmp[:0])
168 if err != nil {
169 t.Fatal(err)
170 }
171 if !bytes.Equal(tmp, decoded) {
172 t.Fatal("output mismatch")
173 }
174 var buf2 bytes.Buffer
175 noDictEncs[i].Reset(&buf2)
176 noDictEncs[i].Write(decoded)
177 noDictEncs[i].Close()
178
179 if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
180 t.Log("reference:", len(in), "no dict:", buf2.Len(), "with dict:", buf.Len(), "SAVED:", buf2.Len()-buf.Len())
181
182 if buf.Len() > 250 {
183 t.Error("output was bigger than expected")
184 }
185 }
186 })
187 }
188 })
189 }
190 }
191
192 func TestEncoder_SmallDictFresh(t *testing.T) {
193
194 zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
195 var dicts [][]byte
196 var encs []func() *Encoder
197 var noDictEncs []*Encoder
198 var encNames []string
199
200 for _, tt := range zr.File {
201 if !strings.HasSuffix(tt.Name, ".dict") {
202 continue
203 }
204 func() {
205 r, err := tt.Open()
206 if err != nil {
207 t.Fatal(err)
208 }
209 defer r.Close()
210 in, err := io.ReadAll(r)
211 if err != nil {
212 t.Fatal(err)
213 }
214 dicts = append(dicts, in)
215 for level := SpeedFastest; level < speedLast; level++ {
216 if isRaceTest && level >= SpeedBestCompression {
217 break
218 }
219 level := level
220 encs = append(encs, func() *Encoder {
221 enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderDict(in), WithEncoderLevel(level), WithWindowSize(1<<17))
222 if err != nil {
223 t.Fatal(err)
224 }
225 return enc
226 })
227 encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
228
229 enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithWindowSize(1<<17))
230 if err != nil {
231 t.Fatal(err)
232 }
233 noDictEncs = append(noDictEncs, enc)
234 }
235 }()
236 }
237 dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
238 if err != nil {
239 t.Fatal(err)
240 return
241 }
242 defer dec.Close()
243 for i, tt := range zr.File {
244 if testing.Short() && i > 100 {
245 break
246 }
247 if !strings.HasSuffix(tt.Name, ".zst") {
248 continue
249 }
250 r, err := tt.Open()
251 if err != nil {
252 t.Fatal(err)
253 }
254 defer r.Close()
255 in, err := io.ReadAll(r)
256 if err != nil {
257 t.Fatal(err)
258 }
259 decoded, err := dec.DecodeAll(in, nil)
260 if err != nil {
261 t.Fatal(err)
262 }
263 if testing.Short() && len(decoded) > 1000 {
264 continue
265 }
266
267 t.Run("encodeall-"+tt.Name, func(t *testing.T) {
268
269 var b []byte
270 var tmp []byte
271 for i := range encs {
272 i := i
273 t.Run(encNames[i], func(t *testing.T) {
274 enc := encs[i]()
275 defer enc.Close()
276 b = enc.EncodeAll(decoded, b[:0])
277 tmp, err = dec.DecodeAll(in, tmp[:0])
278 if err != nil {
279 t.Fatal(err)
280 }
281 if !bytes.Equal(tmp, decoded) {
282 t.Fatal("output mismatch")
283 }
284
285 tmp = noDictEncs[i].EncodeAll(decoded, tmp[:0])
286
287 if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
288 t.Log("reference:", len(in), "no dict:", len(tmp), "with dict:", len(b), "SAVED:", len(tmp)-len(b))
289
290 if len(b) > 250 {
291 t.Error("output was bigger than expected")
292 }
293 }
294 })
295 }
296 })
297 t.Run("stream-"+tt.Name, func(t *testing.T) {
298
299 var tmp []byte
300 for i := range encs {
301 i := i
302 t.Run(encNames[i], func(t *testing.T) {
303 enc := encs[i]()
304 defer enc.Close()
305 var buf bytes.Buffer
306 enc.ResetContentSize(&buf, int64(len(decoded)))
307 _, err := enc.Write(decoded)
308 if err != nil {
309 t.Fatal(err)
310 }
311 err = enc.Close()
312 if err != nil {
313 t.Fatal(err)
314 }
315 tmp, err = dec.DecodeAll(buf.Bytes(), tmp[:0])
316 if err != nil {
317 t.Fatal(err)
318 }
319 if !bytes.Equal(tmp, decoded) {
320 t.Fatal("output mismatch")
321 }
322 var buf2 bytes.Buffer
323 noDictEncs[i].Reset(&buf2)
324 noDictEncs[i].Write(decoded)
325 noDictEncs[i].Close()
326
327 if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
328 t.Log("reference:", len(in), "no dict:", buf2.Len(), "with dict:", buf.Len(), "SAVED:", buf2.Len()-buf.Len())
329
330 if buf.Len() > 250 {
331 t.Error("output was bigger than expected")
332 }
333 }
334 })
335 }
336 })
337 }
338 }
339
340 func benchmarkEncodeAllLimitedBySize(b *testing.B, lowerLimit int, upperLimit int) {
341 zr := testCreateZipReader("testdata/dict-tests-small.zip", b)
342 t := testing.TB(b)
343
344 var dicts [][]byte
345 var encs []*Encoder
346 var encNames []string
347
348 for _, tt := range zr.File {
349 if !strings.HasSuffix(tt.Name, ".dict") {
350 continue
351 }
352 func() {
353 r, err := tt.Open()
354 if err != nil {
355 t.Fatal(err)
356 }
357 defer r.Close()
358 in, err := io.ReadAll(r)
359 if err != nil {
360 t.Fatal(err)
361 }
362 dicts = append(dicts, in)
363 for level := SpeedFastest; level < speedLast; level++ {
364 enc, err := NewWriter(nil, WithEncoderDict(in), WithEncoderLevel(level))
365 if err != nil {
366 t.Fatal(err)
367 }
368 encs = append(encs, enc)
369 encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
370 }
371 }()
372 }
373 const nPer = int(speedLast - SpeedFastest)
374 dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
375 if err != nil {
376 t.Fatal(err)
377 return
378 }
379 defer dec.Close()
380
381 tested := make(map[int]struct{})
382 for j, tt := range zr.File {
383 if !strings.HasSuffix(tt.Name, ".zst") {
384 continue
385 }
386 r, err := tt.Open()
387 if err != nil {
388 t.Fatal(err)
389 }
390 defer r.Close()
391 in, err := io.ReadAll(r)
392 if err != nil {
393 t.Fatal(err)
394 }
395 decoded, err := dec.DecodeAll(in, nil)
396 if err != nil {
397 t.Fatal(err)
398 }
399
400
401 if _, ok := tested[len(decoded)]; ok {
402 continue
403 }
404 tested[len(decoded)] = struct{}{}
405
406 if len(decoded) < lowerLimit {
407 continue
408 }
409
410 if upperLimit > 0 && len(decoded) > upperLimit {
411 continue
412 }
413
414 for i := range encs {
415
416 if i == nPer-1 {
417 break
418 }
419
420 encIdx := (i + j*nPer) % len(encs)
421 enc := encs[encIdx]
422 b.Run(fmt.Sprintf("length-%d-%s", len(decoded), encNames[encIdx]), func(b *testing.B) {
423 b.RunParallel(func(pb *testing.PB) {
424 dst := make([]byte, 0, len(decoded)+10)
425 b.SetBytes(int64(len(decoded)))
426 b.ResetTimer()
427 b.ReportAllocs()
428 for pb.Next() {
429 dst = enc.EncodeAll(decoded, dst[:0])
430 }
431 })
432 })
433 }
434 }
435 }
436
437 func BenchmarkEncodeAllDict0_1024(b *testing.B) {
438 benchmarkEncodeAllLimitedBySize(b, 0, 1024)
439 }
440
441 func BenchmarkEncodeAllDict1024_8192(b *testing.B) {
442 benchmarkEncodeAllLimitedBySize(b, 1024, 8192)
443 }
444
445 func BenchmarkEncodeAllDict8192_16384(b *testing.B) {
446 benchmarkEncodeAllLimitedBySize(b, 8192, 16384)
447 }
448
449 func BenchmarkEncodeAllDict16384_65536(b *testing.B) {
450 benchmarkEncodeAllLimitedBySize(b, 16384, 65536)
451 }
452
453 func BenchmarkEncodeAllDict65536_0(b *testing.B) {
454 benchmarkEncodeAllLimitedBySize(b, 65536, 0)
455 }
456
457 func TestDecoder_MoreDicts(t *testing.T) {
458
459
460 fn := "testdata/zstd-dict-tests.zip"
461 data, err := os.ReadFile(fn)
462 if err != nil {
463 t.Skip("extended dict test not found.")
464 }
465 zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
466 if err != nil {
467 t.Fatal(err)
468 }
469
470 var dicts [][]byte
471 for _, tt := range zr.File {
472 if !strings.HasSuffix(tt.Name, ".dict") {
473 continue
474 }
475 func() {
476 r, err := tt.Open()
477 if err != nil {
478 t.Fatal(err)
479 }
480 defer r.Close()
481 in, err := io.ReadAll(r)
482 if err != nil {
483 t.Fatal(err)
484 }
485 dicts = append(dicts, in)
486 }()
487 }
488 dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
489 if err != nil {
490 t.Fatal(err)
491 return
492 }
493 defer dec.Close()
494 for i, tt := range zr.File {
495 if !strings.HasSuffix(tt.Name, ".zst") {
496 continue
497 }
498 if testing.Short() && i > 50 {
499 continue
500 }
501 t.Run("decodeall-"+tt.Name, func(t *testing.T) {
502 r, err := tt.Open()
503 if err != nil {
504 t.Fatal(err)
505 }
506 defer r.Close()
507 in, err := io.ReadAll(r)
508 if err != nil {
509 t.Fatal(err)
510 }
511 got, err := dec.DecodeAll(in, nil)
512 if err != nil {
513 t.Fatal(err)
514 }
515 _, err = dec.DecodeAll(in, got[:0])
516 if err != nil {
517 t.Fatal(err)
518 }
519 })
520 }
521 }
522
523 func TestDecoder_MoreDicts2(t *testing.T) {
524
525
526 fn := "testdata/zstd-dict-tests.zip"
527 data, err := os.ReadFile(fn)
528 if err != nil {
529 t.Skip("extended dict test not found.")
530 }
531 zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
532 if err != nil {
533 t.Fatal(err)
534 }
535
536 var dicts [][]byte
537 for _, tt := range zr.File {
538 if !strings.HasSuffix(tt.Name, ".dict") {
539 continue
540 }
541 func() {
542 r, err := tt.Open()
543 if err != nil {
544 t.Fatal(err)
545 }
546 defer r.Close()
547 in, err := io.ReadAll(r)
548 if err != nil {
549 t.Fatal(err)
550 }
551 dicts = append(dicts, in)
552 }()
553 }
554 dec, err := NewReader(nil, WithDecoderConcurrency(2), WithDecoderDicts(dicts...))
555 if err != nil {
556 t.Fatal(err)
557 return
558 }
559 defer dec.Close()
560 for i, tt := range zr.File {
561 if !strings.HasSuffix(tt.Name, ".zst") {
562 continue
563 }
564 if testing.Short() && i > 50 {
565 continue
566 }
567 t.Run("decodeall-"+tt.Name, func(t *testing.T) {
568 r, err := tt.Open()
569 if err != nil {
570 t.Fatal(err)
571 }
572 defer r.Close()
573 in, err := io.ReadAll(r)
574 if err != nil {
575 t.Fatal(err)
576 }
577 got, err := dec.DecodeAll(in, nil)
578 if err != nil {
579 t.Fatal(err)
580 }
581 _, err = dec.DecodeAll(in, got[:0])
582 if err != nil {
583 t.Fatal(err)
584 }
585 })
586 }
587 }
588
589 func readDicts(tb testing.TB, zr *zip.Reader) [][]byte {
590 var dicts [][]byte
591 for _, tt := range zr.File {
592 if !strings.HasSuffix(tt.Name, ".dict") {
593 continue
594 }
595 func() {
596 r, err := tt.Open()
597 if err != nil {
598 tb.Fatal(err)
599 }
600 defer r.Close()
601 in, err := io.ReadAll(r)
602 if err != nil {
603 tb.Fatal(err)
604 }
605 dicts = append(dicts, in)
606 }()
607 }
608 return dicts
609 }
610
611
612 func TestDecoderRawDict(t *testing.T) {
613 t.Parallel()
614
615 dict, err := os.ReadFile("testdata/delta/source.txt")
616 if err != nil {
617 t.Fatal(err)
618 }
619
620 delta, err := os.Open("testdata/delta/target.txt.zst")
621 if err != nil {
622 t.Fatal(err)
623 }
624 defer delta.Close()
625
626 dec, err := NewReader(delta, WithDecoderDictRaw(0, dict))
627 if err != nil {
628 t.Fatal(err)
629 }
630
631 out, err := io.ReadAll(dec)
632 if err != nil {
633 t.Fatal(err)
634 }
635
636 ref, err := os.ReadFile("testdata/delta/target.txt")
637 if err != nil {
638 t.Fatal(err)
639 }
640
641 if !bytes.Equal(out, ref) {
642 t.Errorf("mismatch: got %q, wanted %q", out, ref)
643 }
644 }
645
View as plain text