1
2
3
4
5 package zstd
6
7 import (
8 "context"
9 "encoding/binary"
10 "io"
11 "sync"
12
13 "github.com/klauspost/compress/zstd/internal/xxhash"
14 )
15
16
17
18
19
20
21
// Decoder decodes zstandard data, either as a stream (Read/WriteTo after
// NewReader/Reset) or statelessly via DecodeAll.
type Decoder struct {
	// Options the decoder was configured with in NewReader.
	o decoderOptions

	// Pool of block decoders. NewReader fills it with d.o.concurrent entries;
	// Close closes it and sets it to nil.
	decoders chan *blockDec

	// State of the stream currently being read.
	current decoderState

	// State for decoding a stream on the calling goroutine (single-threaded
	// mode, switched on by startSyncDecoder).
	syncStream struct {
		decodedFrame uint64        // bytes produced so far for the current frame
		br           readerWrapper // wrapped input reader
		enabled      bool          // true when streaming synchronously
		inFrame      bool          // true while a frame is only partially decoded
		dstBuf       []byte        // output buffer reused by Reset's in-memory fast path
	}

	// Frame decoder used for streaming; created lazily by Reset.
	frame *frameDec

	// Registered dictionaries keyed by dictionary ID.
	dicts map[uint32]*dict

	// Tracks the background goroutine started by Reset for concurrent streaming.
	streamWg sync.WaitGroup
}
48
49
50
// decoderState holds the read position of the stream currently being decoded.
type decoderState struct {
	// Block currently being handed to the caller: the decoder that produced it,
	// the not-yet-consumed output bytes and the sticky error.
	decodeOutput

	// Channel on which the concurrent stream decoder delivers blocks;
	// nil when no concurrent stream is active.
	output chan decodeOutput

	// Cancels the concurrent stream decoder's context; nil when none is running.
	cancel context.CancelFunc

	// Running checksum over delivered stream output, compared with block CRCs
	// in nextBlock.
	crc *xxhash.Digest

	// True once output has been fully drained (or no concurrent stream exists).
	flushed bool
}
66
67 var (
68
69 _ = io.WriterTo(&Decoder{})
70 _ = io.Reader(&Decoder{})
71 )
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87 func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) {
88 initPredefined()
89 var d Decoder
90 d.o.setDefault()
91 for _, o := range opts {
92 err := o(&d.o)
93 if err != nil {
94 return nil, err
95 }
96 }
97 d.current.crc = xxhash.New()
98 d.current.flushed = true
99
100 if r == nil {
101 d.current.err = ErrDecoderNilInput
102 }
103
104
105 d.dicts = make(map[uint32]*dict, len(d.o.dicts))
106 for _, dc := range d.o.dicts {
107 d.dicts[dc.id] = dc
108 }
109 d.o.dicts = nil
110
111
112 d.decoders = make(chan *blockDec, d.o.concurrent)
113 for i := 0; i < d.o.concurrent; i++ {
114 dec := newBlockDec(d.o.lowMem)
115 dec.localFrame = newFrameDec(d.o)
116 d.decoders <- dec
117 }
118
119 if r == nil {
120 return &d, nil
121 }
122 return &d, d.Reset(r)
123 }
124
125
126
127
// Read copies decompressed data into p and returns the number of bytes copied.
//
// Buffered output from the current block is copied first. Further blocks are
// fetched with nextBlock, which is only allowed to block while nothing has been
// copied yet (n == 0). The pending error in d.current.err is returned only once
// all buffered output has been consumed; at that point any outstanding stream
// output is drained.
func (d *Decoder) Read(p []byte) (int, error) {
	var n int
	for {
		if len(d.current.b) > 0 {
			filled := copy(p, d.current.b)
			p = p[filled:]
			d.current.b = d.current.b[filled:]
			n += filled
		}
		if len(p) == 0 {
			break
		}
		if len(d.current.b) == 0 {
			// Nothing buffered: stop if an error (or end of stream) is pending.
			if d.current.err != nil {
				break
			}
			// Block for the next block only if nothing has been returned yet.
			if !d.nextBlock(n == 0) {
				return n, d.current.err
			}
		}
	}
	if len(d.current.b) > 0 {
		if debugDecoder {
			println("returning", n, "still bytes left:", len(d.current.b))
		}
		// Data is still buffered; the error is only reported after it is consumed.
		return n, nil
	}
	if d.current.err != nil {
		d.drainOutput()
	}
	if debugDecoder {
		println("returning", n, d.current.err, len(d.decoders))
	}
	return n, d.current.err
}
165
166
167
168
169
170
// Reset discards the current decoding state and starts decoding from r.
//
// It returns ErrDecoderClosed if the decoder has been closed. A nil r leaves
// ErrDecoderNilInput pending and returns nil. If r implements byter (it
// exposes Bytes and Len, e.g. *bytes.Buffer per the debug message), is shorter
// than d.o.decodeBufsBelow, and limitToCap is unset, the whole input is decoded
// here with DecodeAll and later served from memory; decode errors on that path
// are not returned by Reset but surface from Read/WriteTo. Otherwise decoding
// streams either on the caller's goroutine (concurrency 1) or on a background
// goroutine feeding d.current.output.
func (d *Decoder) Reset(r io.Reader) error {
	if d.current.err == ErrDecoderClosed {
		return d.current.err
	}

	// Stop and drain any stream in progress before reusing state.
	d.drainOutput()

	d.syncStream.br.r = nil
	if r == nil {
		d.current.err = ErrDecoderNilInput
		if len(d.current.b) > 0 {
			d.current.b = d.current.b[:0]
		}
		d.current.flushed = true
		return nil
	}

	// Small in-memory inputs are decoded synchronously in one go.
	if bb, ok := r.(byter); ok && bb.Len() < d.o.decodeBufsBelow && !d.o.limitToCap {
		bb2 := bb
		if debugDecoder {
			println("*bytes.Buffer detected, doing sync decode, len:", bb.Len())
		}
		b := bb2.Bytes()
		var dst []byte
		if cap(d.syncStream.dstBuf) > 0 {
			// Reuse the output buffer from a previous fast-path decode.
			dst = d.syncStream.dstBuf[:0]
		}

		dst, err := d.DecodeAll(b, dst)
		if err == nil {
			// All output is already available; the stream ends after it.
			err = io.EOF
		}

		// Keep the buffer for reuse and serve it through d.current.
		d.syncStream.dstBuf = dst
		d.current.b = dst
		d.current.err = err
		d.current.flushed = true
		if debugDecoder {
			println("sync decode to", len(dst), "bytes, err:", err)
		}
		return nil
	}

	// Return any held block decoder and clear the current block.
	d.stashDecoder()
	d.current.decodeOutput = decodeOutput{}
	d.current.err = nil
	d.current.flushed = false
	d.current.d = nil
	d.syncStream.dstBuf = nil

	// Make sure a previous stream goroutine has finished before touching d.frame.
	d.streamWg.Wait()
	if d.frame == nil {
		d.frame = newFrameDec(d.o)
	}

	if d.o.concurrent == 1 {
		return d.startSyncDecoder(r)
	}

	d.current.output = make(chan decodeOutput, d.o.concurrent)
	ctx, cancel := context.WithCancel(context.Background())
	d.current.cancel = cancel
	d.streamWg.Add(1)
	go d.startStreamDecoder(ctx, r, d.current.output)

	return nil
}
240
241
242 func (d *Decoder) drainOutput() {
243 if d.current.cancel != nil {
244 if debugDecoder {
245 println("cancelling current")
246 }
247 d.current.cancel()
248 d.current.cancel = nil
249 }
250 if d.current.d != nil {
251 if debugDecoder {
252 printf("re-adding current decoder %p, decoders: %d", d.current.d, len(d.decoders))
253 }
254 d.decoders <- d.current.d
255 d.current.d = nil
256 d.current.b = nil
257 }
258 if d.current.output == nil || d.current.flushed {
259 println("current already flushed")
260 return
261 }
262 for v := range d.current.output {
263 if v.d != nil {
264 if debugDecoder {
265 printf("re-adding decoder %p", v.d)
266 }
267 d.decoders <- v.d
268 }
269 }
270 d.current.output = nil
271 d.current.flushed = true
272 }
273
274
275
276
277 func (d *Decoder) WriteTo(w io.Writer) (int64, error) {
278 var n int64
279 for {
280 if len(d.current.b) > 0 {
281 n2, err2 := w.Write(d.current.b)
282 n += int64(n2)
283 if err2 != nil && (d.current.err == nil || d.current.err == io.EOF) {
284 d.current.err = err2
285 } else if n2 != len(d.current.b) {
286 d.current.err = io.ErrShortWrite
287 }
288 }
289 if d.current.err != nil {
290 break
291 }
292 d.nextBlock(true)
293 }
294 err := d.current.err
295 if err != nil {
296 d.drainOutput()
297 }
298 if err == io.EOF {
299 err = nil
300 }
301 return n, err
302 }
303
304
305
306
307
308
309 func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
310 if d.decoders == nil {
311 return dst, ErrDecoderClosed
312 }
313
314
315 block := <-d.decoders
316 frame := block.localFrame
317 initialSize := len(dst)
318 defer func() {
319 if debugDecoder {
320 printf("re-adding decoder: %p", block)
321 }
322 frame.rawInput = nil
323 frame.bBuf = nil
324 if frame.history.decoders.br != nil {
325 frame.history.decoders.br.in = nil
326 }
327 d.decoders <- block
328 }()
329 frame.bBuf = input
330
331 for {
332 frame.history.reset()
333 err := frame.reset(&frame.bBuf)
334 if err != nil {
335 if err == io.EOF {
336 if debugDecoder {
337 println("frame reset return EOF")
338 }
339 return dst, nil
340 }
341 return dst, err
342 }
343 if err = d.setDict(frame); err != nil {
344 return nil, err
345 }
346 if frame.WindowSize > d.o.maxWindowSize {
347 if debugDecoder {
348 println("window size exceeded:", frame.WindowSize, ">", d.o.maxWindowSize)
349 }
350 return dst, ErrWindowSizeExceeded
351 }
352 if frame.FrameContentSize != fcsUnknown {
353 if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)-initialSize) {
354 if debugDecoder {
355 println("decoder size exceeded; fcs:", frame.FrameContentSize, "> mcs:", d.o.maxDecodedSize-uint64(len(dst)-initialSize), "len:", len(dst))
356 }
357 return dst, ErrDecoderSizeExceeded
358 }
359 if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) {
360 if debugDecoder {
361 println("decoder size exceeded; fcs:", frame.FrameContentSize, "> (cap-len)", cap(dst)-len(dst))
362 }
363 return dst, ErrDecoderSizeExceeded
364 }
365 if cap(dst)-len(dst) < int(frame.FrameContentSize) {
366 dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize)+compressedBlockOverAlloc)
367 copy(dst2, dst)
368 dst = dst2
369 }
370 }
371
372 if cap(dst) == 0 && !d.o.limitToCap {
373
374
375 size := len(input) * 2
376
377 if size > 1<<20 {
378 size = 1 << 20
379 }
380 if uint64(size) > d.o.maxDecodedSize {
381 size = int(d.o.maxDecodedSize)
382 }
383 dst = make([]byte, 0, size)
384 }
385
386 dst, err = frame.runDecoder(dst, block)
387 if err != nil {
388 return dst, err
389 }
390 if uint64(len(dst)-initialSize) > d.o.maxDecodedSize {
391 return dst, ErrDecoderSizeExceeded
392 }
393 if len(frame.bBuf) == 0 {
394 if debugDecoder {
395 println("frame dbuf empty")
396 }
397 break
398 }
399 }
400 return dst, nil
401 }
402
403
404
405
406
407
// nextBlock makes the next decoded block available in d.current.
//
// It returns false immediately if an error is already pending. With blocking
// false it returns false (without setting an error) when no block is ready;
// in synchronous mode a non-blocking call always returns false. In concurrent
// mode it returns true whenever a block was received; that block may carry its
// own error, and a failed running-checksum comparison sets ErrCRCMismatch while
// still returning true.
func (d *Decoder) nextBlock(blocking bool) (ok bool) {
	if d.current.err != nil {
		// Keep the sticky error state.
		return false
	}
	d.current.b = d.current.b[:0]

	// Synchronous mode: decode on this goroutine.
	if d.syncStream.enabled {
		if !blocking {
			return false
		}
		ok = d.nextBlockSync()
		if !ok {
			d.stashDecoder()
		}
		return ok
	}

	// Concurrent mode: return the previous block decoder, then receive the next block.
	d.stashDecoder()
	if blocking {
		d.current.decodeOutput, ok = <-d.current.output
	} else {
		select {
		case d.current.decodeOutput, ok = <-d.current.output:
		default:
			return false
		}
	}
	if !ok {
		// The producer closed the channel without a terminating error block.
		d.current.err = io.ErrUnexpectedEOF
		return false
	}
	next := d.current.decodeOutput
	// The first block of a frame carries newHist; restart the running checksum.
	if next.d != nil && next.d.async.newHist != nil {
		d.current.crc.Reset()
	}
	if debugDecoder {
		var tmp [4]byte
		binary.LittleEndian.PutUint32(tmp[:], uint32(xxhash.Sum64(next.b)))
		println("got", len(d.current.b), "bytes, error:", d.current.err, "data crc:", tmp)
	}

	if d.o.ignoreChecksum {
		return true
	}

	if len(next.b) > 0 {
		d.current.crc.Write(next.b)
	}
	// Only blocks for which a frame checksum was read carry hasCRC.
	if next.err == nil && next.d != nil && next.d.hasCRC {
		got := uint32(d.current.crc.Sum64())
		if got != next.d.checkCRC {
			if debugDecoder {
				printf("CRC Check Failed: %08x (got) != %08x (on stream)\n", got, next.d.checkCRC)
			}
			d.current.err = ErrCRCMismatch
		} else {
			if debugDecoder {
				printf("CRC ok %08x\n", got)
			}
		}
	}

	return true
}
476
// nextBlockSync decodes blocks on the calling goroutine until one yields
// output, starting a new frame when the previous one has ended. On success
// d.current.b points into d.frame.history at the new output and true is
// returned. On any failure (frame header, dictionary, size limits, block
// decode, frame-size or checksum checks) d.current.err is set and false is
// returned.
func (d *Decoder) nextBlockSync() (ok bool) {
	if d.current.d == nil {
		d.current.d = <-d.decoders
	}
	for len(d.current.b) == 0 {
		if !d.syncStream.inFrame {
			// Start of a new frame: read its header and install any dictionary.
			d.frame.history.reset()
			d.current.err = d.frame.reset(&d.syncStream.br)
			if d.current.err == nil {
				d.current.err = d.setDict(d.frame)
			}
			if d.current.err != nil {
				return false
			}
			if d.frame.WindowSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize {
				d.current.err = ErrDecoderSizeExceeded
				return false
			}

			d.syncStream.decodedFrame = 0
			d.syncStream.inFrame = true
		}
		d.current.err = d.frame.next(d.current.d)
		if d.current.err != nil {
			return false
		}
		d.frame.history.ensureBlock()
		if debugDecoder {
			println("History trimmed:", len(d.frame.history.b), "decoded already:", d.syncStream.decodedFrame)
		}
		// Output of this block is whatever decodeBuf appends to the history.
		histBefore := len(d.frame.history.b)
		d.current.err = d.current.d.decodeBuf(&d.frame.history)

		if d.current.err != nil {
			println("error after:", d.current.err)
			return false
		}
		d.current.b = d.frame.history.b[histBefore:]
		if debugDecoder {
			println("history after:", len(d.frame.history.b))
		}

		// Check frame size before the checksum. NOTE(review): for an unknown
		// size this relies on fcsUnknown comparing larger than any count — confirm.
		d.syncStream.decodedFrame += uint64(len(d.current.b))
		if d.syncStream.decodedFrame > d.frame.FrameContentSize {
			if debugDecoder {
				printf("DecodedFrame (%d) > FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize)
			}
			d.current.err = ErrFrameSizeExceeded
			return false
		}

		// On the last block, a known frame size must match exactly.
		if d.current.d.Last && d.frame.FrameContentSize != fcsUnknown && d.syncStream.decodedFrame != d.frame.FrameContentSize {
			if debugDecoder {
				printf("DecodedFrame (%d) != FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize)
			}
			d.current.err = ErrFrameSizeMismatch
			return false
		}

		// Update and, on the last block, verify (or just consume) the frame checksum.
		if d.frame.HasCheckSum {
			if !d.o.ignoreChecksum {
				d.frame.crc.Write(d.current.b)
			}
			if d.current.d.Last {
				if !d.o.ignoreChecksum {
					d.current.err = d.frame.checkCRC()
				} else {
					d.current.err = d.frame.consumeCRC()
				}
				if d.current.err != nil {
					println("CRC error:", d.current.err)
					return false
				}
			}
		}
		d.syncStream.inFrame = !d.current.d.Last
	}
	return true
}
559
560 func (d *Decoder) stashDecoder() {
561 if d.current.d != nil {
562 if debugDecoder {
563 printf("re-adding current decoder %p", d.current.d)
564 }
565 d.decoders <- d.current.d
566 d.current.d = nil
567 }
568 }
569
570
571
572 func (d *Decoder) Close() {
573 if d.current.err == ErrDecoderClosed {
574 return
575 }
576 d.drainOutput()
577 if d.current.cancel != nil {
578 d.current.cancel()
579 d.streamWg.Wait()
580 d.current.cancel = nil
581 }
582 if d.decoders != nil {
583 close(d.decoders)
584 for dec := range d.decoders {
585 dec.Close()
586 }
587 d.decoders = nil
588 }
589 if d.current.d != nil {
590 d.current.d.Close()
591 d.current.d = nil
592 }
593 d.current.err = ErrDecoderClosed
594 }
595
596
597
598
599
600 func (d *Decoder) IOReadCloser() io.ReadCloser {
601 return closeWrapper{d: d}
602 }
603
604
605 type closeWrapper struct {
606 d *Decoder
607 }
608
609
610 func (c closeWrapper) WriteTo(w io.Writer) (n int64, err error) {
611 return c.d.WriteTo(w)
612 }
613
614
615 func (c closeWrapper) Read(p []byte) (n int, err error) {
616 return c.d.Read(p)
617 }
618
619
620 func (c closeWrapper) Close() error {
621 c.d.Close()
622 return nil
623 }
624
// decodeOutput is one unit of output from the stream decoder.
type decodeOutput struct {
	d   *blockDec // block decoder that produced b; the consumer returns it to the pool
	b   []byte    // decoded bytes
	err error     // error for this block, if any
}
630
631 func (d *Decoder) startSyncDecoder(r io.Reader) error {
632 d.frame.history.reset()
633 d.syncStream.br = readerWrapper{r: r}
634 d.syncStream.inFrame = false
635 d.syncStream.enabled = true
636 d.syncStream.decodedFrame = 0
637 return nil
638 }
639
640
641
642
643
644
645
// startStreamDecoder decodes the stream from r in a three-stage pipeline and
// delivers the results on output. It is run as a goroutine by Reset, which
// registered it on d.streamWg.
//
//   - This goroutine reads frame headers and blocks, decodes literals and
//     sends blocks on seqDecode.
//   - "Async 1" prepares and decodes sequences and forwards blocks on seqExecute.
//   - "Async 2" executes sequences against the output history, checks the
//     frame content size and sends results on output, closing it when done.
//
// Each block travels with the pooled decoder taken from d.decoders; the
// consumer of output returns those decoders to the pool. Cancelling ctx stops
// the intake of new blocks.
func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output chan decodeOutput) {
	defer d.streamWg.Done()
	br := readerWrapper{r: r}

	var seqDecode = make(chan *blockDec, d.o.concurrent)
	var seqExecute = make(chan *blockDec, d.o.concurrent)

	// Async 1: prepare and decode sequences.
	go func() {
		var hist history
		var hasErr bool

		for block := range seqDecode {
			// After an error, forward remaining blocks without decoding.
			if hasErr {
				if block != nil {
					seqExecute <- block
				}
				continue
			}
			// First block of a frame: take over the frame's decoding state.
			if block.async.newHist != nil {
				if debugDecoder {
					println("Async 1: new history, recent:", block.async.newHist.recentOffsets)
				}
				hist.reset()
				hist.decoders = block.async.newHist.decoders
				hist.recentOffsets = block.async.newHist.recentOffsets
				hist.windowSize = block.async.newHist.windowSize
				if block.async.newHist.dict != nil {
					hist.setDict(block.async.newHist.dict)
				}
			}
			// Errored and non-compressed blocks have no sequences to decode.
			if block.err != nil || block.Type != blockTypeCompressed {
				hasErr = block.err != nil
				seqExecute <- block
				continue
			}

			hist.decoders.literals = block.async.literals
			block.err = block.prepareSequences(block.async.seqData, &hist)
			if debugDecoder && block.err != nil {
				println("prepareSequences returned:", block.err)
			}
			hasErr = block.err != nil
			if block.err == nil {
				block.err = block.decodeSequences(&hist)
				if debugDecoder && block.err != nil {
					println("decodeSequences returned:", block.err)
				}
				hasErr = block.err != nil

				block.async.seqSize = hist.decoders.seqSize
			}
			seqExecute <- block
		}
		close(seqExecute)
		hist.reset()
	}()

	// wg tracks only the Async 2 goroutine below.
	var wg sync.WaitGroup
	wg.Add(1)

	// Async 2: execute sequences. The frame's history buffer is lent to this
	// goroutine and handed back through frameHistCache when it finishes.
	frameHistCache := d.frame.history.b
	go func() {
		var hist history
		var decodedFrame uint64
		var fcs uint64
		var hasErr bool
		for block := range seqExecute {
			out := decodeOutput{err: block.err, d: block}
			// Pass errored blocks (and everything after an error) straight through.
			if block.err != nil || hasErr {
				hasErr = true
				output <- out
				continue
			}
			// First block of a frame: reset the output history and size counters.
			if block.async.newHist != nil {
				if debugDecoder {
					println("Async 2: new history")
				}
				hist.reset()
				hist.windowSize = block.async.newHist.windowSize
				hist.allocFrameBuffer = block.async.newHist.allocFrameBuffer
				if block.async.newHist.dict != nil {
					hist.setDict(block.async.newHist.dict)
				}

				// Reuse the cached frame buffer when big enough, else allocate.
				if cap(hist.b) < hist.allocFrameBuffer {
					if cap(frameHistCache) >= hist.allocFrameBuffer {
						hist.b = frameHistCache
					} else {
						hist.b = make([]byte, 0, hist.allocFrameBuffer)
						// NOTE(review): not guarded by debugDecoder, unlike most logging
						// here — presumably println is a debug-gated package helper; confirm.
						println("Alloc history sized", hist.allocFrameBuffer)
					}
				}
				hist.b = hist.b[:0]
				fcs = block.async.fcs
				decodedFrame = 0
			}
			do := decodeOutput{err: block.err, d: block}
			switch block.Type {
			case blockTypeRLE:
				if debugDecoder {
					println("add rle block length:", block.RLESize)
				}

				// Expand the run: RLESize copies of the single byte block.data[0].
				if cap(block.dst) < int(block.RLESize) {
					if block.lowMem {
						block.dst = make([]byte, block.RLESize)
					} else {
						block.dst = make([]byte, maxCompressedBlockSize)
					}
				}
				block.dst = block.dst[:block.RLESize]
				v := block.data[0]
				for i := range block.dst {
					block.dst[i] = v
				}
				hist.append(block.dst)
				do.b = block.dst
			case blockTypeRaw:
				if debugDecoder {
					println("add raw block length:", len(block.data))
				}
				hist.append(block.data)
				do.b = block.data
			case blockTypeCompressed:
				if debugDecoder {
					println("execute with history length:", len(hist.b), "window:", hist.windowSize)
				}
				hist.decoders.seqSize = block.async.seqSize
				hist.decoders.literals = block.async.literals
				do.err = block.executeSequences(&hist)
				hasErr = do.err != nil
				if debugDecoder && hasErr {
					println("executeSequences returned:", do.err)
				}
				do.b = block.dst
			}
			// Enforce the frame content size declared in the frame header.
			if !hasErr {
				decodedFrame += uint64(len(do.b))
				if decodedFrame > fcs {
					println("fcs exceeded", block.Last, fcs, decodedFrame)
					do.err = ErrFrameSizeExceeded
					hasErr = true
				} else if block.Last && fcs != fcsUnknown && decodedFrame != fcs {
					do.err = ErrFrameSizeMismatch
					hasErr = true
				} else {
					if debugDecoder {
						println("fcs ok", block.Last, fcs, decodedFrame)
					}
				}
			}
			output <- do
		}
		close(output)
		frameHistCache = hist.b
		wg.Done()
		if debugDecoder {
			println("decoder goroutines finished")
		}
		hist.reset()
	}()

	// Stage 1 (this goroutine): read frames and blocks, decode literals.
	var hist history
decodeStream:
	for {
		var hasErr bool
		hist.reset()
		// decodeBlock decodes literals of compressed blocks and forwards every
		// block to Async 1; after an error, blocks are forwarded untouched.
		decodeBlock := func(block *blockDec) {
			if hasErr {
				if block != nil {
					seqDecode <- block
				}
				return
			}
			if block.err != nil || block.Type != blockTypeCompressed {
				hasErr = block.err != nil
				seqDecode <- block
				return
			}

			remain, err := block.decodeLiterals(block.data, &hist)
			block.err = err
			hasErr = block.err != nil
			if err == nil {
				block.async.literals = hist.decoders.literals
				block.async.seqData = remain
			} else if debugDecoder {
				println("decodeLiterals error:", err)
			}
			seqDecode <- block
		}
		frame := d.frame
		if debugDecoder {
			println("New frame...")
		}
		var historySent bool
		frame.history.reset()
		err := frame.reset(&br)
		if debugDecoder && err != nil {
			println("Frame decoder returned", err)
		}
		if err == nil {
			err = d.setDict(frame)
		}
		if err == nil && d.frame.WindowSize > d.o.maxWindowSize {
			if debugDecoder {
				println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize)
			}

			err = ErrDecoderSizeExceeded
		}
		if err != nil {
			// Deliver the error downstream on a pooled decoder unless cancelled,
			// then stop reading.
			select {
			case <-ctx.Done():
			case dec := <-d.decoders:
				dec.sendErr(err)
				decodeBlock(dec)
			}
			break decodeStream
		}

		// Read all blocks of this frame.
		for {
			var dec *blockDec
			select {
			case <-ctx.Done():
				break decodeStream
			case dec = <-d.decoders:
				// dec now travels downstream and is returned to the pool by
				// whoever consumes output.
			}
			err := frame.next(dec)
			if !historySent {
				// First block of the frame carries a shallow copy of the frame
				// history and the declared content size for the later stages.
				h := frame.history
				if debugDecoder {
					println("Alloc History:", h.allocFrameBuffer)
				}
				hist.reset()
				if h.dict != nil {
					hist.setDict(h.dict)
				}
				dec.async.newHist = &h
				dec.async.fcs = frame.FrameContentSize
				historySent = true
			} else {
				dec.async.newHist = nil
			}
			if debugDecoder && err != nil {
				println("next block returned error:", err)
			}
			dec.err = err
			dec.hasCRC = false
			// After the last block of a checksummed frame, read the 4-byte
			// checksum; nextBlock compares it against the running digest.
			if dec.Last && frame.HasCheckSum && err == nil {
				crc, err := frame.rawInput.readSmall(4)
				if len(crc) < 4 {
					if err == nil {
						err = io.ErrUnexpectedEOF

					}
					println("CRC missing?", err)
					dec.err = err
				} else {
					dec.checkCRC = binary.LittleEndian.Uint32(crc)
					dec.hasCRC = true
					if debugDecoder {
						printf("found crc to check: %08x\n", dec.checkCRC)
					}
				}
			}
			// Capture before handing dec off; it may be mutated downstream.
			err = dec.err
			last := dec.Last
			decodeBlock(dec)
			if err != nil {
				break decodeStream
			}
			if last {
				break
			}
		}
	}
	// Shut down the pipeline and wait for Async 2 before restoring the
	// (possibly reallocated) history buffer to the frame decoder.
	close(seqDecode)
	wg.Wait()
	hist.reset()
	d.frame.history.b = frameHistCache
}
932
933 func (d *Decoder) setDict(frame *frameDec) (err error) {
934 dict, ok := d.dicts[frame.DictionaryID]
935 if ok {
936 if debugDecoder {
937 println("setting dict", frame.DictionaryID)
938 }
939 frame.history.setDict(dict)
940 } else if frame.DictionaryID != 0 {
941
942
943
944
945 err = ErrUnknownDictionary
946 }
947 return err
948 }
949
View as plain text