1
2
3
4
5 package dict
6
7 import (
8 "bytes"
9 "encoding/binary"
10 "errors"
11 "fmt"
12 "io"
13 "math/rand"
14 "sort"
15 "time"
16
17 "github.com/klauspost/compress/s2"
18 "github.com/klauspost/compress/zstd"
19 )
20
// match pairs a sequence hash with the statistics used to rank it.
type match struct {
	// hash identifies the byte sequence; n is the count associated with it.
	hash, n uint32
	// offset is an accumulated input offset, used as a tie-breaker when sorting.
	offset int64
}
26
// matchValue holds the bytes behind a hash together with its neighbour statistics.
type matchValue struct {
	// value is the byte sequence that produced the hash.
	value []byte
	// followBy and preceededBy count, per neighbouring hash, how often that
	// hash was seen directly after or directly before this sequence.
	followBy, preceededBy map[uint32]uint32
}
32
33 type Options struct {
34
35 MaxDictSize int
36
37
38
39 HashBytes int
40
41
42 Output io.Writer
43
44
45
46 ZstdDictID uint32
47
48
49
50 ZstdDictCompat bool
51
52
53
54
55
56 ZstdLevel zstd.EncoderLevel
57
58 outFormat int
59 }
60
// Output containers, stored in Options.outFormat.
const (
	// formatRaw returns the dictionary content bytes as-is.
	formatRaw = iota
	// formatZstd wraps the content in a Zstandard dictionary.
	formatZstd
	// formatS2 wraps the content in an S2 dictionary.
	formatS2
)
66
67
68 func BuildZstdDict(input [][]byte, o Options) ([]byte, error) {
69 o.outFormat = formatZstd
70 if o.ZstdDictID == 0 {
71 rng := rand.New(rand.NewSource(time.Now().UnixNano()))
72 o.ZstdDictID = 32768 + uint32(rng.Int31n((1<<31)-32768))
73 }
74 return buildDict(input, o)
75 }
76
77
78 func BuildS2Dict(input [][]byte, o Options) ([]byte, error) {
79 o.outFormat = formatS2
80 if o.MaxDictSize > s2.MaxDictSize {
81 return nil, errors.New("max dict size too large")
82 }
83 return buildDict(input, o)
84 }
85
86
87
88 func BuildRawDict(input [][]byte, o Options) ([]byte, error) {
89 o.outFormat = formatRaw
90 return buildDict(input, o)
91 }
92
93 func buildDict(input [][]byte, o Options) ([]byte, error) {
94 matches := make(map[uint32]uint32)
95 offsets := make(map[uint32]int64)
96 var total uint64
97
98 wantLen := o.MaxDictSize
99 hashBytes := o.HashBytes
100 if len(input) == 0 {
101 return nil, fmt.Errorf("no input provided")
102 }
103 if hashBytes < 4 || hashBytes > 8 {
104 return nil, fmt.Errorf("HashBytes must be >= 4 and <= 8")
105 }
106 println := func(args ...interface{}) {
107 if o.Output != nil {
108 fmt.Fprintln(o.Output, args...)
109 }
110 }
111 printf := func(s string, args ...interface{}) {
112 if o.Output != nil {
113 fmt.Fprintf(o.Output, s, args...)
114 }
115 }
116 found := make(map[uint32]struct{})
117 for i, b := range input {
118 for k := range found {
119 delete(found, k)
120 }
121 for i := range b {
122 rem := b[i:]
123 if len(rem) < 8 {
124 break
125 }
126 h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
127 if _, ok := found[h]; ok {
128
129 continue
130 }
131 matches[h]++
132 offsets[h] += int64(i)
133 total++
134 found[h] = struct{}{}
135 }
136 printf("\r input %d indexed...", i)
137 }
138 threshold := uint32(total / uint64(len(matches)))
139 println("\nTotal", total, "match", len(matches), "avg", threshold)
140 sorted := make([]match, 0, len(matches)/2)
141 for k, v := range matches {
142 if v <= threshold {
143 continue
144 }
145 sorted = append(sorted, match{hash: k, n: v, offset: offsets[k]})
146 }
147 sort.Slice(sorted, func(i, j int) bool {
148 if true {
149
150
151 deltaN := int(sorted[i].n) - int(sorted[j].n)
152 if deltaN < 0 {
153 deltaN = -deltaN
154 }
155 if uint32(deltaN) < sorted[i].n/32 {
156 return sorted[i].offset < sorted[j].offset
157 }
158 } else {
159 if sorted[i].n == sorted[j].n {
160 return sorted[i].offset < sorted[j].offset
161 }
162 }
163 return sorted[i].n > sorted[j].n
164 })
165 println("Sorted len:", len(sorted))
166 if len(sorted) > wantLen {
167 sorted = sorted[:wantLen]
168 }
169 lowestOcc := sorted[len(sorted)-1].n
170 println("Cropped len:", len(sorted), "Lowest occurrence:", lowestOcc)
171
172 wantMatches := make(map[uint32]uint32, len(sorted))
173 for _, v := range sorted {
174 wantMatches[v.hash] = v.n
175 }
176
177 output := make(map[uint32]matchValue, len(sorted))
178 var remainCnt [256]int
179 var remainTotal int
180 var firstOffsets []int
181 for i, b := range input {
182 for i := range b {
183 rem := b[i:]
184 if len(rem) < 8 {
185 break
186 }
187 var prev []byte
188 if i > hashBytes {
189 prev = b[i-hashBytes:]
190 }
191
192 h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
193 if _, ok := wantMatches[h]; !ok {
194 remainCnt[rem[0]]++
195 remainTotal++
196 continue
197 }
198 mv := output[h]
199 if len(mv.value) == 0 {
200 var tmp = make([]byte, hashBytes)
201 copy(tmp[:], rem)
202 mv.value = tmp[:]
203 }
204 if mv.followBy == nil {
205 mv.followBy = make(map[uint32]uint32, 4)
206 mv.preceededBy = make(map[uint32]uint32, 4)
207 }
208 if len(rem) > hashBytes+8 {
209
210 hNext := hashLen(binary.LittleEndian.Uint64(rem[hashBytes:]), 32, uint8(hashBytes))
211 if _, ok := wantMatches[hNext]; ok {
212 mv.followBy[hNext]++
213 }
214 }
215 if len(prev) >= 8 {
216
217 hPrev := hashLen(binary.LittleEndian.Uint64(prev), 32, uint8(hashBytes))
218 if _, ok := wantMatches[hPrev]; ok {
219 mv.preceededBy[hPrev]++
220 }
221 }
222 output[h] = mv
223 }
224 printf("\rinput %d re-indexed...", i)
225 }
226 println("")
227 dst := make([][]byte, 0, wantLen/hashBytes)
228 added := 0
229 const printUntil = 500
230 for i, e := range sorted {
231 if added > o.MaxDictSize {
232 println("Ending. Next Occurrence:", e.n)
233 break
234 }
235 m, ok := output[e.hash]
236 if !ok {
237
238 continue
239 }
240 wantLen := e.n / uint32(hashBytes) / 4
241 if wantLen <= lowestOcc {
242 wantLen = lowestOcc
243 }
244
245 var tmp = make([]byte, 0, hashBytes*2)
246 {
247 sortedPrev := make([]match, 0, len(m.followBy))
248 for k, v := range m.preceededBy {
249 if _, ok := output[k]; v < wantLen || !ok {
250 continue
251 }
252 sortedPrev = append(sortedPrev, match{
253 hash: k,
254 n: v,
255 })
256 }
257 if len(sortedPrev) > 0 {
258 sort.Slice(sortedPrev, func(i, j int) bool {
259 return sortedPrev[i].n > sortedPrev[j].n
260 })
261 bestPrev := output[sortedPrev[0].hash]
262 tmp = append(tmp, bestPrev.value...)
263 }
264 }
265 tmp = append(tmp, m.value...)
266 delete(output, e.hash)
267
268 sortedFollow := make([]match, 0, len(m.followBy))
269 for {
270 var nh uint32
271 stopAfter := false
272 {
273 sortedFollow = sortedFollow[:0]
274 for k, v := range m.followBy {
275 if _, ok := output[k]; !ok {
276 continue
277 }
278 sortedFollow = append(sortedFollow, match{
279 hash: k,
280 n: v,
281 offset: offsets[k],
282 })
283 }
284 if len(sortedFollow) == 0 {
285
286
287 const stepBack = 2
288 if stepBack > 0 && len(tmp) >= hashBytes+stepBack {
289 var t8 [8]byte
290 copy(t8[:], tmp[len(tmp)-hashBytes-stepBack:])
291 m, ok = output[hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))]
292 if ok && len(m.followBy) > 0 {
293 found := []byte(nil)
294 for k := range m.followBy {
295 v, ok := output[k]
296 if !ok {
297 continue
298 }
299 found = v.value
300 break
301 }
302 if found != nil {
303 tmp = tmp[:len(tmp)-stepBack]
304 printf("Step back: %q + %q\n", string(tmp), string(found))
305 continue
306 }
307 }
308 break
309 } else {
310 if i < printUntil {
311 printf("FOLLOW: none after %q\n", string(m.value))
312 }
313 }
314 break
315 }
316 sort.Slice(sortedFollow, func(i, j int) bool {
317 if sortedFollow[i].n == sortedFollow[j].n {
318 return sortedFollow[i].offset > sortedFollow[j].offset
319 }
320 return sortedFollow[i].n > sortedFollow[j].n
321 })
322 nh = sortedFollow[0].hash
323 stopAfter = sortedFollow[0].n < wantLen
324 if stopAfter && i < printUntil {
325 printf("FOLLOW: %d < %d after %q. Stopping after this.\n", sortedFollow[0].n, wantLen, string(m.value))
326 }
327 }
328 m, ok = output[nh]
329 if !ok {
330 break
331 }
332 if len(tmp) > 0 {
333
334 var toDel [16 + 8]byte
335 copy(toDel[:], tmp[len(tmp)-hashBytes:])
336 copy(toDel[hashBytes:], m.value)
337 for i := range toDel[:hashBytes*2] {
338 delete(output, hashLen(binary.LittleEndian.Uint64(toDel[i:]), 32, uint8(hashBytes)))
339 }
340 }
341 tmp = append(tmp, m.value...)
342
343 if stopAfter {
344
345 break
346 }
347 }
348 if i < printUntil {
349 printf("ENTRY %d: %q (%d occurrences, cutoff %d)\n", i, string(tmp), e.n, wantLen)
350 }
351
352 if len(tmp) > hashBytes {
353 for j := range tmp[:len(tmp)-hashBytes+1] {
354 var t8 [8]byte
355 copy(t8[:], tmp[j:])
356 if i < printUntil {
357
358 }
359 delete(output, hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes)))
360 }
361 }
362 dst = append(dst, tmp)
363 added += len(tmp)
364
365
366 if len(firstOffsets) < 3 {
367 if len(tmp) > 16 {
368 tmp = tmp[:16]
369 }
370 offCnt := make(map[int]int, len(input))
371
372 for _, b := range input {
373 off := bytes.Index(b, tmp)
374 if off == -1 {
375 continue
376 }
377 offCnt[off]++
378 }
379 for _, off := range firstOffsets {
380
381 delete(offCnt, off-added)
382 }
383 maxCnt := 0
384 maxOffset := 0
385 for k, v := range offCnt {
386 if v == maxCnt && k > maxOffset {
387
388 maxCnt = v
389 maxOffset = k
390 continue
391 }
392
393 if v > maxCnt {
394 maxCnt = v
395 maxOffset = k
396 }
397 }
398 if maxCnt > 1 {
399 firstOffsets = append(firstOffsets, maxOffset+added)
400 println(" - Offset:", len(firstOffsets), "at", maxOffset+added, "count:", maxCnt, "total added:", added, "src index", maxOffset)
401 }
402 }
403 }
404 out := bytes.NewBuffer(nil)
405 written := 0
406 for i, toWrite := range dst {
407 if len(toWrite)+written > wantLen {
408 toWrite = toWrite[:wantLen-written]
409 }
410 dst[i] = toWrite
411 written += len(toWrite)
412 if written >= wantLen {
413 dst = dst[:i+1]
414 break
415 }
416 }
417
418 for i := range dst {
419 toWrite := dst[len(dst)-i-1]
420 out.Write(toWrite)
421 }
422 if o.outFormat == formatRaw {
423 return out.Bytes(), nil
424 }
425
426 if o.outFormat == formatS2 {
427 dOff := 0
428 dBytes := out.Bytes()
429 if len(dBytes) > s2.MaxDictSize {
430 dBytes = dBytes[:s2.MaxDictSize]
431 }
432 for _, off := range firstOffsets {
433 myOff := len(dBytes) - off
434 if myOff < 0 || myOff > s2.MaxDictSrcOffset {
435 continue
436 }
437 dOff = myOff
438 }
439
440 dict := s2.MakeDictManual(dBytes, uint16(dOff))
441 if dict == nil {
442 return nil, fmt.Errorf("unable to create s2 dictionary")
443 }
444 return dict.Bytes(), nil
445 }
446
447 offsetsZstd := [3]int{1, 4, 8}
448 for i, off := range firstOffsets {
449 if i >= 3 || off == 0 || off >= out.Len() {
450 break
451 }
452 offsetsZstd[i] = off
453 }
454 println("\nCompressing. Offsets:", offsetsZstd)
455 return zstd.BuildDict(zstd.BuildDictOptions{
456 ID: o.ZstdDictID,
457 Contents: input,
458 History: out.Bytes(),
459 Offsets: offsetsZstd,
460 CompatV155: o.ZstdDictCompat,
461 Level: o.ZstdLevel,
462 DebugOut: o.Output,
463 })
464 }
465
// Multipliers used by the hashN functions.
// The number in each name is the count of input bytes that hash covers.
const (
	prime3bytes = 506832829            // hash3
	prime4bytes = 2654435761           // hash4 and hash4x64
	prime5bytes = 889523592379         // hash5
	prime6bytes = 227718039650203      // hash6
	prime7bytes = 58295818150454627    // hash7
	prime8bytes = 0xcf1bbcdcb7a56463   // hash8
)
474
475
476
477
478
479
480 func hashLen(u uint64, hashLog, mls uint8) uint32 {
481 switch mls {
482 case 5:
483 return hash5(u, hashLog)
484 case 6:
485 return hash6(u, hashLog)
486 case 7:
487 return hash7(u, hashLog)
488 case 8:
489 return hash8(u, hashLog)
490 default:
491 return uint32(u)
492 }
493 }
494
495
496
497 func hash3(u uint32, h uint8) uint32 {
498 return ((u << (32 - 24)) * prime3bytes) >> ((32 - h) & 31)
499 }
500
501
502
503 func hash4(u uint32, h uint8) uint32 {
504 return (u * prime4bytes) >> ((32 - h) & 31)
505 }
506
507
508
509 func hash4x64(u uint64, h uint8) uint32 {
510 return (uint32(u) * prime4bytes) >> ((32 - h) & 31)
511 }
512
513
514
515 func hash5(u uint64, h uint8) uint32 {
516 return uint32(((u << (64 - 40)) * prime5bytes) >> ((64 - h) & 63))
517 }
518
519
520
521 func hash6(u uint64, h uint8) uint32 {
522 return uint32(((u << (64 - 48)) * prime6bytes) >> ((64 - h) & 63))
523 }
524
525
526
527 func hash7(u uint64, h uint8) uint32 {
528 return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & 63))
529 }
530
531
532
533 func hash8(u uint64, h uint8) uint32 {
534 return uint32((u * prime8bytes) >> ((64 - h) & 63))
535 }
536
View as plain text