1
2
3
4
5 package zstd
6
7 import (
8 "bytes"
9 "encoding/binary"
10 "errors"
11 "fmt"
12 "hash/crc32"
13 "io"
14 "os"
15 "path/filepath"
16 "sync"
17
18 "github.com/klauspost/compress/huff0"
19 "github.com/klauspost/compress/zstd/internal/xxhash"
20 )
21
// blockType is the 2-bit block type stored in bits 1-2 of a block header.
type blockType uint8

// Block types, numbered exactly as they appear in the block header.
const (
	// blockTypeRaw is a block whose payload is stored as-is.
	blockTypeRaw blockType = 0
	// blockTypeRLE is a block consisting of a single byte repeated RLESize times.
	blockTypeRLE blockType = 1
	// blockTypeCompressed is a block holding a literals section and sequences.
	blockTypeCompressed blockType = 2
	// blockTypeReserved is not a valid type on the stream; the decoder also
	// uses it internally to mark a block that only carries an error.
	blockTypeReserved blockType = 3
)
32
// literalsBlockType is the 2-bit type found in the lowest bits of the first
// byte of a literals section.
type literalsBlockType uint8

// Literals block types, numbered exactly as they appear on the stream.
const (
	// literalsBlockRaw holds the literals uncompressed.
	literalsBlockRaw literalsBlockType = 0
	// literalsBlockRLE holds a single byte that is repeated.
	literalsBlockRLE literalsBlockType = 1
	// literalsBlockCompressed holds a Huffman table followed by compressed literals.
	literalsBlockCompressed literalsBlockType = 2
	// literalsBlockTreeless holds compressed literals that reuse the Huffman
	// table from the history.
	literalsBlockTreeless literalsBlockType = 3
)
41
// Size limits used while decoding blocks.
const (
	// maxCompressedBlockSize is the largest block size accepted from a
	// block header (128 KiB).
	maxCompressedBlockSize = 128 << 10

	// compressedBlockOverAlloc is the number of extra bytes of capacity
	// added when block buffers are allocated.
	compressedBlockOverAlloc = 16

	// maxCompressedBlockSizeAlloc is the capacity used for full-size block
	// buffers: the maximum block size plus the over-allocation.
	maxCompressedBlockSizeAlloc = maxCompressedBlockSize + compressedBlockOverAlloc

	// maxBlockSize is the largest value a 21-bit block size field can hold.
	maxBlockSize = 1<<21 - 1

	// maxMatchLen is the upper bound on a single match length.
	maxMatchLen = 131074

	// maxSequences is the largest sequence count the 3-byte form of the
	// sequences section header can express.
	maxSequences = 0x7f00 + 0xffff

	// maxOffsetBits is the largest number of offset bits supported.
	// NOTE(review): presumably capped so offsets fit in an int on 32-bit
	// platforms - confirm against the sequence decoder.
	maxOffsetBits = 30
)
59
60 var (
61 huffDecoderPool = sync.Pool{New: func() interface{} {
62 return &huff0.Scratch{}
63 }}
64
65 fseDecoderPool = sync.Pool{New: func() interface{} {
66 return &fseDecoder{}
67 }}
68 )
69
70 type blockDec struct {
71
72 data []byte
73 dataStorage []byte
74
75
76 dst []byte
77
78
79 literalBuf []byte
80
81
82 WindowSize uint64
83
84 err error
85
86
87 checkCRC uint32
88 hasCRC bool
89
90
91
92 localFrame *frameDec
93
94 sequence []seqVals
95
96 async struct {
97 newHist *history
98 literals []byte
99 seqData []byte
100 seqSize int
101 fcs uint64
102 }
103
104
105 RLESize uint32
106
107 Type blockType
108
109
110 Last bool
111
112
113 lowMem bool
114 }
115
116 func (b *blockDec) String() string {
117 if b == nil {
118 return "<nil>"
119 }
120 return fmt.Sprintf("Steam Size: %d, Type: %v, Last: %t, Window: %d", len(b.data), b.Type, b.Last, b.WindowSize)
121 }
122
123 func newBlockDec(lowMem bool) *blockDec {
124 b := blockDec{
125 lowMem: lowMem,
126 }
127 return &b
128 }
129
130
131
132 func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
133 b.WindowSize = windowSize
134 tmp, err := br.readSmall(3)
135 if err != nil {
136 println("Reading block header:", err)
137 return err
138 }
139 bh := uint32(tmp[0]) | (uint32(tmp[1]) << 8) | (uint32(tmp[2]) << 16)
140 b.Last = bh&1 != 0
141 b.Type = blockType((bh >> 1) & 3)
142
143 cSize := int(bh >> 3)
144 maxSize := maxCompressedBlockSizeAlloc
145 switch b.Type {
146 case blockTypeReserved:
147 return ErrReservedBlockType
148 case blockTypeRLE:
149 if cSize > maxCompressedBlockSize || cSize > int(b.WindowSize) {
150 if debugDecoder {
151 printf("rle block too big: csize:%d block: %+v\n", uint64(cSize), b)
152 }
153 return ErrWindowSizeExceeded
154 }
155 b.RLESize = uint32(cSize)
156 if b.lowMem {
157 maxSize = cSize
158 }
159 cSize = 1
160 case blockTypeCompressed:
161 if debugDecoder {
162 println("Data size on stream:", cSize)
163 }
164 b.RLESize = 0
165 maxSize = maxCompressedBlockSizeAlloc
166 if windowSize < maxCompressedBlockSize && b.lowMem {
167 maxSize = int(windowSize) + compressedBlockOverAlloc
168 }
169 if cSize > maxCompressedBlockSize || uint64(cSize) > b.WindowSize {
170 if debugDecoder {
171 printf("compressed block too big: csize:%d block: %+v\n", uint64(cSize), b)
172 }
173 return ErrCompressedSizeTooBig
174 }
175
176
177 if cSize < 2 {
178 return ErrBlockTooSmall
179 }
180 case blockTypeRaw:
181 if cSize > maxCompressedBlockSize || cSize > int(b.WindowSize) {
182 if debugDecoder {
183 printf("rle block too big: csize:%d block: %+v\n", uint64(cSize), b)
184 }
185 return ErrWindowSizeExceeded
186 }
187
188 b.RLESize = 0
189
190 maxSize = -1
191 default:
192 panic("Invalid block type")
193 }
194
195
196 if _, ok := br.(*byteBuf); !ok && cap(b.dataStorage) < cSize {
197
198 if b.lowMem || cSize > maxCompressedBlockSize {
199 b.dataStorage = make([]byte, 0, cSize+compressedBlockOverAlloc)
200 } else {
201 b.dataStorage = make([]byte, 0, maxCompressedBlockSizeAlloc)
202 }
203 }
204 b.data, err = br.readBig(cSize, b.dataStorage)
205 if err != nil {
206 if debugDecoder {
207 println("Reading block:", err, "(", cSize, ")", len(b.data))
208 printf("%T", br)
209 }
210 return err
211 }
212 if cap(b.dst) <= maxSize {
213 b.dst = make([]byte, 0, maxSize+1)
214 }
215 return nil
216 }
217
218
219 func (b *blockDec) sendErr(err error) {
220 b.Last = true
221 b.Type = blockTypeReserved
222 b.err = err
223 }
224
225
226
227 func (b *blockDec) Close() {
228 }
229
230
231 func (b *blockDec) decodeBuf(hist *history) error {
232 switch b.Type {
233 case blockTypeRLE:
234 if cap(b.dst) < int(b.RLESize) {
235 if b.lowMem {
236 b.dst = make([]byte, b.RLESize)
237 } else {
238 b.dst = make([]byte, maxCompressedBlockSize)
239 }
240 }
241 b.dst = b.dst[:b.RLESize]
242 v := b.data[0]
243 for i := range b.dst {
244 b.dst[i] = v
245 }
246 hist.appendKeep(b.dst)
247 return nil
248 case blockTypeRaw:
249 hist.appendKeep(b.data)
250 return nil
251 case blockTypeCompressed:
252 saved := b.dst
253
254 if hist.ignoreBuffer == 0 {
255 b.dst = hist.b
256 hist.b = nil
257 } else {
258 b.dst = b.dst[:0]
259 }
260 err := b.decodeCompressed(hist)
261 if debugDecoder {
262 println("Decompressed to total", len(b.dst), "bytes, hash:", xxhash.Sum64(b.dst), "error:", err)
263 }
264 if hist.ignoreBuffer == 0 {
265 hist.b = b.dst
266 b.dst = saved
267 } else {
268 hist.appendKeep(b.dst)
269 }
270 return err
271 case blockTypeReserved:
272
273 return b.err
274 default:
275 panic("Invalid block type")
276 }
277 }
278
279 func (b *blockDec) decodeLiterals(in []byte, hist *history) (remain []byte, err error) {
280
281 if len(in) < 2 {
282 return in, ErrBlockTooSmall
283 }
284
285 litType := literalsBlockType(in[0] & 3)
286 var litRegenSize int
287 var litCompSize int
288 sizeFormat := (in[0] >> 2) & 3
289 var fourStreams bool
290 var literals []byte
291 switch litType {
292 case literalsBlockRaw, literalsBlockRLE:
293 switch sizeFormat {
294 case 0, 2:
295
296 litRegenSize = int(in[0] >> 3)
297 in = in[1:]
298 case 1:
299
300 litRegenSize = int(in[0]>>4) + (int(in[1]) << 4)
301 in = in[2:]
302 case 3:
303
304 if len(in) < 3 {
305 println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
306 return in, ErrBlockTooSmall
307 }
308 litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12)
309 in = in[3:]
310 }
311 case literalsBlockCompressed, literalsBlockTreeless:
312 switch sizeFormat {
313 case 0, 1:
314
315 if len(in) < 3 {
316 println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
317 return in, ErrBlockTooSmall
318 }
319 n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12)
320 litRegenSize = int(n & 1023)
321 litCompSize = int(n >> 10)
322 fourStreams = sizeFormat == 1
323 in = in[3:]
324 case 2:
325 fourStreams = true
326 if len(in) < 4 {
327 println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
328 return in, ErrBlockTooSmall
329 }
330 n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20)
331 litRegenSize = int(n & 16383)
332 litCompSize = int(n >> 14)
333 in = in[4:]
334 case 3:
335 fourStreams = true
336 if len(in) < 5 {
337 println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
338 return in, ErrBlockTooSmall
339 }
340 n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28)
341 litRegenSize = int(n & 262143)
342 litCompSize = int(n >> 18)
343 in = in[5:]
344 }
345 }
346 if debugDecoder {
347 println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize:", litCompSize, "sizeFormat:", sizeFormat, "4X:", fourStreams)
348 }
349 if litRegenSize > int(b.WindowSize) || litRegenSize > maxCompressedBlockSize {
350 return in, ErrWindowSizeExceeded
351 }
352
353 switch litType {
354 case literalsBlockRaw:
355 if len(in) < litRegenSize {
356 println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize)
357 return in, ErrBlockTooSmall
358 }
359 literals = in[:litRegenSize]
360 in = in[litRegenSize:]
361
362 case literalsBlockRLE:
363 if len(in) < 1 {
364 println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1)
365 return in, ErrBlockTooSmall
366 }
367 if cap(b.literalBuf) < litRegenSize {
368 if b.lowMem {
369 b.literalBuf = make([]byte, litRegenSize, litRegenSize+compressedBlockOverAlloc)
370 } else {
371 b.literalBuf = make([]byte, litRegenSize, maxCompressedBlockSize+compressedBlockOverAlloc)
372 }
373 }
374 literals = b.literalBuf[:litRegenSize]
375 v := in[0]
376 for i := range literals {
377 literals[i] = v
378 }
379 in = in[1:]
380 if debugDecoder {
381 printf("Found %d RLE compressed literals\n", litRegenSize)
382 }
383 case literalsBlockTreeless:
384 if len(in) < litCompSize {
385 println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
386 return in, ErrBlockTooSmall
387 }
388
389 literals = in[:litCompSize]
390 in = in[litCompSize:]
391 if debugDecoder {
392 printf("Found %d compressed literals\n", litCompSize)
393 }
394 huff := hist.huffTree
395 if huff == nil {
396 return in, errors.New("literal block was treeless, but no history was defined")
397 }
398
399 if cap(b.literalBuf) < litRegenSize {
400 if b.lowMem {
401 b.literalBuf = make([]byte, 0, litRegenSize+compressedBlockOverAlloc)
402 } else {
403 b.literalBuf = make([]byte, 0, maxCompressedBlockSize+compressedBlockOverAlloc)
404 }
405 }
406 var err error
407
408 huff.MaxDecodedSize = litRegenSize
409 if fourStreams {
410 literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
411 } else {
412 literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
413 }
414
415 if err != nil {
416 println("decompressing literals:", err)
417 return in, err
418 }
419 if len(literals) != litRegenSize {
420 return in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
421 }
422
423 case literalsBlockCompressed:
424 if len(in) < litCompSize {
425 println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
426 return in, ErrBlockTooSmall
427 }
428 literals = in[:litCompSize]
429 in = in[litCompSize:]
430
431 if cap(b.literalBuf) < litRegenSize {
432 if b.lowMem {
433 b.literalBuf = make([]byte, 0, litRegenSize+compressedBlockOverAlloc)
434 } else {
435 b.literalBuf = make([]byte, 0, maxCompressedBlockSize+compressedBlockOverAlloc)
436 }
437 }
438 huff := hist.huffTree
439 if huff == nil || (hist.dict != nil && huff == hist.dict.litEnc) {
440 huff = huffDecoderPool.Get().(*huff0.Scratch)
441 if huff == nil {
442 huff = &huff0.Scratch{}
443 }
444 }
445 var err error
446 if debugDecoder {
447 println("huff table input:", len(literals), "CRC:", crc32.ChecksumIEEE(literals))
448 }
449 huff, literals, err = huff0.ReadTable(literals, huff)
450 if err != nil {
451 println("reading huffman table:", err)
452 return in, err
453 }
454 hist.huffTree = huff
455 huff.MaxDecodedSize = litRegenSize
456
457 if fourStreams {
458 literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
459 } else {
460 literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
461 }
462 if err != nil {
463 println("decoding compressed literals:", err)
464 return in, err
465 }
466
467 if len(literals) != litRegenSize {
468 return in, fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
469 }
470
471 literals = b.literalBuf[:len(literals)]
472 if debugDecoder {
473 printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize)
474 }
475 }
476 hist.decoders.literals = literals
477 return in, nil
478 }
479
480
481 func (b *blockDec) decodeCompressed(hist *history) error {
482 in := b.data
483 in, err := b.decodeLiterals(in, hist)
484 if err != nil {
485 return err
486 }
487 err = b.prepareSequences(in, hist)
488 if err != nil {
489 return err
490 }
491 if hist.decoders.nSeqs == 0 {
492 b.dst = append(b.dst, hist.decoders.literals...)
493 return nil
494 }
495 before := len(hist.decoders.out)
496 err = hist.decoders.decodeSync(hist.b[hist.ignoreBuffer:])
497 if err != nil {
498 return err
499 }
500 if hist.decoders.maxSyncLen > 0 {
501 hist.decoders.maxSyncLen += uint64(before)
502 hist.decoders.maxSyncLen -= uint64(len(hist.decoders.out))
503 }
504 b.dst = hist.decoders.out
505 hist.recentOffsets = hist.decoders.prevOffset
506 return nil
507 }
508
509 func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) {
510 if debugDecoder {
511 printf("prepareSequences: %d byte(s) input\n", len(in))
512 }
513
514
515 if len(in) < 1 {
516 return ErrBlockTooSmall
517 }
518 var nSeqs int
519 seqHeader := in[0]
520 switch {
521 case seqHeader < 128:
522 nSeqs = int(seqHeader)
523 in = in[1:]
524 case seqHeader < 255:
525 if len(in) < 2 {
526 return ErrBlockTooSmall
527 }
528 nSeqs = int(seqHeader-128)<<8 | int(in[1])
529 in = in[2:]
530 case seqHeader == 255:
531 if len(in) < 3 {
532 return ErrBlockTooSmall
533 }
534 nSeqs = 0x7f00 + int(in[1]) + (int(in[2]) << 8)
535 in = in[3:]
536 }
537 if nSeqs == 0 && len(in) != 0 {
538
539 if debugDecoder {
540 printf("prepareSequences: 0 sequences, but %d byte(s) left on stream\n", len(in))
541 }
542 return ErrUnexpectedBlockSize
543 }
544
545 var seqs = &hist.decoders
546 seqs.nSeqs = nSeqs
547 if nSeqs > 0 {
548 if len(in) < 1 {
549 return ErrBlockTooSmall
550 }
551 br := byteReader{b: in, off: 0}
552 compMode := br.Uint8()
553 br.advance(1)
554 if debugDecoder {
555 printf("Compression modes: 0b%b", compMode)
556 }
557 if compMode&3 != 0 {
558 return errors.New("corrupt block: reserved bits not zero")
559 }
560 for i := uint(0); i < 3; i++ {
561 mode := seqCompMode((compMode >> (6 - i*2)) & 3)
562 if debugDecoder {
563 println("Table", tableIndex(i), "is", mode)
564 }
565 var seq *sequenceDec
566 switch tableIndex(i) {
567 case tableLiteralLengths:
568 seq = &seqs.litLengths
569 case tableOffsets:
570 seq = &seqs.offsets
571 case tableMatchLengths:
572 seq = &seqs.matchLengths
573 default:
574 panic("unknown table")
575 }
576 switch mode {
577 case compModePredefined:
578 if seq.fse != nil && !seq.fse.preDefined {
579 fseDecoderPool.Put(seq.fse)
580 }
581 seq.fse = &fsePredef[i]
582 case compModeRLE:
583 if br.remain() < 1 {
584 return ErrBlockTooSmall
585 }
586 v := br.Uint8()
587 br.advance(1)
588 if seq.fse == nil || seq.fse.preDefined {
589 seq.fse = fseDecoderPool.Get().(*fseDecoder)
590 }
591 symb, err := decSymbolValue(v, symbolTableX[i])
592 if err != nil {
593 printf("RLE Transform table (%v) error: %v", tableIndex(i), err)
594 return err
595 }
596 seq.fse.setRLE(symb)
597 if debugDecoder {
598 printf("RLE set to 0x%x, code: %v", symb, v)
599 }
600 case compModeFSE:
601 println("Reading table for", tableIndex(i))
602 if seq.fse == nil || seq.fse.preDefined {
603 seq.fse = fseDecoderPool.Get().(*fseDecoder)
604 }
605 err := seq.fse.readNCount(&br, uint16(maxTableSymbol[i]))
606 if err != nil {
607 println("Read table error:", err)
608 return err
609 }
610 err = seq.fse.transform(symbolTableX[i])
611 if err != nil {
612 println("Transform table error:", err)
613 return err
614 }
615 if debugDecoder {
616 println("Read table ok", "symbolLen:", seq.fse.symbolLen)
617 }
618 case compModeRepeat:
619 seq.repeat = true
620 }
621 if br.overread() {
622 return io.ErrUnexpectedEOF
623 }
624 }
625 in = br.unread()
626 }
627 if debugDecoder {
628 println("Literals:", len(seqs.literals), "hash:", xxhash.Sum64(seqs.literals), "and", seqs.nSeqs, "sequences.")
629 }
630
631 if nSeqs == 0 {
632 if len(b.sequence) > 0 {
633 b.sequence = b.sequence[:0]
634 }
635 return nil
636 }
637 br := seqs.br
638 if br == nil {
639 br = &bitReader{}
640 }
641 if err := br.init(in); err != nil {
642 return err
643 }
644
645 if err := seqs.initialize(br, hist, b.dst); err != nil {
646 println("initializing sequences:", err)
647 return err
648 }
649
650 if false && hist.dict == nil {
651 fatalErr := func(err error) {
652 if err != nil {
653 panic(err)
654 }
655 }
656 fn := fmt.Sprintf("n-%d-lits-%d-prev-%d-%d-%d-win-%d.blk", hist.decoders.nSeqs, len(hist.decoders.literals), hist.recentOffsets[0], hist.recentOffsets[1], hist.recentOffsets[2], hist.windowSize)
657 var buf bytes.Buffer
658 fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.litLengths.fse))
659 fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.matchLengths.fse))
660 fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.offsets.fse))
661 buf.Write(in)
662 os.WriteFile(filepath.Join("testdata", "seqs", fn), buf.Bytes(), os.ModePerm)
663 }
664
665 return nil
666 }
667
668 func (b *blockDec) decodeSequences(hist *history) error {
669 if cap(b.sequence) < hist.decoders.nSeqs {
670 if b.lowMem {
671 b.sequence = make([]seqVals, 0, hist.decoders.nSeqs)
672 } else {
673 b.sequence = make([]seqVals, 0, 0x7F00+0xffff)
674 }
675 }
676 b.sequence = b.sequence[:hist.decoders.nSeqs]
677 if hist.decoders.nSeqs == 0 {
678 hist.decoders.seqSize = len(hist.decoders.literals)
679 return nil
680 }
681 hist.decoders.windowSize = hist.windowSize
682 hist.decoders.prevOffset = hist.recentOffsets
683
684 err := hist.decoders.decode(b.sequence)
685 hist.recentOffsets = hist.decoders.prevOffset
686 return err
687 }
688
689 func (b *blockDec) executeSequences(hist *history) error {
690 hbytes := hist.b
691 if len(hbytes) > hist.windowSize {
692 hbytes = hbytes[len(hbytes)-hist.windowSize:]
693
694 if hist.dict != nil {
695 hist.dict.content = nil
696 }
697 }
698 hist.decoders.windowSize = hist.windowSize
699 hist.decoders.out = b.dst[:0]
700 err := hist.decoders.execute(b.sequence, hbytes)
701 if err != nil {
702 return err
703 }
704 return b.updateHistory(hist)
705 }
706
707 func (b *blockDec) updateHistory(hist *history) error {
708 if len(b.data) > maxCompressedBlockSize {
709 return fmt.Errorf("compressed block size too large (%d)", len(b.data))
710 }
711
712 b.dst = hist.decoders.out
713 hist.recentOffsets = hist.decoders.prevOffset
714
715 if b.Last {
716
717 println("Last block, no history returned")
718 hist.b = hist.b[:0]
719 return nil
720 } else {
721 hist.append(b.dst)
722 if debugDecoder {
723 println("Finished block with ", len(b.sequence), "sequences. Added", len(b.dst), "to history, now length", len(hist.b))
724 }
725 }
726 hist.decoders.out, hist.decoders.literals = nil, nil
727
728 return nil
729 }
730
View as plain text