1
2
3
4
5 package zstd
6
7 import (
8 "errors"
9 "fmt"
10 "math"
11 )
12
const (
	// maxEncTableLog is the largest table log the encoder selects;
	// optimalTableLog starts from it and clamps its result to it.
	maxEncTableLog = 8

	// NOTE(review): these two are derived from maxTableLog (declared outside
	// this chunk), not from maxEncTableLog — confirm that this is intended.
	maxEncTablesize = 1 << maxTableLog
	maxEncTableMask = (1 << maxTableLog) - 1

	// minEncTablelog is the smallest table log the encoder selects.
	// writeCount stores the table log in the header relative to this value.
	minEncTablelog = 5

	// maxEncSymbolValue mirrors maxMatchLengthSymbol, which is declared
	// outside this chunk.
	maxEncSymbolValue = maxMatchLengthSymbol
)
21
22
// fseEncoder holds the histogram, the normalized counts and the compression
// tables used to encode one stream of symbols.
type fseEncoder struct {
	symbolLen      uint16 // Number of symbols in use; HistogramFinished sets it to maxSymbol+1.
	actualTableLog uint8  // Log2 of the table size; chosen by optimalTableLog, reset to 0 by setRLE.
	ct             cTable // Compression tables, (re)allocated by allocCtable.
	maxCount       int    // Count of the most populated symbol, as passed to HistogramFinished.
	zeroBits       bool   // Set by buildCTable when a normalized count exceeds half the table size.
	clearCount     bool   // Set by HistogramFinished to maxCount != 0.
	useRLE         bool   // Set when the stream is a single repeated symbol (see normalizeCount, setRLE).
	preDefined     bool   // When set, setBits does nothing and writeCount emits no table header.
	reUsed         bool   // When set, normalizeCount and setBits do nothing and writeCount emits no table header.
	rleVal         uint8  // Symbol recorded by setRLE; the single byte writeCount emits in RLE mode.
	maxBits        uint8  // Largest outBits value assigned by setBits.

	count [256]uint32 // Histogram of symbol occurrences; exposed through Histogram.
	norm  [256]int16  // Normalized counts; -1 marks a low-probability symbol that takes one table slot.
}
40
41
// cTable contains the tables used for compression.
type cTable struct {
	tableSymbol []byte            // Symbol stored at each table position; filled by buildCTable.
	stateTable  []uint16          // State values grouped by symbol; filled by buildCTable, read by cState.
	symbolTT    []symbolTransform // Per-symbol transforms; allocCtable always sizes it to 256 entries.
}
47
48
// symbolTransform contains the state transform for a single symbol.
type symbolTransform struct {
	deltaNbBits    uint32 // Packed bit-count value; bitCost reads its upper 16 bits as the minimum number of bits.
	deltaFindState int16  // Offset added when computing a state table index (see cState.init).
	outBits        uint8  // Value assigned by setBits: the symbol index itself or its entry in the transform.
}
54
55
56 func (s symbolTransform) String() string {
57 return fmt.Sprintf("{deltabits: %08x, findstate:%d outbits:%d}", s.deltaNbBits, s.deltaFindState, s.outBits)
58 }
59
60
61
62
63
64
65 func (s *fseEncoder) Histogram() *[256]uint32 {
66 return &s.count
67 }
68
69
70
71
72
73 func (s *fseEncoder) HistogramFinished(maxSymbol uint8, maxCount int) {
74 s.maxCount = maxCount
75 s.symbolLen = uint16(maxSymbol) + 1
76 s.clearCount = maxCount != 0
77 }
78
79
80
81 func (s *fseEncoder) allocCtable() {
82 tableSize := 1 << s.actualTableLog
83
84 if cap(s.ct.tableSymbol) < tableSize {
85 s.ct.tableSymbol = make([]byte, tableSize)
86 }
87 s.ct.tableSymbol = s.ct.tableSymbol[:tableSize]
88
89 ctSize := tableSize
90 if cap(s.ct.stateTable) < ctSize {
91 s.ct.stateTable = make([]uint16, ctSize)
92 }
93 s.ct.stateTable = s.ct.stateTable[:ctSize]
94
95 if cap(s.ct.symbolTT) < 256 {
96 s.ct.symbolTT = make([]symbolTransform, 256)
97 }
98 s.ct.symbolTT = s.ct.symbolTT[:256]
99 }
100
101
// buildCTable fills the compression tables in s.ct from the normalized counts
// in s.norm[:s.symbolLen], for a table of 1<<s.actualTableLog entries.
// It also recomputes s.zeroBits. An error is returned if the normalized counts
// do not add up to the table size, or if spreading the symbols over the table
// does not end back at position 0.
func (s *fseEncoder) buildCTable() error {
	tableSize := uint32(1 << s.actualTableLog)
	highThreshold := tableSize - 1
	var cumul [256]int16

	s.allocCtable()
	tableSymbol := s.ct.tableSymbol[:tableSize]

	// Step 1: compute the start position of every symbol (cumulative counts).
	{
		cumul[0] = 0
		for ui, v := range s.norm[:s.symbolLen-1] {
			u := byte(ui)
			if v == -1 {
				// Low-probability symbol: occupies a single slot, taken from the
				// top of the table downwards.
				cumul[u+1] = cumul[u] + 1
				tableSymbol[highThreshold] = u
				highThreshold--
			} else {
				cumul[u+1] = cumul[u] + v
			}
		}

		// The last symbol is handled outside the loop with an int index, so
		// that u+1 is not computed on a byte.
		// NOTE(review): cumul has 256 entries, so cumul[u+1] requires
		// s.symbolLen <= 255 — confirm callers guarantee that.
		u := int(s.symbolLen - 1)
		v := s.norm[s.symbolLen-1]
		if v == -1 {
			// Low-probability symbol, same handling as in the loop above.
			cumul[u+1] = cumul[u] + 1
			tableSymbol[highThreshold] = byte(u)
			highThreshold--
		} else {
			cumul[u+1] = cumul[u] + v
		}
		// The cumulative total must cover the table exactly.
		if uint32(cumul[s.symbolLen]) != tableSize {
			return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize)
		}
		cumul[s.symbolLen] = int16(tableSize) + 1
	}

	// Step 2: spread the symbols over the table positions.
	s.zeroBits = false
	{
		// tableStep is defined outside this chunk.
		step := tableStep(tableSize)
		tableMask := tableSize - 1
		var position uint32

		// A symbol that occupies more than half of the table sets zeroBits.
		largeLimit := int16(1 << (s.actualTableLog - 1))
		for ui, v := range s.norm[:s.symbolLen] {
			symbol := byte(ui)
			if v > largeLimit {
				s.zeroBits = true
			}
			for nbOccurrences := int16(0); nbOccurrences < v; nbOccurrences++ {
				tableSymbol[position] = symbol
				position = (position + step) & tableMask
				// Skip the area above highThreshold, which step 1 reserved for
				// the low-probability symbols.
				for position > highThreshold {
					position = (position + step) & tableMask
				}
			}
		}

		// After placing every occurrence the walk must be back at the start.
		if position != 0 {
			return errors.New("position!=0")
		}
	}

	// Step 3: build the state table. Entries are grouped by symbol using the
	// start positions from step 1; each entry holds tableSize + table position.
	table := s.ct.stateTable
	{
		tsi := int(tableSize)
		for u, v := range tableSymbol {
			// cumul[v] is the next free state table slot for symbol v.
			table[cumul[v]] = uint16(tsi + u)
			cumul[v]++
		}
	}

	// Step 4: build the per-symbol transformation table.
	{
		total := int16(0)
		symbolTT := s.ct.symbolTT[:s.symbolLen]
		tableLog := s.actualTableLog
		tl := (uint32(tableLog) << 16) - (1 << tableLog)
		for i, v := range s.norm[:s.symbolLen] {
			switch v {
			case 0:
				// Symbol is not present; its transform is left untouched.
			case -1, 1:
				// Single-slot symbol.
				symbolTT[i].deltaNbBits = tl
				symbolTT[i].deltaFindState = total - 1
				total++
			default:
				// highBit is defined outside this chunk.
				maxBitsOut := uint32(tableLog) - highBit(uint32(v-1))
				minStatePlus := uint32(v) << maxBitsOut
				symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus
				symbolTT[i].deltaFindState = total - v
				total += v
			}
		}
		// Every table slot must have been attributed to a symbol.
		if total != int16(tableSize) {
			return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize)
		}
	}
	return nil
}
205
// rtbTable holds the "rest to beat" thresholds used by normalizeCount.
// For a truncated probability p < 8, the probability is rounded up when the
// discarded remainder exceeds vStep * rtbTable[p].
var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}
207
208 func (s *fseEncoder) setRLE(val byte) {
209 s.allocCtable()
210 s.actualTableLog = 0
211 s.ct.stateTable = s.ct.stateTable[:1]
212 s.ct.symbolTT[val] = symbolTransform{
213 deltaFindState: 0,
214 deltaNbBits: 0,
215 }
216 if debugEncoder {
217 println("setRLE: val", val, "symbolTT", s.ct.symbolTT[val])
218 }
219 s.rleVal = val
220 s.useRLE = true
221 }
222
223
224
225 func (s *fseEncoder) setBits(transform []byte) {
226 if s.reUsed || s.preDefined {
227 return
228 }
229 if s.useRLE {
230 if transform == nil {
231 s.ct.symbolTT[s.rleVal].outBits = s.rleVal
232 s.maxBits = s.rleVal
233 return
234 }
235 s.maxBits = transform[s.rleVal]
236 s.ct.symbolTT[s.rleVal].outBits = s.maxBits
237 return
238 }
239 if transform == nil {
240 for i := range s.ct.symbolTT[:s.symbolLen] {
241 s.ct.symbolTT[i].outBits = uint8(i)
242 }
243 s.maxBits = uint8(s.symbolLen - 1)
244 return
245 }
246 s.maxBits = 0
247 for i, v := range transform[:s.symbolLen] {
248 s.ct.symbolTT[i].outBits = v
249 if v > s.maxBits {
250
251 s.maxBits = v
252 }
253 }
254 }
255
256
257
258
// normalizeCount converts the histogram in s.count into normalized counts in
// s.norm and then builds the compression table with buildCTable.
// length is compared against s.maxCount to detect single-symbol input, so it
// is presumably the total number of counted symbols — verify against callers.
//
// A re-used encoder is left untouched. When one symbol accounts for the whole
// input, only s.useRLE is set and no table is built.
// NOTE(review): length is used as a divisor below, so it must not be 0.
func (s *fseEncoder) normalizeCount(length int) error {
	if s.reUsed {
		return nil
	}
	s.optimalTableLog(length)
	var (
		tableLog          = s.actualTableLog
		scale             = 62 - uint64(tableLog)          // Right shift that turns cnt*step into a table share.
		step              = (1 << 62) / uint64(length)     // Fixed-point 1/length with 62 fractional bits.
		vStep             = uint64(1) << (scale - 20)      // Scales rtbTable entries to the remainder's precision.
		stillToDistribute = int16(1 << tableLog)           // Table slots not yet handed out; may go negative.
		largest           int                              // Index of the symbol with the largest share so far.
		largestP          int16                            // Share of that symbol.
		lowThreshold      = (uint32)(length >> tableLog)   // Counts up to this are low-probability symbols.
	)
	if s.maxCount == length {
		// Single symbol: RLE.
		s.useRLE = true
		return nil
	}
	s.useRLE = false
	for i, cnt := range s.count[:s.symbolLen] {
		// Absent symbols get no slot.
		if cnt == 0 {
			s.norm[i] = 0
			continue
		}
		if cnt <= lowThreshold {
			// Low-probability symbol: marked -1, costs one slot.
			s.norm[i] = -1
			stillToDistribute--
		} else {
			proba := (int16)((uint64(cnt) * step) >> scale)
			if proba < 8 {
				// Small shares are rounded up when the truncated remainder
				// exceeds the threshold from rtbTable.
				restToBeat := vStep * uint64(rtbTable[proba])
				v := uint64(cnt)*step - (uint64(proba) << scale)
				if v > restToBeat {
					proba++
				}
			}
			if proba > largestP {
				largestP = proba
				largest = i
			}
			s.norm[i] = proba
			stillToDistribute -= proba
		}
	}

	if -stillToDistribute >= (s.norm[largest] >> 1) {
		// Too many slots were handed out to correct on the largest symbol
		// alone; redo the normalization with the fallback method.
		err := s.normalizeCount2(length)
		if err != nil {
			return err
		}
		if debugAsserts {
			err = s.validateNorm()
			if err != nil {
				return err
			}
		}
		return s.buildCTable()
	}
	// Apply the remaining difference (positive or negative) to the largest symbol.
	s.norm[largest] += stillToDistribute
	if debugAsserts {
		err := s.validateNorm()
		if err != nil {
			return err
		}
	}
	return s.buildCTable()
}
331
332
333
334 func (s *fseEncoder) normalizeCount2(length int) error {
335 const notYetAssigned = -2
336 var (
337 distributed uint32
338 total = uint32(length)
339 tableLog = s.actualTableLog
340 lowThreshold = total >> tableLog
341 lowOne = (total * 3) >> (tableLog + 1)
342 )
343 for i, cnt := range s.count[:s.symbolLen] {
344 if cnt == 0 {
345 s.norm[i] = 0
346 continue
347 }
348 if cnt <= lowThreshold {
349 s.norm[i] = -1
350 distributed++
351 total -= cnt
352 continue
353 }
354 if cnt <= lowOne {
355 s.norm[i] = 1
356 distributed++
357 total -= cnt
358 continue
359 }
360 s.norm[i] = notYetAssigned
361 }
362 toDistribute := (1 << tableLog) - distributed
363
364 if (total / toDistribute) > lowOne {
365
366 lowOne = (total * 3) / (toDistribute * 2)
367 for i, cnt := range s.count[:s.symbolLen] {
368 if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) {
369 s.norm[i] = 1
370 distributed++
371 total -= cnt
372 continue
373 }
374 }
375 toDistribute = (1 << tableLog) - distributed
376 }
377 if distributed == uint32(s.symbolLen)+1 {
378
379
380
381 var maxV int
382 var maxC uint32
383 for i, cnt := range s.count[:s.symbolLen] {
384 if cnt > maxC {
385 maxV = i
386 maxC = cnt
387 }
388 }
389 s.norm[maxV] += int16(toDistribute)
390 return nil
391 }
392
393 if total == 0 {
394
395 for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) {
396 if s.norm[i] > 0 {
397 toDistribute--
398 s.norm[i]++
399 }
400 }
401 return nil
402 }
403
404 var (
405 vStepLog = 62 - uint64(tableLog)
406 mid = uint64((1 << (vStepLog - 1)) - 1)
407 rStep = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total)
408 tmpTotal = mid
409 )
410 for i, cnt := range s.count[:s.symbolLen] {
411 if s.norm[i] == notYetAssigned {
412 var (
413 end = tmpTotal + uint64(cnt)*rStep
414 sStart = uint32(tmpTotal >> vStepLog)
415 sEnd = uint32(end >> vStepLog)
416 weight = sEnd - sStart
417 )
418 if weight < 1 {
419 return errors.New("weight < 1")
420 }
421 s.norm[i] = int16(weight)
422 tmpTotal = end
423 }
424 }
425 return nil
426 }
427
428
429 func (s *fseEncoder) optimalTableLog(length int) {
430 tableLog := uint8(maxEncTableLog)
431 minBitsSrc := highBit(uint32(length)) + 1
432 minBitsSymbols := highBit(uint32(s.symbolLen-1)) + 2
433 minBits := uint8(minBitsSymbols)
434 if minBitsSrc < minBitsSymbols {
435 minBits = uint8(minBitsSrc)
436 }
437
438 maxBitsSrc := uint8(highBit(uint32(length-1))) - 2
439 if maxBitsSrc < tableLog {
440
441 tableLog = maxBitsSrc
442 }
443 if minBits > tableLog {
444 tableLog = minBits
445 }
446
447 if tableLog < minEncTablelog {
448 tableLog = minEncTablelog
449 }
450 if tableLog > maxEncTableLog {
451 tableLog = maxEncTableLog
452 }
453 s.actualTableLog = tableLog
454 }
455
456
457 func (s *fseEncoder) validateNorm() (err error) {
458 var total int
459 for _, v := range s.norm[:s.symbolLen] {
460 if v >= 0 {
461 total += int(v)
462 } else {
463 total -= int(v)
464 }
465 }
466 defer func() {
467 if err == nil {
468 return
469 }
470 fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen)
471 for i, v := range s.norm[:s.symbolLen] {
472 fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v)
473 }
474 }()
475 if total != (1 << s.actualTableLog) {
476 return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog)
477 }
478 for i, v := range s.count[s.symbolLen:] {
479 if v != 0 {
480 return fmt.Errorf("warning: Found symbol out of range, %d after cut", i)
481 }
482 }
483 return nil
484 }
485
486
487
// writeCount appends the table description to out and returns the extended
// slice.
//
// In RLE mode only the RLE symbol is appended. Predefined and re-used
// encoders append nothing. Otherwise the table log and the normalized counts
// in s.norm are written as a bit stream, least significant bits first.
// An error is returned when the normalized counts are inconsistent or the
// header does not fit in the reserved space.
func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
	if s.useRLE {
		return append(out, s.rleVal), nil
	}
	if s.preDefined || s.reUsed {
		// Nothing to write.
		return out, nil
	}

	var (
		tableLog  = s.actualTableLog
		tableSize = 1 << tableLog
		previous0 bool   // The previously written count was 0; a zero-run length follows.
		charnum   uint16 // Index of the next symbol to write.

		// Number of bytes reserved in out for the header.
		maxHeaderSize = ((int(s.symbolLen) * int(tableLog)) >> 3) + 3 + 2

		// The stream starts with the table log, stored relative to
		// minEncTablelog in 4 bits.
		bitStream = uint32(tableLog - minEncTablelog)
		bitCount  = uint(4)
		remaining = int16(tableSize + 1) // Table slots left to describe, plus one.
		threshold = int16(tableSize)
		nbBits    = uint(tableLog + 1) // Bits used for the next value.
		outP      = len(out)           // Write position in out.
	)
	if cap(out) < outP+maxHeaderSize {
		// Grow the capacity, then restore the original length.
		out = append(out, make([]byte, maxHeaderSize*3)...)
		out = out[:len(out)-maxHeaderSize*3]
	}
	out = out[:outP+maxHeaderSize]

	// Write one value per symbol until all table slots are accounted for.
	for remaining > 1 {
		if previous0 {
			// Count the zeros that follow and write the run length.
			start := charnum
			for s.norm[charnum] == 0 {
				charnum++
			}
			// Every 24 skipped symbols are written as 16 set bits, flushed
			// immediately so bitCount stays unchanged.
			for charnum >= start+24 {
				start += 24
				bitStream += uint32(0xFFFF) << bitCount
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
			}
			// Every further 3 skipped symbols are written as the 2-bit value 3.
			for charnum >= start+3 {
				start += 3
				bitStream += 3 << bitCount
				bitCount += 2
			}
			// The remainder (0 to 2) terminates the run.
			bitStream += uint32(charnum-start) << bitCount
			bitCount += 2
			if bitCount > 16 {
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
				bitCount -= 16
			}
		}

		count := s.norm[charnum]
		charnum++
		max := (2*threshold - 1) - remaining
		// A -1 entry consumes one slot, like a count of 1.
		if count < 0 {
			remaining += count
		} else {
			remaining -= count
		}
		// Values are stored offset by one, so that -1 is written as 0.
		count++
		if count >= threshold {
			count += max
		}
		bitStream += uint32(count) << bitCount
		bitCount += nbBits
		// Values below max take one bit less.
		if count < max {
			bitCount--
		}

		// A stored 1 is a normalized count of 0.
		previous0 = count == 1
		if remaining < 1 {
			return nil, errors.New("internal error: remaining < 1")
		}
		// Fewer remaining slots need fewer bits per value.
		for remaining < threshold {
			nbBits--
			threshold >>= 1
		}

		// Flush 16 bits once more than 16 are pending.
		if bitCount > 16 {
			out[outP] = byte(bitStream)
			out[outP+1] = byte(bitStream >> 8)
			outP += 2
			bitStream >>= 16
			bitCount -= 16
		}
	}

	// Flush the remaining bits, rounded up to whole bytes.
	if outP+2 > len(out) {
		return nil, fmt.Errorf("internal error: %d > %d, maxheader: %d, sl: %d, tl: %d, normcount: %v", outP+2, len(out), maxHeaderSize, s.symbolLen, int(tableLog), s.norm[:s.symbolLen])
	}
	out[outP] = byte(bitStream)
	out[outP+1] = byte(bitStream >> 8)
	outP += int((bitCount + 7) / 8)

	if charnum > s.symbolLen {
		return nil, errors.New("internal error: charnum > s.symbolLen")
	}
	return out[:outP], nil
}
599
600
601
602
603 func (s *fseEncoder) bitCost(symbolValue uint8, accuracyLog uint32) uint32 {
604 minNbBits := s.ct.symbolTT[symbolValue].deltaNbBits >> 16
605 threshold := (minNbBits + 1) << 16
606 if debugAsserts {
607 if !(s.actualTableLog < 16) {
608 panic("!s.actualTableLog < 16")
609 }
610
611 if !(uint8(accuracyLog) < 31-s.actualTableLog) {
612 panic("!uint8(accuracyLog) < 31-s.actualTableLog")
613 }
614 }
615 tableSize := uint32(1) << s.actualTableLog
616 deltaFromThreshold := threshold - (s.ct.symbolTT[symbolValue].deltaNbBits + tableSize)
617
618 normalizedDeltaFromThreshold := (deltaFromThreshold << accuracyLog) >> s.actualTableLog
619 bitMultiplier := uint32(1) << accuracyLog
620 if debugAsserts {
621 if s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold {
622 panic("s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold")
623 }
624 if normalizedDeltaFromThreshold > bitMultiplier {
625 panic("normalizedDeltaFromThreshold > bitMultiplier")
626 }
627 }
628 return (minNbBits+1)*bitMultiplier - normalizedDeltaFromThreshold
629 }
630
631
632
633
634 func (s *fseEncoder) approxSize(hist []uint32) uint32 {
635 if int(s.symbolLen) < len(hist) {
636
637 return math.MaxUint32
638 }
639 if s.useRLE {
640
641 return math.MaxUint32
642 }
643 const kAccuracyLog = 8
644 badCost := (uint32(s.actualTableLog) + 1) << kAccuracyLog
645 var cost uint32
646 for i, v := range hist {
647 if v == 0 {
648 continue
649 }
650 if s.norm[i] == 0 {
651 return math.MaxUint32
652 }
653 bitCost := s.bitCost(uint8(i), kAccuracyLog)
654 if bitCost > badCost {
655 return math.MaxUint32
656 }
657 cost += v * bitCost
658 }
659 return cost >> kAccuracyLog
660 }
661
662
663
664 func (s *fseEncoder) maxHeaderSize() uint32 {
665 if s.preDefined {
666 return 0
667 }
668 if s.useRLE {
669 return 8
670 }
671 return (((uint32(s.symbolLen) * uint32(s.actualTableLog)) >> 3) + 3) * 8
672 }
673
674
// cState holds the state of one encoder while a stream is written.
type cState struct {
	bw         *bitWriter // Destination of the output bits; bitWriter is declared outside this chunk.
	stateTable []uint16   // State table taken from the cTable passed to init.
	state      uint16     // Current state; written out by flush.
}
680
681
682 func (c *cState) init(bw *bitWriter, ct *cTable, first symbolTransform) {
683 c.bw = bw
684 c.stateTable = ct.stateTable
685 if len(c.stateTable) == 1 {
686
687 c.stateTable[0] = uint16(0)
688 c.state = 0
689 return
690 }
691 nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16
692 im := int32((nbBitsOut << 16) - first.deltaNbBits)
693 lu := (im >> nbBitsOut) + int32(first.deltaFindState)
694 c.state = c.stateTable[lu]
695 }
696
697
698 func (c *cState) flush(tableLog uint8) {
699 c.bw.flush32()
700 c.bw.addBits16NC(c.state, tableLog)
701 }
702
View as plain text