...

Source file src/github.com/klauspost/compress/zstd/enc_best.go

Documentation: github.com/klauspost/compress/zstd

     1  // Copyright 2019+ Klaus Post. All rights reserved.
     2  // License information can be found in the LICENSE file.
     3  // Based on work by Yann Collet, released under BSD License.
     4  
     5  package zstd
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  
    11  	"github.com/klauspost/compress"
    12  )
    13  
    14  const (
    15  	bestLongTableBits = 22                     // Bits used in the long match table
    16  	bestLongTableSize = 1 << bestLongTableBits // Size of the table
    17  	bestLongLen       = 8                      // Bytes used for table hash
    18  
    19  	// Note: Increasing the short table bits or making the hash shorter
    20  	// can actually lead to compression degradation since it will 'steal' more from the
    21  	// long match table and match offsets are quite big.
    22  	// This greatly depends on the type of input.
    23  	bestShortTableBits = 18                      // Bits used in the short match table
    24  	bestShortTableSize = 1 << bestShortTableBits // Size of the table
    25  	bestShortLen       = 4                       // Bytes used for table hash
    26  
    27  )
    28  
    29  type match struct {
    30  	offset int32
    31  	s      int32
    32  	length int32
    33  	rep    int32
    34  	est    int32
    35  }
    36  
    37  const highScore = maxMatchLen * 8
    38  
    39  // estBits will estimate output bits from predefined tables.
    40  func (m *match) estBits(bitsPerByte int32) {
    41  	mlc := mlCode(uint32(m.length - zstdMinMatch))
    42  	var ofc uint8
    43  	if m.rep < 0 {
    44  		ofc = ofCode(uint32(m.s-m.offset) + 3)
    45  	} else {
    46  		ofc = ofCode(uint32(m.rep) & 3)
    47  	}
    48  	// Cost, excluding
    49  	ofTT, mlTT := fsePredefEnc[tableOffsets].ct.symbolTT[ofc], fsePredefEnc[tableMatchLengths].ct.symbolTT[mlc]
    50  
    51  	// Add cost of match encoding...
    52  	m.est = int32(ofTT.outBits + mlTT.outBits)
    53  	m.est += int32(ofTT.deltaNbBits>>16 + mlTT.deltaNbBits>>16)
    54  	// Subtract savings compared to literal encoding...
    55  	m.est -= (m.length * bitsPerByte) >> 10
    56  	if m.est > 0 {
    57  		// Unlikely gain..
    58  		m.length = 0
    59  		m.est = highScore
    60  	}
    61  }
    62  
    63  // bestFastEncoder uses 2 tables, one for short matches (5 bytes) and one for long matches.
    64  // The long match table contains the previous entry with the same hash,
    65  // effectively making it a "chain" of length 2.
    66  // When we find a long match we choose between the two values and select the longest.
    67  // When we find a short match, after checking the long, we check if we can find a long at n+1
    68  // and that it is longer (lazy matching).
    69  type bestFastEncoder struct {
    70  	fastBase
    71  	table         [bestShortTableSize]prevEntry
    72  	longTable     [bestLongTableSize]prevEntry
    73  	dictTable     []prevEntry
    74  	dictLongTable []prevEntry
    75  }
    76  
    77  // Encode improves compression...
    78  func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) {
    79  	const (
    80  		// Input margin is the number of bytes we read (8)
    81  		// and the maximum we will read ahead (2)
    82  		inputMargin            = 8 + 4
    83  		minNonLiteralBlockSize = 16
    84  	)
    85  
    86  	// Protect against e.cur wraparound.
    87  	for e.cur >= e.bufferReset-int32(len(e.hist)) {
    88  		if len(e.hist) == 0 {
    89  			e.table = [bestShortTableSize]prevEntry{}
    90  			e.longTable = [bestLongTableSize]prevEntry{}
    91  			e.cur = e.maxMatchOff
    92  			break
    93  		}
    94  		// Shift down everything in the table that isn't already too far away.
    95  		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
    96  		for i := range e.table[:] {
    97  			v := e.table[i].offset
    98  			v2 := e.table[i].prev
    99  			if v < minOff {
   100  				v = 0
   101  				v2 = 0
   102  			} else {
   103  				v = v - e.cur + e.maxMatchOff
   104  				if v2 < minOff {
   105  					v2 = 0
   106  				} else {
   107  					v2 = v2 - e.cur + e.maxMatchOff
   108  				}
   109  			}
   110  			e.table[i] = prevEntry{
   111  				offset: v,
   112  				prev:   v2,
   113  			}
   114  		}
   115  		for i := range e.longTable[:] {
   116  			v := e.longTable[i].offset
   117  			v2 := e.longTable[i].prev
   118  			if v < minOff {
   119  				v = 0
   120  				v2 = 0
   121  			} else {
   122  				v = v - e.cur + e.maxMatchOff
   123  				if v2 < minOff {
   124  					v2 = 0
   125  				} else {
   126  					v2 = v2 - e.cur + e.maxMatchOff
   127  				}
   128  			}
   129  			e.longTable[i] = prevEntry{
   130  				offset: v,
   131  				prev:   v2,
   132  			}
   133  		}
   134  		e.cur = e.maxMatchOff
   135  		break
   136  	}
   137  
   138  	// Add block to history
   139  	s := e.addBlock(src)
   140  	blk.size = len(src)
   141  
   142  	// Check RLE first
   143  	if len(src) > zstdMinMatch {
   144  		ml := matchLen(src[1:], src)
   145  		if ml == len(src)-1 {
   146  			blk.literals = append(blk.literals, src[0])
   147  			blk.sequences = append(blk.sequences, seq{litLen: 1, matchLen: uint32(len(src)-1) - zstdMinMatch, offset: 1 + 3})
   148  			return
   149  		}
   150  	}
   151  
   152  	if len(src) < minNonLiteralBlockSize {
   153  		blk.extraLits = len(src)
   154  		blk.literals = blk.literals[:len(src)]
   155  		copy(blk.literals, src)
   156  		return
   157  	}
   158  
   159  	// Use this to estimate literal cost.
   160  	// Scaled by 10 bits.
   161  	bitsPerByte := int32((compress.ShannonEntropyBits(src) * 1024) / len(src))
   162  	// Huffman can never go < 1 bit/byte
   163  	if bitsPerByte < 1024 {
   164  		bitsPerByte = 1024
   165  	}
   166  
   167  	// Override src
   168  	src = e.hist
   169  	sLimit := int32(len(src)) - inputMargin
   170  	const kSearchStrength = 10
   171  
   172  	// nextEmit is where in src the next emitLiteral should start from.
   173  	nextEmit := s
   174  
   175  	// Relative offsets
   176  	offset1 := int32(blk.recentOffsets[0])
   177  	offset2 := int32(blk.recentOffsets[1])
   178  	offset3 := int32(blk.recentOffsets[2])
   179  
   180  	addLiterals := func(s *seq, until int32) {
   181  		if until == nextEmit {
   182  			return
   183  		}
   184  		blk.literals = append(blk.literals, src[nextEmit:until]...)
   185  		s.litLen = uint32(until - nextEmit)
   186  	}
   187  
   188  	if debugEncoder {
   189  		println("recent offsets:", blk.recentOffsets)
   190  	}
   191  
   192  encodeLoop:
   193  	for {
   194  		// We allow the encoder to optionally turn off repeat offsets across blocks
   195  		canRepeat := len(blk.sequences) > 2
   196  
   197  		if debugAsserts && canRepeat && offset1 == 0 {
   198  			panic("offset0 was 0")
   199  		}
   200  
   201  		const goodEnough = 250
   202  
   203  		cv := load6432(src, s)
   204  
   205  		nextHashL := hashLen(cv, bestLongTableBits, bestLongLen)
   206  		nextHashS := hashLen(cv, bestShortTableBits, bestShortLen)
   207  		candidateL := e.longTable[nextHashL]
   208  		candidateS := e.table[nextHashS]
   209  
   210  		// Set m to a match at offset if it looks like that will improve compression.
   211  		improve := func(m *match, offset int32, s int32, first uint32, rep int32) {
   212  			delta := s - offset
   213  			if delta >= e.maxMatchOff || delta <= 0 || load3232(src, offset) != first {
   214  				return
   215  			}
   216  			// Try to quick reject if we already have a long match.
   217  			if m.length > 16 {
   218  				left := len(src) - int(m.s+m.length)
   219  				// If we are too close to the end, keep as is.
   220  				if left <= 0 {
   221  					return
   222  				}
   223  				checkLen := m.length - (s - m.s) - 8
   224  				if left > 2 && checkLen > 4 {
   225  					// Check 4 bytes, 4 bytes from the end of the current match.
   226  					a := load3232(src, offset+checkLen)
   227  					b := load3232(src, s+checkLen)
   228  					if a != b {
   229  						return
   230  					}
   231  				}
   232  			}
   233  			l := 4 + e.matchlen(s+4, offset+4, src)
   234  			if m.rep <= 0 {
   235  				// Extend candidate match backwards as far as possible.
   236  				// Do not extend repeats as we can assume they are optimal
   237  				// and offsets change if s == nextEmit.
   238  				tMin := s - e.maxMatchOff
   239  				if tMin < 0 {
   240  					tMin = 0
   241  				}
   242  				for offset > tMin && s > nextEmit && src[offset-1] == src[s-1] && l < maxMatchLength {
   243  					s--
   244  					offset--
   245  					l++
   246  				}
   247  			}
   248  			if debugAsserts {
   249  				if offset >= s {
   250  					panic(fmt.Sprintf("offset: %d - s:%d - rep: %d - cur :%d - max: %d", offset, s, rep, e.cur, e.maxMatchOff))
   251  				}
   252  				if !bytes.Equal(src[s:s+l], src[offset:offset+l]) {
   253  					panic(fmt.Sprintf("second match mismatch: %v != %v, first: %08x", src[s:s+4], src[offset:offset+4], first))
   254  				}
   255  			}
   256  			cand := match{offset: offset, s: s, length: l, rep: rep}
   257  			cand.estBits(bitsPerByte)
   258  			if m.est >= highScore || cand.est-m.est+(cand.s-m.s)*bitsPerByte>>10 < 0 {
   259  				*m = cand
   260  			}
   261  		}
   262  
   263  		best := match{s: s, est: highScore}
   264  		improve(&best, candidateL.offset-e.cur, s, uint32(cv), -1)
   265  		improve(&best, candidateL.prev-e.cur, s, uint32(cv), -1)
   266  		improve(&best, candidateS.offset-e.cur, s, uint32(cv), -1)
   267  		improve(&best, candidateS.prev-e.cur, s, uint32(cv), -1)
   268  
   269  		if canRepeat && best.length < goodEnough {
   270  			if s == nextEmit {
   271  				// Check repeats straight after a match.
   272  				improve(&best, s-offset2, s, uint32(cv), 1|4)
   273  				improve(&best, s-offset3, s, uint32(cv), 2|4)
   274  				if offset1 > 1 {
   275  					improve(&best, s-(offset1-1), s, uint32(cv), 3|4)
   276  				}
   277  			}
   278  
   279  			// If either no match or a non-repeat match, check at + 1
   280  			if best.rep <= 0 {
   281  				cv32 := uint32(cv >> 8)
   282  				spp := s + 1
   283  				improve(&best, spp-offset1, spp, cv32, 1)
   284  				improve(&best, spp-offset2, spp, cv32, 2)
   285  				improve(&best, spp-offset3, spp, cv32, 3)
   286  				if best.rep < 0 {
   287  					cv32 = uint32(cv >> 24)
   288  					spp += 2
   289  					improve(&best, spp-offset1, spp, cv32, 1)
   290  					improve(&best, spp-offset2, spp, cv32, 2)
   291  					improve(&best, spp-offset3, spp, cv32, 3)
   292  				}
   293  			}
   294  		}
   295  		// Load next and check...
   296  		e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: candidateL.offset}
   297  		e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: candidateS.offset}
   298  		index0 := s + 1
   299  
   300  		// Look far ahead, unless we have a really long match already...
   301  		if best.length < goodEnough {
   302  			// No match found, move forward on input, no need to check forward...
   303  			if best.length < 4 {
   304  				s += 1 + (s-nextEmit)>>(kSearchStrength-1)
   305  				if s >= sLimit {
   306  					break encodeLoop
   307  				}
   308  				continue
   309  			}
   310  
   311  			candidateS = e.table[hashLen(cv>>8, bestShortTableBits, bestShortLen)]
   312  			cv = load6432(src, s+1)
   313  			cv2 := load6432(src, s+2)
   314  			candidateL = e.longTable[hashLen(cv, bestLongTableBits, bestLongLen)]
   315  			candidateL2 := e.longTable[hashLen(cv2, bestLongTableBits, bestLongLen)]
   316  
   317  			// Short at s+1
   318  			improve(&best, candidateS.offset-e.cur, s+1, uint32(cv), -1)
   319  			// Long at s+1, s+2
   320  			improve(&best, candidateL.offset-e.cur, s+1, uint32(cv), -1)
   321  			improve(&best, candidateL.prev-e.cur, s+1, uint32(cv), -1)
   322  			improve(&best, candidateL2.offset-e.cur, s+2, uint32(cv2), -1)
   323  			improve(&best, candidateL2.prev-e.cur, s+2, uint32(cv2), -1)
   324  			if false {
   325  				// Short at s+3.
   326  				// Too often worse...
   327  				improve(&best, e.table[hashLen(cv2>>8, bestShortTableBits, bestShortLen)].offset-e.cur, s+3, uint32(cv2>>8), -1)
   328  			}
   329  
   330  			// Start check at a fixed offset to allow for a few mismatches.
   331  			// For this compression level 2 yields the best results.
   332  			// We cannot do this if we have already indexed this position.
   333  			const skipBeginning = 2
   334  			if best.s > s-skipBeginning {
   335  				// See if we can find a better match by checking where the current best ends.
   336  				// Use that offset to see if we can find a better full match.
   337  				if sAt := best.s + best.length; sAt < sLimit {
   338  					nextHashL := hashLen(load6432(src, sAt), bestLongTableBits, bestLongLen)
   339  					candidateEnd := e.longTable[nextHashL]
   340  
   341  					if off := candidateEnd.offset - e.cur - best.length + skipBeginning; off >= 0 {
   342  						improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1)
   343  						if off := candidateEnd.prev - e.cur - best.length + skipBeginning; off >= 0 {
   344  							improve(&best, off, best.s+skipBeginning, load3232(src, best.s+skipBeginning), -1)
   345  						}
   346  					}
   347  				}
   348  			}
   349  		}
   350  
   351  		if debugAsserts {
   352  			if best.offset >= best.s {
   353  				panic(fmt.Sprintf("best.offset > s: %d >= %d", best.offset, best.s))
   354  			}
   355  			if best.s < nextEmit {
   356  				panic(fmt.Sprintf("s %d < nextEmit %d", best.s, nextEmit))
   357  			}
   358  			if best.offset < s-e.maxMatchOff {
   359  				panic(fmt.Sprintf("best.offset < s-e.maxMatchOff: %d < %d", best.offset, s-e.maxMatchOff))
   360  			}
   361  			if !bytes.Equal(src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]) {
   362  				panic(fmt.Sprintf("match mismatch: %v != %v", src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]))
   363  			}
   364  		}
   365  
   366  		// We have a match, we can store the forward value
   367  		s = best.s
   368  		if best.rep > 0 {
   369  			var seq seq
   370  			seq.matchLen = uint32(best.length - zstdMinMatch)
   371  			addLiterals(&seq, best.s)
   372  
   373  			// Repeat. If bit 4 is set, this is a non-lit repeat.
   374  			seq.offset = uint32(best.rep & 3)
   375  			if debugSequences {
   376  				println("repeat sequence", seq, "next s:", best.s, "off:", best.s-best.offset)
   377  			}
   378  			blk.sequences = append(blk.sequences, seq)
   379  
   380  			// Index old s + 1 -> s - 1
   381  			s = best.s + best.length
   382  			nextEmit = s
   383  
   384  			// Index skipped...
   385  			end := s
   386  			if s > sLimit+4 {
   387  				end = sLimit + 4
   388  			}
   389  			off := index0 + e.cur
   390  			for index0 < end {
   391  				cv0 := load6432(src, index0)
   392  				h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
   393  				h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
   394  				e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
   395  				e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
   396  				off++
   397  				index0++
   398  			}
   399  
   400  			switch best.rep {
   401  			case 2, 4 | 1:
   402  				offset1, offset2 = offset2, offset1
   403  			case 3, 4 | 2:
   404  				offset1, offset2, offset3 = offset3, offset1, offset2
   405  			case 4 | 3:
   406  				offset1, offset2, offset3 = offset1-1, offset1, offset2
   407  			}
   408  			if s >= sLimit {
   409  				if debugEncoder {
   410  					println("repeat ended", s, best.length)
   411  				}
   412  				break encodeLoop
   413  			}
   414  			continue
   415  		}
   416  
   417  		// A 4-byte match has been found. Update recent offsets.
   418  		// We'll later see if more than 4 bytes.
   419  		t := best.offset
   420  		offset1, offset2, offset3 = s-t, offset1, offset2
   421  
   422  		if debugAsserts && s <= t {
   423  			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
   424  		}
   425  
   426  		if debugAsserts && int(offset1) > len(src) {
   427  			panic("invalid offset")
   428  		}
   429  
   430  		// Write our sequence
   431  		var seq seq
   432  		l := best.length
   433  		seq.litLen = uint32(s - nextEmit)
   434  		seq.matchLen = uint32(l - zstdMinMatch)
   435  		if seq.litLen > 0 {
   436  			blk.literals = append(blk.literals, src[nextEmit:s]...)
   437  		}
   438  		seq.offset = uint32(s-t) + 3
   439  		s += l
   440  		if debugSequences {
   441  			println("sequence", seq, "next s:", s)
   442  		}
   443  		blk.sequences = append(blk.sequences, seq)
   444  		nextEmit = s
   445  
   446  		// Index old s + 1 -> s - 1 or sLimit
   447  		end := s
   448  		if s > sLimit-4 {
   449  			end = sLimit - 4
   450  		}
   451  
   452  		off := index0 + e.cur
   453  		for index0 < end {
   454  			cv0 := load6432(src, index0)
   455  			h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
   456  			h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
   457  			e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
   458  			e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
   459  			index0++
   460  			off++
   461  		}
   462  		if s >= sLimit {
   463  			break encodeLoop
   464  		}
   465  	}
   466  
   467  	if int(nextEmit) < len(src) {
   468  		blk.literals = append(blk.literals, src[nextEmit:]...)
   469  		blk.extraLits = len(src) - int(nextEmit)
   470  	}
   471  	blk.recentOffsets[0] = uint32(offset1)
   472  	blk.recentOffsets[1] = uint32(offset2)
   473  	blk.recentOffsets[2] = uint32(offset3)
   474  	if debugEncoder {
   475  		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
   476  	}
   477  }
   478  
   479  // EncodeNoHist will encode a block with no history and no following blocks.
   480  // Most notable difference is that src will not be copied for history and
   481  // we do not need to check for max match length.
   482  func (e *bestFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
   483  	e.ensureHist(len(src))
   484  	e.Encode(blk, src)
   485  }
   486  
   487  // Reset will reset and set a dictionary if not nil
   488  func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
   489  	e.resetBase(d, singleBlock)
   490  	if d == nil {
   491  		return
   492  	}
   493  	// Init or copy dict table
   494  	if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
   495  		if len(e.dictTable) != len(e.table) {
   496  			e.dictTable = make([]prevEntry, len(e.table))
   497  		}
   498  		end := int32(len(d.content)) - 8 + e.maxMatchOff
   499  		for i := e.maxMatchOff; i < end; i += 4 {
   500  			const hashLog = bestShortTableBits
   501  
   502  			cv := load6432(d.content, i-e.maxMatchOff)
   503  			nextHash := hashLen(cv, hashLog, bestShortLen)      // 0 -> 4
   504  			nextHash1 := hashLen(cv>>8, hashLog, bestShortLen)  // 1 -> 5
   505  			nextHash2 := hashLen(cv>>16, hashLog, bestShortLen) // 2 -> 6
   506  			nextHash3 := hashLen(cv>>24, hashLog, bestShortLen) // 3 -> 7
   507  			e.dictTable[nextHash] = prevEntry{
   508  				prev:   e.dictTable[nextHash].offset,
   509  				offset: i,
   510  			}
   511  			e.dictTable[nextHash1] = prevEntry{
   512  				prev:   e.dictTable[nextHash1].offset,
   513  				offset: i + 1,
   514  			}
   515  			e.dictTable[nextHash2] = prevEntry{
   516  				prev:   e.dictTable[nextHash2].offset,
   517  				offset: i + 2,
   518  			}
   519  			e.dictTable[nextHash3] = prevEntry{
   520  				prev:   e.dictTable[nextHash3].offset,
   521  				offset: i + 3,
   522  			}
   523  		}
   524  		e.lastDictID = d.id
   525  	}
   526  
   527  	// Init or copy dict table
   528  	if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
   529  		if len(e.dictLongTable) != len(e.longTable) {
   530  			e.dictLongTable = make([]prevEntry, len(e.longTable))
   531  		}
   532  		if len(d.content) >= 8 {
   533  			cv := load6432(d.content, 0)
   534  			h := hashLen(cv, bestLongTableBits, bestLongLen)
   535  			e.dictLongTable[h] = prevEntry{
   536  				offset: e.maxMatchOff,
   537  				prev:   e.dictLongTable[h].offset,
   538  			}
   539  
   540  			end := int32(len(d.content)) - 8 + e.maxMatchOff
   541  			off := 8 // First to read
   542  			for i := e.maxMatchOff + 1; i < end; i++ {
   543  				cv = cv>>8 | (uint64(d.content[off]) << 56)
   544  				h := hashLen(cv, bestLongTableBits, bestLongLen)
   545  				e.dictLongTable[h] = prevEntry{
   546  					offset: i,
   547  					prev:   e.dictLongTable[h].offset,
   548  				}
   549  				off++
   550  			}
   551  		}
   552  		e.lastDictID = d.id
   553  	}
   554  	// Reset table to initial state
   555  	copy(e.longTable[:], e.dictLongTable)
   556  
   557  	e.cur = e.maxMatchOff
   558  	// Reset table to initial state
   559  	copy(e.table[:], e.dictTable)
   560  }
   561  

View as plain text