1
2
3
4
5
6 package fse
7
8 import (
9 "errors"
10 "fmt"
11 )
12
13
14
15
16
17
18 func Compress(in []byte, s *Scratch) ([]byte, error) {
19 if len(in) <= 1 {
20 return nil, ErrIncompressible
21 }
22 if len(in) > (2<<30)-1 {
23 return nil, errors.New("input too big, must be < 2GB")
24 }
25 s, err := s.prepare(in)
26 if err != nil {
27 return nil, err
28 }
29
30
31 maxCount := s.maxCount
32 if maxCount == 0 {
33 maxCount = s.countSimple(in)
34 }
35
36 s.clearCount = true
37 s.maxCount = 0
38 if maxCount == len(in) {
39
40 return nil, ErrUseRLE
41 }
42 if maxCount == 1 || maxCount < (len(in)>>7) {
43
44 return nil, ErrIncompressible
45 }
46 s.optimalTableLog()
47 err = s.normalizeCount()
48 if err != nil {
49 return nil, err
50 }
51 err = s.writeCount()
52 if err != nil {
53 return nil, err
54 }
55
56 if false {
57 err = s.validateNorm()
58 if err != nil {
59 return nil, err
60 }
61 }
62
63 err = s.buildCTable()
64 if err != nil {
65 return nil, err
66 }
67 err = s.compress(in)
68 if err != nil {
69 return nil, err
70 }
71 s.Out = s.bw.out
72
73 if len(s.Out) >= len(in) {
74 return nil, ErrIncompressible
75 }
76 return s.Out, nil
77 }
78
79
// cState is one FSE encoder state. compress drives two of them, interleaved,
// both writing into the same bitWriter.
type cState struct {
	// bw is the shared output bit writer (set by init).
	bw *bitWriter
	// stateTable is the next-state lookup table taken from the cTable in init.
	stateTable []uint16
	// state is the current encoder state; it is used as an index-derived value
	// into stateTable and is written out by flush.
	state uint16
}
85
86
87 func (c *cState) init(bw *bitWriter, ct *cTable, tableLog uint8, first symbolTransform) {
88 c.bw = bw
89 c.stateTable = ct.stateTable
90
91 nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16
92 im := int32((nbBitsOut << 16) - first.deltaNbBits)
93 lu := (im >> nbBitsOut) + first.deltaFindState
94 c.state = c.stateTable[lu]
95 }
96
97
98 func (c *cState) encode(symbolTT symbolTransform) {
99 nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16
100 dstState := int32(c.state>>(nbBitsOut&15)) + symbolTT.deltaFindState
101 c.bw.addBits16NC(c.state, uint8(nbBitsOut))
102 c.state = c.stateTable[dstState]
103 }
104
105
106 func (c *cState) encodeZero(symbolTT symbolTransform) {
107 nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16
108 dstState := int32(c.state>>(nbBitsOut&15)) + symbolTT.deltaFindState
109 c.bw.addBits16ZeroNC(c.state, uint8(nbBitsOut))
110 c.state = c.stateTable[dstState]
111 }
112
113
114 func (c *cState) flush(tableLog uint8) {
115 c.bw.flush32()
116 c.bw.addBits16NC(c.state, tableLog)
117 c.bw.flush()
118 }
119
120
// compress encodes src into s.bw using the symbol transforms in s.ct, which
// buildCTable must already have filled for s.actualTableLog.
//
// src is consumed from its end towards its start with two interleaved encoder
// states (c1, c2). The final states are flushed at the end so decoding can start
// from them. src must be longer than 2 bytes; otherwise an error is returned.
func (s *Scratch) compress(src []byte) error {
	if len(src) <= 2 {
		return errors.New("compress: src too small")
	}
	tt := s.ct.symbolTT[:256]
	s.bw.reset(s.Out)

	// Two states, each encoding every second byte. With the prefix handling
	// below, src[0] (the last byte encoded) is always encoded by c1.
	var c1, c2 cState

	// Initialize both states from the trailing bytes and consume 2 or 3 bytes,
	// then possibly 2 more, so the remaining length is a multiple of 4.
	ip := len(src)
	if ip&1 == 1 {
		c1.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-1]])
		c2.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-2]])
		c1.encodeZero(tt[src[ip-3]])
		ip -= 3
	} else {
		c2.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-1]])
		c1.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-2]])
		ip -= 2
	}
	if ip&2 != 0 {
		c2.encodeZero(tt[src[ip-1]])
		c1.encodeZero(tt[src[ip-2]])
		ip -= 2
	}
	src = src[:ip]

	// Main loop: 4 symbols per iteration, taken from the end of src.
	// s.zeroBits (set by buildCTable) selects the encodeZero variants. When
	// actualTableLog <= 8, at most 4*8 bits are added between flush32 calls;
	// otherwise flush32 runs every 2 symbols. Presumably flush32 leaves room
	// for 32 more bits; confirm in bitWriter.
	switch {
	case !s.zeroBits && s.actualTableLog <= 8:
		// No zero-bit symbols; 4 symbols fit between flushes.
		for ; len(src) >= 4; src = src[:len(src)-4] {
			s.bw.flush32()
			v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1]
			c2.encode(tt[v0])
			c1.encode(tt[v1])
			c2.encode(tt[v2])
			c1.encode(tt[v3])
		}
	case !s.zeroBits:
		// No zero-bit symbols; larger tables need a flush every 2 symbols.
		for ; len(src) >= 4; src = src[:len(src)-4] {
			s.bw.flush32()
			v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1]
			c2.encode(tt[v0])
			c1.encode(tt[v1])
			s.bw.flush32()
			c2.encode(tt[v2])
			c1.encode(tt[v3])
		}
	case s.actualTableLog <= 8:
		// Zero-bit symbols possible; 4 symbols fit between flushes.
		for ; len(src) >= 4; src = src[:len(src)-4] {
			s.bw.flush32()
			v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1]
			c2.encodeZero(tt[v0])
			c1.encodeZero(tt[v1])
			c2.encodeZero(tt[v2])
			c1.encodeZero(tt[v3])
		}
	default:
		// Zero-bit symbols possible; flush every 2 symbols.
		for ; len(src) >= 4; src = src[:len(src)-4] {
			s.bw.flush32()
			v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1]
			c2.encodeZero(tt[v0])
			c1.encodeZero(tt[v1])
			s.bw.flush32()
			c2.encodeZero(tt[v2])
			c1.encodeZero(tt[v3])
		}
	}

	// Write the final states (c2 first, then c1); a decoder reads them first
	// to initialize its own states.
	c2.flush(s.actualTableLog)
	c1.flush(s.actualTableLog)

	s.bw.close()
	return nil
}
205
206
207
// writeCount serializes the normalized counts in s.norm into s.Out as the
// table header, replacing s.Out with the header bytes.
//
// Layout, as produced below:
//   - 4 bits: actualTableLog - minTablelog.
//   - For each symbol, norm+1 (so -1 is written as 0 and 0 as 1), using
//     nbBits or nbBits-1 bits. nbBits shrinks as the remaining probability
//     mass drops below threshold.
//   - After a symbol with norm 0, the run of further zero-count symbols is
//     coded as 16-bit 0xFFFF words (24 symbols each), then 2-bit 3s (3 symbols
//     each), then a 2-bit remainder (0-2).
//
// Returns an error if the counts do not sum as expected (remaining < 1) or if
// more symbols than s.symbolLen were consumed.
func (s *Scratch) writeCount() error {
	var (
		tableLog  = s.actualTableLog
		tableSize = 1 << tableLog
		previous0 bool
		charnum   uint16

		// Size of the output window below. Intended as an upper bound on the
		// header size; NOTE(review): derived from symbolLen*tableLog bits plus
		// slack, confirm the bound.
		maxHeaderSize = ((int(s.symbolLen)*int(tableLog) + 4 + 2) >> 3) + 3

		// Table size is written first, in 4 bits.
		bitStream = uint32(tableLog - minTablelog)
		bitCount  = uint(4)
		remaining = int16(tableSize + 1) // +1 matches the norm+1 encoding of each count
		threshold = int16(tableSize)
		nbBits    = uint(tableLog + 1)
	)
	if cap(s.Out) < maxHeaderSize {
		s.Out = make([]byte, 0, s.br.remain()+maxHeaderSize)
	}
	outP := uint(0)
	out := s.Out[:maxHeaderSize]

	// Loop until all probability mass has been written (stops at 1).
	for remaining > 1 {
		if previous0 {
			// Encode how many additional zero-count symbols follow.
			start := charnum
			for s.norm[charnum] == 0 {
				charnum++
			}
			for charnum >= start+24 {
				start += 24
				bitStream += uint32(0xFFFF) << bitCount
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
			}
			for charnum >= start+3 {
				start += 3
				bitStream += 3 << bitCount
				bitCount += 2
			}
			bitStream += uint32(charnum-start) << bitCount
			bitCount += 2
			if bitCount > 16 {
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
				bitCount -= 16
			}
		}

		count := s.norm[charnum]
		charnum++
		max := (2*threshold - 1) - remaining
		// -1 (low probability) consumes one unit, same as 1.
		if count < 0 {
			remaining += count
		} else {
			remaining -= count
		}
		count++ // written value is norm+1
		if count >= threshold {
			// Large values are shifted up so that values below max can use
			// one bit fewer.
			count += max
		}
		bitStream += uint32(count) << bitCount
		bitCount += nbBits
		if count < max {
			bitCount--
		}

		// count == 1 here means norm was 0: a zero run may follow.
		previous0 = count == 1
		if remaining < 1 {
			return errors.New("internal error: remaining<1")
		}
		// Fewer bits are needed once the remaining mass drops.
		for remaining < threshold {
			nbBits--
			threshold >>= 1
		}

		if bitCount > 16 {
			out[outP] = byte(bitStream)
			out[outP+1] = byte(bitStream >> 8)
			outP += 2
			bitStream >>= 16
			bitCount -= 16
		}
	}

	// Write out the remaining bits, rounded up to whole bytes.
	out[outP] = byte(bitStream)
	out[outP+1] = byte(bitStream >> 8)
	outP += (bitCount + 7) / 8

	if charnum > s.symbolLen {
		return errors.New("internal error: charnum > s.symbolLen")
	}
	s.Out = out[:outP]
	return nil
}
307
308
// symbolTransform holds the per-symbol constants cState.encode uses to derive
// the output bit count and next-state index from the current state. Values are
// filled in by buildCTable.
type symbolTransform struct {
	// deltaFindState is added to the shifted state to index stateTable.
	deltaFindState int32
	// deltaNbBits is added to the state; the top 16 bits give the bits to emit.
	deltaNbBits uint32
}
313
314
315 func (s symbolTransform) String() string {
316 return fmt.Sprintf("dnbits: %08x, fs:%d", s.deltaNbBits, s.deltaFindState)
317 }
318
319
// cTable holds the compression tables built by buildCTable. allocCtable sizes
// them for the current actualTableLog.
type cTable struct {
	// tableSymbol maps each table position to the symbol spread there (1<<tableLog entries).
	tableSymbol []byte
	// stateTable holds next-state values, grouped by symbol (1<<tableLog entries).
	stateTable []uint16
	// symbolTT holds one transform per byte value (256 entries).
	symbolTT []symbolTransform
}
325
326
327
328 func (s *Scratch) allocCtable() {
329 tableSize := 1 << s.actualTableLog
330
331 if cap(s.ct.tableSymbol) < tableSize {
332 s.ct.tableSymbol = make([]byte, tableSize)
333 }
334 s.ct.tableSymbol = s.ct.tableSymbol[:tableSize]
335
336 ctSize := tableSize
337 if cap(s.ct.stateTable) < ctSize {
338 s.ct.stateTable = make([]uint16, ctSize)
339 }
340 s.ct.stateTable = s.ct.stateTable[:ctSize]
341
342 if cap(s.ct.symbolTT) < 256 {
343 s.ct.symbolTT = make([]symbolTransform, 256)
344 }
345 s.ct.symbolTT = s.ct.symbolTT[:256]
346 }
347
348
// buildCTable builds s.ct (tableSymbol, stateTable, symbolTT) from the
// normalized counts in s.norm[:s.symbolLen] and s.actualTableLog.
//
// It also sets s.zeroBits when some symbol's normalized count exceeds half the
// table. Such a symbol can be encoded with 0 bits, so compress must use the
// encodeZero path.
//
// Returns an error if the normalized counts do not cover exactly
// 1<<actualTableLog slots, or if the symbol spread does not end at position 0.
func (s *Scratch) buildCTable() error {
	tableSize := uint32(1 << s.actualTableLog)
	highThreshold := tableSize - 1
	var cumul [maxSymbolValue + 2]int16

	s.allocCtable()
	tableSymbol := s.ct.tableSymbol[:tableSize]
	// cumul[u] = first stateTable index for symbol u.
	{
		cumul[0] = 0
		for ui, v := range s.norm[:s.symbolLen-1] {
			u := byte(ui)
			if v == -1 {
				// Low-probability symbol: one slot, placed at the top of the
				// table (above highThreshold) instead of being spread.
				cumul[u+1] = cumul[u] + 1
				tableSymbol[highThreshold] = u
				highThreshold--
			} else {
				cumul[u+1] = cumul[u] + v
			}
		}
		// Last symbol handled separately so the byte index u cannot overflow
		// when symbolLen is 256.
		u := int(s.symbolLen - 1)
		v := s.norm[s.symbolLen-1]
		if v == -1 {
			cumul[u+1] = cumul[u] + 1
			tableSymbol[highThreshold] = byte(u)
			highThreshold--
		} else {
			cumul[u+1] = cumul[u] + v
		}
		if uint32(cumul[s.symbolLen]) != tableSize {
			return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize)
		}
		cumul[s.symbolLen] = int16(tableSize) + 1
	}

	// Spread the remaining symbols over positions 0..highThreshold using a
	// fixed step; positions above highThreshold are reserved for -1 symbols.
	s.zeroBits = false
	{
		step := tableStep(tableSize)
		tableMask := tableSize - 1
		var position uint32
		// A symbol with more than half the table can produce 0-bit outputs.
		largeLimit := int16(1 << (s.actualTableLog - 1))
		for ui, v := range s.norm[:s.symbolLen] {
			symbol := byte(ui)
			if v > largeLimit {
				s.zeroBits = true
			}
			for nbOccurrences := int16(0); nbOccurrences < v; nbOccurrences++ {
				tableSymbol[position] = symbol
				position = (position + step) & tableMask
				for position > highThreshold {
					position = (position + step) & tableMask
				}
			}
		}

		// Every position must have been visited exactly once.
		if position != 0 {
			return errors.New("position!=0")
		}
	}

	// stateTable: for each table position u (in order), its symbol's next
	// free slot gets the value tableSize+u.
	table := s.ct.stateTable
	{
		tsi := int(tableSize)
		for u, v := range tableSymbol {
			table[cumul[v]] = uint16(tsi + u)
			cumul[v]++
		}
	}

	// Per-symbol transforms used by cState.encode/init.
	{
		total := int16(0)
		symbolTT := s.ct.symbolTT[:s.symbolLen]
		tableLog := s.actualTableLog
		tl := (uint32(tableLog) << 16) - (1 << tableLog)
		for i, v := range s.norm[:s.symbolLen] {
			switch v {
			case 0:
				// Absent symbol: transform is left untouched.
			case -1, 1:
				// Single-slot symbols always emit tableLog bits.
				symbolTT[i].deltaNbBits = tl
				symbolTT[i].deltaFindState = int32(total - 1)
				total++
			default:
				maxBitsOut := uint32(tableLog) - highBits(uint32(v-1))
				minStatePlus := uint32(v) << maxBitsOut
				symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus
				symbolTT[i].deltaFindState = int32(total - v)
				total += v
			}
		}
		if total != int16(tableSize) {
			return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize)
		}
	}
	return nil
}
452
453
454
455
456 func (s *Scratch) countSimple(in []byte) (max int) {
457 for _, v := range in {
458 s.count[v]++
459 }
460 m, symlen := uint32(0), s.symbolLen
461 for i, v := range s.count[:] {
462 if v == 0 {
463 continue
464 }
465 if v > m {
466 m = v
467 }
468 symlen = uint16(i) + 1
469 }
470 s.symbolLen = symlen
471 return int(m)
472 }
473
474
475 func (s *Scratch) minTableLog() uint8 {
476 minBitsSrc := highBits(uint32(s.br.remain()-1)) + 1
477 minBitsSymbols := highBits(uint32(s.symbolLen-1)) + 2
478 if minBitsSrc < minBitsSymbols {
479 return uint8(minBitsSrc)
480 }
481 return uint8(minBitsSymbols)
482 }
483
484
485 func (s *Scratch) optimalTableLog() {
486 tableLog := s.TableLog
487 minBits := s.minTableLog()
488 maxBitsSrc := uint8(highBits(uint32(s.br.remain()-1))) - 2
489 if maxBitsSrc < tableLog {
490
491 tableLog = maxBitsSrc
492 }
493 if minBits > tableLog {
494 tableLog = minBits
495 }
496
497 if tableLog < minTablelog {
498 tableLog = minTablelog
499 }
500 if tableLog > maxTableLog {
501 tableLog = maxTableLog
502 }
503 s.actualTableLog = tableLog
504 }
505
// rtbTable holds the rounding thresholds used by normalizeCount. For a symbol
// whose scaled probability proba is below 8, proba is rounded up only if the
// discarded fractional part exceeds vStep*rtbTable[proba].
var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}
507
508
509
// normalizeCount scales the histogram s.count[:s.symbolLen] into s.norm so that
// the table slots sum to 1<<s.actualTableLog.
//
// s.norm values:
//   - 0: absent symbol.
//   - -1: count <= input length >> tableLog; occupies one slot.
//   - >0: number of slots.
//
// The rounding surplus or deficit is given to the most probable symbol. If that
// would remove at least half of that symbol's slots, normalizeCount2 is used
// instead.
func (s *Scratch) normalizeCount() error {
	var (
		tableLog          = s.actualTableLog
		scale             = 62 - uint64(tableLog)
		step              = (1 << 62) / uint64(s.br.remain()) // fixed-point 1/total
		vStep             = uint64(1) << (scale - 20)
		stillToDistribute = int16(1 << tableLog)
		largest           int
		largestP          int16
		lowThreshold      = (uint32)(s.br.remain() >> tableLog)
	)

	for i, cnt := range s.count[:s.symbolLen] {
		// The single-symbol (RLE) case is rejected by Compress before this
		// function runs.

		if cnt == 0 {
			s.norm[i] = 0
			continue
		}
		if cnt <= lowThreshold {
			s.norm[i] = -1
			stillToDistribute--
		} else {
			proba := (int16)((uint64(cnt) * step) >> scale)
			if proba < 8 {
				// Small probabilities: round up when the truncated fraction
				// beats the rtbTable threshold.
				restToBeat := vStep * uint64(rtbTable[proba])
				v := uint64(cnt)*step - (uint64(proba) << scale)
				if v > restToBeat {
					proba++
				}
			}
			if proba > largestP {
				largestP = proba
				largest = i
			}
			s.norm[i] = proba
			stillToDistribute -= proba
		}
	}

	if -stillToDistribute >= (s.norm[largest] >> 1) {
		// Over-allocated by at least half of the largest symbol: fall back to
		// the alternative normalization.
		return s.normalizeCount2()
	}
	s.norm[largest] += stillToDistribute
	return nil
}
558
559
560
561 func (s *Scratch) normalizeCount2() error {
562 const notYetAssigned = -2
563 var (
564 distributed uint32
565 total = uint32(s.br.remain())
566 tableLog = s.actualTableLog
567 lowThreshold = total >> tableLog
568 lowOne = (total * 3) >> (tableLog + 1)
569 )
570 for i, cnt := range s.count[:s.symbolLen] {
571 if cnt == 0 {
572 s.norm[i] = 0
573 continue
574 }
575 if cnt <= lowThreshold {
576 s.norm[i] = -1
577 distributed++
578 total -= cnt
579 continue
580 }
581 if cnt <= lowOne {
582 s.norm[i] = 1
583 distributed++
584 total -= cnt
585 continue
586 }
587 s.norm[i] = notYetAssigned
588 }
589 toDistribute := (1 << tableLog) - distributed
590
591 if (total / toDistribute) > lowOne {
592
593 lowOne = (total * 3) / (toDistribute * 2)
594 for i, cnt := range s.count[:s.symbolLen] {
595 if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) {
596 s.norm[i] = 1
597 distributed++
598 total -= cnt
599 continue
600 }
601 }
602 toDistribute = (1 << tableLog) - distributed
603 }
604 if distributed == uint32(s.symbolLen)+1 {
605
606
607
608 var maxV int
609 var maxC uint32
610 for i, cnt := range s.count[:s.symbolLen] {
611 if cnt > maxC {
612 maxV = i
613 maxC = cnt
614 }
615 }
616 s.norm[maxV] += int16(toDistribute)
617 return nil
618 }
619
620 if total == 0 {
621
622 for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) {
623 if s.norm[i] > 0 {
624 toDistribute--
625 s.norm[i]++
626 }
627 }
628 return nil
629 }
630
631 var (
632 vStepLog = 62 - uint64(tableLog)
633 mid = uint64((1 << (vStepLog - 1)) - 1)
634 rStep = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total)
635 tmpTotal = mid
636 )
637 for i, cnt := range s.count[:s.symbolLen] {
638 if s.norm[i] == notYetAssigned {
639 var (
640 end = tmpTotal + uint64(cnt)*rStep
641 sStart = uint32(tmpTotal >> vStepLog)
642 sEnd = uint32(end >> vStepLog)
643 weight = sEnd - sStart
644 )
645 if weight < 1 {
646 return errors.New("weight < 1")
647 }
648 s.norm[i] = int16(weight)
649 tmpTotal = end
650 }
651 }
652 return nil
653 }
654
655
656 func (s *Scratch) validateNorm() (err error) {
657 var total int
658 for _, v := range s.norm[:s.symbolLen] {
659 if v >= 0 {
660 total += int(v)
661 } else {
662 total -= int(v)
663 }
664 }
665 defer func() {
666 if err == nil {
667 return
668 }
669 fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen)
670 for i, v := range s.norm[:s.symbolLen] {
671 fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v)
672 }
673 }()
674 if total != (1 << s.actualTableLog) {
675 return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog)
676 }
677 for i, v := range s.count[s.symbolLen:] {
678 if v != 0 {
679 return fmt.Errorf("warning: Found symbol out of range, %d after cut", i)
680 }
681 }
682 return nil
683 }
684
View as plain text