...

Source file src/github.com/klauspost/compress/dict/builder.go

Documentation: github.com/klauspost/compress/dict

     1  // Copyright 2023+ Klaus Post. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package dict
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/binary"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"math/rand"
    14  	"sort"
    15  	"time"
    16  
    17  	"github.com/klauspost/compress/s2"
    18  	"github.com/klauspost/compress/zstd"
    19  )
    20  
// match is a ranked candidate substring, identified by the hash of its
// first HashBytes bytes.
type match struct {
	// hash is the hashLen hash of the substring.
	hash   uint32
	// n is the count used for ranking: how many inputs contain the hash
	// (first pass), or how often it neighbours another entry (chaining).
	n      uint32
	// offset is the accumulated offset of the first occurrence of the hash
	// in each input; it is used to order entries with similar counts.
	offset int64
}
    26  
// matchValue holds the bytes behind a kept hash plus its neighbours.
type matchValue struct {
	// value is a copy of the HashBytes bytes that produced the hash.
	value       []byte
	// followBy counts kept hashes found HashBytes bytes after this one.
	followBy    map[uint32]uint32
	// preceededBy counts kept hashes found HashBytes bytes before this one.
	preceededBy map[uint32]uint32
}
    32  
// Options controls how a dictionary is built by BuildZstdDict,
// BuildS2Dict and BuildRawDict.
type Options struct {
	// MaxDictSize is the max size of the backreference dictionary.
	// It should be positive. For S2 dictionaries it must not exceed
	// s2.MaxDictSize.
	MaxDictSize int

	// HashBytes is the minimum length to index: substrings of exactly
	// this many bytes are hashed and counted.
	// Must be >=4 and <=8
	HashBytes int

	// Output receives debug/progress output when non-nil.
	Output io.Writer

	// ZstdDictID is the Zstd dictionary ID to use.
	// Leave at zero to generate a random ID.
	ZstdDictID uint32

	// ZstdDictCompat will make the dictionary compatible with Zstd v1.5.5 and earlier.
	// See https://github.com/facebook/zstd/issues/3724
	ZstdDictCompat bool

	// Use the specified encoder level for Zstandard dictionaries.
	// The dictionary will be built using the specified encoder level,
	// which will reflect speed and make the dictionary tailored for that level.
	// If not set zstd.SpeedBestCompression will be used.
	ZstdLevel zstd.EncoderLevel

	// outFormat selects the output container (formatRaw, formatZstd or
	// formatS2); it is set by the exported Build* functions.
	outFormat int
}
    60  
    61  const (
    62  	formatRaw = iota
    63  	formatZstd
    64  	formatS2
    65  )
    66  
    67  // BuildZstdDict will build a Zstandard dictionary from the provided input.
    68  func BuildZstdDict(input [][]byte, o Options) ([]byte, error) {
    69  	o.outFormat = formatZstd
    70  	if o.ZstdDictID == 0 {
    71  		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
    72  		o.ZstdDictID = 32768 + uint32(rng.Int31n((1<<31)-32768))
    73  	}
    74  	return buildDict(input, o)
    75  }
    76  
    77  // BuildS2Dict will build a S2 dictionary from the provided input.
    78  func BuildS2Dict(input [][]byte, o Options) ([]byte, error) {
    79  	o.outFormat = formatS2
    80  	if o.MaxDictSize > s2.MaxDictSize {
    81  		return nil, errors.New("max dict size too large")
    82  	}
    83  	return buildDict(input, o)
    84  }
    85  
    86  // BuildRawDict will build a raw dictionary from the provided input.
    87  // This can be used for deflate, lz4 and others.
    88  func BuildRawDict(input [][]byte, o Options) ([]byte, error) {
    89  	o.outFormat = formatRaw
    90  	return buildDict(input, o)
    91  }
    92  
    93  func buildDict(input [][]byte, o Options) ([]byte, error) {
    94  	matches := make(map[uint32]uint32)
    95  	offsets := make(map[uint32]int64)
    96  	var total uint64
    97  
    98  	wantLen := o.MaxDictSize
    99  	hashBytes := o.HashBytes
   100  	if len(input) == 0 {
   101  		return nil, fmt.Errorf("no input provided")
   102  	}
   103  	if hashBytes < 4 || hashBytes > 8 {
   104  		return nil, fmt.Errorf("HashBytes must be >= 4 and <= 8")
   105  	}
   106  	println := func(args ...interface{}) {
   107  		if o.Output != nil {
   108  			fmt.Fprintln(o.Output, args...)
   109  		}
   110  	}
   111  	printf := func(s string, args ...interface{}) {
   112  		if o.Output != nil {
   113  			fmt.Fprintf(o.Output, s, args...)
   114  		}
   115  	}
   116  	found := make(map[uint32]struct{})
   117  	for i, b := range input {
   118  		for k := range found {
   119  			delete(found, k)
   120  		}
   121  		for i := range b {
   122  			rem := b[i:]
   123  			if len(rem) < 8 {
   124  				break
   125  			}
   126  			h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
   127  			if _, ok := found[h]; ok {
   128  				// Only count first occurrence
   129  				continue
   130  			}
   131  			matches[h]++
   132  			offsets[h] += int64(i)
   133  			total++
   134  			found[h] = struct{}{}
   135  		}
   136  		printf("\r input %d indexed...", i)
   137  	}
   138  	threshold := uint32(total / uint64(len(matches)))
   139  	println("\nTotal", total, "match", len(matches), "avg", threshold)
   140  	sorted := make([]match, 0, len(matches)/2)
   141  	for k, v := range matches {
   142  		if v <= threshold {
   143  			continue
   144  		}
   145  		sorted = append(sorted, match{hash: k, n: v, offset: offsets[k]})
   146  	}
   147  	sort.Slice(sorted, func(i, j int) bool {
   148  		if true {
   149  			// Group very similar counts together and emit low offsets first.
   150  			// This will keep together strings that are very similar.
   151  			deltaN := int(sorted[i].n) - int(sorted[j].n)
   152  			if deltaN < 0 {
   153  				deltaN = -deltaN
   154  			}
   155  			if uint32(deltaN) < sorted[i].n/32 {
   156  				return sorted[i].offset < sorted[j].offset
   157  			}
   158  		} else {
   159  			if sorted[i].n == sorted[j].n {
   160  				return sorted[i].offset < sorted[j].offset
   161  			}
   162  		}
   163  		return sorted[i].n > sorted[j].n
   164  	})
   165  	println("Sorted len:", len(sorted))
   166  	if len(sorted) > wantLen {
   167  		sorted = sorted[:wantLen]
   168  	}
   169  	lowestOcc := sorted[len(sorted)-1].n
   170  	println("Cropped len:", len(sorted), "Lowest occurrence:", lowestOcc)
   171  
   172  	wantMatches := make(map[uint32]uint32, len(sorted))
   173  	for _, v := range sorted {
   174  		wantMatches[v.hash] = v.n
   175  	}
   176  
   177  	output := make(map[uint32]matchValue, len(sorted))
   178  	var remainCnt [256]int
   179  	var remainTotal int
   180  	var firstOffsets []int
   181  	for i, b := range input {
   182  		for i := range b {
   183  			rem := b[i:]
   184  			if len(rem) < 8 {
   185  				break
   186  			}
   187  			var prev []byte
   188  			if i > hashBytes {
   189  				prev = b[i-hashBytes:]
   190  			}
   191  
   192  			h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
   193  			if _, ok := wantMatches[h]; !ok {
   194  				remainCnt[rem[0]]++
   195  				remainTotal++
   196  				continue
   197  			}
   198  			mv := output[h]
   199  			if len(mv.value) == 0 {
   200  				var tmp = make([]byte, hashBytes)
   201  				copy(tmp[:], rem)
   202  				mv.value = tmp[:]
   203  			}
   204  			if mv.followBy == nil {
   205  				mv.followBy = make(map[uint32]uint32, 4)
   206  				mv.preceededBy = make(map[uint32]uint32, 4)
   207  			}
   208  			if len(rem) > hashBytes+8 {
   209  				// Check if we should add next as well.
   210  				hNext := hashLen(binary.LittleEndian.Uint64(rem[hashBytes:]), 32, uint8(hashBytes))
   211  				if _, ok := wantMatches[hNext]; ok {
   212  					mv.followBy[hNext]++
   213  				}
   214  			}
   215  			if len(prev) >= 8 {
   216  				// Check if we should prev next as well.
   217  				hPrev := hashLen(binary.LittleEndian.Uint64(prev), 32, uint8(hashBytes))
   218  				if _, ok := wantMatches[hPrev]; ok {
   219  					mv.preceededBy[hPrev]++
   220  				}
   221  			}
   222  			output[h] = mv
   223  		}
   224  		printf("\rinput %d re-indexed...", i)
   225  	}
   226  	println("")
   227  	dst := make([][]byte, 0, wantLen/hashBytes)
   228  	added := 0
   229  	const printUntil = 500
   230  	for i, e := range sorted {
   231  		if added > o.MaxDictSize {
   232  			println("Ending. Next Occurrence:", e.n)
   233  			break
   234  		}
   235  		m, ok := output[e.hash]
   236  		if !ok {
   237  			// Already added
   238  			continue
   239  		}
   240  		wantLen := e.n / uint32(hashBytes) / 4
   241  		if wantLen <= lowestOcc {
   242  			wantLen = lowestOcc
   243  		}
   244  
   245  		var tmp = make([]byte, 0, hashBytes*2)
   246  		{
   247  			sortedPrev := make([]match, 0, len(m.followBy))
   248  			for k, v := range m.preceededBy {
   249  				if _, ok := output[k]; v < wantLen || !ok {
   250  					continue
   251  				}
   252  				sortedPrev = append(sortedPrev, match{
   253  					hash: k,
   254  					n:    v,
   255  				})
   256  			}
   257  			if len(sortedPrev) > 0 {
   258  				sort.Slice(sortedPrev, func(i, j int) bool {
   259  					return sortedPrev[i].n > sortedPrev[j].n
   260  				})
   261  				bestPrev := output[sortedPrev[0].hash]
   262  				tmp = append(tmp, bestPrev.value...)
   263  			}
   264  		}
   265  		tmp = append(tmp, m.value...)
   266  		delete(output, e.hash)
   267  
   268  		sortedFollow := make([]match, 0, len(m.followBy))
   269  		for {
   270  			var nh uint32 // Next hash
   271  			stopAfter := false
   272  			{
   273  				sortedFollow = sortedFollow[:0]
   274  				for k, v := range m.followBy {
   275  					if _, ok := output[k]; !ok {
   276  						continue
   277  					}
   278  					sortedFollow = append(sortedFollow, match{
   279  						hash:   k,
   280  						n:      v,
   281  						offset: offsets[k],
   282  					})
   283  				}
   284  				if len(sortedFollow) == 0 {
   285  					// Step back
   286  					// Extremely small impact, but helps longer hashes a bit.
   287  					const stepBack = 2
   288  					if stepBack > 0 && len(tmp) >= hashBytes+stepBack {
   289  						var t8 [8]byte
   290  						copy(t8[:], tmp[len(tmp)-hashBytes-stepBack:])
   291  						m, ok = output[hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))]
   292  						if ok && len(m.followBy) > 0 {
   293  							found := []byte(nil)
   294  							for k := range m.followBy {
   295  								v, ok := output[k]
   296  								if !ok {
   297  									continue
   298  								}
   299  								found = v.value
   300  								break
   301  							}
   302  							if found != nil {
   303  								tmp = tmp[:len(tmp)-stepBack]
   304  								printf("Step back: %q +  %q\n", string(tmp), string(found))
   305  								continue
   306  							}
   307  						}
   308  						break
   309  					} else {
   310  						if i < printUntil {
   311  							printf("FOLLOW: none after %q\n", string(m.value))
   312  						}
   313  					}
   314  					break
   315  				}
   316  				sort.Slice(sortedFollow, func(i, j int) bool {
   317  					if sortedFollow[i].n == sortedFollow[j].n {
   318  						return sortedFollow[i].offset > sortedFollow[j].offset
   319  					}
   320  					return sortedFollow[i].n > sortedFollow[j].n
   321  				})
   322  				nh = sortedFollow[0].hash
   323  				stopAfter = sortedFollow[0].n < wantLen
   324  				if stopAfter && i < printUntil {
   325  					printf("FOLLOW: %d < %d after %q. Stopping after this.\n", sortedFollow[0].n, wantLen, string(m.value))
   326  				}
   327  			}
   328  			m, ok = output[nh]
   329  			if !ok {
   330  				break
   331  			}
   332  			if len(tmp) > 0 {
   333  				// Delete all hashes that are in the current string to avoid stuttering.
   334  				var toDel [16 + 8]byte
   335  				copy(toDel[:], tmp[len(tmp)-hashBytes:])
   336  				copy(toDel[hashBytes:], m.value)
   337  				for i := range toDel[:hashBytes*2] {
   338  					delete(output, hashLen(binary.LittleEndian.Uint64(toDel[i:]), 32, uint8(hashBytes)))
   339  				}
   340  			}
   341  			tmp = append(tmp, m.value...)
   342  			//delete(output, nh)
   343  			if stopAfter {
   344  				// Last entry was no significant.
   345  				break
   346  			}
   347  		}
   348  		if i < printUntil {
   349  			printf("ENTRY %d: %q (%d occurrences, cutoff %d)\n", i, string(tmp), e.n, wantLen)
   350  		}
   351  		// Delete substrings already added.
   352  		if len(tmp) > hashBytes {
   353  			for j := range tmp[:len(tmp)-hashBytes+1] {
   354  				var t8 [8]byte
   355  				copy(t8[:], tmp[j:])
   356  				if i < printUntil {
   357  					//printf("* POST DELETE %q\n", string(t8[:hashBytes]))
   358  				}
   359  				delete(output, hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes)))
   360  			}
   361  		}
   362  		dst = append(dst, tmp)
   363  		added += len(tmp)
   364  		// Find offsets
   365  		// TODO: This can be better if done as a global search.
   366  		if len(firstOffsets) < 3 {
   367  			if len(tmp) > 16 {
   368  				tmp = tmp[:16]
   369  			}
   370  			offCnt := make(map[int]int, len(input))
   371  			// Find first offsets
   372  			for _, b := range input {
   373  				off := bytes.Index(b, tmp)
   374  				if off == -1 {
   375  					continue
   376  				}
   377  				offCnt[off]++
   378  			}
   379  			for _, off := range firstOffsets {
   380  				// Very unlikely, but we deleted it just in case
   381  				delete(offCnt, off-added)
   382  			}
   383  			maxCnt := 0
   384  			maxOffset := 0
   385  			for k, v := range offCnt {
   386  				if v == maxCnt && k > maxOffset {
   387  					// Prefer the longer offset on ties , since it is more expensive to encode
   388  					maxCnt = v
   389  					maxOffset = k
   390  					continue
   391  				}
   392  
   393  				if v > maxCnt {
   394  					maxCnt = v
   395  					maxOffset = k
   396  				}
   397  			}
   398  			if maxCnt > 1 {
   399  				firstOffsets = append(firstOffsets, maxOffset+added)
   400  				println(" - Offset:", len(firstOffsets), "at", maxOffset+added, "count:", maxCnt, "total added:", added, "src index", maxOffset)
   401  			}
   402  		}
   403  	}
   404  	out := bytes.NewBuffer(nil)
   405  	written := 0
   406  	for i, toWrite := range dst {
   407  		if len(toWrite)+written > wantLen {
   408  			toWrite = toWrite[:wantLen-written]
   409  		}
   410  		dst[i] = toWrite
   411  		written += len(toWrite)
   412  		if written >= wantLen {
   413  			dst = dst[:i+1]
   414  			break
   415  		}
   416  	}
   417  	// Write in reverse order.
   418  	for i := range dst {
   419  		toWrite := dst[len(dst)-i-1]
   420  		out.Write(toWrite)
   421  	}
   422  	if o.outFormat == formatRaw {
   423  		return out.Bytes(), nil
   424  	}
   425  
   426  	if o.outFormat == formatS2 {
   427  		dOff := 0
   428  		dBytes := out.Bytes()
   429  		if len(dBytes) > s2.MaxDictSize {
   430  			dBytes = dBytes[:s2.MaxDictSize]
   431  		}
   432  		for _, off := range firstOffsets {
   433  			myOff := len(dBytes) - off
   434  			if myOff < 0 || myOff > s2.MaxDictSrcOffset {
   435  				continue
   436  			}
   437  			dOff = myOff
   438  		}
   439  
   440  		dict := s2.MakeDictManual(dBytes, uint16(dOff))
   441  		if dict == nil {
   442  			return nil, fmt.Errorf("unable to create s2 dictionary")
   443  		}
   444  		return dict.Bytes(), nil
   445  	}
   446  
   447  	offsetsZstd := [3]int{1, 4, 8}
   448  	for i, off := range firstOffsets {
   449  		if i >= 3 || off == 0 || off >= out.Len() {
   450  			break
   451  		}
   452  		offsetsZstd[i] = off
   453  	}
   454  	println("\nCompressing. Offsets:", offsetsZstd)
   455  	return zstd.BuildDict(zstd.BuildDictOptions{
   456  		ID:         o.ZstdDictID,
   457  		Contents:   input,
   458  		History:    out.Bytes(),
   459  		Offsets:    offsetsZstd,
   460  		CompatV155: o.ZstdDictCompat,
   461  		Level:      o.ZstdLevel,
   462  		DebugOut:   o.Output,
   463  	})
   464  }
   465  
