...

Source file src/github.com/klauspost/compress/zstd/dict.go

Documentation: github.com/klauspost/compress/zstd

     1  package zstd
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"math"
    10  	"sort"
    11  
    12  	"github.com/klauspost/compress/huff0"
    13  )
    14  
    15  type dict struct {
    16  	id uint32
    17  
    18  	litEnc              *huff0.Scratch
    19  	llDec, ofDec, mlDec sequenceDec
    20  	offsets             [3]int
    21  	content             []byte
    22  }
    23  
    24  const dictMagic = "\x37\xa4\x30\xec"
    25  
    26  // Maximum dictionary size for the reference implementation (1.5.3) is 2 GiB.
    27  const dictMaxLength = 1 << 31
    28  
    29  // ID returns the dictionary id or 0 if d is nil.
    30  func (d *dict) ID() uint32 {
    31  	if d == nil {
    32  		return 0
    33  	}
    34  	return d.id
    35  }
    36  
    37  // ContentSize returns the dictionary content size or 0 if d is nil.
    38  func (d *dict) ContentSize() int {
    39  	if d == nil {
    40  		return 0
    41  	}
    42  	return len(d.content)
    43  }
    44  
    45  // Content returns the dictionary content.
    46  func (d *dict) Content() []byte {
    47  	if d == nil {
    48  		return nil
    49  	}
    50  	return d.content
    51  }
    52  
    53  // Offsets returns the initial offsets.
    54  func (d *dict) Offsets() [3]int {
    55  	if d == nil {
    56  		return [3]int{}
    57  	}
    58  	return d.offsets
    59  }
    60  
    61  // LitEncoder returns the literal encoder.
    62  func (d *dict) LitEncoder() *huff0.Scratch {
    63  	if d == nil {
    64  		return nil
    65  	}
    66  	return d.litEnc
    67  }
    68  
    69  // Load a dictionary as described in
    70  // https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
    71  func loadDict(b []byte) (*dict, error) {
    72  	// Check static field size.
    73  	if len(b) <= 8+(3*4) {
    74  		return nil, io.ErrUnexpectedEOF
    75  	}
    76  	d := dict{
    77  		llDec: sequenceDec{fse: &fseDecoder{}},
    78  		ofDec: sequenceDec{fse: &fseDecoder{}},
    79  		mlDec: sequenceDec{fse: &fseDecoder{}},
    80  	}
    81  	if string(b[:4]) != dictMagic {
    82  		return nil, ErrMagicMismatch
    83  	}
    84  	d.id = binary.LittleEndian.Uint32(b[4:8])
    85  	if d.id == 0 {
    86  		return nil, errors.New("dictionaries cannot have ID 0")
    87  	}
    88  
    89  	// Read literal table
    90  	var err error
    91  	d.litEnc, b, err = huff0.ReadTable(b[8:], nil)
    92  	if err != nil {
    93  		return nil, fmt.Errorf("loading literal table: %w", err)
    94  	}
    95  	d.litEnc.Reuse = huff0.ReusePolicyMust
    96  
    97  	br := byteReader{
    98  		b:   b,
    99  		off: 0,
   100  	}
   101  	readDec := func(i tableIndex, dec *fseDecoder) error {
   102  		if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
   103  			return err
   104  		}
   105  		if br.overread() {
   106  			return io.ErrUnexpectedEOF
   107  		}
   108  		err = dec.transform(symbolTableX[i])
   109  		if err != nil {
   110  			println("Transform table error:", err)
   111  			return err
   112  		}
   113  		if debugDecoder || debugEncoder {
   114  			println("Read table ok", "symbolLen:", dec.symbolLen)
   115  		}
   116  		// Set decoders as predefined so they aren't reused.
   117  		dec.preDefined = true
   118  		return nil
   119  	}
   120  
   121  	if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
   122  		return nil, err
   123  	}
   124  	if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
   125  		return nil, err
   126  	}
   127  	if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
   128  		return nil, err
   129  	}
   130  	if br.remain() < 12 {
   131  		return nil, io.ErrUnexpectedEOF
   132  	}
   133  
   134  	d.offsets[0] = int(br.Uint32())
   135  	br.advance(4)
   136  	d.offsets[1] = int(br.Uint32())
   137  	br.advance(4)
   138  	d.offsets[2] = int(br.Uint32())
   139  	br.advance(4)
   140  	if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
   141  		return nil, errors.New("invalid offset in dictionary")
   142  	}
   143  	d.content = make([]byte, br.remain())
   144  	copy(d.content, br.unread())
   145  	if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
   146  		return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
   147  	}
   148  
   149  	return &d, nil
   150  }
   151  
   152  // InspectDictionary loads a zstd dictionary and provides functions to inspect the content.
   153  func InspectDictionary(b []byte) (interface {
   154  	ID() uint32
   155  	ContentSize() int
   156  	Content() []byte
   157  	Offsets() [3]int
   158  	LitEncoder() *huff0.Scratch
   159  }, error) {
   160  	initPredefined()
   161  	d, err := loadDict(b)
   162  	return d, err
   163  }
   164  
