1
2
3
4
5 package zstd
6
7 import (
8 "crypto/rand"
9 "fmt"
10 "io"
11 "math"
12 rdebug "runtime/debug"
13 "sync"
14
15 "github.com/klauspost/compress/zstd/internal/xxhash"
16 )
17
18
19
20
21
22
23
// Encoder compresses data to Zstandard.
// It holds one streaming state (used by Write, ReadFrom, Flush and Close) and
// a lazily created pool of encoders used by EncodeAll.
type Encoder struct {
	o        encoderOptions // options applied by NewWriter (defaulted by initialize if unset)
	encoders chan encoder   // pool of encoders for EncodeAll, filled by initialize
	state    encoderState   // streaming state for Write, ReadFrom, Flush and Close
	init     sync.Once      // runs initialize once, on first use by EncodeAll
}
30
31 type encoder interface {
32 Encode(blk *blockEnc, src []byte)
33 EncodeNoHist(blk *blockEnc, src []byte)
34 Block() *blockEnc
35 CRC() *xxhash.Digest
36 AppendCRC([]byte) []byte
37 WindowSize(size int64) int32
38 UseBlock(*blockEnc)
39 Reset(d *dict, singleBlock bool)
40 }
41
// encoderState is the per-stream state used by Write, ReadFrom, Flush and
// Close. Reset clears it for a new stream while keeping allocated buffers.
type encoderState struct {
	w                io.Writer // destination of the compressed stream
	filling          []byte    // input collected for the next block
	current          []byte    // data handed to the encoding goroutine (concurrent mode); also scratch output for a single-frame stream in nextBlock
	previous         []byte    // buffer from an earlier round, recycled as filling by nextBlock
	encoder          encoder   // encoder used for this stream
	writing          *blockEnc // block owned by the background writer goroutine (concurrent mode)
	err              error     // sticky error from encoding or writing; returned by later calls
	writeErr         error     // error (or recovered panic) from the background writer goroutine
	nWritten         int64     // bytes written to w so far
	nInput           int64     // uncompressed bytes handed to blocks so far
	frameContentSize int64     // content size set by ResetContentSize; 0 when unset
	headerWritten    bool      // frame header (or a complete single frame) has been emitted
	eofWritten       bool      // the last block of the frame has been written or scheduled
	fullFrameWritten bool      // nextBlock emitted the whole stream as one frame, or nothing for an empty stream

	// wg tracks the background encoding goroutine started by nextBlock.
	wg sync.WaitGroup
	// wWg tracks the background goroutine that writes encoded blocks to w.
	wWg sync.WaitGroup
}
63
64
65
66 func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
67 initPredefined()
68 var e Encoder
69 e.o.setDefault()
70 for _, o := range opts {
71 err := o(&e.o)
72 if err != nil {
73 return nil, err
74 }
75 }
76 if w != nil {
77 e.Reset(w)
78 }
79 return &e, nil
80 }
81
82 func (e *Encoder) initialize() {
83 if e.o.concurrent == 0 {
84 e.o.setDefault()
85 }
86 e.encoders = make(chan encoder, e.o.concurrent)
87 for i := 0; i < e.o.concurrent; i++ {
88 enc := e.o.encoder()
89 e.encoders <- enc
90 }
91 }
92
93
94
95 func (e *Encoder) Reset(w io.Writer) {
96 s := &e.state
97 s.wg.Wait()
98 s.wWg.Wait()
99 if cap(s.filling) == 0 {
100 s.filling = make([]byte, 0, e.o.blockSize)
101 }
102 if e.o.concurrent > 1 {
103 if cap(s.current) == 0 {
104 s.current = make([]byte, 0, e.o.blockSize)
105 }
106 if cap(s.previous) == 0 {
107 s.previous = make([]byte, 0, e.o.blockSize)
108 }
109 s.current = s.current[:0]
110 s.previous = s.previous[:0]
111 if s.writing == nil {
112 s.writing = &blockEnc{lowMem: e.o.lowMem}
113 s.writing.init()
114 }
115 s.writing.initNewEncode()
116 }
117 if s.encoder == nil {
118 s.encoder = e.o.encoder()
119 }
120 s.filling = s.filling[:0]
121 s.encoder.Reset(e.o.dict, false)
122 s.headerWritten = false
123 s.eofWritten = false
124 s.fullFrameWritten = false
125 s.w = w
126 s.err = nil
127 s.nWritten = 0
128 s.nInput = 0
129 s.writeErr = nil
130 s.frameContentSize = 0
131 }
132
133
134
135
136
137
138 func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
139 e.Reset(w)
140 if size >= 0 {
141 e.state.frameContentSize = size
142 }
143 }
144
145
146
147
148
149
150 func (e *Encoder) Write(p []byte) (n int, err error) {
151 s := &e.state
152 for len(p) > 0 {
153 if len(p)+len(s.filling) < e.o.blockSize {
154 if e.o.crc {
155 _, _ = s.encoder.CRC().Write(p)
156 }
157 s.filling = append(s.filling, p...)
158 return n + len(p), nil
159 }
160 add := p
161 if len(p)+len(s.filling) > e.o.blockSize {
162 add = add[:e.o.blockSize-len(s.filling)]
163 }
164 if e.o.crc {
165 _, _ = s.encoder.CRC().Write(add)
166 }
167 s.filling = append(s.filling, add...)
168 p = p[len(add):]
169 n += len(add)
170 if len(s.filling) < e.o.blockSize {
171 return n, nil
172 }
173 err := e.nextBlock(false)
174 if err != nil {
175 return n, err
176 }
177 if debugAsserts && len(s.filling) > 0 {
178 panic(len(s.filling))
179 }
180 }
181 return n, nil
182 }
183
184
185
// nextBlock encodes the data buffered in e.state.filling as the next block of
// the current frame and writes it to e.state.w. final marks the last block of
// the frame. It returns the first error seen, which also stays in s.err.
//
// On the first call of a stream it writes the frame header. If that first
// call is already final, the whole stream is instead written as one complete
// frame via EncodeAll. With e.o.concurrent == 1, blocks are encoded and
// written on the calling goroutine. Otherwise, encoding runs in a goroutine
// tracked by s.wg and the output write runs in a nested goroutine tracked by
// s.wWg, so this function can return before the block has been written.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for the previous block's encoding goroutine (concurrent mode).
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// Closing a stream that received no data: write nothing at all,
		// unless e.o.fullZero requests an explicit empty frame.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		// The entire stream is this one final block. Encode it as a complete
		// frame with EncodeAll, using s.current as the output buffer, and
		// write it synchronously.
		if final && len(s.filling) > 0 {
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.nInput += int64(len(s.filling))
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		// Streaming frame: write the frame header before the first block.
		// The content size comes from s.frameContentSize (0 unless set by
		// ResetContentSize).
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   uint64(s.frameContentSize),
			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst := fh.appendTo(tmp[:0])
		s.headerWritten = true
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// The last block was already written; never mark another one final.
		final = false
	}

	if len(s.filling) == 0 {
		// Nothing buffered. On a final call, end the frame with an empty raw
		// block flagged as last, written after any pending background write.
		if final {
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// Synchronous mode: encode and write on the calling goroutine.
	if e.o.concurrent == 1 {
		src := s.filling
		s.nInput += int64(len(s.filling))
		if debugEncoder {
			println("Adding sync block,", len(src), "bytes, final:", final)
		}
		enc := s.encoder
		blk := enc.Block()
		blk.reset(nil)
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}

		s.err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		if s.err != nil {
			return s.err
		}
		_, s.err = s.w.Write(blk.output)
		s.nWritten += int64(len(blk.output))
		s.filling = s.filling[:0]
		return s.err
	}

	// Concurrent mode. Rotate the three buffers: the data to encode moves to
	// s.current, and s.filling takes the buffer from two rounds ago. Its write
	// was awaited (s.wWg) by the previous encoding goroutine before that
	// goroutine finished, and that goroutine was awaited above.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.nInput += int64(len(s.current))
	s.wg.Add(1)
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			// Turn a panic while encoding into a sticky error.
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}

		// Wait for the previous block's write; s.writing is free afterwards.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}

		// NOTE(review): swapEncoders presumably exchanges entropy-coder state
		// with the previously written block - confirm in blockEnc.
		blk.swapEncoders(s.writing)

		// Give the encoder the already written block to fill next, and hand
		// the freshly encoded block to the writer goroutine.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			defer func() {
				// Turn a panic while finishing/writing into a write error.
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			s.writeErr = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			if s.writeErr != nil {
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}
339
340
341
342
343
344
345 func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
346 if debugEncoder {
347 println("Using ReadFrom")
348 }
349
350
351 if len(e.state.filling) > 0 {
352 if err := e.nextBlock(false); err != nil {
353 return 0, err
354 }
355 }
356 e.state.filling = e.state.filling[:e.o.blockSize]
357 src := e.state.filling
358 for {
359 n2, err := r.Read(src)
360 if e.o.crc {
361 _, _ = e.state.encoder.CRC().Write(src[:n2])
362 }
363
364 src = src[n2:]
365 n += int64(n2)
366 switch err {
367 case io.EOF:
368 e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
369 if debugEncoder {
370 println("ReadFrom: got EOF final block:", len(e.state.filling))
371 }
372 return n, nil
373 case nil:
374 default:
375 if debugEncoder {
376 println("ReadFrom: got error:", err)
377 }
378 e.state.err = err
379 return n, err
380 }
381 if len(src) > 0 {
382 if debugEncoder {
383 println("ReadFrom: got space left in source:", len(src))
384 }
385 continue
386 }
387 err = e.nextBlock(false)
388 if err != nil {
389 return n, err
390 }
391 e.state.filling = e.state.filling[:e.o.blockSize]
392 src = e.state.filling
393 }
394 }
395
396
397
398
399 func (e *Encoder) Flush() error {
400 s := &e.state
401 if len(s.filling) > 0 {
402 err := e.nextBlock(false)
403 if err != nil {
404 return err
405 }
406 }
407 s.wg.Wait()
408 s.wWg.Wait()
409 if s.err != nil {
410 return s.err
411 }
412 return s.writeErr
413 }
414
415
416
417
418 func (e *Encoder) Close() error {
419 s := &e.state
420 if s.encoder == nil {
421 return nil
422 }
423 err := e.nextBlock(true)
424 if err != nil {
425 return err
426 }
427 if s.frameContentSize > 0 {
428 if s.nInput != s.frameContentSize {
429 return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
430 }
431 }
432 if e.state.fullFrameWritten {
433 return s.err
434 }
435 s.wg.Wait()
436 s.wWg.Wait()
437
438 if s.err != nil {
439 return s.err
440 }
441 if s.writeErr != nil {
442 return s.writeErr
443 }
444
445
446 if e.o.crc && s.err == nil {
447
448 var tmp [4]byte
449 _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
450 s.nWritten += 4
451 }
452
453
454 if s.err == nil && e.o.pad > 0 {
455 add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
456 frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
457 if err != nil {
458 return err
459 }
460 _, s.err = s.w.Write(frame)
461 }
462 return s.err
463 }
464
465
466
467
468
469
470
// EncodeAll encodes all of src as one complete frame and appends it to dst,
// returning the extended slice. The frame holds the header, the blocks, an
// optional checksum and optional padding.
// It takes an encoder from the e.encoders pool and does not touch the
// streaming state, so concurrent calls are allowed. Each call runs entirely
// on the calling goroutine.
// Empty src appends nothing unless e.o.fullZero is set, in which case a
// minimal empty frame is appended.
// Internal encoding or padding errors cause a panic.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Frame header for empty content.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// No checksum for the empty frame.
				Checksum: false,
				DictID:   0,
			}
			dst = fh.appendTo(dst)

			// A single empty raw block, flagged as last.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	e.init.Do(e.initialize)
	enc := <-e.encoders
	defer func() {
		// Return the encoder to the pool; this also runs if encoding panics.
		e.encoders <- enc
	}()
	// Use a single-segment frame when the input fits in the window and is
	// larger than the minimum window size. e.o.single overrides this choice.
	single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// Preallocate output for inputs under 1 MiB when the caller passed no
	// buffer and low-memory mode is off.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst = fh.appendTo(dst)

	// The input fits in one block: encode it as a single, last block.
	if len(src) <= e.o.blockSize {
		enc.Reset(e.o.dict, true)

		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		// Without a dictionary, use the no-history entry point.
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// Encode straight into dst by temporarily pointing the block's
		// output at it, then restore the block's own buffer.
		oldout := blk.output

		blk.output = dst

		err := blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		if err != nil {
			panic(err)
		}
		dst = blk.output
		blk.output = oldout
	} else {
		// Multiple blocks: encode e.o.blockSize chunks, appending each to dst.
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			// NOTE(review): pushOffsets presumably saves repeat-offset state
			// before encoding the next chunk - confirm in blockEnc.
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			if err != nil {
				panic(err)
			}
			dst = append(dst, blk.output...)
			blk.reset(nil)
		}
	}
	// Append the content checksum.
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}

	// Pad the output via skippableFrame, with content from crypto/rand.
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		var err error
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}
586
587
588
589 func (e *Encoder) MaxEncodedSize(size int) int {
590 frameHeader := 4 + 2
591 if e.o.dict != nil {
592 frameHeader += 4
593 }
594
595 if size < 256 {
596 frameHeader++
597 } else if size < 65536+256 {
598 frameHeader += 2
599 } else if size < math.MaxInt32 {
600 frameHeader += 4
601 } else {
602 frameHeader += 8
603 }
604
605 if e.o.crc {
606 frameHeader += 4
607 }
608
609
610
611 blocks := (size + e.o.blockSize) / e.o.blockSize
612
613
614 maxSz := frameHeader + 3*blocks + size
615 if e.o.pad > 1 {
616 maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
617 }
618 return maxSz
619 }
620
View as plain text