1 package huff0
2
3 import (
4 "fmt"
5 "math"
6 "runtime"
7 "sync"
8 )
9
10
11
12
13
14 func Compress1X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
15 s, err = s.prepare(in)
16 if err != nil {
17 return nil, false, err
18 }
19 return compress(in, s, s.compress1X)
20 }
21
22
23
24
25
26
27 func Compress4X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
28 s, err = s.prepare(in)
29 if err != nil {
30 return nil, false, err
31 }
32 if false {
33
34 const parallelThreshold = 8 << 10
35 if len(in) < parallelThreshold || runtime.GOMAXPROCS(0) == 1 {
36 return compress(in, s, s.compress4X)
37 }
38 return compress(in, s, s.compress4Xp)
39 }
40 return compress(in, s, s.compress4X)
41 }
42
// compress runs the shared encoding pipeline for in, using compressor to emit
// the bit stream(s) and s for the histogram, tables and output buffer.
//
// Depending on s.Reuse the previous table (s.prevTable) is tried first
// (Prefer/Must), weighed against a freshly built table (Allow), or dropped
// (None). reUsed reports whether the returned output was encoded with the
// previous table; in that case this function writes no table.
//
// Errors: ErrUseRLE when a single symbol fills the whole input (and the input
// is longer than one byte); ErrIncompressible when the input is rejected up
// front, when re-use is mandatory but not possible, or when the result is not
// smaller than the wanted size. Errors from s.buildCTable and s.cTable.write
// are returned unchanged, as are compressor errors outside the Prefer/Must
// re-use attempt.
func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)) (out []byte, reUsed bool, err error) {
	// With re-use disabled the previous table can never be used; empty it.
	if s.Reuse == ReusePolicyNone {
		s.prevTable = s.prevTable[:0]
	}

	// A non-zero s.maxCount is taken to mean the histogram is already in
	// place; otherwise count the input now.
	maxCount := s.maxCount
	var canReuse = false
	if maxCount == 0 {
		maxCount, canReuse = s.countSimple(in)
	} else {
		canReuse = s.canUseTable(s.prevTable)
	}

	// The output is only accepted when it is strictly smaller than wantSize.
	wantSize := len(in)
	if s.WantLogLess > 0 {
		wantSize -= wantSize >> s.WantLogLess
	}

	// Reset the per-call histogram state on s before any of the returns below.
	s.clearCount = true
	s.maxCount = 0
	if maxCount >= len(in) {
		if maxCount > len(in) {
			return nil, false, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
		}
		if len(in) == 1 {
			return nil, false, ErrIncompressible
		}
		// The whole input is one repeated symbol.
		return nil, false, ErrUseRLE
	}
	if maxCount == 1 || maxCount < (len(in)>>7) {
		// No symbol occurs more than once, or the most frequent symbol covers
		// less than 1/128 of the input.
		return nil, false, ErrIncompressible
	}
	if s.Reuse == ReusePolicyMust && !canReuse {
		// Re-use is mandatory but the previous table does not cover the input.
		return nil, false, ErrIncompressible
	}
	if (s.Reuse == ReusePolicyPrefer || s.Reuse == ReusePolicyMust) && canReuse {
		// Encode with the previous table, then put the current table back.
		keepTable := s.cTable
		keepTL := s.actualTableLog
		s.cTable = s.prevTable
		s.actualTableLog = s.prevTableLog
		s.Out, err = compressor(in)
		s.cTable = keepTable
		s.actualTableLog = keepTL
		if err == nil && len(s.Out) < wantSize {
			s.OutData = s.Out
			return s.Out, true, nil
		}
		if s.Reuse == ReusePolicyMust {
			return nil, false, ErrIncompressible
		}
		// Prefer: the attempt failed, so forget the previous table and fall
		// through to building a new one.
		s.prevTable = s.prevTable[:0]
	}

	// Build a new table from the histogram.
	err = s.buildCTable()
	if err != nil {
		return nil, false, err
	}

	// Debug check, disabled by the constant false.
	if false && !s.canUseTable(s.cTable) {
		panic("invalid table generated")
	}

	if s.Reuse == ReusePolicyAllow && canReuse {
		// hSize is the length of s.Out at this point; the new table has not
		// been written yet.
		hSize := len(s.Out)
		oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
		newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
		if oldSize <= hSize+newSize || hSize+12 >= wantSize {
			// Keep the freshly built table even though the old one is used.
			keepTable := s.cTable
			keepTL := s.actualTableLog

			s.cTable = s.prevTable
			s.actualTableLog = s.prevTableLog
			s.Out, err = compressor(in)

			// Put the new table back before looking at the result.
			s.cTable = keepTable
			s.actualTableLog = keepTL
			if err != nil {
				return nil, false, err
			}
			if len(s.Out) >= wantSize {
				return nil, false, ErrIncompressible
			}
			s.OutData = s.Out
			return s.Out, true, nil
		}
	}

	// Emit the new table; s.OutTable records s.Out as it is right after.
	err = s.cTable.write(s)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	s.OutTable = s.Out

	// Encode the data with the new table.
	s.Out, err = compressor(in)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	if len(s.Out) >= wantSize {
		s.OutTable = nil
		return nil, false, ErrIncompressible
	}
	// The table just used becomes the previous table; the old previous table's
	// storage is recycled as the (emptied) current table.
	s.prevTable, s.prevTableLog, s.cTable = s.cTable, s.actualTableLog, s.prevTable[:0]
	s.OutData = s.Out[len(s.OutTable):]
	return s.Out, false, nil
}
164
165
166 func EstimateSizes(in []byte, s *Scratch) (tableSz, dataSz, reuseSz int, err error) {
167 s, err = s.prepare(in)
168 if err != nil {
169 return 0, 0, 0, err
170 }
171
172
173 tableSz, dataSz, reuseSz = -1, -1, -1
174 maxCount := s.maxCount
175 var canReuse = false
176 if maxCount == 0 {
177 maxCount, canReuse = s.countSimple(in)
178 } else {
179 canReuse = s.canUseTable(s.prevTable)
180 }
181
182
183 wantSize := len(in)
184 if s.WantLogLess > 0 {
185 wantSize -= wantSize >> s.WantLogLess
186 }
187
188
189 s.clearCount = true
190 s.maxCount = 0
191 if maxCount >= len(in) {
192 if maxCount > len(in) {
193 return 0, 0, 0, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
194 }
195 if len(in) == 1 {
196 return 0, 0, 0, ErrIncompressible
197 }
198
199 return 0, 0, 0, ErrUseRLE
200 }
201 if maxCount == 1 || maxCount < (len(in)>>7) {
202
203 return 0, 0, 0, ErrIncompressible
204 }
205
206
207 err = s.buildCTable()
208 if err != nil {
209 return 0, 0, 0, err
210 }
211
212 if false && !s.canUseTable(s.cTable) {
213 panic("invalid table generated")
214 }
215
216 tableSz, err = s.cTable.estTableSize(s)
217 if err != nil {
218 return 0, 0, 0, err
219 }
220 if canReuse {
221 reuseSz = s.prevTable.estimateSize(s.count[:s.symbolLen])
222 }
223 dataSz = s.cTable.estimateSize(s.count[:s.symbolLen])
224
225
226 return tableSz, dataSz, reuseSz, nil
227 }
228
229 func (s *Scratch) compress1X(src []byte) ([]byte, error) {
230 return s.compress1xDo(s.Out, src), nil
231 }
232
233 func (s *Scratch) compress1xDo(dst, src []byte) []byte {
234 var bw = bitWriter{out: dst}
235
236
237 n := len(src)
238 n -= n & 3
239 cTable := s.cTable[:256]
240
241
242 for i := len(src) & 3; i > 0; i-- {
243 bw.encSymbol(cTable, src[n+i-1])
244 }
245 n -= 4
246 if s.actualTableLog <= 8 {
247 for ; n >= 0; n -= 4 {
248 tmp := src[n : n+4]
249
250 bw.flush32()
251 bw.encFourSymbols(cTable[tmp[3]], cTable[tmp[2]], cTable[tmp[1]], cTable[tmp[0]])
252 }
253 } else {
254 for ; n >= 0; n -= 4 {
255 tmp := src[n : n+4]
256
257 bw.flush32()
258 bw.encTwoSymbols(cTable, tmp[3], tmp[2])
259 bw.flush32()
260 bw.encTwoSymbols(cTable, tmp[1], tmp[0])
261 }
262 }
263 bw.close()
264 return bw.out
265 }
266
// sixZeros is the zeroed 6-byte placeholder that compress4X appends to the
// output ahead of the streams; the 2-byte lengths of the first three streams
// are written over it afterwards.
var sixZeros [6]byte
268
269 func (s *Scratch) compress4X(src []byte) ([]byte, error) {
270 if len(src) < 12 {
271 return nil, ErrIncompressible
272 }
273 segmentSize := (len(src) + 3) / 4
274
275
276 offsetIdx := len(s.Out)
277 s.Out = append(s.Out, sixZeros[:]...)
278
279 for i := 0; i < 4; i++ {
280 toDo := src
281 if len(toDo) > segmentSize {
282 toDo = toDo[:segmentSize]
283 }
284 src = src[len(toDo):]
285
286 idx := len(s.Out)
287 s.Out = s.compress1xDo(s.Out, toDo)
288 if len(s.Out)-idx > math.MaxUint16 {
289
290 return nil, ErrIncompressible
291 }
292
293 if i < 3 {
294
295 length := len(s.Out) - idx
296 s.Out[i*2+offsetIdx] = byte(length)
297 s.Out[i*2+offsetIdx+1] = byte(length >> 8)
298 }
299 }
300
301 return s.Out, nil
302 }
303
304
305 func (s *Scratch) compress4Xp(src []byte) ([]byte, error) {
306 if len(src) < 12 {
307 return nil, ErrIncompressible
308 }
309
310 s.Out = s.Out[:6]
311
312 segmentSize := (len(src) + 3) / 4
313 var wg sync.WaitGroup
314 wg.Add(4)
315 for i := 0; i < 4; i++ {
316 toDo := src
317 if len(toDo) > segmentSize {
318 toDo = toDo[:segmentSize]
319 }
320 src = src[len(toDo):]
321
322
323 go func(i int) {
324 s.tmpOut[i] = s.compress1xDo(s.tmpOut[i][:0], toDo)
325 wg.Done()
326 }(i)
327 }
328 wg.Wait()
329 for i := 0; i < 4; i++ {
330 o := s.tmpOut[i]
331 if len(o) > math.MaxUint16 {
332
333 return nil, ErrIncompressible
334 }
335
336 if i < 3 {
337
338 s.Out[i*2] = byte(len(o))
339 s.Out[i*2+1] = byte(len(o) >> 8)
340 }
341
342
343 s.Out = append(s.Out, o...)
344 }
345 return s.Out, nil
346 }
347
348
349
350
351 func (s *Scratch) countSimple(in []byte) (max int, reuse bool) {
352 reuse = true
353 _ = s.count
354 for _, v := range in {
355 s.count[v]++
356 }
357 m := uint32(0)
358 if len(s.prevTable) > 0 {
359 for i, v := range s.count[:] {
360 if v == 0 {
361 continue
362 }
363 if v > m {
364 m = v
365 }
366 s.symbolLen = uint16(i) + 1
367 if i >= len(s.prevTable) {
368 reuse = false
369 } else if s.prevTable[i].nBits == 0 {
370 reuse = false
371 }
372 }
373 return int(m), reuse
374 }
375 for i, v := range s.count[:] {
376 if v == 0 {
377 continue
378 }
379 if v > m {
380 m = v
381 }
382 s.symbolLen = uint16(i) + 1
383 }
384 return int(m), false
385 }
386
387 func (s *Scratch) canUseTable(c cTable) bool {
388 if len(c) < int(s.symbolLen) {
389 return false
390 }
391 for i, v := range s.count[:s.symbolLen] {
392 if v != 0 && c[i].nBits == 0 {
393 return false
394 }
395 }
396 return true
397 }
398
399
400 func (s *Scratch) validateTable(c cTable) bool {
401 if len(c) < int(s.symbolLen) {
402 return false
403 }
404 for i, v := range s.count[:s.symbolLen] {
405 if v != 0 {
406 if c[i].nBits == 0 {
407 return false
408 }
409 if c[i].nBits > s.actualTableLog {
410 return false
411 }
412 }
413 }
414 return true
415 }
416
417
418 func (s *Scratch) minTableLog() uint8 {
419 minBitsSrc := highBit32(uint32(s.srcLen)) + 1
420 minBitsSymbols := highBit32(uint32(s.symbolLen-1)) + 2
421 if minBitsSrc < minBitsSymbols {
422 return uint8(minBitsSrc)
423 }
424 return uint8(minBitsSymbols)
425 }
426
427
428 func (s *Scratch) optimalTableLog() {
429 tableLog := s.TableLog
430 minBits := s.minTableLog()
431 maxBitsSrc := uint8(highBit32(uint32(s.srcLen-1))) - 1
432 if maxBitsSrc < tableLog {
433
434 tableLog = maxBitsSrc
435 }
436 if minBits > tableLog {
437 tableLog = minBits
438 }
439
440 if tableLog < minTablelog {
441 tableLog = minTablelog
442 }
443 if tableLog > tableLogMax {
444 tableLog = tableLogMax
445 }
446 s.actualTableLog = tableLog
447 }
448
// cTableEntry is the encoding table entry for a single symbol.
type cTableEntry struct {
	val   uint16 // code value; assigned per code length in buildCTable
	nBits uint8  // code length in bits; 0 is treated as "symbol has no code"

}
454
// huffNodesMask is ANDed with node positions in huffSort so that the computed
// index (including the pos-1 lookup at position 0) stays inside the node slice.
// NOTE(review): this wraps correctly only if huffNodesLen is a power of two — confirm at its definition.
const huffNodesMask = huffNodesLen - 1
456
// buildCTable builds the encoding table s.cTable from the histogram in
// s.count[:s.symbolLen].
//
// Steps: choose the table log, sort the symbols into s.nodes, build the tree
// by repeatedly joining the two lowest-count nodes, derive code lengths from
// tree depth, limit them via setMaxHeight, and finally assign code values per
// length. s.actualTableLog is updated to the value returned by setMaxHeight.
//
// An error is returned only if that value exceeds tableLogMax.
//
// NOTE(review): the first join reads huffNode[lowS-1], so at least two symbols
// with a non-zero count appear to be required — callers in this file reject
// single-symbol input beforehand; confirm for other callers.
func (s *Scratch) buildCTable() error {
	s.optimalTableLog()
	s.huffSort()
	// Reuse the table's storage when it is large enough, clearing old entries.
	if cap(s.cTable) < maxSymbolValue+1 {
		s.cTable = make([]cTableEntry, s.symbolLen, maxSymbolValue+1)
	} else {
		s.cTable = s.cTable[:s.symbolLen]
		for i := range s.cTable {
			s.cTable[i] = cTableEntry{}
		}
	}

	// Leaves occupy huffNode[0:symbolLen]; internal nodes start at startNode.
	var startNode = int16(s.symbolLen)
	nonNullRank := s.symbolLen - 1

	nodeNb := startNode
	huffNode := s.nodes[1 : huffNodesLen+1]

	// huffNode0 views the same storage shifted by one, so huffNode0[i+1] is
	// huffNode[i] and index -1 of huffNode is reachable as huffNode0[0].
	huffNode0 := s.nodes[0 : huffNodesLen+1]

	// Skip trailing symbols with a zero count.
	for huffNode[nonNullRank].count() == 0 {
		nonNullRank--
	}

	// Join the two last (lowest-count) leaves into the first internal node.
	lowS := int16(nonNullRank)
	nodeRoot := nodeNb + lowS - 1
	lowN := nodeNb
	huffNode[nodeNb].setCount(huffNode[lowS].count() + huffNode[lowS-1].count())
	huffNode[lowS].setParent(nodeNb)
	huffNode[lowS-1].setParent(nodeNb)
	nodeNb++
	lowS -= 2
	// Not-yet-created internal nodes get a large count so they are never
	// picked before they exist.
	for n := nodeNb; n <= nodeRoot; n++ {
		huffNode[n].setCount(1 << 30)
	}
	// Sentinel at index -1: larger than any other count, so the leaf cursor
	// is never selected once it has run off the front.
	huffNode0[0].setCount(1 << 31)

	// Create the remaining parents: each round picks the two smallest among
	// the next leaf (lowS, moving down) and the next internal node (lowN,
	// moving up).
	for nodeNb <= nodeRoot {
		var n1, n2 int16
		if huffNode0[lowS+1].count() < huffNode0[lowN+1].count() {
			n1 = lowS
			lowS--
		} else {
			n1 = lowN
			lowN++
		}
		if huffNode0[lowS+1].count() < huffNode0[lowN+1].count() {
			n2 = lowS
			lowS--
		} else {
			n2 = lowN
			lowN++
		}

		huffNode[nodeNb].setCount(huffNode0[n1+1].count() + huffNode0[n2+1].count())
		huffNode0[n1+1].setParent(nodeNb)
		huffNode0[n2+1].setParent(nodeNb)
		nodeNb++
	}

	// Code length = depth in the tree: root is 0, every node is parent + 1.
	// Internal nodes are walked from the root down so parents are set first.
	huffNode[nodeRoot].setNbBits(0)
	for n := nodeRoot - 1; n >= startNode; n-- {
		huffNode[n].setNbBits(huffNode[huffNode[n].parent()].nbBits() + 1)
	}
	for n := uint16(0); n <= nonNullRank; n++ {
		huffNode[n].setNbBits(huffNode[huffNode[n].parent()].nbBits() + 1)
	}
	// Limit the code lengths; the result becomes the actual table log.
	s.actualTableLog = s.setMaxHeight(int(nonNullRank))
	maxNbBits := s.actualTableLog

	// Guard the nbPerRank indexing below.
	if maxNbBits > tableLogMax {
		return fmt.Errorf("internal error: maxNbBits (%d) > tableLogMax (%d)", maxNbBits, tableLogMax)
	}
	var nbPerRank [tableLogMax + 1]uint16
	var valPerRank [16]uint16
	// Count how many symbols use each code length.
	for _, v := range huffNode[:nonNullRank+1] {
		nbPerRank[v.nbBits()]++
	}

	// Determine the first code value for each length, longest codes first.
	{
		min := uint16(0)
		for n := maxNbBits; n > 0; n-- {
			// First value of this length; the running total is halved when
			// moving to the next shorter length.
			valPerRank[n] = min
			min += nbPerRank[n]
			min >>= 1
		}
	}

	// Store the code length of every used symbol, indexed by symbol.
	for _, v := range huffNode[:nonNullRank+1] {
		s.cTable[v.symbol()].nBits = v.nbBits()
	}

	// Hand out consecutive values within each length, in symbol order.
	// The &15 keeps the index within valPerRank.
	t := s.cTable[:s.symbolLen]
	for n, val := range t {
		nbits := val.nBits & 15
		v := valPerRank[nbits]
		t[n].val = v
		valPerRank[nbits] = v + 1
	}

	return nil
}
568
569
// huffSort fills s.nodes (from index 1 on) with one node per symbol in
// s.count[:s.symbolLen], ordered by decreasing count.
//
// Symbols are bucketed by highBit32(count+1), with buckets for larger counts
// placed first, and insertion-sorted within their bucket. Symbols with a zero
// count are included and end up last.
func (s *Scratch) huffSort() {
	type rankPos struct {
		base    uint32 // first slot belonging to the bucket
		current uint32 // next free slot in the bucket
	}

	// Make sure s.nodes spans huffNodesLen+1 entries; the local view skips
	// index 0.
	nodes := s.nodes[:huffNodesLen+1]
	s.nodes = nodes
	nodes = nodes[1 : huffNodesLen+1]

	// Count the symbols per bucket.
	var rank [32]rankPos
	for _, v := range s.count[:s.symbolLen] {
		r := highBit32(v+1) & 31
		rank[r].base++
	}

	// Turn the per-bucket counts into a suffix sum: rank[r].base becomes the
	// number of symbols whose bucket index is >= r.
	// NOTE(review): 18+1 presumably reflects the largest supported input size — confirm.
	const maxBitLength = 18 + 1
	for n := maxBitLength; n > 0; n-- {
		rank[n-1].base += rank[n].base
	}
	for n := range rank[:maxBitLength] {
		rank[n].current = rank[n].base
	}
	// Place every symbol. Using bucket index + 1 selects the start offset
	// "number of symbols in strictly larger buckets".
	for n, c := range s.count[:s.symbolLen] {
		r := (highBit32(c+1) + 1) & 31
		pos := rank[r].current
		rank[r].current++
		// Insertion sort inside the bucket: shift smaller counts back.
		// The mask keeps the pos-1 lookup in range when pos is 0.
		prev := nodes[(pos-1)&huffNodesMask]
		for pos > rank[r].base && c > prev.count() {
			nodes[pos&huffNodesMask] = prev
			pos--
			prev = nodes[(pos-1)&huffNodesMask]
		}
		nodes[pos&huffNodesMask] = makeNodeElt(c, byte(n))
	}
}
608
// setMaxHeight limits the code lengths stored in s.nodes[1:] to
// s.actualTableLog bits and returns the resulting maximum length.
//
// lastNonNull is the index of the last leaf to consider; its code length is
// taken as the largest one. If that length already fits, it is returned and
// nothing is changed. Otherwise over-long codes are cut to the limit and the
// resulting cost is repaid by lengthening codes of other symbols.
func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
	maxNbBits := s.actualTableLog
	huffNode := s.nodes[1 : huffNodesLen+1]

	largestBits := huffNode[lastNonNull].nbBits()

	// Early exit: nothing exceeds the limit.
	if largestBits <= maxNbBits {
		return largestBits
	}
	totalCost := int(0)
	baseCost := int(1) << (largestBits - maxNbBits)
	n := uint32(lastNonNull)

	// Cut every over-long code to maxNbBits and accumulate the cost, measured
	// in units of 2^-largestBits.
	for huffNode[n].nbBits() > maxNbBits {
		totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits()))
		huffNode[n].setNbBits(maxNbBits)
		n--
	}

	// Skip the symbols that already sat exactly at maxNbBits; afterwards n
	// indexes the last symbol using fewer than maxNbBits.
	for huffNode[n].nbBits() == maxNbBits {
		n--
	}

	// Re-express the cost in units of 2^-maxNbBits.
	totalCost >>= largestBits - maxNbBits

	// Repay the normalized cost.
	{
		// noSymbol marks a rank that currently has no symbol.
		const noSymbol = 0xF0F0F0F0
		// rankLast[k] is the position of the last symbol whose code is k bits
		// shorter than maxNbBits.
		var rankLast [tableLogMax + 2]uint32

		for i := range rankLast[:] {
			rankLast[i] = noSymbol
		}

		// Walk from n towards index 0 and record the first position seen for
		// each strictly smaller code length.
		{
			currentNbBits := maxNbBits
			for pos := int(n); pos >= 0; pos-- {
				if huffNode[pos].nbBits() >= currentNbBits {
					continue
				}
				currentNbBits = huffNode[pos].nbBits()
				rankLast[maxNbBits-currentNbBits] = uint32(pos)
			}
		}

		for totalCost > 0 {
			// Start from the rank that would repay the largest part of the cost.
			nBitsToDecrease := uint8(highBit32(uint32(totalCost))) + 1

			// Step down to a lower rank while that rank is empty or while
			// lengthening one symbol here costs more than two at the next
			// lower rank.
			for ; nBitsToDecrease > 1; nBitsToDecrease-- {
				highPos := rankLast[nBitsToDecrease]
				lowPos := rankLast[nBitsToDecrease-1]
				if highPos == noSymbol {
					continue
				}
				if lowPos == noSymbol {
					break
				}
				highTotal := huffNode[highPos].count()
				lowTotal := 2 * huffNode[lowPos].count()
				if highTotal <= lowTotal {
					break
				}
			}

			// If the chosen rank is empty, move up to the closest rank that
			// has a symbol; the bound keeps the index inside rankLast.
			for (nBitsToDecrease <= tableLogMax) && (rankLast[nBitsToDecrease] == noSymbol) {
				nBitsToDecrease++
			}
			totalCost -= 1 << (nBitsToDecrease - 1)
			if rankLast[nBitsToDecrease-1] == noSymbol {
				// The symbol about to be lengthened becomes the last one of
				// the rank below.
				rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
			}
			// Lengthen the selected symbol's code by one bit.
			huffNode[rankLast[nBitsToDecrease]].setNbBits(1 +
				huffNode[rankLast[nBitsToDecrease]].nbBits())
			if rankLast[nBitsToDecrease] == 0 {
				// Position 0 has no predecessor; the rank is now empty.
				rankLast[nBitsToDecrease] = noSymbol
			} else {
				// Move to the previous symbol; if it has a different length
				// the rank is now empty.
				rankLast[nBitsToDecrease]--
				if huffNode[rankLast[nBitsToDecrease]].nbBits() != maxNbBits-nBitsToDecrease {
					rankLast[nBitsToDecrease] = noSymbol
				}
			}
		}

		// The repayment can overshoot; give bits back one unit at a time by
		// shortening a maxNbBits code to maxNbBits-1.
		for totalCost < 0 {
			if rankLast[1] == noSymbol {
				// No symbol at maxNbBits-1 yet: create one from the first
				// symbol that uses maxNbBits.
				for huffNode[n].nbBits() == maxNbBits {
					n--
				}
				huffNode[n+1].setNbBits(huffNode[n+1].nbBits() - 1)
				rankLast[1] = n + 1
				totalCost++
				continue
			}
			huffNode[rankLast[1]+1].setNbBits(huffNode[rankLast[1]+1].nbBits() - 1)
			rankLast[1]++
			totalCost++
		}
	}
	return maxNbBits
}
719
720
721
722
723
724
725
726
727
728
// nodeElt packs the four fields of a tree node into one 64-bit word:
//
//	bits  0-31  count
//	bits 32-47  parent
//	bits 48-55  symbol
//	bits 56-63  nbBits
type nodeElt uint64

