...

Source file src/github.com/klauspost/compress/zstd/dict_test.go

Documentation: github.com/klauspost/compress/zstd

     1  package zstd
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/klauspost/compress/zip"
    12  )
    13  
    14  func TestDecoder_SmallDict(t *testing.T) {
    15  	// All files have CRC
    16  	zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
    17  	dicts := readDicts(t, zr)
    18  	dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
    19  	if err != nil {
    20  		t.Fatal(err)
    21  		return
    22  	}
    23  	defer dec.Close()
    24  	for _, tt := range zr.File {
    25  		if !strings.HasSuffix(tt.Name, ".zst") {
    26  			continue
    27  		}
    28  		t.Run("decodeall-"+tt.Name, func(t *testing.T) {
    29  			r, err := tt.Open()
    30  			if err != nil {
    31  				t.Fatal(err)
    32  			}
    33  			defer r.Close()
    34  			in, err := io.ReadAll(r)
    35  			if err != nil {
    36  				t.Fatal(err)
    37  			}
    38  			got, err := dec.DecodeAll(in, nil)
    39  			if err != nil {
    40  				t.Fatal(err)
    41  			}
    42  			_, err = dec.DecodeAll(in, got[:0])
    43  			if err != nil {
    44  				t.Fatal(err)
    45  			}
    46  		})
    47  	}
    48  }
    49  
    50  func TestEncoder_SmallDict(t *testing.T) {
    51  	// All files have CRC
    52  	zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
    53  	var dicts [][]byte
    54  	var encs []*Encoder
    55  	var noDictEncs []*Encoder
    56  	var encNames []string
    57  
    58  	for _, tt := range zr.File {
    59  		if !strings.HasSuffix(tt.Name, ".dict") {
    60  			continue
    61  		}
    62  		func() {
    63  			r, err := tt.Open()
    64  			if err != nil {
    65  				t.Fatal(err)
    66  			}
    67  			defer r.Close()
    68  			in, err := io.ReadAll(r)
    69  			if err != nil {
    70  				t.Fatal(err)
    71  			}
    72  			dicts = append(dicts, in)
    73  			for level := SpeedFastest; level < speedLast; level++ {
    74  				if isRaceTest && level >= SpeedBestCompression {
    75  					break
    76  				}
    77  				enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderDict(in), WithEncoderLevel(level), WithWindowSize(1<<17))
    78  				if err != nil {
    79  					t.Fatal(err)
    80  				}
    81  				encs = append(encs, enc)
    82  				encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
    83  
    84  				enc, err = NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithWindowSize(1<<17))
    85  				if err != nil {
    86  					t.Fatal(err)
    87  				}
    88  				noDictEncs = append(noDictEncs, enc)
    89  			}
    90  		}()
    91  	}
    92  	dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
    93  	if err != nil {
    94  		t.Fatal(err)
    95  		return
    96  	}
    97  	defer dec.Close()
    98  	for i, tt := range zr.File {
    99  		if testing.Short() && i > 100 {
   100  			break
   101  		}
   102  		if !strings.HasSuffix(tt.Name, ".zst") {
   103  			continue
   104  		}
   105  		r, err := tt.Open()
   106  		if err != nil {
   107  			t.Fatal(err)
   108  		}
   109  		defer r.Close()
   110  		in, err := io.ReadAll(r)
   111  		if err != nil {
   112  			t.Fatal(err)
   113  		}
   114  		decoded, err := dec.DecodeAll(in, nil)
   115  		if err != nil {
   116  			t.Fatal(err)
   117  		}
   118  		if testing.Short() && len(decoded) > 1000 {
   119  			continue
   120  		}
   121  
   122  		t.Run("encodeall-"+tt.Name, func(t *testing.T) {
   123  			// Attempt to compress with all dicts
   124  			var b []byte
   125  			var tmp []byte
   126  			for i := range encs {
   127  				i := i
   128  				t.Run(encNames[i], func(t *testing.T) {
   129  					b = encs[i].EncodeAll(decoded, b[:0])
   130  					tmp, err = dec.DecodeAll(in, tmp[:0])
   131  					if err != nil {
   132  						t.Fatal(err)
   133  					}
   134  					if !bytes.Equal(tmp, decoded) {
   135  						t.Fatal("output mismatch")
   136  					}
   137  
   138  					tmp = noDictEncs[i].EncodeAll(decoded, tmp[:0])
   139  
   140  					if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
   141  						t.Log("reference:", len(in), "no dict:", len(tmp), "with dict:", len(b), "SAVED:", len(tmp)-len(b))
   142  						// Check that we reduced this significantly
   143  						if len(b) > 250 {
   144  							t.Error("output was bigger than expected")
   145  						}
   146  					}
   147  				})
   148  			}
   149  		})
   150  		t.Run("stream-"+tt.Name, func(t *testing.T) {
   151  			// Attempt to compress with all dicts
   152  			var tmp []byte
   153  			for i := range encs {
   154  				i := i
   155  				enc := encs[i]
   156  				t.Run(encNames[i], func(t *testing.T) {
   157  					var buf bytes.Buffer
   158  					enc.ResetContentSize(&buf, int64(len(decoded)))
   159  					_, err := enc.Write(decoded)
   160  					if err != nil {
   161  						t.Fatal(err)
   162  					}
   163  					err = enc.Close()
   164  					if err != nil {
   165  						t.Fatal(err)
   166  					}
   167  					tmp, err = dec.DecodeAll(buf.Bytes(), tmp[:0])
   168  					if err != nil {
   169  						t.Fatal(err)
   170  					}
   171  					if !bytes.Equal(tmp, decoded) {
   172  						t.Fatal("output mismatch")
   173  					}
   174  					var buf2 bytes.Buffer
   175  					noDictEncs[i].Reset(&buf2)
   176  					noDictEncs[i].Write(decoded)
   177  					noDictEncs[i].Close()
   178  
   179  					if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
   180  						t.Log("reference:", len(in), "no dict:", buf2.Len(), "with dict:", buf.Len(), "SAVED:", buf2.Len()-buf.Len())
   181  						// Check that we reduced this significantly
   182  						if buf.Len() > 250 {
   183  							t.Error("output was bigger than expected")
   184  						}
   185  					}
   186  				})
   187  			}
   188  		})
   189  	}
   190  }
   191  
   192  func TestEncoder_SmallDictFresh(t *testing.T) {
   193  	// All files have CRC
   194  	zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
   195  	var dicts [][]byte
   196  	var encs []func() *Encoder
   197  	var noDictEncs []*Encoder
   198  	var encNames []string
   199  
   200  	for _, tt := range zr.File {
   201  		if !strings.HasSuffix(tt.Name, ".dict") {
   202  			continue
   203  		}
   204  		func() {
   205  			r, err := tt.Open()
   206  			if err != nil {
   207  				t.Fatal(err)
   208  			}
   209  			defer r.Close()
   210  			in, err := io.ReadAll(r)
   211  			if err != nil {
   212  				t.Fatal(err)
   213  			}
   214  			dicts = append(dicts, in)
   215  			for level := SpeedFastest; level < speedLast; level++ {
   216  				if isRaceTest && level >= SpeedBestCompression {
   217  					break
   218  				}
   219  				level := level
   220  				encs = append(encs, func() *Encoder {
   221  					enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderDict(in), WithEncoderLevel(level), WithWindowSize(1<<17))
   222  					if err != nil {
   223  						t.Fatal(err)
   224  					}
   225  					return enc
   226  				})
   227  				encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
   228  
   229  				enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithWindowSize(1<<17))
   230  				if err != nil {
   231  					t.Fatal(err)
   232  				}
   233  				noDictEncs = append(noDictEncs, enc)
   234  			}
   235  		}()
   236  	}
   237  	dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
   238  	if err != nil {
   239  		t.Fatal(err)
   240  		return
   241  	}
   242  	defer dec.Close()
   243  	for i, tt := range zr.File {
   244  		if testing.Short() && i > 100 {
   245  			break
   246  		}
   247  		if !strings.HasSuffix(tt.Name, ".zst") {
   248  			continue
   249  		}
   250  		r, err := tt.Open()
   251  		if err != nil {
   252  			t.Fatal(err)
   253  		}
   254  		defer r.Close()
   255  		in, err := io.ReadAll(r)
   256  		if err != nil {
   257  			t.Fatal(err)
   258  		}
   259  		decoded, err := dec.DecodeAll(in, nil)
   260  		if err != nil {
   261  			t.Fatal(err)
   262  		}
   263  		if testing.Short() && len(decoded) > 1000 {
   264  			continue
   265  		}
   266  
   267  		t.Run("encodeall-"+tt.Name, func(t *testing.T) {
   268  			// Attempt to compress with all dicts
   269  			var b []byte
   270  			var tmp []byte
   271  			for i := range encs {
   272  				i := i
   273  				t.Run(encNames[i], func(t *testing.T) {
   274  					enc := encs[i]()
   275  					defer enc.Close()
   276  					b = enc.EncodeAll(decoded, b[:0])
   277  					tmp, err = dec.DecodeAll(in, tmp[:0])
   278  					if err != nil {
   279  						t.Fatal(err)
   280  					}
   281  					if !bytes.Equal(tmp, decoded) {
   282  						t.Fatal("output mismatch")
   283  					}
   284  
   285  					tmp = noDictEncs[i].EncodeAll(decoded, tmp[:0])
   286  
   287  					if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
   288  						t.Log("reference:", len(in), "no dict:", len(tmp), "with dict:", len(b), "SAVED:", len(tmp)-len(b))
   289  						// Check that we reduced this significantly
   290  						if len(b) > 250 {
   291  							t.Error("output was bigger than expected")
   292  						}
   293  					}
   294  				})
   295  			}
   296  		})
   297  		t.Run("stream-"+tt.Name, func(t *testing.T) {
   298  			// Attempt to compress with all dicts
   299  			var tmp []byte
   300  			for i := range encs {
   301  				i := i
   302  				t.Run(encNames[i], func(t *testing.T) {
   303  					enc := encs[i]()
   304  					defer enc.Close()
   305  					var buf bytes.Buffer
   306  					enc.ResetContentSize(&buf, int64(len(decoded)))
   307  					_, err := enc.Write(decoded)
   308  					if err != nil {
   309  						t.Fatal(err)
   310  					}
   311  					err = enc.Close()
   312  					if err != nil {
   313  						t.Fatal(err)
   314  					}
   315  					tmp, err = dec.DecodeAll(buf.Bytes(), tmp[:0])
   316  					if err != nil {
   317  						t.Fatal(err)
   318  					}
   319  					if !bytes.Equal(tmp, decoded) {
   320  						t.Fatal("output mismatch")
   321  					}
   322  					var buf2 bytes.Buffer
   323  					noDictEncs[i].Reset(&buf2)
   324  					noDictEncs[i].Write(decoded)
   325  					noDictEncs[i].Close()
   326  
   327  					if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
   328  						t.Log("reference:", len(in), "no dict:", buf2.Len(), "with dict:", buf.Len(), "SAVED:", buf2.Len()-buf.Len())
   329  						// Check that we reduced this significantly
   330  						if buf.Len() > 250 {
   331  							t.Error("output was bigger than expected")
   332  						}
   333  					}
   334  				})
   335  			}
   336  		})
   337  	}
   338  }
   339  
   340  func benchmarkEncodeAllLimitedBySize(b *testing.B, lowerLimit int, upperLimit int) {
   341  	zr := testCreateZipReader("testdata/dict-tests-small.zip", b)
   342  	t := testing.TB(b)
   343  
   344  	var dicts [][]byte
   345  	var encs []*Encoder
   346  	var encNames []string
   347  
   348  	for _, tt := range zr.File {
   349  		if !strings.HasSuffix(tt.Name, ".dict") {
   350  			continue
   351  		}
   352  		func() {
   353  			r, err := tt.Open()
   354  			if err != nil {
   355  				t.Fatal(err)
   356  			}
   357  			defer r.Close()
   358  			in, err := io.ReadAll(r)
   359  			if err != nil {
   360  				t.Fatal(err)
   361  			}
   362  			dicts = append(dicts, in)
   363  			for level := SpeedFastest; level < speedLast; level++ {
   364  				enc, err := NewWriter(nil, WithEncoderDict(in), WithEncoderLevel(level))
   365  				if err != nil {
   366  					t.Fatal(err)
   367  				}
   368  				encs = append(encs, enc)
   369  				encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))
   370  			}
   371  		}()
   372  	}
   373  	const nPer = int(speedLast - SpeedFastest)
   374  	dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
   375  	if err != nil {
   376  		t.Fatal(err)
   377  		return
   378  	}
   379  	defer dec.Close()
   380  
   381  	tested := make(map[int]struct{})
   382  	for j, tt := range zr.File {
   383  		if !strings.HasSuffix(tt.Name, ".zst") {
   384  			continue
   385  		}
   386  		r, err := tt.Open()
   387  		if err != nil {
   388  			t.Fatal(err)
   389  		}
   390  		defer r.Close()
   391  		in, err := io.ReadAll(r)
   392  		if err != nil {
   393  			t.Fatal(err)
   394  		}
   395  		decoded, err := dec.DecodeAll(in, nil)
   396  		if err != nil {
   397  			t.Fatal(err)
   398  		}
   399  
   400  		// Only test each size once
   401  		if _, ok := tested[len(decoded)]; ok {
   402  			continue
   403  		}
   404  		tested[len(decoded)] = struct{}{}
   405  
   406  		if len(decoded) < lowerLimit {
   407  			continue
   408  		}
   409  
   410  		if upperLimit > 0 && len(decoded) > upperLimit {
   411  			continue
   412  		}
   413  
   414  		for i := range encs {
   415  			// Only do 1 dict (4 encoders) for now.
   416  			if i == nPer-1 {
   417  				break
   418  			}
   419  			// Attempt to compress with all dicts
   420  			encIdx := (i + j*nPer) % len(encs)
   421  			enc := encs[encIdx]
   422  			b.Run(fmt.Sprintf("length-%d-%s", len(decoded), encNames[encIdx]), func(b *testing.B) {
   423  				b.RunParallel(func(pb *testing.PB) {
   424  					dst := make([]byte, 0, len(decoded)+10)
   425  					b.SetBytes(int64(len(decoded)))
   426  					b.ResetTimer()
   427  					b.ReportAllocs()
   428  					for pb.Next() {
   429  						dst = enc.EncodeAll(decoded, dst[:0])
   430  					}
   431  				})
   432  			})
   433  		}
   434  	}
   435  }
   436  
   437  func BenchmarkEncodeAllDict0_1024(b *testing.B) {
   438  	benchmarkEncodeAllLimitedBySize(b, 0, 1024)
   439  }
   440  
   441  func BenchmarkEncodeAllDict1024_8192(b *testing.B) {
   442  	benchmarkEncodeAllLimitedBySize(b, 1024, 8192)
   443  }
   444  
   445  func BenchmarkEncodeAllDict8192_16384(b *testing.B) {
   446  	benchmarkEncodeAllLimitedBySize(b, 8192, 16384)
   447  }
   448  
   449  func BenchmarkEncodeAllDict16384_65536(b *testing.B) {
   450  	benchmarkEncodeAllLimitedBySize(b, 16384, 65536)
   451  }
   452  
   453  func BenchmarkEncodeAllDict65536_0(b *testing.B) {
   454  	benchmarkEncodeAllLimitedBySize(b, 65536, 0)
   455  }
   456  
   457  func TestDecoder_MoreDicts(t *testing.T) {
   458  	// All files have CRC
   459  	// https://files.klauspost.com/compress/zstd-dict-tests.zip
   460  	fn := "testdata/zstd-dict-tests.zip"
   461  	data, err := os.ReadFile(fn)
   462  	if err != nil {
   463  		t.Skip("extended dict test not found.")
   464  	}
   465  	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
   466  	if err != nil {
   467  		t.Fatal(err)
   468  	}
   469  
   470  	var dicts [][]byte
   471  	for _, tt := range zr.File {
   472  		if !strings.HasSuffix(tt.Name, ".dict") {
   473  			continue
   474  		}
   475  		func() {
   476  			r, err := tt.Open()
   477  			if err != nil {
   478  				t.Fatal(err)
   479  			}
   480  			defer r.Close()
   481  			in, err := io.ReadAll(r)
   482  			if err != nil {
   483  				t.Fatal(err)
   484  			}
   485  			dicts = append(dicts, in)
   486  		}()
   487  	}
   488  	dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
   489  	if err != nil {
   490  		t.Fatal(err)
   491  		return
   492  	}
   493  	defer dec.Close()
   494  	for i, tt := range zr.File {
   495  		if !strings.HasSuffix(tt.Name, ".zst") {
   496  			continue
   497  		}
   498  		if testing.Short() && i > 50 {
   499  			continue
   500  		}
   501  		t.Run("decodeall-"+tt.Name, func(t *testing.T) {
   502  			r, err := tt.Open()
   503  			if err != nil {
   504  				t.Fatal(err)
   505  			}
   506  			defer r.Close()
   507  			in, err := io.ReadAll(r)
   508  			if err != nil {
   509  				t.Fatal(err)
   510  			}
   511  			got, err := dec.DecodeAll(in, nil)
   512  			if err != nil {
   513  				t.Fatal(err)
   514  			}
   515  			_, err = dec.DecodeAll(in, got[:0])
   516  			if err != nil {
   517  				t.Fatal(err)
   518  			}
   519  		})
   520  	}
   521  }
   522  
   523  func TestDecoder_MoreDicts2(t *testing.T) {
   524  	// All files have CRC
   525  	// https://files.klauspost.com/compress/zstd-dict-tests.zip
   526  	fn := "testdata/zstd-dict-tests.zip"
   527  	data, err := os.ReadFile(fn)
   528  	if err != nil {
   529  		t.Skip("extended dict test not found.")
   530  	}
   531  	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
   532  	if err != nil {
   533  		t.Fatal(err)
   534  	}
   535  
   536  	var dicts [][]byte
   537  	for _, tt := range zr.File {
   538  		if !strings.HasSuffix(tt.Name, ".dict") {
   539  			continue
   540  		}
   541  		func() {
   542  			r, err := tt.Open()
   543  			if err != nil {
   544  				t.Fatal(err)
   545  			}
   546  			defer r.Close()
   547  			in, err := io.ReadAll(r)
   548  			if err != nil {
   549  				t.Fatal(err)
   550  			}
   551  			dicts = append(dicts, in)
   552  		}()
   553  	}
   554  	dec, err := NewReader(nil, WithDecoderConcurrency(2), WithDecoderDicts(dicts...))
   555  	if err != nil {
   556  		t.Fatal(err)
   557  		return
   558  	}
   559  	defer dec.Close()
   560  	for i, tt := range zr.File {
   561  		if !strings.HasSuffix(tt.Name, ".zst") {
   562  			continue
   563  		}
   564  		if testing.Short() && i > 50 {
   565  			continue
   566  		}
   567  		t.Run("decodeall-"+tt.Name, func(t *testing.T) {
   568  			r, err := tt.Open()
   569  			if err != nil {
   570  				t.Fatal(err)
   571  			}
   572  			defer r.Close()
   573  			in, err := io.ReadAll(r)
   574  			if err != nil {
   575  				t.Fatal(err)
   576  			}
   577  			got, err := dec.DecodeAll(in, nil)
   578  			if err != nil {
   579  				t.Fatal(err)
   580  			}
   581  			_, err = dec.DecodeAll(in, got[:0])
   582  			if err != nil {
   583  				t.Fatal(err)
   584  			}
   585  		})
   586  	}
   587  }
   588  
   589  func readDicts(tb testing.TB, zr *zip.Reader) [][]byte {
   590  	var dicts [][]byte
   591  	for _, tt := range zr.File {
   592  		if !strings.HasSuffix(tt.Name, ".dict") {
   593  			continue
   594  		}
   595  		func() {
   596  			r, err := tt.Open()
   597  			if err != nil {
   598  				tb.Fatal(err)
   599  			}
   600  			defer r.Close()
   601  			in, err := io.ReadAll(r)
   602  			if err != nil {
   603  				tb.Fatal(err)
   604  			}
   605  			dicts = append(dicts, in)
   606  		}()
   607  	}
   608  	return dicts
   609  }
   610  
   611  // Test decoding of zstd --patch-from output.
   612  func TestDecoderRawDict(t *testing.T) {
   613  	t.Parallel()
   614  
   615  	dict, err := os.ReadFile("testdata/delta/source.txt")
   616  	if err != nil {
   617  		t.Fatal(err)
   618  	}
   619  
   620  	delta, err := os.Open("testdata/delta/target.txt.zst")
   621  	if err != nil {
   622  		t.Fatal(err)
   623  	}
   624  	defer delta.Close()
   625  
   626  	dec, err := NewReader(delta, WithDecoderDictRaw(0, dict))
   627  	if err != nil {
   628  		t.Fatal(err)
   629  	}
   630  
   631  	out, err := io.ReadAll(dec)
   632  	if err != nil {
   633  		t.Fatal(err)
   634  	}
   635  
   636  	ref, err := os.ReadFile("testdata/delta/target.txt")
   637  	if err != nil {
   638  		t.Fatal(err)
   639  	}
   640  
   641  	if !bytes.Equal(out, ref) {
   642  		t.Errorf("mismatch: got %q, wanted %q", out, ref)
   643  	}
   644  }
   645  

View as plain text