// BuildDictOptions configures BuildDict.
type BuildDictOptions struct {
	// Dictionary ID.
	// It is written to the dictionary header. Dictionaries with ID 0 are
	// rejected when loaded, so use a non-zero value.
	ID uint32

	// Content to use to create dictionary tables.
	// Each entry is compressed against History to collect the statistics the
	// tables are built from. Entries shorter than 8 bytes are skipped.
	Contents [][]byte

	// History to use for all blocks.
	// It is stored verbatim as the dictionary content. It must be at least
	// 8 bytes and at most 2 GiB long.
	History []byte

	// Offsets to use.
	// These are the initial repeat offsets. BuildDict may replace them with
	// repeat offsets observed while compressing Contents.
	Offsets [3]int

	// CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier.
	// See https://github.com/facebook/zstd/issues/3724
	// When set, BuildDict makes sure byte value 255 is present in the data the
	// literal table is built from.
	CompatV155 bool

	// Use the specified encoder level.
	// The dictionary will be built using the specified encoder level,
	// which will reflect speed and make the dictionary tailored for that level.
	// If not set SpeedBestCompression will be used.
	Level EncoderLevel

	// DebugOut will write stats and other details here if set.
	// Setting it also makes BuildDict load the result and compress Contents
	// with and without it, reporting the sizes.
	DebugOut io.Writer
}
   191  
   192  func BuildDict(o BuildDictOptions) ([]byte, error) {
   193  	initPredefined()
   194  	hist := o.History
   195  	contents := o.Contents
   196  	debug := o.DebugOut != nil
   197  	println := func(args ...interface{}) {
   198  		if o.DebugOut != nil {
   199  			fmt.Fprintln(o.DebugOut, args...)
   200  		}
   201  	}
   202  	printf := func(s string, args ...interface{}) {
   203  		if o.DebugOut != nil {
   204  			fmt.Fprintf(o.DebugOut, s, args...)
   205  		}
   206  	}
   207  	print := func(args ...interface{}) {
   208  		if o.DebugOut != nil {
   209  			fmt.Fprint(o.DebugOut, args...)
   210  		}
   211  	}
   212  
   213  	if int64(len(hist)) > dictMaxLength {
   214  		return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength))
   215  	}
   216  	if len(hist) < 8 {
   217  		return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8)
   218  	}
   219  	if len(contents) == 0 {
   220  		return nil, errors.New("no content provided")
   221  	}
   222  	d := dict{
   223  		id:      o.ID,
   224  		litEnc:  nil,
   225  		llDec:   sequenceDec{},
   226  		ofDec:   sequenceDec{},
   227  		mlDec:   sequenceDec{},
   228  		offsets: o.Offsets,
   229  		content: hist,
   230  	}
   231  	block := blockEnc{lowMem: false}
   232  	block.init()
   233  	enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}})
   234  	if o.Level != 0 {
   235  		eOpts := encoderOptions{
   236  			level:      o.Level,
   237  			blockSize:  maxMatchLen,
   238  			windowSize: maxMatchLen,
   239  			dict:       &d,
   240  			lowMem:     false,
   241  		}
   242  		enc = eOpts.encoder()
   243  	} else {
   244  		o.Level = SpeedBestCompression
   245  	}
   246  	var (
   247  		remain [256]int
   248  		ll     [256]int
   249  		ml     [256]int
   250  		of     [256]int
   251  	)
   252  	addValues := func(dst *[256]int, src []byte) {
   253  		for _, v := range src {
   254  			dst[v]++
   255  		}
   256  	}
   257  	addHist := func(dst *[256]int, src *[256]uint32) {
   258  		for i, v := range src {
   259  			dst[i] += int(v)
   260  		}
   261  	}
   262  	seqs := 0
   263  	nUsed := 0
   264  	litTotal := 0
   265  	newOffsets := make(map[uint32]int, 1000)
   266  	for _, b := range contents {
   267  		block.reset(nil)
   268  		if len(b) < 8 {
   269  			continue
   270  		}
   271  		nUsed++
   272  		enc.Reset(&d, true)
   273  		enc.Encode(&block, b)
   274  		addValues(&remain, block.literals)
   275  		litTotal += len(block.literals)
   276  		seqs += len(block.sequences)
   277  		block.genCodes()
   278  		addHist(&ll, block.coders.llEnc.Histogram())
   279  		addHist(&ml, block.coders.mlEnc.Histogram())
   280  		addHist(&of, block.coders.ofEnc.Histogram())
   281  		for i, seq := range block.sequences {
   282  			if i > 3 {
   283  				break
   284  			}
   285  			offset := seq.offset
   286  			if offset == 0 {
   287  				continue
   288  			}
   289  			if offset > 3 {
   290  				newOffsets[offset-3]++
   291  			} else {
   292  				newOffsets[uint32(o.Offsets[offset-1])]++
   293  			}
   294  		}
   295  	}
   296  	// Find most used offsets.
   297  	var sortedOffsets []uint32
   298  	for k := range newOffsets {
   299  		sortedOffsets = append(sortedOffsets, k)
   300  	}
   301  	sort.Slice(sortedOffsets, func(i, j int) bool {
   302  		a, b := sortedOffsets[i], sortedOffsets[j]
   303  		if a == b {
   304  			// Prefer the longer offset
   305  			return sortedOffsets[i] > sortedOffsets[j]
   306  		}
   307  		return newOffsets[sortedOffsets[i]] > newOffsets[sortedOffsets[j]]
   308  	})
   309  	if len(sortedOffsets) > 3 {
   310  		if debug {
   311  			print("Offsets:")
   312  			for i, v := range sortedOffsets {
   313  				if i > 20 {
   314  					break
   315  				}
   316  				printf("[%d: %d],", v, newOffsets[v])
   317  			}
   318  			println("")
   319  		}
   320  
   321  		sortedOffsets = sortedOffsets[:3]
   322  	}
   323  	for i, v := range sortedOffsets {
   324  		o.Offsets[i] = int(v)
   325  	}
   326  	if debug {
   327  		println("New repeat offsets", o.Offsets)
   328  	}
   329  
   330  	if nUsed == 0 || seqs == 0 {
   331  		return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs)
   332  	}
   333  	if debug {
   334  		println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal)
   335  	}
   336  	if seqs/nUsed < 512 {
   337  		// Use 512 as minimum.
   338  		nUsed = seqs / 512
   339  	}
   340  	copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) {
   341  		hist := dst.Histogram()
   342  		var maxSym uint8
   343  		var maxCount int
   344  		var fakeLength int
   345  		for i, v := range src {
   346  			if v > 0 {
   347  				v = v / nUsed
   348  				if v == 0 {
   349  					v = 1
   350  				}
   351  			}
   352  			if v > maxCount {
   353  				maxCount = v
   354  			}
   355  			if v != 0 {
   356  				maxSym = uint8(i)
   357  			}
   358  			fakeLength += v
   359  			hist[i] = uint32(v)
   360  		}
   361  		dst.HistogramFinished(maxSym, maxCount)
   362  		dst.reUsed = false
   363  		dst.useRLE = false
   364  		err := dst.normalizeCount(fakeLength)
   365  		if err != nil {
   366  			return nil, err
   367  		}
   368  		if debug {
   369  			println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength)
   370  		}
   371  		return dst.writeCount(nil)
   372  	}
   373  	if debug {
   374  		print("Literal lengths: ")
   375  	}
   376  	llTable, err := copyHist(block.coders.llEnc, &ll)
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  	if debug {
   381  		print("Match lengths: ")
   382  	}
   383  	mlTable, err := copyHist(block.coders.mlEnc, &ml)
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  	if debug {
   388  		print("Offsets: ")
   389  	}
   390  	ofTable, err := copyHist(block.coders.ofEnc, &of)
   391  	if err != nil {
   392  		return nil, err
   393  	}
   394  
   395  	// Literal table
   396  	avgSize := litTotal
   397  	if avgSize > huff0.BlockSizeMax/2 {
   398  		avgSize = huff0.BlockSizeMax / 2
   399  	}
   400  	huffBuff := make([]byte, 0, avgSize)
   401  	// Target size
   402  	div := litTotal / avgSize
   403  	if div < 1 {
   404  		div = 1
   405  	}
   406  	if debug {
   407  		println("Huffman weights:")
   408  	}
   409  	for i, n := range remain[:] {
   410  		if n > 0 {
   411  			n = n / div
   412  			// Allow all entries to be represented.
   413  			if n == 0 {
   414  				n = 1
   415  			}
   416  			huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
   417  			if debug {
   418  				printf("[%d: %d], ", i, n)
   419  			}
   420  		}
   421  	}
   422  	if o.CompatV155 && remain[255]/div == 0 {
   423  		huffBuff = append(huffBuff, 255)
   424  	}
   425  	scratch := &huff0.Scratch{TableLog: 11}
   426  	for tries := 0; tries < 255; tries++ {
   427  		scratch = &huff0.Scratch{TableLog: 11}
   428  		_, _, err = huff0.Compress1X(huffBuff, scratch)
   429  		if err == nil {
   430  			break
   431  		}
   432  		if debug {
   433  			printf("Try %d: Huffman error: %v\n", tries+1, err)
   434  		}
   435  		huffBuff = huffBuff[:0]
   436  		if tries == 250 {
   437  			if debug {
   438  				println("Huffman: Bailing out with predefined table")
   439  			}
   440  
   441  			// Bail out.... Just generate something
   442  			huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...)
   443  			for i := 0; i < 128; i++ {
   444  				huffBuff = append(huffBuff, byte(i))
   445  			}
   446  			continue
   447  		}
   448  		if errors.Is(err, huff0.ErrIncompressible) {
   449  			// Try truncating least common.
   450  			for i, n := range remain[:] {
   451  				if n > 0 {
   452  					n = n / (div * (i + 1))
   453  					if n > 0 {
   454  						huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
   455  					}
   456  				}
   457  			}
   458  			if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 {
   459  				huffBuff = append(huffBuff, 255)
   460  			}
   461  			if len(huffBuff) == 0 {
   462  				huffBuff = append(huffBuff, 0, 255)
   463  			}
   464  		}
   465  		if errors.Is(err, huff0.ErrUseRLE) {
   466  			for i, n := range remain[:] {
   467  				n = n / (div * (i + 1))
   468  				// Allow all entries to be represented.
   469  				if n == 0 {
   470  					n = 1
   471  				}
   472  				huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
   473  			}
   474  		}
   475  	}
   476  
   477  	var out bytes.Buffer
   478  	out.Write([]byte(dictMagic))
   479  	out.Write(binary.LittleEndian.AppendUint32(nil, o.ID))
   480  	out.Write(scratch.OutTable)
   481  	if debug {
   482  		println("huff table:", len(scratch.OutTable), "bytes")
   483  		println("of table:", len(ofTable), "bytes")
   484  		println("ml table:", len(mlTable), "bytes")
   485  		println("ll table:", len(llTable), "bytes")
   486  	}
   487  	out.Write(ofTable)
   488  	out.Write(mlTable)
   489  	out.Write(llTable)
   490  	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0])))
   491  	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1])))
   492  	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2])))
   493  	out.Write(hist)
   494  	if debug {
   495  		_, err := loadDict(out.Bytes())
   496  		if err != nil {
   497  			panic(err)
   498  		}
   499  		i, err := InspectDictionary(out.Bytes())
   500  		if err != nil {
   501  			panic(err)
   502  		}
   503  		println("ID:", i.ID())
   504  		println("Content size:", i.ContentSize())
   505  		println("Encoder:", i.LitEncoder() != nil)
   506  		println("Offsets:", i.Offsets())
   507  		var totalSize int
   508  		for _, b := range contents {
   509  			totalSize += len(b)
   510  		}
   511  
   512  		encWith := func(opts ...EOption) int {
   513  			enc, err := NewWriter(nil, opts...)
   514  			if err != nil {
   515  				panic(err)
   516  			}
   517  			defer enc.Close()
   518  			var dst []byte
   519  			var totalSize int
   520  			for _, b := range contents {
   521  				dst = enc.EncodeAll(b, dst[:0])
   522  				totalSize += len(dst)
   523  			}
   524  			return totalSize
   525  		}
   526  		plain := encWith(WithEncoderLevel(o.Level))
   527  		withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes()))
   528  		println("Input size:", totalSize)
   529  		println("Plain Compressed:", plain)
   530  		println("Dict Compressed:", withDict)
   531  		println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)")
   532  	}
   533  	return out.Bytes(), nil
   534  }
   535  

View as plain text