// Odd multiplicative constants used by the hashN helpers in this file,
// one per number of input bytes hashed.
const (
	prime3bytes = 506832829
	prime4bytes = 2654435761
	prime5bytes = 889523592379
	prime6bytes = 227718039650203
	prime7bytes = 58295818150454627
	prime8bytes = 0xcf1bbcdcb7a56463
)
   474  
   475  // hashLen returns a hash of the lowest l bytes of u for a size size of h bytes.
   476  // l must be >=4 and <=8. Any other value will return hash for 4 bytes.
   477  // h should always be <32.
   478  // Preferably h and l should be a constant.
   479  // LENGTH 4 is passed straight through
   480  func hashLen(u uint64, hashLog, mls uint8) uint32 {
   481  	switch mls {
   482  	case 5:
   483  		return hash5(u, hashLog)
   484  	case 6:
   485  		return hash6(u, hashLog)
   486  	case 7:
   487  		return hash7(u, hashLog)
   488  	case 8:
   489  		return hash8(u, hashLog)
   490  	default:
   491  		return uint32(u)
   492  	}
   493  }
   494  
   495  // hash3 returns the hash of the lower 3 bytes of u to fit in a hash table with h bits.
   496  // Preferably h should be a constant and should always be <32.
   497  func hash3(u uint32, h uint8) uint32 {
   498  	return ((u << (32 - 24)) * prime3bytes) >> ((32 - h) & 31)
   499  }
   500  
   501  // hash4 returns the hash of u to fit in a hash table with h bits.
   502  // Preferably h should be a constant and should always be <32.
   503  func hash4(u uint32, h uint8) uint32 {
   504  	return (u * prime4bytes) >> ((32 - h) & 31)
   505  }
   506  
   507  // hash4x64 returns the hash of the lowest 4 bytes of u to fit in a hash table with h bits.
   508  // Preferably h should be a constant and should always be <32.
   509  func hash4x64(u uint64, h uint8) uint32 {
   510  	return (uint32(u) * prime4bytes) >> ((32 - h) & 31)
   511  }
   512  
   513  // hash5 returns the hash of the lowest 5 bytes of u to fit in a hash table with h bits.
   514  // Preferably h should be a constant and should always be <64.
   515  func hash5(u uint64, h uint8) uint32 {
   516  	return uint32(((u << (64 - 40)) * prime5bytes) >> ((64 - h) & 63))
   517  }
   518  
   519  // hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits.
   520  // Preferably h should be a constant and should always be <64.
   521  func hash6(u uint64, h uint8) uint32 {
   522  	return uint32(((u << (64 - 48)) * prime6bytes) >> ((64 - h) & 63))
   523  }
   524  
   525  // hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits.
   526  // Preferably h should be a constant and should always be <64.
   527  func hash7(u uint64, h uint8) uint32 {
   528  	return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & 63))
   529  }
   530  
   531  // hash8 returns the hash of u to fit in a hash table with h bits.
   532  // Preferably h should be a constant and should always be <64.
   533  func hash8(u uint64, h uint8) uint32 {
   534  	return uint32((u * prime8bytes) >> ((64 - h) & 63))
   535  }
   536  

View as plain text