...

Source file src/github.com/klauspost/compress/zstd/decoder_test.go

Documentation: github.com/klauspost/compress/zstd

     1  // Copyright 2019+ Klaus Post. All rights reserved.
     2  // License information can be found in the LICENSE file.
     3  // Based on work by Yann Collet, released under BSD License.
     4  
     5  package zstd
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"encoding/binary"
    11  	"encoding/hex"
    12  	"errors"
    13  	"fmt"
    14  	"io"
    15  	"log"
    16  	"math/rand"
    17  	"os"
    18  	"path/filepath"
    19  	"reflect"
    20  	"runtime"
    21  	"strconv"
    22  	"strings"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	// "github.com/DataDog/zstd"
    28  	// zstd "github.com/valyala/gozstd"
    29  
    30  	"github.com/klauspost/compress/zstd/internal/xxhash"
    31  )
    32  
    33  func TestNewReaderMismatch(t *testing.T) {
    34  	// To identify a potential decoding error, do the following steps:
    35  	// 1) Place the compressed file in testdata, eg 'testdata/backup.bin.zst'
    36  	// 2) Decompress the file to using zstd, so it will be named 'testdata/backup.bin'
    37  	// 3) Run the test. A hash file will be generated 'testdata/backup.bin.hash'
    38  	// 4) The decoder will also run and decode the file. It will stop as soon as a mismatch is found.
    39  	// The hash file will be reused between runs if present.
    40  	const baseFile = "testdata/backup.bin"
    41  	const blockSize = 1024
    42  	hashes, err := os.ReadFile(baseFile + ".hash")
    43  	if os.IsNotExist(err) {
    44  		// Create the hash file.
    45  		f, err := os.Open(baseFile)
    46  		if os.IsNotExist(err) {
    47  			t.Skip("no decompressed file found")
    48  			return
    49  		}
    50  		defer f.Close()
    51  		br := bufio.NewReader(f)
    52  		var tmp [8]byte
    53  		xx := xxhash.New()
    54  		for {
    55  			xx.Reset()
    56  			buf := make([]byte, blockSize)
    57  			n, err := io.ReadFull(br, buf)
    58  			if err != nil {
    59  				if err != io.EOF && err != io.ErrUnexpectedEOF {
    60  					t.Fatal(err)
    61  				}
    62  			}
    63  			if n > 0 {
    64  				_, _ = xx.Write(buf[:n])
    65  				binary.LittleEndian.PutUint64(tmp[:], xx.Sum64())
    66  				hashes = append(hashes, tmp[4:]...)
    67  			}
    68  			if n != blockSize {
    69  				break
    70  			}
    71  		}
    72  		err = os.WriteFile(baseFile+".hash", hashes, os.ModePerm)
    73  		if err != nil {
    74  			// We can continue for now
    75  			t.Error(err)
    76  		}
    77  		t.Log("Saved", len(hashes)/4, "hashes as", baseFile+".hash")
    78  	}
    79  
    80  	f, err := os.Open(baseFile + ".zst")
    81  	if os.IsNotExist(err) {
    82  		t.Skip("no compressed file found")
    83  		return
    84  	}
    85  	defer f.Close()
    86  	dec, err := NewReader(f, WithDecoderConcurrency(1))
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  	defer dec.Close()
    91  	var tmp [8]byte
    92  	xx := xxhash.New()
    93  	var cHash int
    94  	for {
    95  		xx.Reset()
    96  		buf := make([]byte, blockSize)
    97  		n, err := io.ReadFull(dec, buf)
    98  		if err != nil {
    99  			if err != io.EOF && err != io.ErrUnexpectedEOF {
   100  				t.Fatal("block", cHash, "err:", err)
   101  			}
   102  		}
   103  		if n > 0 {
   104  			if cHash+4 > len(hashes) {
   105  				extra, _ := io.Copy(io.Discard, dec)
   106  				t.Fatal("not enough hashes (length mismatch). Only have", len(hashes)/4, "hashes. Got block of", n, "bytes and", extra, "bytes still on stream.")
   107  			}
   108  			_, _ = xx.Write(buf[:n])
   109  			binary.LittleEndian.PutUint64(tmp[:], xx.Sum64())
   110  			want, got := hashes[cHash:cHash+4], tmp[4:]
   111  			if !bytes.Equal(want, got) {
   112  				org, err := os.Open(baseFile)
   113  				if err == nil {
   114  					const sizeBack = 8 << 20
   115  					defer org.Close()
   116  					start := int64(cHash)/4*blockSize - sizeBack
   117  					if start < 0 {
   118  						start = 0
   119  					}
   120  					_, err = org.Seek(start, io.SeekStart)
   121  					if err != nil {
   122  						t.Fatal(err)
   123  					}
   124  					buf2 := make([]byte, sizeBack+1<<20)
   125  					n, _ := io.ReadFull(org, buf2)
   126  					if n > 0 {
   127  						err = os.WriteFile(baseFile+".section", buf2[:n], os.ModePerm)
   128  						if err == nil {
   129  							t.Log("Wrote problematic section to", baseFile+".section")
   130  						}
   131  					}
   132  				}
   133  
   134  				t.Fatal("block", cHash/4, "offset", cHash/4*blockSize, "hash mismatch, want:", hex.EncodeToString(want), "got:", hex.EncodeToString(got))
   135  			}
   136  			cHash += 4
   137  		}
   138  		if n != blockSize {
   139  			break
   140  		}
   141  	}
   142  	t.Log("Output matched")
   143  }
   144  
   145  type errorReader struct {
   146  	err error
   147  }
   148  
   149  func (r *errorReader) Read(p []byte) (int, error) {
   150  	return 0, r.err
   151  }
   152  
   153  func TestErrorReader(t *testing.T) {
   154  	wantErr := fmt.Errorf("i'm a failure")
   155  	zr, err := NewReader(&errorReader{err: wantErr})
   156  	if err != nil {
   157  		t.Fatal(err)
   158  	}
   159  	defer zr.Close()
   160  
   161  	_, err = io.ReadAll(zr)
   162  	if !errors.Is(err, wantErr) {
   163  		t.Errorf("want error %v, got %v", wantErr, err)
   164  	}
   165  }
   166  
   167  type failingWriter struct {
   168  	err error
   169  }
   170  
   171  func (f failingWriter) Write(_ []byte) (n int, err error) {
   172  	return 0, f.err
   173  }
   174  
   175  func TestErrorWriter(t *testing.T) {
   176  	input := make([]byte, 100)
   177  	cmp := bytes.Buffer{}
   178  	w, err := NewWriter(&cmp)
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  	_, _ = rand.Read(input)
   183  	_, err = w.Write(input)
   184  	if err != nil {
   185  		t.Fatal(err)
   186  	}
   187  	err = w.Close()
   188  	if err != nil {
   189  		t.Fatal(err)
   190  	}
   191  	wantErr := fmt.Errorf("i'm a failure")
   192  	zr, err := NewReader(&cmp)
   193  	if err != nil {
   194  		t.Fatal(err)
   195  	}
   196  	defer zr.Close()
   197  	out := failingWriter{err: wantErr}
   198  	_, err = zr.WriteTo(out)
   199  	if !errors.Is(err, wantErr) {
   200  		t.Errorf("error: wanted: %v, got: %v", wantErr, err)
   201  	}
   202  }
   203  
   204  func TestNewDecoder(t *testing.T) {
   205  	for _, n := range []int{1, 4} {
   206  		for _, ignoreCRC := range []bool{false, true} {
   207  			t.Run(fmt.Sprintf("cpu-%d", n), func(t *testing.T) {
   208  				newFn := func() (*Decoder, error) {
   209  					return NewReader(nil, WithDecoderConcurrency(n), IgnoreChecksum(ignoreCRC))
   210  				}
   211  				testDecoderFile(t, "testdata/decoder.zip", newFn)
   212  				dec, err := newFn()
   213  				if err != nil {
   214  					t.Fatal(err)
   215  				}
   216  				testDecoderDecodeAll(t, "testdata/decoder.zip", dec)
   217  			})
   218  		}
   219  	}
   220  }
   221  
   222  func TestNewDecoderMemory(t *testing.T) {
   223  	defer timeout(60 * time.Second)()
   224  	var testdata bytes.Buffer
   225  	enc, err := NewWriter(&testdata, WithWindowSize(32<<10), WithSingleSegment(false))
   226  	if err != nil {
   227  		t.Fatal(err)
   228  	}
   229  	// Write 256KB
   230  	for i := 0; i < 256; i++ {
   231  		tmp := strings.Repeat(string([]byte{byte(i)}), 1024)
   232  		_, err := enc.Write([]byte(tmp))
   233  		if err != nil {
   234  			t.Fatal(err)
   235  		}
   236  	}
   237  	err = enc.Close()
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  
   242  	var n = 5000
   243  	if testing.Short() {
   244  		n = 200
   245  	}
   246  
   247  	// 16K buffer
   248  	var tmp [16 << 10]byte
   249  
   250  	var before, after runtime.MemStats
   251  	runtime.GC()
   252  	runtime.ReadMemStats(&before)
   253  
   254  	var decs = make([]*Decoder, n)
   255  	for i := range decs {
   256  		// Wrap in NopCloser to avoid shortcut.
   257  		input := io.NopCloser(bytes.NewBuffer(testdata.Bytes()))
   258  		decs[i], err = NewReader(input, WithDecoderConcurrency(1), WithDecoderLowmem(true))
   259  		if err != nil {
   260  			t.Fatal(err)
   261  		}
   262  	}
   263  
   264  	for i := range decs {
   265  		_, err := io.ReadFull(decs[i], tmp[:])
   266  		if err != nil {
   267  			t.Fatal(err)
   268  		}
   269  	}
   270  
   271  	runtime.GC()
   272  	runtime.ReadMemStats(&after)
   273  	size := (after.HeapInuse - before.HeapInuse) / uint64(n) / 1024
   274  
   275  	const expect = 124
   276  	t.Log(size, "KiB per decoder")
   277  	// This is not exact science, but fail if we suddenly get more than 2x what we expect.
   278  	if size > expect*2 && !testing.Short() {
   279  		t.Errorf("expected < %dKB per decoder, got %d", expect, size)
   280  	}
   281  
   282  	for _, dec := range decs {
   283  		dec.Close()
   284  	}
   285  }
   286  
   287  func TestNewDecoderMemoryHighMem(t *testing.T) {
   288  	defer timeout(60 * time.Second)()
   289  	var testdata bytes.Buffer
   290  	enc, err := NewWriter(&testdata, WithWindowSize(32<<10), WithSingleSegment(false))
   291  	if err != nil {
   292  		t.Fatal(err)
   293  	}
   294  	// Write 256KB
   295  	for i := 0; i < 256; i++ {
   296  		tmp := strings.Repeat(string([]byte{byte(i)}), 1024)
   297  		_, err := enc.Write([]byte(tmp))
   298  		if err != nil {
   299  			t.Fatal(err)
   300  		}
   301  	}
   302  	err = enc.Close()
   303  	if err != nil {
   304  		t.Fatal(err)
   305  	}
   306  
   307  	var n = 50
   308  	if testing.Short() {
   309  		n = 10
   310  	}
   311  
   312  	// 16K buffer
   313  	var tmp [16 << 10]byte
   314  
   315  	var before, after runtime.MemStats
   316  	runtime.GC()
   317  	runtime.ReadMemStats(&before)
   318  
   319  	var decs = make([]*Decoder, n)
   320  	for i := range decs {
   321  		// Wrap in NopCloser to avoid shortcut.
   322  		input := io.NopCloser(bytes.NewBuffer(testdata.Bytes()))
   323  		decs[i], err = NewReader(input, WithDecoderConcurrency(1), WithDecoderLowmem(false))
   324  		if err != nil {
   325  			t.Fatal(err)
   326  		}
   327  	}
   328  
   329  	for i := range decs {
   330  		_, err := io.ReadFull(decs[i], tmp[:])
   331  		if err != nil {
   332  			t.Fatal(err)
   333  		}
   334  	}
   335  
   336  	runtime.GC()
   337  	runtime.ReadMemStats(&after)
   338  	size := (after.HeapInuse - before.HeapInuse) / uint64(n) / 1024
   339  
   340  	const expect = 3915
   341  	t.Log(size, "KiB per decoder")
   342  	// This is not exact science, but fail if we suddenly get more than 2x what we expect.
   343  	if size > expect*2 && !testing.Short() {
   344  		t.Errorf("expected < %dKB per decoder, got %d", expect, size)
   345  	}
   346  
   347  	for _, dec := range decs {
   348  		dec.Close()
   349  	}
   350  }
   351  
   352  func TestNewDecoderFrameSize(t *testing.T) {
   353  	defer timeout(60 * time.Second)()
   354  	var testdata bytes.Buffer
   355  	enc, err := NewWriter(&testdata, WithWindowSize(64<<10))
   356  	if err != nil {
   357  		t.Fatal(err)
   358  	}
   359  	// Write 256KB
   360  	for i := 0; i < 256; i++ {
   361  		tmp := strings.Repeat(string([]byte{byte(i)}), 1024)
   362  		_, err := enc.Write([]byte(tmp))
   363  		if err != nil {
   364  			t.Fatal(err)
   365  		}
   366  	}
   367  	err = enc.Close()
   368  	if err != nil {
   369  		t.Fatal(err)
   370  	}
   371  	// Must fail
   372  	dec, err := NewReader(bytes.NewReader(testdata.Bytes()), WithDecoderMaxWindow(32<<10))
   373  	if err != nil {
   374  		t.Fatal(err)
   375  	}
   376  	_, err = io.Copy(io.Discard, dec)
   377  	if err == nil {
   378  		dec.Close()
   379  		t.Fatal("Wanted error, got none")
   380  	}
   381  	dec.Close()
   382  
   383  	// Must succeed.
   384  	dec, err = NewReader(bytes.NewReader(testdata.Bytes()), WithDecoderMaxWindow(64<<10))
   385  	if err != nil {
   386  		t.Fatal(err)
   387  	}
   388  	_, err = io.Copy(io.Discard, dec)
   389  	if err != nil {
   390  		dec.Close()
   391  		t.Fatalf("Wanted no error, got %+v", err)
   392  	}
   393  	dec.Close()
   394  }
   395  
   396  func TestNewDecoderGood(t *testing.T) {
   397  	for _, n := range []int{1, 4} {
   398  		t.Run(fmt.Sprintf("cpu-%d", n), func(t *testing.T) {
   399  			newFn := func() (*Decoder, error) {
   400  				return NewReader(nil, WithDecoderConcurrency(n))
   401  			}
   402  			testDecoderFile(t, "testdata/good.zip", newFn)
   403  			dec, err := newFn()
   404  			if err != nil {
   405  				t.Fatal(err)
   406  			}
   407  			testDecoderDecodeAll(t, "testdata/good.zip", dec)
   408  		})
   409  	}
   410  }
   411  
   412  func TestNewDecoderBad(t *testing.T) {
   413  	var errMap = make(map[string]string)
   414  	if true {
   415  		t.Run("Reader-4", func(t *testing.T) {
   416  			newFn := func() (*Decoder, error) {
   417  				return NewReader(nil, WithDecoderConcurrency(4), WithDecoderMaxMemory(1<<30))
   418  			}
   419  			testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap)
   420  
   421  		})
   422  		t.Run("Reader-1", func(t *testing.T) {
   423  			newFn := func() (*Decoder, error) {
   424  				return NewReader(nil, WithDecoderConcurrency(1), WithDecoderMaxMemory(1<<30))
   425  			}
   426  			testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap)
   427  		})
   428  		t.Run("Reader-4-bigmem", func(t *testing.T) {
   429  			newFn := func() (*Decoder, error) {
   430  				return NewReader(nil, WithDecoderConcurrency(4), WithDecoderMaxMemory(1<<30), WithDecoderLowmem(false))
   431  			}
   432  			testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap)
   433  
   434  		})
   435  		t.Run("Reader-1-bigmem", func(t *testing.T) {
   436  			newFn := func() (*Decoder, error) {
   437  				return NewReader(nil, WithDecoderConcurrency(1), WithDecoderMaxMemory(1<<30), WithDecoderLowmem(false))
   438  			}
   439  			testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap)
   440  		})
   441  	}
   442  	t.Run("DecodeAll", func(t *testing.T) {
   443  		defer timeout(10 * time.Second)()
   444  		dec, err := NewReader(nil, WithDecoderMaxMemory(1<<30))
   445  		if err != nil {
   446  			t.Fatal(err)
   447  		}
   448  		testDecoderDecodeAllError(t, "testdata/bad.zip", dec, errMap)
   449  	})
   450  	t.Run("DecodeAll-bigmem", func(t *testing.T) {
   451  		defer timeout(10 * time.Second)()
   452  		dec, err := NewReader(nil, WithDecoderMaxMemory(1<<30), WithDecoderLowmem(false))
   453  		if err != nil {
   454  			t.Fatal(err)
   455  		}
   456  		testDecoderDecodeAllError(t, "testdata/bad.zip", dec, errMap)
   457  	})
   458  }
   459  
   460  func TestNewDecoderLarge(t *testing.T) {
   461  	newFn := func() (*Decoder, error) {
   462  		return NewReader(nil)
   463  	}
   464  	testDecoderFile(t, "testdata/large.zip", newFn)
   465  	dec, err := NewReader(nil)
   466  	if err != nil {
   467  		t.Fatal(err)
   468  	}
   469  	testDecoderDecodeAll(t, "testdata/large.zip", dec)
   470  }
   471  
   472  func TestNewReaderRead(t *testing.T) {
   473  	dec, err := NewReader(nil)
   474  	if err != nil {
   475  		t.Fatal(err)
   476  	}
   477  	defer dec.Close()
   478  	_, err = dec.Read([]byte{0})
   479  	if err == nil {
   480  		t.Fatal("Wanted error on uninitialized read, got nil")
   481  	}
   482  	t.Log("correctly got error", err)
   483  }
   484  
   485  func TestNewDecoderBig(t *testing.T) {
   486  	if testing.Short() || isRaceTest {
   487  		t.SkipNow()
   488  	}
   489  	file := "testdata/zstd-10kfiles.zip"
   490  	if _, err := os.Stat(file); os.IsNotExist(err) {
   491  		t.Skip("To run extended tests, download https://files.klauspost.com/compress/zstd-10kfiles.zip \n" +
   492  			"and place it in " + file + "\n" + "Running it requires about 5GB of RAM")
   493  	}
   494  	newFn := func() (*Decoder, error) {
   495  		return NewReader(nil)
   496  	}
   497  	testDecoderFile(t, file, newFn)
   498  	dec, err := NewReader(nil)
   499  	if err != nil {
   500  		t.Fatal(err)
   501  	}
   502  	testDecoderDecodeAll(t, file, dec)
   503  }
   504  
   505  func TestNewDecoderBigFile(t *testing.T) {
   506  	if testing.Short() || isRaceTest {
   507  		t.SkipNow()
   508  	}
   509  	file := "testdata/enwik9.zst"
   510  	const wantSize = 1000000000
   511  	if _, err := os.Stat(file); os.IsNotExist(err) {
   512  		t.Skip("To run extended tests, download http://mattmahoney.net/dc/enwik9.zip unzip it \n" +
   513  			"compress it with 'zstd -15 -T0 enwik9' and place it in " + file)
   514  	}
   515  	f, err := os.Open(file)
   516  	if err != nil {
   517  		t.Fatal(err)
   518  	}
   519  	defer f.Close()
   520  	start := time.Now()
   521  	dec, err := NewReader(f)
   522  	if err != nil {
   523  		t.Fatal(err)
   524  	}
   525  	defer dec.Close()
   526  	n, err := io.Copy(io.Discard, dec)
   527  	if err != nil {
   528  		t.Fatal(err)
   529  	}
   530  	if n != wantSize {
   531  		t.Errorf("want size %d, got size %d", wantSize, n)
   532  	}
   533  	elapsed := time.Since(start)
   534  	mbpersec := (float64(n) / (1024 * 1024)) / (float64(elapsed) / (float64(time.Second)))
   535  	t.Logf("Decoded %d bytes with %f.2 MB/s", n, mbpersec)
   536  }
   537  
   538  func TestNewDecoderSmallFile(t *testing.T) {
   539  	if testing.Short() {
   540  		t.SkipNow()
   541  	}
   542  	file := "testdata/z000028.zst"
   543  	const wantSize = 39807
   544  	f, err := os.Open(file)
   545  	if err != nil {
   546  		t.Fatal(err)
   547  	}
   548  	defer f.Close()
   549  	start := time.Now()
   550  	dec, err := NewReader(f)
   551  	if err != nil {
   552  		t.Fatal(err)
   553  	}
   554  	defer dec.Close()
   555  	n, err := io.Copy(io.Discard, dec)
   556  	if err != nil {
   557  		t.Fatal(err)
   558  	}
   559  	if n != wantSize {
   560  		t.Errorf("want size %d, got size %d", wantSize, n)
   561  	}
   562  	mbpersec := (float64(n) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
   563  	t.Logf("Decoded %d bytes with %f.2 MB/s", n, mbpersec)
   564  }
   565  
   566  // cursedReader wraps a reader and returns zero bytes every other read.
   567  // This is used to test the ability of the consumer to handle empty reads without EOF,
   568  // which can happen when reading from a network connection.
   569  type cursedReader struct {
   570  	io.Reader
   571  	numReads int
   572  }
   573  
   574  func (r *cursedReader) Read(p []byte) (n int, err error) {
   575  	r.numReads++
   576  	if r.numReads%2 == 0 {
   577  		return 0, nil
   578  	}
   579  
   580  	return r.Reader.Read(p)
   581  }
   582  
   583  func TestNewDecoderZeroLengthReads(t *testing.T) {
   584  	if testing.Short() {
   585  		t.SkipNow()
   586  	}
   587  	file := "testdata/z000028.zst"
   588  	const wantSize = 39807
   589  	f, err := os.Open(file)
   590  	if err != nil {
   591  		t.Fatal(err)
   592  	}
   593  	defer f.Close()
   594  	dec, err := NewReader(&cursedReader{Reader: f})
   595  	if err != nil {
   596  		t.Fatal(err)
   597  	}
   598  	defer dec.Close()
   599  	n, err := io.Copy(io.Discard, dec)
   600  	if err != nil {
   601  		t.Fatal(err)
   602  	}
   603  	if n != wantSize {
   604  		t.Errorf("want size %d, got size %d", wantSize, n)
   605  	}
   606  }
   607  
   608  type readAndBlock struct {
   609  	buf     []byte
   610  	unblock chan struct{}
   611  }
   612  
   613  func (r *readAndBlock) Read(p []byte) (int, error) {
   614  	n := copy(p, r.buf)
   615  	if n == 0 {
   616  		<-r.unblock
   617  		return 0, io.EOF
   618  	}
   619  	r.buf = r.buf[n:]
   620  	return n, nil
   621  }
   622  
   623  func TestNewDecoderFlushed(t *testing.T) {
   624  	if testing.Short() {
   625  		t.SkipNow()
   626  	}
   627  	file := "testdata/z000028.zst"
   628  	payload, err := os.ReadFile(file)
   629  	if err != nil {
   630  		t.Fatal(err)
   631  	}
   632  	payload = append(payload, payload...) //2x
   633  	payload = append(payload, payload...) //4x
   634  	payload = append(payload, payload...) //8x
   635  	rng := rand.New(rand.NewSource(0x1337))
   636  	runs := 100
   637  	if testing.Short() {
   638  		runs = 5
   639  	}
   640  	enc, err := NewWriter(nil, WithWindowSize(128<<10))
   641  	if err != nil {
   642  		t.Fatal(err)
   643  	}
   644  	defer enc.Close()
   645  	for i := 0; i < runs; i++ {
   646  		wantSize := rng.Intn(len(payload)-1) + 1
   647  		t.Run(fmt.Sprint("size-", wantSize), func(t *testing.T) {
   648  			var encoded bytes.Buffer
   649  			enc.Reset(&encoded)
   650  			_, err := enc.Write(payload[:wantSize])
   651  			if err != nil {
   652  				t.Fatal(err)
   653  			}
   654  			err = enc.Flush()
   655  			if err != nil {
   656  				t.Fatal(err)
   657  			}
   658  
   659  			// We must be able to read back up until the flush...
   660  			r := readAndBlock{
   661  				buf:     encoded.Bytes(),
   662  				unblock: make(chan struct{}),
   663  			}
   664  			defer timeout(5 * time.Second)()
   665  			dec, err := NewReader(&r)
   666  			if err != nil {
   667  				t.Fatal(err)
   668  			}
   669  			defer dec.Close()
   670  			defer close(r.unblock)
   671  			readBack := 0
   672  			dst := make([]byte, 1024)
   673  			for readBack < wantSize {
   674  				// Read until we have enough.
   675  				n, err := dec.Read(dst)
   676  				if err != nil {
   677  					t.Fatal(err)
   678  				}
   679  				readBack += n
   680  			}
   681  		})
   682  	}
   683  }
   684  
   685  func TestDecoderRegression(t *testing.T) {
   686  	defer timeout(160 * time.Second)()
   687  
   688  	zr := testCreateZipReader("testdata/regression.zip", t)
   689  	dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderLowmem(true), WithDecoderMaxMemory(1<<20))
   690  	if err != nil {
   691  		t.Error(err)
   692  		return
   693  	}
   694  	defer dec.Close()
   695  	for i, tt := range zr.File {
   696  		if testing.Short() && i > 10 {
   697  			continue
   698  		}
   699  		t.Run("Reader-"+tt.Name, func(t *testing.T) {
   700  			r, err := tt.Open()
   701  			if err != nil {
   702  				t.Error(err)
   703  				return
   704  			}
   705  			err = dec.Reset(r)
   706  			if err != nil {
   707  				t.Error(err)
   708  				return
   709  			}
   710  			got, gotErr := io.ReadAll(dec)
   711  			t.Log("Received:", len(got), gotErr)
   712  
   713  			// Check a fresh instance
   714  			r, err = tt.Open()
   715  			if err != nil {
   716  				t.Error(err)
   717  				return
   718  			}
   719  			decL, err := NewReader(r, WithDecoderConcurrency(1), WithDecoderLowmem(true), WithDecoderMaxMemory(1<<20))
   720  			if err != nil {
   721  				t.Error(err)
   722  				return
   723  			}
   724  			defer decL.Close()
   725  			got2, gotErr2 := io.ReadAll(decL)
   726  			t.Log("Fresh Reader received:", len(got2), gotErr2)
   727  			if gotErr != gotErr2 {
   728  				if gotErr != nil && gotErr2 != nil && gotErr.Error() != gotErr2.Error() {
   729  					t.Error(gotErr, "!=", gotErr2)
   730  				}
   731  				if (gotErr == nil) != (gotErr2 == nil) {
   732  					t.Error(gotErr, "!=", gotErr2)
   733  				}
   734  			}
   735  			if !bytes.Equal(got2, got) {
   736  				if gotErr != nil {
   737  					t.Log("Buffer mismatch without Reset")
   738  				} else {
   739  					t.Error("Buffer mismatch without Reset")
   740  				}
   741  			}
   742  		})
   743  		t.Run("DecodeAll-"+tt.Name, func(t *testing.T) {
   744  			r, err := tt.Open()
   745  			if err != nil {
   746  				t.Error(err)
   747  				return
   748  			}
   749  			in, err := io.ReadAll(r)
   750  			if err != nil {
   751  				t.Error(err)
   752  			}
   753  			got, gotErr := dec.DecodeAll(in, make([]byte, 0, len(in)))
   754  			t.Log("Received:", len(got), gotErr)
   755  
   756  			// Check if we got the same:
   757  			decL, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderLowmem(true), WithDecoderMaxMemory(1<<20))
   758  			if err != nil {
   759  				t.Error(err)
   760  				return
   761  			}
   762  			defer decL.Close()
   763  			got2, gotErr2 := decL.DecodeAll(in, make([]byte, 0, len(in)/2))
   764  			t.Log("Fresh Reader received:", len(got2), gotErr2)
   765  			if gotErr != gotErr2 {
   766  				if gotErr != nil && gotErr2 != nil && gotErr.Error() != gotErr2.Error() {
   767  					t.Error(gotErr, "!=", gotErr2)
   768  				}
   769  				if (gotErr == nil) != (gotErr2 == nil) {
   770  					t.Error(gotErr, "!=", gotErr2)
   771  				}
   772  			}
   773  			if !bytes.Equal(got2, got) {
   774  				if gotErr != nil {
   775  					t.Log("Buffer mismatch without Reset")
   776  				} else {
   777  					t.Error("Buffer mismatch without Reset")
   778  				}
   779  			}
   780  		})
   781  		t.Run("Match-"+tt.Name, func(t *testing.T) {
   782  			r, err := tt.Open()
   783  			if err != nil {
   784  				t.Error(err)
   785  				return
   786  			}
   787  			in, err := io.ReadAll(r)
   788  			if err != nil {
   789  				t.Error(err)
   790  			}
   791  			got, gotErr := dec.DecodeAll(in, make([]byte, 0, len(in)))
   792  			t.Log("Received:", len(got), gotErr)
   793  
   794  			// Check a fresh instance
   795  			decL, err := NewReader(bytes.NewBuffer(in), WithDecoderConcurrency(1), WithDecoderLowmem(true), WithDecoderMaxMemory(1<<20))
   796  			if err != nil {
   797  				t.Error(err)
   798  				return
   799  			}
   800  			defer decL.Close()
   801  			got2, gotErr2 := io.ReadAll(decL)
   802  			t.Log("Reader Reader received:", len(got2), gotErr2)
   803  			if gotErr != gotErr2 {
   804  				if gotErr != nil && gotErr2 != nil && gotErr.Error() != gotErr2.Error() {
   805  					t.Error(gotErr, "!=", gotErr2)
   806  				}
   807  				if (gotErr == nil) != (gotErr2 == nil) {
   808  					t.Error(gotErr, "!=", gotErr2)
   809  				}
   810  			}
   811  			if !bytes.Equal(got2, got) {
   812  				if gotErr != nil {
   813  					t.Log("Buffer mismatch")
   814  				} else {
   815  					t.Error("Buffer mismatch")
   816  				}
   817  			}
   818  		})
   819  	}
   820  }
   821  
   822  func TestShort(t *testing.T) {
   823  	for _, in := range []string{"f", "fo", "foo"} {
   824  		inb := []byte(in)
   825  		dec, err := NewReader(nil)
   826  		if err != nil {
   827  			t.Fatal(err)
   828  		}
   829  		defer dec.Close()
   830  
   831  		t.Run(fmt.Sprintf("DecodeAll-%d", len(in)), func(t *testing.T) {
   832  			_, err := dec.DecodeAll(inb, nil)
   833  			if err == nil {
   834  				t.Error("want error, got nil")
   835  			}
   836  		})
   837  		t.Run(fmt.Sprintf("Reader-%d", len(in)), func(t *testing.T) {
   838  			dec.Reset(bytes.NewReader(inb))
   839  			_, err := io.Copy(io.Discard, dec)
   840  			if err == nil {
   841  				t.Error("want error, got nil")
   842  			}
   843  		})
   844  	}
   845  }
   846  
   847  func TestDecoder_Reset(t *testing.T) {
   848  	in, err := os.ReadFile("testdata/z000028")
   849  	if err != nil {
   850  		t.Fatal(err)
   851  	}
   852  	in = append(in, in...)
   853  	var e Encoder
   854  	start := time.Now()
   855  	dst := e.EncodeAll(in, nil)
   856  	t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
   857  	mbpersec := (float64(len(in)) / (1024 * 1024)) / (float64(time.Since(start)) / (float64(time.Second)))
   858  	t.Logf("Encoded %d bytes with %.2f MB/s", len(in), mbpersec)
   859  
   860  	dec, err := NewReader(nil)
   861  	if err != nil {
   862  		t.Fatal(err)
   863  	}
   864  	defer dec.Close()
   865  	decoded, err := dec.DecodeAll(dst, nil)
   866  	if err != nil {
   867  		t.Error(err, len(decoded))
   868  	}
   869  	if !bytes.Equal(decoded, in) {
   870  		t.Logf("size = %d, got = %d", len(decoded), len(in))
   871  		t.Fatal("Decoded does not match")
   872  	}
   873  	t.Log("Encoded content matched")
   874  
   875  	// Decode using reset+copy
   876  	for i := 0; i < 3; i++ {
   877  		err = dec.Reset(bytes.NewBuffer(dst))
   878  		if err != nil {
   879  			t.Fatal(err)
   880  		}
   881  		var dBuf bytes.Buffer
   882  		n, err := io.Copy(&dBuf, dec)
   883  		if err != nil {
   884  			t.Fatal(err)
   885  		}
   886  		decoded = dBuf.Bytes()
   887  		if int(n) != len(decoded) {
   888  			t.Fatalf("decoded reported length mismatch %d != %d", n, len(decoded))
   889  		}
   890  		if !bytes.Equal(decoded, in) {
   891  			os.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
   892  			os.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
   893  			t.Fatal("Decoded does not match")
   894  		}
   895  	}
   896  	// Test without WriterTo interface support.
   897  	for i := 0; i < 3; i++ {
   898  		err = dec.Reset(bytes.NewBuffer(dst))
   899  		if err != nil {
   900  			t.Fatal(err)
   901  		}
   902  		decoded, err := io.ReadAll(io.NopCloser(dec))
   903  		if err != nil {
   904  			t.Fatal(err)
   905  		}
   906  		if !bytes.Equal(decoded, in) {
   907  			os.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
   908  			os.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
   909  			t.Fatal("Decoded does not match")
   910  		}
   911  	}
   912  }
   913  
   914  func TestDecoderMultiFrame(t *testing.T) {
   915  	zr := testCreateZipReader("testdata/benchdecoder.zip", t)
   916  	dec, err := NewReader(nil)
   917  	if err != nil {
   918  		t.Fatal(err)
   919  		return
   920  	}
   921  	defer dec.Close()
   922  	for _, tt := range zr.File {
   923  		if !strings.HasSuffix(tt.Name, ".zst") {
   924  			continue
   925  		}
   926  		t.Run(tt.Name, func(t *testing.T) {
   927  			r, err := tt.Open()
   928  			if err != nil {
   929  				t.Fatal(err)
   930  			}
   931  			defer r.Close()
   932  			in, err := io.ReadAll(r)
   933  			if err != nil {
   934  				t.Fatal(err)
   935  			}
   936  			// 2x
   937  			in = append(in, in...)
   938  			if !testing.Short() {
   939  				// 4x
   940  				in = append(in, in...)
   941  				// 8x
   942  				in = append(in, in...)
   943  			}
   944  			err = dec.Reset(bytes.NewBuffer(in))
   945  			if err != nil {
   946  				t.Fatal(err)
   947  			}
   948  			got, err := io.ReadAll(dec)
   949  			if err != nil {
   950  				t.Fatal(err)
   951  			}
   952  			err = dec.Reset(bytes.NewBuffer(in))
   953  			if err != nil {
   954  				t.Fatal(err)
   955  			}
   956  			got2, err := io.ReadAll(dec)
   957  			if err != nil {
   958  				t.Fatal(err)
   959  			}
   960  			if !bytes.Equal(got, got2) {
   961  				t.Error("results mismatch")
   962  			}
   963  		})
   964  	}
   965  }
   966  
   967  func TestDecoderMultiFrameReset(t *testing.T) {
   968  	zr := testCreateZipReader("testdata/benchdecoder.zip", t)
   969  	dec, err := NewReader(nil)
   970  	if err != nil {
   971  		t.Fatal(err)
   972  		return
   973  	}
   974  	rng := rand.New(rand.NewSource(1337))
   975  	defer dec.Close()
   976  	for _, tt := range zr.File {
   977  		if !strings.HasSuffix(tt.Name, ".zst") {
   978  			continue
   979  		}
   980  		t.Run(tt.Name, func(t *testing.T) {
   981  			r, err := tt.Open()
   982  			if err != nil {
   983  				t.Fatal(err)
   984  			}
   985  			defer r.Close()
   986  			in, err := io.ReadAll(r)
   987  			if err != nil {
   988  				t.Fatal(err)
   989  			}
   990  			// 2x
   991  			in = append(in, in...)
   992  			if !testing.Short() {
   993  				// 4x
   994  				in = append(in, in...)
   995  				// 8x
   996  				in = append(in, in...)
   997  			}
   998  			err = dec.Reset(bytes.NewBuffer(in))
   999  			if err != nil {
  1000  				t.Fatal(err)
  1001  			}
  1002  			got, err := io.ReadAll(dec)
  1003  			if err != nil {
  1004  				t.Fatal(err)
  1005  			}
  1006  			err = dec.Reset(bytes.NewBuffer(in))
  1007  			if err != nil {
  1008  				t.Fatal(err)
  1009  			}
  1010  			// Read a random number of bytes
  1011  			tmp := make([]byte, rng.Intn(len(got)))
  1012  			_, err = io.ReadAtLeast(dec, tmp, len(tmp))
  1013  			if err != nil {
  1014  				t.Fatal(err)
  1015  			}
  1016  			err = dec.Reset(bytes.NewBuffer(in))
  1017  			if err != nil {
  1018  				t.Fatal(err)
  1019  			}
  1020  			got2, err := io.ReadAll(dec)
  1021  			if err != nil {
  1022  				t.Fatal(err)
  1023  			}
  1024  			if !bytes.Equal(got, got2) {
  1025  				t.Error("results mismatch")
  1026  			}
  1027  		})
  1028  	}
  1029  }
  1030  
  1031  func testDecoderFile(t *testing.T, fn string, newDec func() (*Decoder, error)) {
  1032  	zr := testCreateZipReader(fn, t)
  1033  	var want = make(map[string][]byte)
  1034  	for _, tt := range zr.File {
  1035  		if strings.HasSuffix(tt.Name, ".zst") {
  1036  			continue
  1037  		}
  1038  		r, err := tt.Open()
  1039  		if err != nil {
  1040  			t.Fatal(err)
  1041  			return
  1042  		}
  1043  		want[tt.Name+".zst"], _ = io.ReadAll(r)
  1044  	}
  1045  
  1046  	dec, err := newDec()
  1047  	if err != nil {
  1048  		t.Error(err)
  1049  		return
  1050  	}
  1051  	defer dec.Close()
  1052  	for i, tt := range zr.File {
  1053  		if !strings.HasSuffix(tt.Name, ".zst") || (testing.Short() && i > 20) {
  1054  			continue
  1055  		}
  1056  		t.Run("Reader-"+tt.Name, func(t *testing.T) {
  1057  			defer timeout(10 * time.Second)()
  1058  			r, err := tt.Open()
  1059  			if err != nil {
  1060  				t.Error(err)
  1061  				return
  1062  			}
  1063  			data, err := io.ReadAll(r)
  1064  			r.Close()
  1065  			if err != nil {
  1066  				t.Error(err)
  1067  				return
  1068  			}
  1069  			err = dec.Reset(io.NopCloser(bytes.NewBuffer(data)))
  1070  			if err != nil {
  1071  				t.Error(err)
  1072  				return
  1073  			}
  1074  			var got []byte
  1075  			var gotError error
  1076  			var wg sync.WaitGroup
  1077  			wg.Add(1)
  1078  			go func() {
  1079  				got, gotError = io.ReadAll(dec)
  1080  				wg.Done()
  1081  			}()
  1082  
  1083  			// This decode should not interfere with the stream...
  1084  			gotDecAll, err := dec.DecodeAll(data, nil)
  1085  			if err != nil {
  1086  				t.Error(err)
  1087  				if err != ErrCRCMismatch {
  1088  					wg.Wait()
  1089  					return
  1090  				}
  1091  			}
  1092  			wg.Wait()
  1093  			if gotError != nil {
  1094  				t.Error(gotError, err)
  1095  				if err != ErrCRCMismatch {
  1096  					return
  1097  				}
  1098  			}
  1099  
  1100  			wantB := want[tt.Name]
  1101  
  1102  			compareWith := func(got []byte, displayName, name string) bool {
  1103  				if bytes.Equal(wantB, got) {
  1104  					return false
  1105  				}
  1106  
  1107  				if len(wantB)+len(got) < 1000 {
  1108  					t.Logf(" got: %v\nwant: %v", got, wantB)
  1109  				} else {
  1110  					fileName, _ := filepath.Abs(filepath.Join("testdata", t.Name()+"-want.bin"))
  1111  					_ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm)
  1112  					err := os.WriteFile(fileName, wantB, os.ModePerm)
  1113  					t.Log("Wrote file", fileName, err)
  1114  
  1115  					fileName, _ = filepath.Abs(filepath.Join("testdata", t.Name()+"-"+name+".bin"))
  1116  					_ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm)
  1117  					err = os.WriteFile(fileName, got, os.ModePerm)
  1118  					t.Log("Wrote file", fileName, err)
  1119  				}
  1120  				t.Logf("Length, want: %d, got: %d", len(wantB), len(got))
  1121  				t.Errorf("%s mismatch", displayName)
  1122  				return true
  1123  			}
  1124  
  1125  			if compareWith(got, "Output", "got") {
  1126  				return
  1127  			}
  1128  
  1129  			if compareWith(gotDecAll, "DecodeAll Output", "decoded") {
  1130  				return
  1131  			}
  1132  
  1133  			t.Log(len(got), "bytes returned, matches input, ok!")
  1134  		})
  1135  	}
  1136  }
  1137  
  1138  func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error), errMap map[string]string) {
  1139  	zr := testCreateZipReader(fn, t)
  1140  	var want = make(map[string][]byte)
  1141  	for _, tt := range zr.File {
  1142  		if strings.HasSuffix(tt.Name, ".zst") {
  1143  			continue
  1144  		}
  1145  		r, err := tt.Open()
  1146  		if err != nil {
  1147  			t.Fatal(err)
  1148  			return
  1149  		}
  1150  		want[tt.Name+".zst"], _ = io.ReadAll(r)
  1151  	}
  1152  
  1153  	dec, err := newDec()
  1154  	if err != nil {
  1155  		t.Error(err)
  1156  		return
  1157  	}
  1158  	defer dec.Close()
  1159  	for _, tt := range zr.File {
  1160  		t.Run(tt.Name, func(t *testing.T) {
  1161  			defer timeout(10 * time.Second)()
  1162  			r, err := tt.Open()
  1163  			if err != nil {
  1164  				t.Error(err)
  1165  				return
  1166  			}
  1167  			defer r.Close()
  1168  			err = dec.Reset(r)
  1169  			if err != nil {
  1170  				t.Error(err)
  1171  				return
  1172  			}
  1173  			got, err := io.ReadAll(dec)
  1174  			if err == ErrCRCMismatch && !strings.Contains(tt.Name, "badsum") {
  1175  				t.Error(err)
  1176  				return
  1177  			}
  1178  			if err == nil {
  1179  				want := errMap[tt.Name]
  1180  				if want == "" {
  1181  					want = "<error>"
  1182  				}
  1183  				t.Error("Did not get expected error", want, "- got", len(got), "bytes")
  1184  				return
  1185  			}
  1186  			if errMap[tt.Name] == "" {
  1187  				errMap[tt.Name] = err.Error()
  1188  			} else {
  1189  				want := errMap[tt.Name]
  1190  				if want != err.Error() {
  1191  					t.Errorf("error mismatch, prev run got %s, now got %s", want, err.Error())
  1192  				}
  1193  				return
  1194  			}
  1195  			t.Log("got error", err)
  1196  		})
  1197  	}
  1198  }
  1199  
  1200  func BenchmarkDecoder_DecoderSmall(b *testing.B) {
  1201  	zr := testCreateZipReader("testdata/benchdecoder.zip", b)
  1202  	dec, err := NewReader(nil, WithDecodeBuffersBelow(1<<30))
  1203  	if err != nil {
  1204  		b.Fatal(err)
  1205  		return
  1206  	}
  1207  	defer dec.Close()
  1208  	dec2, err := NewReader(nil, WithDecodeBuffersBelow(0))
  1209  	if err != nil {
  1210  		b.Fatal(err)
  1211  		return
  1212  	}
  1213  	defer dec2.Close()
  1214  	for _, tt := range zr.File {
  1215  		if !strings.HasSuffix(tt.Name, ".zst") {
  1216  			continue
  1217  		}
  1218  		b.Run(tt.Name, func(b *testing.B) {
  1219  			r, err := tt.Open()
  1220  			if err != nil {
  1221  				b.Fatal(err)
  1222  			}
  1223  			defer r.Close()
  1224  			in, err := io.ReadAll(r)
  1225  			if err != nil {
  1226  				b.Fatal(err)
  1227  			}
  1228  			// 2x
  1229  			in = append(in, in...)
  1230  			// 4x
  1231  			in = append(in, in...)
  1232  			// 8x
  1233  			in = append(in, in...)
  1234  
  1235  			err = dec.Reset(bytes.NewBuffer(in))
  1236  			if err != nil {
  1237  				b.Fatal(err)
  1238  			}
  1239  			got, err := io.ReadAll(dec)
  1240  			if err != nil {
  1241  				b.Fatal(err)
  1242  			}
  1243  			b.Run("buffered", func(b *testing.B) {
  1244  				b.SetBytes(int64(len(got)))
  1245  				b.ReportAllocs()
  1246  				b.ResetTimer()
  1247  				for i := 0; i < b.N; i++ {
  1248  					err = dec.Reset(bytes.NewBuffer(in))
  1249  					if err != nil {
  1250  						b.Fatal(err)
  1251  					}
  1252  					n, err := io.Copy(io.Discard, dec)
  1253  					if err != nil {
  1254  						b.Fatal(err)
  1255  					}
  1256  					if int(n) != len(got) {
  1257  						b.Fatalf("want %d, got %d", len(got), n)
  1258  					}
  1259  
  1260  				}
  1261  			})
  1262  			b.Run("unbuffered", func(b *testing.B) {
  1263  				b.SetBytes(int64(len(got)))
  1264  				b.ReportAllocs()
  1265  				b.ResetTimer()
  1266  				for i := 0; i < b.N; i++ {
  1267  					err = dec2.Reset(bytes.NewBuffer(in))
  1268  					if err != nil {
  1269  						b.Fatal(err)
  1270  					}
  1271  					n, err := io.Copy(io.Discard, dec2)
  1272  					if err != nil {
  1273  						b.Fatal(err)
  1274  					}
  1275  					if int(n) != len(got) {
  1276  						b.Fatalf("want %d, got %d", len(got), n)
  1277  					}
  1278  				}
  1279  			})
  1280  		})
  1281  	}
  1282  }
  1283  
  1284  func BenchmarkDecoder_DecoderReset(b *testing.B) {
  1285  	zr := testCreateZipReader("testdata/benchdecoder.zip", b)
  1286  	dec, err := NewReader(nil, WithDecodeBuffersBelow(0))
  1287  	if err != nil {
  1288  		b.Fatal(err)
  1289  		return
  1290  	}
  1291  	defer dec.Close()
  1292  	bench := func(name string, b *testing.B, opts []DOption, in, want []byte) {
  1293  		b.Helper()
  1294  		buf := newBytesReader(in)
  1295  		dec, err := NewReader(nil, opts...)
  1296  		if err != nil {
  1297  			b.Fatal(err)
  1298  			return
  1299  		}
  1300  		defer dec.Close()
  1301  		b.Run(name, func(b *testing.B) {
  1302  			b.SetBytes(1)
  1303  			b.ReportAllocs()
  1304  			b.ResetTimer()
  1305  			for i := 0; i < b.N; i++ {
  1306  				buf.Reset(in)
  1307  				err = dec.Reset(buf)
  1308  				if err != nil {
  1309  					b.Fatal(err)
  1310  				}
  1311  			}
  1312  		})
  1313  	}
  1314  	for _, tt := range zr.File {
  1315  		if !strings.HasSuffix(tt.Name, ".zst") {
  1316  			continue
  1317  		}
  1318  		b.Run(tt.Name, func(b *testing.B) {
  1319  			r, err := tt.Open()
  1320  			if err != nil {
  1321  				b.Fatal(err)
  1322  			}
  1323  			defer r.Close()
  1324  			in, err := io.ReadAll(r)
  1325  			if err != nil {
  1326  				b.Fatal(err)
  1327  			}
  1328  
  1329  			got, err := dec.DecodeAll(in, nil)
  1330  			if err != nil {
  1331  				b.Fatal(err)
  1332  			}
  1333  			// Disable buffers:
  1334  			bench("stream", b, []DOption{WithDecodeBuffersBelow(0)}, in, got)
  1335  			bench("stream-single", b, []DOption{WithDecodeBuffersBelow(0), WithDecoderConcurrency(1)}, in, got)
  1336  			// Force buffers:
  1337  			bench("buffer", b, []DOption{WithDecodeBuffersBelow(1 << 30)}, in, got)
  1338  			bench("buffer-single", b, []DOption{WithDecodeBuffersBelow(1 << 30), WithDecoderConcurrency(1)}, in, got)
  1339  		})
  1340  	}
  1341  }
  1342  
  1343  // newBytesReader returns a *bytes.Reader that also supports Bytes() []byte
  1344  func newBytesReader(b []byte) *bytesReader {
  1345  	return &bytesReader{Reader: bytes.NewReader(b), buf: b}
  1346  }
  1347  
  1348  type bytesReader struct {
  1349  	*bytes.Reader
  1350  	buf []byte
  1351  }
  1352  
  1353  func (b *bytesReader) Bytes() []byte {
  1354  	n := b.Reader.Len()
  1355  	if n > len(b.buf) {
  1356  		panic("buffer mismatch")
  1357  	}
  1358  	return b.buf[len(b.buf)-n:]
  1359  }
  1360  
  1361  func (b *bytesReader) Reset(data []byte) {
  1362  	b.buf = data
  1363  	b.Reader.Reset(data)
  1364  }
  1365  
  1366  func BenchmarkDecoder_DecoderNewNoRead(b *testing.B) {
  1367  	zr := testCreateZipReader("testdata/benchdecoder.zip", b)
  1368  	dec, err := NewReader(nil)
  1369  	if err != nil {
  1370  		b.Fatal(err)
  1371  		return
  1372  	}
  1373  	defer dec.Close()
  1374  
  1375  	bench := func(name string, b *testing.B, opts []DOption, in, want []byte) {
  1376  		b.Helper()
  1377  		b.Run(name, func(b *testing.B) {
  1378  			buf := newBytesReader(in)
  1379  			b.SetBytes(1)
  1380  			b.ReportAllocs()
  1381  			b.ResetTimer()
  1382  			for i := 0; i < b.N; i++ {
  1383  				buf.Reset(in)
  1384  				dec, err := NewReader(buf, opts...)
  1385  				if err != nil {
  1386  					b.Fatal(err)
  1387  					return
  1388  				}
  1389  				dec.Close()
  1390  			}
  1391  		})
  1392  	}
  1393  	for _, tt := range zr.File {
  1394  		if !strings.HasSuffix(tt.Name, ".zst") {
  1395  			continue
  1396  		}
  1397  		b.Run(tt.Name, func(b *testing.B) {
  1398  			r, err := tt.Open()
  1399  			if err != nil {
  1400  				b.Fatal(err)
  1401  			}
  1402  			defer r.Close()
  1403  			in, err := io.ReadAll(r)
  1404  			if err != nil {
  1405  				b.Fatal(err)
  1406  			}
  1407  
  1408  			got, err := dec.DecodeAll(in, nil)
  1409  			if err != nil {
  1410  				b.Fatal(err)
  1411  			}
  1412  			// Disable buffers:
  1413  			bench("stream", b, []DOption{WithDecodeBuffersBelow(0)}, in, got)
  1414  			bench("stream-single", b, []DOption{WithDecodeBuffersBelow(0), WithDecoderConcurrency(1)}, in, got)
  1415  			// Force buffers:
  1416  			bench("buffer", b, []DOption{WithDecodeBuffersBelow(1 << 30)}, in, got)
  1417  			bench("buffer-single", b, []DOption{WithDecodeBuffersBelow(1 << 30), WithDecoderConcurrency(1)}, in, got)
  1418  		})
  1419  	}
  1420  }
  1421  
  1422  func BenchmarkDecoder_DecoderNewSomeRead(b *testing.B) {
  1423  	var buf [1 << 20]byte
  1424  	bench := func(name string, b *testing.B, opts []DOption, in *os.File) {
  1425  		b.Helper()
  1426  		b.Run(name, func(b *testing.B) {
  1427  			//b.ReportAllocs()
  1428  			b.ResetTimer()
  1429  			var heapTotal int64
  1430  			var m runtime.MemStats
  1431  			for i := 0; i < b.N; i++ {
  1432  				runtime.GC()
  1433  				runtime.ReadMemStats(&m)
  1434  				heapTotal -= int64(m.HeapInuse)
  1435  				_, err := in.Seek(io.SeekStart, 0)
  1436  				if err != nil {
  1437  					b.Fatal(err)
  1438  				}
  1439  
  1440  				dec, err := NewReader(in, opts...)
  1441  				if err != nil {
  1442  					b.Fatal(err)
  1443  				}
  1444  				// Read 16 MB
  1445  				_, err = io.CopyBuffer(io.Discard, io.LimitReader(dec, 16<<20), buf[:])
  1446  				if err != nil {
  1447  					b.Fatal(err)
  1448  				}
  1449  				runtime.GC()
  1450  				runtime.ReadMemStats(&m)
  1451  				heapTotal += int64(m.HeapInuse)
  1452  
  1453  				dec.Close()
  1454  			}
  1455  			b.ReportMetric(float64(heapTotal)/float64(b.N), "b/op")
  1456  		})
  1457  	}
  1458  	files := []string{"testdata/000002.map.win32K.zst", "testdata/000002.map.win1MB.zst", "testdata/000002.map.win8MB.zst"}
  1459  	for _, file := range files {
  1460  		if !strings.HasSuffix(file, ".zst") {
  1461  			continue
  1462  		}
  1463  		r, err := os.Open(file)
  1464  		if err != nil {
  1465  			b.Fatal(err)
  1466  		}
  1467  		defer r.Close()
  1468  
  1469  		b.Run(file, func(b *testing.B) {
  1470  			bench("stream-single", b, []DOption{WithDecodeBuffersBelow(0), WithDecoderConcurrency(1)}, r)
  1471  			bench("stream-single-himem", b, []DOption{WithDecodeBuffersBelow(0), WithDecoderConcurrency(1), WithDecoderLowmem(false)}, r)
  1472  		})
  1473  	}
  1474  }
  1475  
  1476  func BenchmarkDecoder_DecodeAll(b *testing.B) {
  1477  	zr := testCreateZipReader("testdata/benchdecoder.zip", b)
  1478  	dec, err := NewReader(nil, WithDecoderConcurrency(1))
  1479  	if err != nil {
  1480  		b.Fatal(err)
  1481  		return
  1482  	}
  1483  	defer dec.Close()
  1484  	for _, tt := range zr.File {
  1485  		if !strings.HasSuffix(tt.Name, ".zst") {
  1486  			continue
  1487  		}
  1488  		b.Run(tt.Name, func(b *testing.B) {
  1489  			r, err := tt.Open()
  1490  			if err != nil {
  1491  				b.Fatal(err)
  1492  			}
  1493  			defer r.Close()
  1494  			in, err := io.ReadAll(r)
  1495  			if err != nil {
  1496  				b.Fatal(err)
  1497  			}
  1498  			got, err := dec.DecodeAll(in, nil)
  1499  			if err != nil {
  1500  				b.Fatal(err)
  1501  			}
  1502  			b.SetBytes(int64(len(got)))
  1503  			b.ReportAllocs()
  1504  			b.ResetTimer()
  1505  			for i := 0; i < b.N; i++ {
  1506  				_, err = dec.DecodeAll(in, got[:0])
  1507  				if err != nil {
  1508  					b.Fatal(err)
  1509  				}
  1510  			}
  1511  		})
  1512  	}
  1513  }
  1514  
  1515  func BenchmarkDecoder_DecodeAllFiles(b *testing.B) {
  1516  	filepath.Walk("../testdata/", func(path string, info os.FileInfo, err error) error {
  1517  		if info.IsDir() || info.Size() < 100 {
  1518  			return nil
  1519  		}
  1520  		b.Run(filepath.Base(path), func(b *testing.B) {
  1521  			raw, err := os.ReadFile(path)
  1522  			if err != nil {
  1523  				b.Error(err)
  1524  			}
  1525  			for i := SpeedFastest; i <= SpeedBestCompression; i++ {
  1526  				if testing.Short() && i > SpeedFastest {
  1527  					break
  1528  				}
  1529  				b.Run(i.String(), func(b *testing.B) {
  1530  					enc, err := NewWriter(nil, WithEncoderLevel(i), WithSingleSegment(true))
  1531  					if err != nil {
  1532  						b.Error(err)
  1533  					}
  1534  					encoded := enc.EncodeAll(raw, nil)
  1535  					if err != nil {
  1536  						b.Error(err)
  1537  					}
  1538  					dec, err := NewReader(nil, WithDecoderConcurrency(1))
  1539  					if err != nil {
  1540  						b.Error(err)
  1541  					}
  1542  					decoded, err := dec.DecodeAll(encoded, nil)
  1543  					if err != nil {
  1544  						b.Error(err)
  1545  					}
  1546  					b.SetBytes(int64(len(raw)))
  1547  					b.ReportAllocs()
  1548  					b.ResetTimer()
  1549  					for i := 0; i < b.N; i++ {
  1550  						decoded, err = dec.DecodeAll(encoded, decoded[:0])
  1551  						if err != nil {
  1552  							b.Error(err)
  1553  						}
  1554  					}
  1555  					b.ReportMetric(100*float64(len(encoded))/float64(len(raw)), "pct")
  1556  				})
  1557  			}
  1558  		})
  1559  		return nil
  1560  	})
  1561  }
  1562  
  1563  func BenchmarkDecoder_DecodeAllFilesP(b *testing.B) {
  1564  	filepath.Walk("../testdata/", func(path string, info os.FileInfo, err error) error {
  1565  		if info.IsDir() || info.Size() < 100 {
  1566  			return nil
  1567  		}
  1568  		b.Run(filepath.Base(path), func(b *testing.B) {
  1569  			raw, err := os.ReadFile(path)
  1570  			if err != nil {
  1571  				b.Error(err)
  1572  			}
  1573  			for i := SpeedFastest; i <= SpeedBestCompression; i++ {
  1574  				if testing.Short() && i > SpeedFastest {
  1575  					break
  1576  				}
  1577  				b.Run(i.String(), func(b *testing.B) {
  1578  					enc, err := NewWriter(nil, WithEncoderLevel(i), WithSingleSegment(true))
  1579  					if err != nil {
  1580  						b.Error(err)
  1581  					}
  1582  					encoded := enc.EncodeAll(raw, nil)
  1583  					if err != nil {
  1584  						b.Error(err)
  1585  					}
  1586  					dec, err := NewReader(nil, WithDecoderConcurrency(0))
  1587  					if err != nil {
  1588  						b.Error(err)
  1589  					}
  1590  					raw, err := dec.DecodeAll(encoded, nil)
  1591  					if err != nil {
  1592  						b.Error(err)
  1593  					}
  1594  
  1595  					b.SetBytes(int64(len(raw)))
  1596  					b.ReportAllocs()
  1597  					b.ResetTimer()
  1598  					b.RunParallel(func(pb *testing.PB) {
  1599  						buf := make([]byte, cap(raw))
  1600  						var err error
  1601  						for pb.Next() {
  1602  							buf, err = dec.DecodeAll(encoded, buf[:0])
  1603  							if err != nil {
  1604  								b.Error(err)
  1605  							}
  1606  						}
  1607  					})
  1608  					b.ReportMetric(100*float64(len(encoded))/float64(len(raw)), "pct")
  1609  				})
  1610  			}
  1611  		})
  1612  		return nil
  1613  	})
  1614  }
  1615  
  1616  func BenchmarkDecoder_DecodeAllParallel(b *testing.B) {
  1617  	zr := testCreateZipReader("testdata/benchdecoder.zip", b)
  1618  	dec, err := NewReader(nil, WithDecoderConcurrency(runtime.GOMAXPROCS(0)))
  1619  	if err != nil {
  1620  		b.Fatal(err)
  1621  		return
  1622  	}
  1623  	defer dec.Close()
  1624  	for _, tt := range zr.File {
  1625  		if !strings.HasSuffix(tt.Name, ".zst") {
  1626  			continue
  1627  		}
  1628  		b.Run(tt.Name, func(b *testing.B) {
  1629  			r, err := tt.Open()
  1630  			if err != nil {
  1631  				b.Fatal(err)
  1632  			}
  1633  			defer r.Close()
  1634  			in, err := io.ReadAll(r)
  1635  			if err != nil {
  1636  				b.Fatal(err)
  1637  			}
  1638  			got, err := dec.DecodeAll(in, nil)
  1639  			if err != nil {
  1640  				b.Fatal(err)
  1641  			}
  1642  			b.SetBytes(int64(len(got)))
  1643  			b.ReportAllocs()
  1644  			b.ResetTimer()
  1645  			b.RunParallel(func(pb *testing.PB) {
  1646  				got := make([]byte, cap(got))
  1647  				for pb.Next() {
  1648  					_, err = dec.DecodeAll(in, got[:0])
  1649  					if err != nil {
  1650  						b.Fatal(err)
  1651  					}
  1652  				}
  1653  			})
  1654  			b.ReportMetric(100*float64(len(in))/float64(len(got)), "pct")
  1655  		})
  1656  	}
  1657  }
  1658  
  1659  func benchmarkDecoderWithFile(path string, b *testing.B) {
  1660  	_, err := os.Stat(path)
  1661  	if err != nil {
  1662  		if os.IsNotExist(err) {
  1663  			b.Skipf("Missing %s", path)
  1664  			return
  1665  		}
  1666  		b.Fatal(err)
  1667  	}
  1668  
  1669  	data, err := os.ReadFile(path)
  1670  	if err != nil {
  1671  		b.Fatal(err)
  1672  	}
  1673  	dec, err := NewReader(bytes.NewBuffer(data), WithDecoderLowmem(false), WithDecoderConcurrency(1))
  1674  	if err != nil {
  1675  		b.Fatal(err)
  1676  	}
  1677  	n, err := io.Copy(io.Discard, dec)
  1678  	if err != nil {
  1679  		b.Fatal(err)
  1680  	}
  1681  
  1682  	b.Run("multithreaded-writer", func(b *testing.B) {
  1683  		dec, err := NewReader(nil, WithDecoderLowmem(true))
  1684  		if err != nil {
  1685  			b.Fatal(err)
  1686  		}
  1687  		b.SetBytes(n)
  1688  		b.ReportAllocs()
  1689  		b.ResetTimer()
  1690  		for i := 0; i < b.N; i++ {
  1691  			err = dec.Reset(bytes.NewBuffer(data))
  1692  			if err != nil {
  1693  				b.Fatal(err)
  1694  			}
  1695  			_, err := io.CopyN(io.Discard, dec, n)
  1696  			if err != nil {
  1697  				b.Fatal(err)
  1698  			}
  1699  		}
  1700  	})
  1701  
  1702  	b.Run("multithreaded-writer-himem", func(b *testing.B) {
  1703  		dec, err := NewReader(nil, WithDecoderLowmem(false))
  1704  		if err != nil {
  1705  			b.Fatal(err)
  1706  		}
  1707  
  1708  		b.SetBytes(n)
  1709  		b.ReportAllocs()
  1710  		b.ResetTimer()
  1711  		for i := 0; i < b.N; i++ {
  1712  			err = dec.Reset(bytes.NewBuffer(data))
  1713  			if err != nil {
  1714  				b.Fatal(err)
  1715  			}
  1716  			_, err := io.CopyN(io.Discard, dec, n)
  1717  			if err != nil {
  1718  				b.Fatal(err)
  1719  			}
  1720  		}
  1721  	})
  1722  
  1723  	b.Run("singlethreaded-writer", func(b *testing.B) {
  1724  		dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderLowmem(true))
  1725  		if err != nil {
  1726  			b.Fatal(err)
  1727  		}
  1728  
  1729  		b.SetBytes(n)
  1730  		b.ReportAllocs()
  1731  		b.ResetTimer()
  1732  		for i := 0; i < b.N; i++ {
  1733  			err = dec.Reset(bytes.NewBuffer(data))
  1734  			if err != nil {
  1735  				b.Fatal(err)
  1736  			}
  1737  			_, err := io.CopyN(io.Discard, dec, n)
  1738  			if err != nil {
  1739  				b.Fatal(err)
  1740  			}
  1741  		}
  1742  	})
  1743  
  1744  	b.Run("singlethreaded-writerto", func(b *testing.B) {
  1745  		dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderLowmem(true))
  1746  		if err != nil {
  1747  			b.Fatal(err)
  1748  		}
  1749  
  1750  		b.SetBytes(n)
  1751  		b.ReportAllocs()
  1752  		b.ResetTimer()
  1753  		for i := 0; i < b.N; i++ {
  1754  			err = dec.Reset(bytes.NewBuffer(data))
  1755  			if err != nil {
  1756  				b.Fatal(err)
  1757  			}
  1758  			// io.Copy will use io.WriterTo
  1759  			_, err := io.Copy(io.Discard, dec)
  1760  			if err != nil {
  1761  				b.Fatal(err)
  1762  			}
  1763  		}
  1764  	})
  1765  	b.Run("singlethreaded-himem", func(b *testing.B) {
  1766  		dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderLowmem(false))
  1767  		if err != nil {
  1768  			b.Fatal(err)
  1769  		}
  1770  
  1771  		b.SetBytes(n)
  1772  		b.ReportAllocs()
  1773  		b.ResetTimer()
  1774  		for i := 0; i < b.N; i++ {
  1775  			err = dec.Reset(bytes.NewBuffer(data))
  1776  			if err != nil {
  1777  				b.Fatal(err)
  1778  			}
  1779  			// io.Copy will use io.WriterTo
  1780  			_, err := io.Copy(io.Discard, dec)
  1781  			if err != nil {
  1782  				b.Fatal(err)
  1783  			}
  1784  		}
  1785  	})
  1786  }
  1787  
  1788  func BenchmarkDecoderSilesia(b *testing.B) {
  1789  	benchmarkDecoderWithFile("testdata/silesia.tar.zst", b)
  1790  }
  1791  
  1792  func BenchmarkDecoderEnwik9(b *testing.B) {
  1793  	benchmarkDecoderWithFile("testdata/enwik9.zst", b)
  1794  }
  1795  
  1796  func BenchmarkDecoderWithCustomFiles(b *testing.B) {
  1797  	const info = "To run benchmark on custom .zst files, please place your files in subdirectory 'testdata/benchmark-custom'.\nEach file is tested in a separate benchmark, thus it is possible to select files with the standard command 'go test -bench BenchmarkDecoderWithCustomFiles/<pattern>."
  1798  
  1799  	const subdir = "testdata/benchmark-custom"
  1800  
  1801  	if _, err := os.Stat(subdir); os.IsNotExist(err) {
  1802  		b.Skip(info)
  1803  	}
  1804  
  1805  	files, err := filepath.Glob(filepath.Join(subdir, "*.zst"))
  1806  	if err != nil {
  1807  		b.Error(err)
  1808  		return
  1809  	}
  1810  
  1811  	if len(files) == 0 {
  1812  		b.Skip(info)
  1813  	}
  1814  
  1815  	for _, path := range files {
  1816  		name := filepath.Base(path)
  1817  		b.Run(name, func(b *testing.B) { benchmarkDecoderWithFile(path, b) })
  1818  	}
  1819  }
  1820  
  1821  func testDecoderDecodeAll(t *testing.T, fn string, dec *Decoder) {
  1822  	zr := testCreateZipReader(fn, t)
  1823  	var want = make(map[string][]byte)
  1824  	for _, tt := range zr.File {
  1825  		if strings.HasSuffix(tt.Name, ".zst") {
  1826  			continue
  1827  		}
  1828  		r, err := tt.Open()
  1829  		if err != nil {
  1830  			t.Fatal(err)
  1831  			return
  1832  		}
  1833  		want[tt.Name+".zst"], _ = io.ReadAll(r)
  1834  	}
  1835  	var wg sync.WaitGroup
  1836  	for i, tt := range zr.File {
  1837  		tt := tt
  1838  		if !strings.HasSuffix(tt.Name, ".zst") || (testing.Short() && i > 20) {
  1839  			continue
  1840  		}
  1841  		wg.Add(1)
  1842  		t.Run("DecodeAll-"+tt.Name, func(t *testing.T) {
  1843  			defer wg.Done()
  1844  			t.Parallel()
  1845  			r, err := tt.Open()
  1846  			if err != nil {
  1847  				t.Fatal(err)
  1848  			}
  1849  			in, err := io.ReadAll(r)
  1850  			if err != nil {
  1851  				t.Fatal(err)
  1852  			}
  1853  			wantB := want[tt.Name]
  1854  			// make a buffer that is too small.
  1855  			got, err := dec.DecodeAll(in, make([]byte, 10, 200))
  1856  			if err != nil {
  1857  				t.Error(err)
  1858  			}
  1859  			if len(got) < 10 {
  1860  				t.Fatal("didn't get input back")
  1861  			}
  1862  			got = got[10:]
  1863  			if !bytes.Equal(wantB, got) {
  1864  				if len(wantB)+len(got) < 1000 {
  1865  					t.Logf(" got: %v\nwant: %v", got, wantB)
  1866  				} else {
  1867  					fileName, _ := filepath.Abs(filepath.Join("testdata", t.Name()+"-want.bin"))
  1868  					_ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm)
  1869  					err := os.WriteFile(fileName, wantB, os.ModePerm)
  1870  					t.Log("Wrote file", fileName, err)
  1871  
  1872  					fileName, _ = filepath.Abs(filepath.Join("testdata", t.Name()+"-got.bin"))
  1873  					_ = os.MkdirAll(filepath.Dir(fileName), os.ModePerm)
  1874  					err = os.WriteFile(fileName, got, os.ModePerm)
  1875  					t.Log("Wrote file", fileName, err)
  1876  				}
  1877  				t.Logf("Length, want: %d, got: %d", len(wantB), len(got))
  1878  				t.Error("Output mismatch")
  1879  				return
  1880  			}
  1881  			t.Log(len(got), "bytes returned, matches input, ok!")
  1882  		})
  1883  	}
  1884  	go func() {
  1885  		wg.Wait()
  1886  		dec.Close()
  1887  	}()
  1888  }
  1889  
  1890  func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder, errMap map[string]string) {
  1891  	zr := testCreateZipReader(fn, t)
  1892  
  1893  	var wg sync.WaitGroup
  1894  	for _, tt := range zr.File {
  1895  		tt := tt
  1896  		if !strings.HasSuffix(tt.Name, ".zst") {
  1897  			continue
  1898  		}
  1899  		wg.Add(1)
  1900  		t.Run(tt.Name, func(t *testing.T) {
  1901  			defer wg.Done()
  1902  			r, err := tt.Open()
  1903  			if err != nil {
  1904  				t.Fatal(err)
  1905  			}
  1906  			in, err := io.ReadAll(r)
  1907  			if err != nil {
  1908  				t.Fatal(err)
  1909  			}
  1910  			// make a buffer that is small.
  1911  			got, err := dec.DecodeAll(in, make([]byte, 0, 20))
  1912  			if err == nil {
  1913  				t.Error("Did not get expected error, got", len(got), "bytes")
  1914  				return
  1915  			}
  1916  			t.Log(err)
  1917  			if errMap[tt.Name] == "" {
  1918  				t.Error("cannot check error")
  1919  			} else {
  1920  				want := errMap[tt.Name]
  1921  				if want != err.Error() {
  1922  					if want == ErrFrameSizeMismatch.Error() && err == ErrDecoderSizeExceeded {
  1923  						return
  1924  					}
  1925  					if want == ErrWindowSizeExceeded.Error() && err == ErrDecoderSizeExceeded {
  1926  						return
  1927  					}
  1928  					t.Errorf("error mismatch, prev run got %s, now got %s", want, err.Error())
  1929  				}
  1930  				return
  1931  			}
  1932  		})
  1933  	}
  1934  	go func() {
  1935  		wg.Wait()
  1936  		dec.Close()
  1937  	}()
  1938  }
  1939  
