1
2
3
4
5 package zstd
6
7 import (
8 "bytes"
9 "fmt"
10
11 "github.com/klauspost/compress"
12 )
13
const (
	bestLongTableBits = 22                     // Bits used to index the long match table.
	bestLongTableSize = 1 << bestLongTableBits // Number of entries in the long match table.
	bestLongLen       = 8                      // Number of input bytes hashed for the long table.

	// The short table is smaller and hashes fewer bytes than the long table.
	bestShortTableBits = 18                      // Bits used to index the short match table.
	bestShortTableSize = 1 << bestShortTableBits // Number of entries in the short match table.
	bestShortLen       = 4                       // Number of input bytes hashed for the short table.

)
28
// match describes a candidate match found by the best encoder.
type match struct {
	offset int32 // Position in the history buffer the match copies from.
	s      int32 // Position in the history buffer where the match starts.
	length int32 // Match length in bytes; set to 0 by estBits when the match is rejected.
	rep    int32 // Negative for a regular offset match, positive for a repeat-offset code.
	est    int32 // Estimated bit cost minus literal savings (see estBits); lower is better.
}
36
// highScore is the est value given to a match that estBits judged not
// worth encoding, and the starting est of the "no match yet" candidate.
const highScore = maxMatchLen * 8
38
39
40 func (m *match) estBits(bitsPerByte int32) {
41 mlc := mlCode(uint32(m.length - zstdMinMatch))
42 var ofc uint8
43 if m.rep < 0 {
44 ofc = ofCode(uint32(m.s-m.offset) + 3)
45 } else {
46 ofc = ofCode(uint32(m.rep) & 3)
47 }
48
49 ofTT, mlTT := fsePredefEnc[tableOffsets].ct.symbolTT[ofc], fsePredefEnc[tableMatchLengths].ct.symbolTT[mlc]
50
51
52 m.est = int32(ofTT.outBits + mlTT.outBits)
53 m.est += int32(ofTT.deltaNbBits>>16 + mlTT.deltaNbBits>>16)
54
55 m.est -= (m.length * bitsPerByte) >> 10
56 if m.est > 0 {
57
58 m.length = 0
59 m.est = highScore
60 }
61 }
62
63
64
65
66
67
68
// bestFastEncoder uses two hash tables: one keyed on bestShortLen bytes and
// one keyed on bestLongLen bytes.
// Every table entry keeps the most recent offset for its hash together with
// the offset it replaced, giving each bucket a chain of length 2.
type bestFastEncoder struct {
	fastBase
	table         [bestShortTableSize]prevEntry // Short match table.
	longTable     [bestLongTableSize]prevEntry  // Long match table.
	dictTable     []prevEntry                   // Short table pre-filled from the dictionary; copied into table by Reset.
	dictLongTable []prevEntry                   // Long table pre-filled from the dictionary; copied into longTable by Reset.
}
76
77
78 func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) {
79 const (
80
81
82 inputMargin = 8 + 4
83 minNonLiteralBlockSize = 16
84 )
85
86
87 for e.cur >= e.bufferReset-int32(len(e.hist)) {
88 if len(e.hist) == 0 {
89 e.table = [bestShortTableSize]prevEntry{}
90 e.longTable = [bestLongTableSize]prevEntry{}
91 e.cur = e.maxMatchOff
92 break
93 }
94
95 minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
96 for i := range e.table[:] {
97 v := e.table[i].offset
98 v2 := e.table[i].prev
99 if v < minOff {
100 v = 0
101 v2 = 0
102 } else {
103 v = v - e.cur + e.maxMatchOff
104 if v2 < minOff {
105 v2 = 0
106 } else {
107 v2 = v2 - e.cur + e.maxMatchOff
108 }
109 }
110 e.table[i] = prevEntry{
111 offset: v,
112 prev: v2,
113 }
114 }
115 for i := range e.longTable[:] {
116 v := e.longTable[i].offset
117 v2 := e.longTable[i].prev
118 if v < minOff {
119 v = 0
120 v2 = 0
121 } else {
122 v = v - e.cur + e.maxMatchOff
123 if v2 < minOff {
124 v2 = 0
125 } else {
126 v2 = v2 - e.cur + e.maxMatchOff
127 }
128 }
129 e.longTable[i] = prevEntry{
130 offset: v,
131 prev: v2,
132 }
133 }
134 e.cur = e.maxMatchOff
135 break
136 }
137
138
139 s := e.addBlock(src)
140 blk.size = len(src)
141
142
143 if len(src) > zstdMinMatch {
144 ml := matchLen(src[1:], src)
145 if ml == len(src)-1 {
146 blk.literals = append(blk.literals, src[0])
147 blk.sequences = append(blk.sequences, seq{litLen: 1, matchLen: uint32(len(src)-1) - zstdMinMatch, offset: 1 + 3})
148 return
149 }
150 }
151
152 if len(src) < minNonLiteralBlockSize {
153 blk.extraLits = len(src)
154 blk.literals = blk.literals[:len(src)]
155 copy(blk.literals, src)
156 return
157 }
158
159
160
161 bitsPerByte := int32((compress.ShannonEntropyBits(src) * 1024) / len(src))
162
163 if bitsPerByte < 1024 {
164 bitsPerByte = 1024
165 }
166
167
168 src = e.hist
169 sLimit := int32(len(src)) - inputMargin
170 const kSearchStrength = 10
171
172
173 nextEmit := s
174
175
176 offset1 := int32(blk.recentOffsets[0])
177 offset2 := int32(blk.recentOffsets[1])
178 offset3 := int32(blk.recentOffsets[2])
179
180 addLiterals := func(s *seq, until int32) {
181 if until == nextEmit {
182 return
183 }
184 blk.literals = append(blk.literals, src[nextEmit:until]...)
185 s.litLen = uint32(until - nextEmit)
186 }
187
188 if debugEncoder {
189 println("recent offsets:", blk.recentOffsets)
190 }
191
192 encodeLoop:
193 for {
194
195 canRepeat := len(blk.sequences) > 2
196
197 if debugAsserts && canRepeat && offset1 == 0 {
198 panic("offset0 was 0")
199 }
200
201 const goodEnough = 250
202
203 cv := load6432(src, s)
204
205 nextHashL := hashLen(cv, bestLongTableBits, bestLongLen)
206 nextHashS := hashLen(cv, bestShortTableBits, bestShortLen)
207 candidateL := e.longTable[nextHashL]
208 candidateS := e.table[nextHashS]
209
210
211 improve := func(m *match, offset int32, s int32, first uint32, rep int32) {
212 delta := s - offset
213 if delta >= e.maxMatchOff || delta <= 0 || load3232(src, offset) != first {
214 return
215 }
216
217 if m.length > 16 {
218 left := len(src) - int(m.s+m.length)
219
220 if left <= 0 {
221 return
222 }
223 checkLen := m.length - (s - m.s) - 8
224 if left > 2 && checkLen > 4 {
225
226 a := load3232(src, offset+checkLen)
227 b := load3232(src, s+checkLen)
228 if a != b {
229 return
230 }
231 }
232 }
233 l := 4 + e.matchlen(s+4, offset+4, src)
234 if m.rep <= 0 {
235
236
237
238 tMin := s - e.maxMatchOff
239 if tMin < 0 {
240 tMin = 0
241 }
242 for offset > tMin && s > nextEmit && src[offset-1] == src[s-1] && l < maxMatchLength {
243 s--
244 offset--
245 l++
246 }
247 }
248 if debugAsserts {
249 if offset >= s {
250 panic(fmt.Sprintf("offset: %d - s:%d - rep: %d - cur :%d - max: %d", offset, s, rep, e.cur, e.maxMatchOff))
251 }
252 if !bytes.Equal(src[s:s+l], src[offset:offset+l]) {
253 panic(fmt.Sprintf("second match mismatch: %v != %v, first: %08x", src[s:s+4], src[offset:offset+4], first))
254 }
255 }
256 cand := match{offset: offset, s: s, length: l, rep: rep}
257 cand.estBits(bitsPerByte)
258 if m.est >= highScore || cand.est-m.est+(cand.s-m.s)*bitsPerByte>>10 < 0 {
259 *m = cand
260 }
261 }
262
263 best := match{s: s, est: highScore}
264 improve(&best, candidateL.offset-e.cur, s, uint32(cv), -1)
265 improve(&best, candidateL.prev-e.cur, s, uint32(cv), -1)
266 improve(&best, candidateS.offset-e.cur, s, uint32(cv), -1)
267 improve(&best, candidateS.prev-e.cur, s, uint32(cv), -1)
268
269 if canRepeat && best.length < goodEnough {
270 if s == nextEmit {
271
272 improve(&best, s-offset2, s, uint32(cv), 1|4)
273 improve(&best, s-offset3, s, uint32(cv), 2|4)
274 if offset1 > 1 {
275 improve(&best, s-(offset1-1), s, uint32(cv), 3|4)
276 }
277 }
278
279
280 if best.rep <= 0 {
281 cv32 := uint32(cv >> 8)
282 spp := s + 1
283 improve(&best, spp-offset1, spp, cv32, 1)
284 improve(&best, spp-offset2, spp, cv32, 2)
285 improve(&best, spp-offset3, spp, cv32, 3)
286 if best.rep < 0 {
287 cv32 = uint32(cv >> 24)
288 spp += 2
289 improve(&best, spp-offset1, spp, cv32, 1)
290 improve(&best, spp-offset2, spp, cv32, 2)
291 improve(&best, spp-offset3, spp, cv32, 3)
292 }
293 }
294 }
295
296 e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: candidateL.offset}
297 e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: candidateS.offset}
298 index0 := s + 1
299
300
301 if best.length < goodEnough {
302
303 if best.length < 4 {
304 s += 1 + (s-nextEmit)>>(kSearchStrength-1)
305 if s >= sLimit {
306 break encodeLoop
307 }
308 continue
309 }
310
311 candidateS = e.table[hashLen(cv>>8, bestShortTableBits, bestShortLen)]
312 cv = load6432(src, s+1)
313 cv2 := load6432(src, s+2)
314 candidateL = e.longTable[hashLen(cv, bestLongTableBits, bestLongLen)]
315 candidateL2 := e.longTable[hashLen(cv2, bestLongTableBits, bestLongLen)]
316
317
318 improve(&best, candidateS.offset-e.cur, s+1, uint32(cv), -1)
319
320 improve(&best, candidateL.offset-e.cur, s+1, uint32(cv), -1)
321 improve(&best, candidateL.prev-e.cur, s+1, uint32(cv), -1)
322 improve(&best, candidateL2.offset-e.cur, s+2, uint32(cv2), -1)
323 improve(&best, candidateL2.prev-e.cur, s+2, uint32(cv2), -1)
324 if false {
325
326
327 improve(&best, e.table[hashLen(cv2>>8, bestShortTableBits, bestShortLen)].offset-e.cur, s+3, uint32(cv2>>8), -1)
328 }
329
330
331
332
333 const skipBeginning = 2
334 if best.s > s-skipBeginning {
335
336
337 if sAt := best.s + best.length; sAt < sLimit {
338 nextHashL := hashLen(load6432(src, sAt), bestLongTableBits, bestLongLen)
339 candidateEnd := e.longTable[nextHashL]
340
341 if off := candidateEnd.offset - e.cur - best.length + skipBeginning; off >= 0 {
342 improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1)
343 if off := candidateEnd.prev - e.cur - best.length + skipBeginning; off >= 0 {
344 improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1)
345 }
346 }
347 }
348 }
349 }
350
351 if debugAsserts {
352 if best.offset >= best.s {
353 panic(fmt.Sprintf("best.offset > s: %d >= %d", best.offset, best.s))
354 }
355 if best.s < nextEmit {
356 panic(fmt.Sprintf("s %d < nextEmit %d", best.s, nextEmit))
357 }
358 if best.offset < s-e.maxMatchOff {
359 panic(fmt.Sprintf("best.offset < s-e.maxMatchOff: %d < %d", best.offset, s-e.maxMatchOff))
360 }
361 if !bytes.Equal(src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]) {
362 panic(fmt.Sprintf("match mismatch: %v != %v", src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]))
363 }
364 }
365
366
367 s = best.s
368 if best.rep > 0 {
369 var seq seq
370 seq.matchLen = uint32(best.length - zstdMinMatch)
371 addLiterals(&seq, best.s)
372
373
374 seq.offset = uint32(best.rep & 3)
375 if debugSequences {
376 println("repeat sequence", seq, "next s:", best.s, "off:", best.s-best.offset)
377 }
378 blk.sequences = append(blk.sequences, seq)
379
380
381 s = best.s + best.length
382 nextEmit = s
383
384
385 end := s
386 if s > sLimit+4 {
387 end = sLimit + 4
388 }
389 off := index0 + e.cur
390 for index0 < end {
391 cv0 := load6432(src, index0)
392 h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
393 h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
394 e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
395 e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
396 off++
397 index0++
398 }
399
400 switch best.rep {
401 case 2, 4 | 1:
402 offset1, offset2 = offset2, offset1
403 case 3, 4 | 2:
404 offset1, offset2, offset3 = offset3, offset1, offset2
405 case 4 | 3:
406 offset1, offset2, offset3 = offset1-1, offset1, offset2
407 }
408 if s >= sLimit {
409 if debugEncoder {
410 println("repeat ended", s, best.length)
411 }
412 break encodeLoop
413 }
414 continue
415 }
416
417
418
419 t := best.offset
420 offset1, offset2, offset3 = s-t, offset1, offset2
421
422 if debugAsserts && s <= t {
423 panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
424 }
425
426 if debugAsserts && int(offset1) > len(src) {
427 panic("invalid offset")
428 }
429
430
431 var seq seq
432 l := best.length
433 seq.litLen = uint32(s - nextEmit)
434 seq.matchLen = uint32(l - zstdMinMatch)
435 if seq.litLen > 0 {
436 blk.literals = append(blk.literals, src[nextEmit:s]...)
437 }
438 seq.offset = uint32(s-t) + 3
439 s += l
440 if debugSequences {
441 println("sequence", seq, "next s:", s)
442 }
443 blk.sequences = append(blk.sequences, seq)
444 nextEmit = s
445
446
447 end := s
448 if s > sLimit-4 {
449 end = sLimit - 4
450 }
451
452 off := index0 + e.cur
453 for index0 < end {
454 cv0 := load6432(src, index0)
455 h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
456 h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
457 e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
458 e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
459 index0++
460 off++
461 }
462 if s >= sLimit {
463 break encodeLoop
464 }
465 }
466
467 if int(nextEmit) < len(src) {
468 blk.literals = append(blk.literals, src[nextEmit:]...)
469 blk.extraLits = len(src) - int(nextEmit)
470 }
471 blk.recentOffsets[0] = uint32(offset1)
472 blk.recentOffsets[1] = uint32(offset2)
473 blk.recentOffsets[2] = uint32(offset3)
474 if debugEncoder {
475 println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
476 }
477 }
478
479
480
481
// EncodeNoHist encodes src into blk.
// It makes sure the history buffer can hold len(src) bytes
// and then encodes exactly as Encode does.
func (e *bestFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	e.ensureHist(len(src))
	e.Encode(blk, src)
}
486
487
488 func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
489 e.resetBase(d, singleBlock)
490 if d == nil {
491 return
492 }
493
494 if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
495 if len(e.dictTable) != len(e.table) {
496 e.dictTable = make([]prevEntry, len(e.table))
497 }
498 end := int32(len(d.content)) - 8 + e.maxMatchOff
499 for i := e.maxMatchOff; i < end; i += 4 {
500 const hashLog = bestShortTableBits
501
502 cv := load6432(d.content, i-e.maxMatchOff)
503 nextHash := hashLen(cv, hashLog, bestShortLen)
504 nextHash1 := hashLen(cv>>8, hashLog, bestShortLen)
505 nextHash2 := hashLen(cv>>16, hashLog, bestShortLen)
506 nextHash3 := hashLen(cv>>24, hashLog, bestShortLen)
507 e.dictTable[nextHash] = prevEntry{
508 prev: e.dictTable[nextHash].offset,
509 offset: i,
510 }
511 e.dictTable[nextHash1] = prevEntry{
512 prev: e.dictTable[nextHash1].offset,
513 offset: i + 1,
514 }
515 e.dictTable[nextHash2] = prevEntry{
516 prev: e.dictTable[nextHash2].offset,
517 offset: i + 2,
518 }
519 e.dictTable[nextHash3] = prevEntry{
520 prev: e.dictTable[nextHash3].offset,
521 offset: i + 3,
522 }
523 }
524 e.lastDictID = d.id
525 }
526
527
528 if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
529 if len(e.dictLongTable) != len(e.longTable) {
530 e.dictLongTable = make([]prevEntry, len(e.longTable))
531 }
532 if len(d.content) >= 8 {
533 cv := load6432(d.content, 0)
534 h := hashLen(cv, bestLongTableBits, bestLongLen)
535 e.dictLongTable[h] = prevEntry{
536 offset: e.maxMatchOff,
537 prev: e.dictLongTable[h].offset,
538 }
539
540 end := int32(len(d.content)) - 8 + e.maxMatchOff
541 off := 8
542 for i := e.maxMatchOff + 1; i < end; i++ {
543 cv = cv>>8 | (uint64(d.content[off]) << 56)
544 h := hashLen(cv, bestLongTableBits, bestLongLen)
545 e.dictLongTable[h] = prevEntry{
546 offset: i,
547 prev: e.dictLongTable[h].offset,
548 }
549 off++
550 }
551 }
552 e.lastDictID = d.id
553 }
554
555 copy(e.longTable[:], e.dictLongTable)
556
557 e.cur = e.maxMatchOff
558
559 copy(e.table[:], e.dictTable)
560 }
561
View as plain text