// makeNodeElt returns a node carrying count and symbol; parent and nbBits are zero.
func makeNodeElt(count uint32, symbol byte) nodeElt {
	packed := nodeElt(symbol) << 48
	packed |= nodeElt(count)
	return packed
}

// count returns the count field (bits 0-31).
func (e *nodeElt) count() uint32 {
	return uint32(*e)
}

// parent returns the parent field (bits 32-47).
func (e *nodeElt) parent() uint16 {
	return uint16(*e >> 32)
}

// symbol returns the symbol field (bits 48-55).
func (e *nodeElt) symbol() byte {
	return byte(*e >> 48)
}

// nbBits returns the nbBits field (bits 56-63).
func (e *nodeElt) nbBits() uint8 {
	return uint8(*e >> 56)
}

// setCount replaces the count field and leaves the other fields untouched.
func (e *nodeElt) setCount(c uint32) {
	*e = *e&^0xffffffff | nodeElt(c)
}

// setParent replaces the parent field with the 16-bit pattern of p and leaves
// the other fields untouched.
func (e *nodeElt) setParent(p int16) {
	*e = *e&^(0xffff<<32) | nodeElt(uint16(p))<<32
}

// setNbBits replaces the nbBits field and leaves the other fields untouched.
func (e *nodeElt) setNbBits(n uint8) {
	*e = *e&^(0xff<<56) | nodeElt(n)<<56
}
743
View as plain text