package zstd

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"sort"

	"github.com/klauspost/compress/huff0"
)

// dict holds a loaded zstd dictionary: entropy tables, repeat offsets and content.
type dict struct {
	id uint32

	litEnc              *huff0.Scratch
	llDec, ofDec, mlDec sequenceDec
	offsets             [3]int
	content             []byte
}

// dictMagic is the 4-byte magic number that starts every zstd dictionary.
const dictMagic = "\x37\xa4\x30\xec"

// dictMaxLength is the maximum history size accepted when building a dictionary.
const dictMaxLength = 1 << 31

// ID returns the dictionary id or 0 if d is nil.
func (d *dict) ID() uint32 {
	if d == nil {
		return 0
	}
	return d.id
}

// ContentSize returns the dictionary content size or 0 if d is nil.
func (d *dict) ContentSize() int {
	if d == nil {
		return 0
	}
	return len(d.content)
}

// Content returns the dictionary content or nil if d is nil.
func (d *dict) Content() []byte {
	if d == nil {
		return nil
	}
	return d.content
}

// Offsets returns the initial repeat offsets or zeroes if d is nil.
func (d *dict) Offsets() [3]int {
	if d == nil {
		return [3]int{}
	}
	return d.offsets
}

// LitEncoder returns the literal Huffman encoder or nil if d is nil.
func (d *dict) LitEncoder() *huff0.Scratch {
	if d == nil {
		return nil
	}
	return d.litEnc
}

// loadDict loads a dictionary in the standard zstd dictionary format and
// populates the decoder tables from it.
func loadDict(b []byte) (*dict, error) {
	// Verify that the fixed header fields (magic, ID and repeat offsets) fit.
	if len(b) <= 8+(3*4) {
		return nil, io.ErrUnexpectedEOF
	}
	d := dict{
		llDec: sequenceDec{fse: &fseDecoder{}},
		ofDec: sequenceDec{fse: &fseDecoder{}},
		mlDec: sequenceDec{fse: &fseDecoder{}},
	}
	if string(b[:4]) != dictMagic {
		return nil, ErrMagicMismatch
	}
	d.id = binary.LittleEndian.Uint32(b[4:8])
	if d.id == 0 {
		return nil, errors.New("dictionaries cannot have ID 0")
	}

	// Read the literal Huffman table.
	var err error
	d.litEnc, b, err = huff0.ReadTable(b[8:], nil)
	if err != nil {
		return nil, fmt.Errorf("loading literal table: %w", err)
	}
	d.litEnc.Reuse = huff0.ReusePolicyMust

	br := byteReader{
		b:   b,
		off: 0,
	}
	readDec := func(i tableIndex, dec *fseDecoder) error {
		if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
			return err
		}
		if br.overread() {
			return io.ErrUnexpectedEOF
		}
		err = dec.transform(symbolTableX[i])
		if err != nil {
			println("Transform table error:", err)
			return err
		}
		if debugDecoder || debugEncoder {
			println("Read table ok", "symbolLen:", dec.symbolLen)
		}
		// Dictionary tables are treated as predefined tables.
		dec.preDefined = true
		return nil
	}

	if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
		return nil, err
	}
	if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
		return nil, err
	}
	if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
		return nil, err
	}
	if br.remain() < 12 {
		return nil, io.ErrUnexpectedEOF
	}

	// Read the three initial repeat offsets.
	d.offsets[0] = int(br.Uint32())
	br.advance(4)
	d.offsets[1] = int(br.Uint32())
	br.advance(4)
	d.offsets[2] = int(br.Uint32())
	br.advance(4)
	if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
		return nil, errors.New("invalid offset in dictionary")
	}
	d.content = make([]byte, br.remain())
	copy(d.content, br.unread())
	if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
		return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
	}

	return &d, nil
}

// InspectDictionary loads a zstd dictionary and provides functions to inspect the content.
func InspectDictionary(b []byte) (interface {
	ID() uint32
	ContentSize() int
	Content() []byte
	Offsets() [3]int
	LitEncoder() *huff0.Scratch
}, error) {
	initPredefined()
	d, err := loadDict(b)
	return d, err
}
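
// A minimal usage sketch for InspectDictionary (illustrative only, outside this package):
// it assumes the dictionary has already been read into memory, e.g. with os.ReadFile,
// and that the caller imports fmt, log, os and this zstd package.
//
//	raw, err := os.ReadFile("my.dict") // hypothetical path
//	if err != nil {
//		log.Fatal(err)
//	}
//	info, err := zstd.InspectDictionary(raw)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println("ID:", info.ID(), "content bytes:", info.ContentSize(), "offsets:", info.Offsets())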

// BuildDictOptions holds the parameters for BuildDict.
type BuildDictOptions struct {
	// ID is the dictionary ID.
	ID uint32

	// Contents are sample inputs used to build the dictionary tables.
	Contents [][]byte

	// History is the dictionary content, used as history for all blocks.
	History []byte

	// Offsets are the initial repeat offsets stored in the dictionary.
	Offsets [3]int

	// CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier
	// by ensuring literal symbol 255 is always present in the Huffman table.
	CompatV155 bool

	// Level is the encoder level used when gathering statistics,
	// so the dictionary is tailored for that level.
	// If not set, SpeedBestCompression is used.
	Level EncoderLevel

	// DebugOut will receive stats and other details if set.
	DebugOut io.Writer
}

// BuildDict will build a dictionary from the provided options.
// The returned bytes can be used with WithEncoderDict and WithDecoderDicts.
func BuildDict(o BuildDictOptions) ([]byte, error) {
	initPredefined()
	hist := o.History
	contents := o.Contents
	debug := o.DebugOut != nil
	println := func(args ...interface{}) {
		if o.DebugOut != nil {
			fmt.Fprintln(o.DebugOut, args...)
		}
	}
	printf := func(s string, args ...interface{}) {
		if o.DebugOut != nil {
			fmt.Fprintf(o.DebugOut, s, args...)
		}
	}
	print := func(args ...interface{}) {
		if o.DebugOut != nil {
			fmt.Fprint(o.DebugOut, args...)
		}
	}

	if int64(len(hist)) > dictMaxLength {
		return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength))
	}
	if len(hist) < 8 {
		return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8)
	}
	if len(contents) == 0 {
		return nil, errors.New("no content provided")
	}
	d := dict{
		id:      o.ID,
		litEnc:  nil,
		llDec:   sequenceDec{},
		ofDec:   sequenceDec{},
		mlDec:   sequenceDec{},
		offsets: o.Offsets,
		content: hist,
	}
	block := blockEnc{lowMem: false}
	block.init()
	// Default to the best compression encoder; use the requested level if set.
	enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}})
	if o.Level != 0 {
		eOpts := encoderOptions{
			level:      o.Level,
			blockSize:  maxMatchLen,
			windowSize: maxMatchLen,
			dict:       &d,
			lowMem:     false,
		}
		enc = eOpts.encoder()
	} else {
		o.Level = SpeedBestCompression
	}
	var (
		remain [256]int
		ll     [256]int
		ml     [256]int
		of     [256]int
	)
	addValues := func(dst *[256]int, src []byte) {
		for _, v := range src {
			dst[v]++
		}
	}
	addHist := func(dst *[256]int, src *[256]uint32) {
		for i, v := range src {
			dst[i] += int(v)
		}
	}
	seqs := 0
	nUsed := 0
	litTotal := 0
	newOffsets := make(map[uint32]int, 1000)
	// Encode each sample against the history and collect literal and code histograms.
	for _, b := range contents {
		block.reset(nil)
		if len(b) < 8 {
			continue
		}
		nUsed++
		enc.Reset(&d, true)
		enc.Encode(&block, b)
		addValues(&remain, block.literals)
		litTotal += len(block.literals)
		seqs += len(block.sequences)
		block.genCodes()
		addHist(&ll, block.coders.llEnc.Histogram())
		addHist(&ml, block.coders.mlEnc.Histogram())
		addHist(&of, block.coders.ofEnc.Histogram())
		for i, seq := range block.sequences {
			if i > 3 {
				break
			}
			offset := seq.offset
			if offset == 0 {
				continue
			}
			if offset > 3 {
				// Real offsets are stored with a +3 bias; values 1-3 refer to repeat offsets.
				newOffsets[offset-3]++
			} else {
				newOffsets[uint32(o.Offsets[offset-1])]++
			}
		}
	}

	// Pick the three most frequently seen offsets as the new repeat offsets.
	var sortedOffsets []uint32
	for k := range newOffsets {
		sortedOffsets = append(sortedOffsets, k)
	}
	sort.Slice(sortedOffsets, func(i, j int) bool {
		a, b := sortedOffsets[i], sortedOffsets[j]
		if a == b {
			// Map keys are unique, so this only serves as a deterministic fallback.
			return sortedOffsets[i] > sortedOffsets[j]
		}
		return newOffsets[sortedOffsets[i]] > newOffsets[sortedOffsets[j]]
	})
	if len(sortedOffsets) > 3 {
		if debug {
			print("Offsets:")
			for i, v := range sortedOffsets {
				if i > 20 {
					break
				}
				printf("[%d: %d],", v, newOffsets[v])
			}
			println("")
		}

		sortedOffsets = sortedOffsets[:3]
	}
	for i, v := range sortedOffsets {
		o.Offsets[i] = int(v)
	}
	if debug {
		println("New repeat offsets", o.Offsets)
	}

	if nUsed == 0 || seqs == 0 {
		return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs)
	}
	if debug {
		println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal)
	}
	if seqs/nUsed < 512 {
		// Scale nUsed so the averaged histograms represent at least 512 sequences.
		nUsed = seqs / 512
		if nUsed == 0 {
			// Guard against dividing by zero below when fewer than 512 sequences were found.
			nUsed = 1
		}
	}
	// copyHist scales a gathered histogram down to one average block and
	// writes the normalized FSE table for it.
	copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) {
		hist := dst.Histogram()
		var maxSym uint8
		var maxCount int
		var fakeLength int
		for i, v := range src {
			if v > 0 {
				v = v / nUsed
				if v == 0 {
					v = 1
				}
			}
			if v > maxCount {
				maxCount = v
			}
			if v != 0 {
				maxSym = uint8(i)
			}
			fakeLength += v
			hist[i] = uint32(v)
		}
		dst.HistogramFinished(maxSym, maxCount)
		dst.reUsed = false
		dst.useRLE = false
		err := dst.normalizeCount(fakeLength)
		if err != nil {
			return nil, err
		}
		if debug {
			println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength)
		}
		return dst.writeCount(nil)
	}
	if debug {
		print("Literal lengths: ")
	}
	llTable, err := copyHist(block.coders.llEnc, &ll)
	if err != nil {
		return nil, err
	}
	if debug {
		print("Match lengths: ")
	}
	mlTable, err := copyHist(block.coders.mlEnc, &ml)
	if err != nil {
		return nil, err
	}
	if debug {
		print("Offsets: ")
	}
	ofTable, err := copyHist(block.coders.ofEnc, &of)
	if err != nil {
		return nil, err
	}

	// Build the literal Huffman table from the collected literal histogram.
	avgSize := litTotal
	if avgSize > huff0.BlockSizeMax/2 {
		avgSize = huff0.BlockSizeMax / 2
	}
	huffBuff := make([]byte, 0, avgSize)
	// Scale the histogram down to roughly the target size.
	div := litTotal / avgSize
	if div < 1 {
		div = 1
	}
	if debug {
		println("Huffman weights:")
	}
	for i, n := range remain[:] {
		if n > 0 {
			n = n / div
			// Allow all seen symbols to be represented.
			if n == 0 {
				n = 1
			}
			huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
			if debug {
				printf("[%d: %d], ", i, n)
			}
		}
	}
	if o.CompatV155 && remain[255]/div == 0 {
		huffBuff = append(huffBuff, 255)
	}
	scratch := &huff0.Scratch{TableLog: 11}
	for tries := 0; tries < 255; tries++ {
		scratch = &huff0.Scratch{TableLog: 11}
		_, _, err = huff0.Compress1X(huffBuff, scratch)
		if err == nil {
			break
		}
		if debug {
			printf("Try %d: Huffman error: %v\n", tries+1, err)
		}
		huffBuff = huffBuff[:0]
		if tries == 250 {
			if debug {
				println("Huffman: Bailing out with predefined table")
			}
			// Bail out: use a simple distribution that is guaranteed to produce a table.
			huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...)
			for i := 0; i < 128; i++ {
				huffBuff = append(huffBuff, byte(i))
			}
			continue
		}
		if errors.Is(err, huff0.ErrIncompressible) {
			// Try dropping the least common symbols.
			for i, n := range remain[:] {
				if n > 0 {
					n = n / (div * (i + 1))
					if n > 0 {
						huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
					}
				}
			}
			if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 {
				huffBuff = append(huffBuff, 255)
			}
			if len(huffBuff) == 0 {
				huffBuff = append(huffBuff, 0, 255)
			}
		}
		if errors.Is(err, huff0.ErrUseRLE) {
			for i, n := range remain[:] {
				n = n / (div * (i + 1))
				// Allow all symbols to be represented.
				if n == 0 {
					n = 1
				}
				huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
			}
		}
	}

	// Assemble the dictionary: magic, ID, Huffman table, FSE tables, repeat offsets and history.
	var out bytes.Buffer
	out.Write([]byte(dictMagic))
	out.Write(binary.LittleEndian.AppendUint32(nil, o.ID))
	out.Write(scratch.OutTable)
	if debug {
		println("huff table:", len(scratch.OutTable), "bytes")
		println("of table:", len(ofTable), "bytes")
		println("ml table:", len(mlTable), "bytes")
		println("ll table:", len(llTable), "bytes")
	}
	out.Write(ofTable)
	out.Write(mlTable)
	out.Write(llTable)
	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0])))
	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1])))
	out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2])))
	out.Write(hist)
	if debug {
		// Verify the result loads and report compression with and without the dictionary.
		_, err := loadDict(out.Bytes())
		if err != nil {
			panic(err)
		}
		i, err := InspectDictionary(out.Bytes())
		if err != nil {
			panic(err)
		}
		println("ID:", i.ID())
		println("Content size:", i.ContentSize())
		println("Encoder:", i.LitEncoder() != nil)
		println("Offsets:", i.Offsets())
		var totalSize int
		for _, b := range contents {
			totalSize += len(b)
		}

		encWith := func(opts ...EOption) int {
			enc, err := NewWriter(nil, opts...)
			if err != nil {
				panic(err)
			}
			defer enc.Close()
			var dst []byte
			var totalSize int
			for _, b := range contents {
				dst = enc.EncodeAll(b, dst[:0])
				totalSize += len(dst)
			}
			return totalSize
		}
		plain := encWith(WithEncoderLevel(o.Level))
		withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes()))
		println("Input size:", totalSize)
		println("Plain Compressed:", plain)
		println("Dict Compressed:", withDict)
		println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)")
	}
	return out.Bytes(), nil
}
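
// A minimal usage sketch for BuildDict (illustrative only, outside this package):
// "samples" is hypothetical training data, the first sample doubles as the history,
// and {1, 4, 8} are the standard initial zstd repeat offsets.
//
//	samples := [][]byte{sample1, sample2, sample3} // hypothetical inputs, each >= 8 bytes
//	dictBytes, err := zstd.BuildDict(zstd.BuildDictOptions{
//		ID:       1234,
//		Contents: samples,
//		History:  samples[0],
//		Offsets:  [3]int{1, 4, 8},
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	// Use the dictionary for encoding and decoding.
//	enc, _ := zstd.NewWriter(nil, zstd.WithEncoderDict(dictBytes))
//	dec, _ := zstd.NewReader(nil, zstd.WithDecoderDicts(dictBytes))
//	defer enc.Close()
//	defer dec.Close()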