// TestPredefTables checks that the predefined FSE decoding tables in fsePredef,
// as built by initPredefined, match the reference tables of the C implementation.
// The tables are not hard-coded in Go, so this also tests the transformations
// that build them.
// Reference from here: https://github.com/facebook/zstd/blob/ededcfca57366461021c922720878c81a5854a0a/lib/decompress/zstd_decompress_block.c#L234
func TestPredefTables(t *testing.T) {
	initPredefined()
	// x takes the fields in the reference tables' column order
	// (nextState, nbAddBits, nbBits, baseVal) and passes them to
	// newDecSymbol, whose argument order differs.
	x := func(nextState uint16, nbAddBits, nbBits uint8, baseVal uint32) decSymbol {
		return newDecSymbol(nbBits, nbAddBits, nextState, baseVal)
	}
	for i := range fsePredef[:] {
		// want is the expected table for index i; it stays nil for any index
		// not handled by the switch below.
		var want []decSymbol
		switch tableIndex(i) {
		case tableLiteralLengths:
			want = []decSymbol{
				/* nextState, nbAddBits, nbBits, baseVal */
				x(0, 0, 4, 0), x(16, 0, 4, 0),
				x(32, 0, 5, 1), x(0, 0, 5, 3),
				x(0, 0, 5, 4), x(0, 0, 5, 6),
				x(0, 0, 5, 7), x(0, 0, 5, 9),
				x(0, 0, 5, 10), x(0, 0, 5, 12),
				x(0, 0, 6, 14), x(0, 1, 5, 16),
				x(0, 1, 5, 20), x(0, 1, 5, 22),
				x(0, 2, 5, 28), x(0, 3, 5, 32),
				x(0, 4, 5, 48), x(32, 6, 5, 64),
				x(0, 7, 5, 128), x(0, 8, 6, 256),
				x(0, 10, 6, 1024), x(0, 12, 6, 4096),
				x(32, 0, 4, 0), x(0, 0, 4, 1),
				x(0, 0, 5, 2), x(32, 0, 5, 4),
				x(0, 0, 5, 5), x(32, 0, 5, 7),
				x(0, 0, 5, 8), x(32, 0, 5, 10),
				x(0, 0, 5, 11), x(0, 0, 6, 13),
				x(32, 1, 5, 16), x(0, 1, 5, 18),
				x(32, 1, 5, 22), x(0, 2, 5, 24),
				x(32, 3, 5, 32), x(0, 3, 5, 40),
				x(0, 6, 4, 64), x(16, 6, 4, 64),
				x(32, 7, 5, 128), x(0, 9, 6, 512),
				x(0, 11, 6, 2048), x(48, 0, 4, 0),
				x(16, 0, 4, 1), x(32, 0, 5, 2),
				x(32, 0, 5, 3), x(32, 0, 5, 5),
				x(32, 0, 5, 6), x(32, 0, 5, 8),
				x(32, 0, 5, 9), x(32, 0, 5, 11),
				x(32, 0, 5, 12), x(0, 0, 6, 15),
				x(32, 1, 5, 18), x(32, 1, 5, 20),
				x(32, 2, 5, 24), x(32, 2, 5, 28),
				x(32, 3, 5, 40), x(32, 4, 5, 48),
				x(0, 16, 6, 65536), x(0, 15, 6, 32768),
				x(0, 14, 6, 16384), x(0, 13, 6, 8192),
			}
		case tableOffsets:
			want = []decSymbol{
				/* nextState, nbAddBits, nbBits, baseVal */
				x(0, 0, 5, 0), x(0, 6, 4, 61),
				x(0, 9, 5, 509), x(0, 15, 5, 32765),
				x(0, 21, 5, 2097149), x(0, 3, 5, 5),
				x(0, 7, 4, 125), x(0, 12, 5, 4093),
				x(0, 18, 5, 262141), x(0, 23, 5, 8388605),
				x(0, 5, 5, 29), x(0, 8, 4, 253),
				x(0, 14, 5, 16381), x(0, 20, 5, 1048573),
				x(0, 2, 5, 1), x(16, 7, 4, 125),
				x(0, 11, 5, 2045), x(0, 17, 5, 131069),
				x(0, 22, 5, 4194301), x(0, 4, 5, 13),
				x(16, 8, 4, 253), x(0, 13, 5, 8189),
				x(0, 19, 5, 524285), x(0, 1, 5, 1),
				x(16, 6, 4, 61), x(0, 10, 5, 1021),
				x(0, 16, 5, 65533), x(0, 28, 5, 268435453),
				x(0, 27, 5, 134217725), x(0, 26, 5, 67108861),
				x(0, 25, 5, 33554429), x(0, 24, 5, 16777213),
			}
		case tableMatchLengths:
			want = []decSymbol{
				/* nextState, nbAddBits, nbBits, baseVal */
				x(0, 0, 6, 3), x(0, 0, 4, 4),
				x(32, 0, 5, 5), x(0, 0, 5, 6),
				x(0, 0, 5, 8), x(0, 0, 5, 9),
				x(0, 0, 5, 11), x(0, 0, 6, 13),
				x(0, 0, 6, 16), x(0, 0, 6, 19),
				x(0, 0, 6, 22), x(0, 0, 6, 25),
				x(0, 0, 6, 28), x(0, 0, 6, 31),
				x(0, 0, 6, 34), x(0, 1, 6, 37),
				x(0, 1, 6, 41), x(0, 2, 6, 47),
				x(0, 3, 6, 59), x(0, 4, 6, 83),
				x(0, 7, 6, 131), x(0, 9, 6, 515),
				x(16, 0, 4, 4), x(0, 0, 4, 5),
				x(32, 0, 5, 6), x(0, 0, 5, 7),
				x(32, 0, 5, 9), x(0, 0, 5, 10),
				x(0, 0, 6, 12), x(0, 0, 6, 15),
				x(0, 0, 6, 18), x(0, 0, 6, 21),
				x(0, 0, 6, 24), x(0, 0, 6, 27),
				x(0, 0, 6, 30), x(0, 0, 6, 33),
				x(0, 1, 6, 35), x(0, 1, 6, 39),
				x(0, 2, 6, 43), x(0, 3, 6, 51),
				x(0, 4, 6, 67), x(0, 5, 6, 99),
				x(0, 8, 6, 259), x(32, 0, 4, 4),
				x(48, 0, 4, 4), x(16, 0, 4, 5),
				x(32, 0, 5, 7), x(32, 0, 5, 8),
				x(32, 0, 5, 10), x(32, 0, 5, 11),
				x(0, 0, 6, 14), x(0, 0, 6, 17),
				x(0, 0, 6, 20), x(0, 0, 6, 23),
				x(0, 0, 6, 26), x(0, 0, 6, 29),
				x(0, 0, 6, 32), x(0, 16, 6, 65539),
				x(0, 15, 6, 32771), x(0, 14, 6, 16387),
				x(0, 13, 6, 8195), x(0, 12, 6, 4099),
				x(0, 11, 6, 2051), x(0, 10, 6, 1027),
			}
		}
		// Compare only the first 1<<actualTableLog entries of the decoding table.
		pre := fsePredef[i]
		got := pre.dt[:1<<pre.actualTableLog]
		if !reflect.DeepEqual(got, want) {
			t.Logf("want: %v", want)
			t.Logf("got : %v", got)
			t.Errorf("Predefined table %d incorrect, len(got) = %d, len(want) = %d", i, len(got), len(want))
		}
	}
}
  2053  
  2054  func TestResetNil(t *testing.T) {
  2055  	dec, err := NewReader(nil)
  2056  	if err != nil {
  2057  		t.Fatal(err)
  2058  	}
  2059  	defer dec.Close()
  2060  
  2061  	_, err = io.ReadAll(dec)
  2062  	if err != ErrDecoderNilInput {
  2063  		t.Fatalf("Expected ErrDecoderNilInput when decoding from a nil reader, got %v", err)
  2064  	}
  2065  
  2066  	emptyZstdBlob := []byte{40, 181, 47, 253, 32, 0, 1, 0, 0}
  2067  
  2068  	dec.Reset(bytes.NewBuffer(emptyZstdBlob))
  2069  
  2070  	result, err := io.ReadAll(dec)
  2071  	if err != nil && err != io.EOF {
  2072  		t.Fatal(err)
  2073  	}
  2074  	if len(result) != 0 {
  2075  		t.Fatalf("Expected to read 0 bytes, actually read %d", len(result))
  2076  	}
  2077  
  2078  	dec.Reset(nil)
  2079  
  2080  	_, err = io.ReadAll(dec)
  2081  	if err != ErrDecoderNilInput {
  2082  		t.Fatalf("Expected ErrDecoderNilInput when decoding from a nil reader, got %v", err)
  2083  	}
  2084  
  2085  	dec.Reset(bytes.NewBuffer(emptyZstdBlob))
  2086  
  2087  	result, err = io.ReadAll(dec)
  2088  	if err != nil && err != io.EOF {
  2089  		t.Fatal(err)
  2090  	}
  2091  	if len(result) != 0 {
  2092  		t.Fatalf("Expected to read 0 bytes, actually read %d", len(result))
  2093  	}
  2094  }
  2095  
  2096  func TestIgnoreChecksum(t *testing.T) {
  2097  	// zstd file containing text "compress\n" and has an xxhash checksum
  2098  	zstdBlob := []byte{0x28, 0xb5, 0x2f, 0xfd, 0x24, 0x09, 0x49, 0x00, 0x00, 'C', 'o', 'm', 'p', 'r', 'e', 's', 's', '\n', 0x79, 0x6e, 0xe0, 0xd2}
  2099  
  2100  	// replace letter 'c' with 'C', so decoding should fail.
  2101  	zstdBlob[9] = 'C'
  2102  
  2103  	{
  2104  		// Check if the file is indeed incorrect
  2105  		dec, err := NewReader(nil)
  2106  		if err != nil {
  2107  			t.Fatal(err)
  2108  		}
  2109  		defer dec.Close()
  2110  
  2111  		dec.Reset(bytes.NewBuffer(zstdBlob))
  2112  
  2113  		_, err = io.ReadAll(dec)
  2114  		if err == nil {
  2115  			t.Fatal("Expected decoding error")
  2116  		}
  2117  
  2118  		if !errors.Is(err, ErrCRCMismatch) {
  2119  			t.Fatalf("Expected checksum error, got '%s'", err)
  2120  		}
  2121  	}
  2122  
  2123  	{
  2124  		// Ignore CRC error and decompress the content
  2125  		dec, err := NewReader(nil, IgnoreChecksum(true))
  2126  		if err != nil {
  2127  			t.Fatal(err)
  2128  		}
  2129  		defer dec.Close()
  2130  
  2131  		dec.Reset(bytes.NewBuffer(zstdBlob))
  2132  
  2133  		res, err := io.ReadAll(dec)
  2134  		if err != nil {
  2135  			t.Fatalf("Unexpected error: '%s'", err)
  2136  		}
  2137  
  2138  		want := []byte{'C', 'o', 'm', 'p', 'r', 'e', 's', 's', '\n'}
  2139  		if !bytes.Equal(res, want) {
  2140  			t.Logf("want: %s", want)
  2141  			t.Logf("got:  %s", res)
  2142  			t.Fatalf("Wrong output")
  2143  		}
  2144  	}
  2145  }
  2146  
  2147  func timeout(after time.Duration) (cancel func()) {
  2148  	if isRaceTest {
  2149  		return func() {}
  2150  	}
  2151  	c := time.After(after)
  2152  	cc := make(chan struct{})
  2153  	go func() {
  2154  		select {
  2155  		case <-cc:
  2156  			return
  2157  		case <-c:
  2158  			buf := make([]byte, 1<<20)
  2159  			stacklen := runtime.Stack(buf, true)
  2160  			log.Printf("=== Timeout, assuming deadlock ===\n*** goroutine dump...\n%s\n*** end\n", string(buf[:stacklen]))
  2161  			os.Exit(2)
  2162  		}
  2163  	}()
  2164  	return func() {
  2165  		close(cc)
  2166  	}
  2167  }
  2168  
  2169  func TestWithDecodeAllCapLimit(t *testing.T) {
  2170  	var encs []*Encoder
  2171  	var decs []*Decoder
  2172  	addEnc := func(e *Encoder, _ error) {
  2173  		encs = append(encs, e)
  2174  	}
  2175  	addDec := func(d *Decoder, _ error) {
  2176  		decs = append(decs, d)
  2177  	}
  2178  	addEnc(NewWriter(nil, WithZeroFrames(true), WithWindowSize(4<<10)))
  2179  	addEnc(NewWriter(nil, WithEncoderConcurrency(1), WithWindowSize(4<<10)))
  2180  	addEnc(NewWriter(nil, WithZeroFrames(false), WithWindowSize(4<<10)))
  2181  	addEnc(NewWriter(nil, WithWindowSize(128<<10)))
  2182  	addDec(NewReader(nil, WithDecodeAllCapLimit(true)))
  2183  	addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderConcurrency(1)))
  2184  	addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderLowmem(true)))
  2185  	addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxWindow(128<<10)))
  2186  	addDec(NewReader(nil, WithDecodeAllCapLimit(true), WithDecoderMaxMemory(1<<20)))
  2187  	for sz := 0; sz < 1<<20; sz = (sz + 1) * 2 {
  2188  		sz := sz
  2189  		t.Run(strconv.Itoa(sz), func(t *testing.T) {
  2190  			t.Parallel()
  2191  			for ei, enc := range encs {
  2192  				for di, dec := range decs {
  2193  					t.Run(fmt.Sprintf("e%d:d%d", ei, di), func(t *testing.T) {
  2194  						encoded := enc.EncodeAll(make([]byte, sz), nil)
  2195  						for i := sz - 1; i < sz+1; i++ {
  2196  							if i < 0 {
  2197  								continue
  2198  							}
  2199  							const existinglen = 5
  2200  							got, err := dec.DecodeAll(encoded, make([]byte, existinglen, i+existinglen))
  2201  							if i < sz {
  2202  								if err != ErrDecoderSizeExceeded {
  2203  									t.Errorf("cap: %d, want %v, got %v", i, ErrDecoderSizeExceeded, err)
  2204  								}
  2205  							} else {
  2206  								if err != nil {
  2207  									t.Errorf("cap: %d, want %v, got %v", i, nil, err)
  2208  									continue
  2209  								}
  2210  								if len(got) != existinglen+i {
  2211  									t.Errorf("cap: %d, want output size %d, got %d", i, existinglen+i, len(got))
  2212  								}
  2213  							}
  2214  						}
  2215  					})
  2216  				}
  2217  			}
  2218  		})
  2219  	}
  2220  }
  2221  

View as plain text