...

Source file src/github.com/klauspost/compress/zstd/fuzz_test.go

Documentation: github.com/klauspost/compress/zstd

     1  //go:build go1.18
     2  // +build go1.18
     3  
     4  package zstd
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"io"
    10  	rdebug "runtime/debug"
    11  	"testing"
    12  
    13  	"github.com/klauspost/compress/internal/cpuinfo"
    14  	"github.com/klauspost/compress/internal/fuzz"
    15  )
    16  
    17  func FuzzDecodeAll(f *testing.F) {
    18  	fuzz.AddFromZip(f, "testdata/decode-regression.zip", fuzz.TypeRaw, false)
    19  	fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
    20  	fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-encoded.zip", fuzz.TypeGoFuzz, testing.Short())
    21  
    22  	f.Fuzz(func(t *testing.T, b []byte) {
    23  		// Just test if we crash...
    24  		defer func() {
    25  			if r := recover(); r != nil {
    26  				rdebug.PrintStack()
    27  				t.Fatal(r)
    28  			}
    29  		}()
    30  
    31  		decLow, err := NewReader(nil, WithDecoderLowmem(true), WithDecoderConcurrency(2), WithDecoderMaxMemory(20<<20), WithDecoderMaxWindow(1<<20), IgnoreChecksum(true))
    32  		if err != nil {
    33  			t.Fatal(err)
    34  		}
    35  		defer decLow.Close()
    36  		decHi, err := NewReader(nil, WithDecoderLowmem(false), WithDecoderConcurrency(2), WithDecoderMaxMemory(20<<20), WithDecoderMaxWindow(1<<20), IgnoreChecksum(true))
    37  		if err != nil {
    38  			t.Fatal(err)
    39  		}
    40  		defer decHi.Close()
    41  		b1, err1 := decLow.DecodeAll(b, make([]byte, 0, len(b)))
    42  		b2, err2 := decHi.DecodeAll(b, make([]byte, 0, len(b)))
    43  		if err1 != err2 {
    44  			if (err1 == nil) != (err2 == nil) {
    45  				t.Errorf("err low: %v, hi: %v", err1, err2)
    46  			}
    47  		}
    48  		if err1 != nil {
    49  			b1, b2 = b1[:0], b2[:0]
    50  		}
    51  		if !bytes.Equal(b1, b2) {
    52  			t.Fatalf("Output mismatch, low: %v, hi: %v", err1, err2)
    53  		}
    54  	})
    55  }
    56  
    57  func FuzzDecAllNoBMI2(f *testing.F) {
    58  	if !cpuinfo.HasBMI2() {
    59  		f.Skip("No BMI, so already tested")
    60  		return
    61  	}
    62  	defer cpuinfo.DisableBMI2()()
    63  	FuzzDecodeAll(f)
    64  }
    65  
    66  func FuzzDecoder(f *testing.F) {
    67  	fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
    68  	fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-encoded.zip", fuzz.TypeGoFuzz, testing.Short())
    69  	//fuzz.AddFromZip(f, "testdata/fuzz/decode-oss.zip", fuzz.TypeOSSFuzz, false)
    70  
    71  	brLow := newBytesReader(nil)
    72  	brHi := newBytesReader(nil)
    73  	f.Fuzz(func(t *testing.T, b []byte) {
    74  		// Just test if we crash...
    75  		defer func() {
    76  			if r := recover(); r != nil {
    77  				rdebug.PrintStack()
    78  				t.Fatal(r)
    79  			}
    80  		}()
    81  		brLow.Reset(b)
    82  		brHi.Reset(b)
    83  		decLow, err := NewReader(brLow, WithDecoderLowmem(true), WithDecoderConcurrency(2), WithDecoderMaxMemory(20<<20), WithDecoderMaxWindow(1<<20), IgnoreChecksum(true), WithDecodeBuffersBelow(8<<10))
    84  		if err != nil {
    85  			t.Fatal(err)
    86  		}
    87  		defer decLow.Close()
    88  
    89  		// Test with high memory, but sync decoding
    90  		decHi, err := NewReader(brHi, WithDecoderLowmem(false), WithDecoderConcurrency(1), WithDecoderMaxMemory(20<<20), WithDecoderMaxWindow(1<<20), IgnoreChecksum(true), WithDecodeBuffersBelow(8<<10))
    91  		if err != nil {
    92  			t.Fatal(err)
    93  		}
    94  		defer decHi.Close()
    95  
    96  		if debugDecoder {
    97  			fmt.Println("LOW CONCURRENT")
    98  		}
    99  		b1, err1 := io.ReadAll(decLow)
   100  
   101  		if debugDecoder {
   102  			fmt.Println("HI NOT CONCURRENT")
   103  		}
   104  		b2, err2 := io.ReadAll(decHi)
   105  		if err1 != err2 {
   106  			if (err1 == nil) != (err2 == nil) {
   107  				t.Errorf("err low concurrent: %v, hi: %v", err1, err2)
   108  			}
   109  		}
   110  		if err1 != nil {
   111  			b1, b2 = b1[:0], b2[:0]
   112  		}
   113  		if !bytes.Equal(b1, b2) {
   114  			t.Fatalf("Output mismatch, low concurrent: %v, hi: %v", err1, err2)
   115  		}
   116  	})
   117  }
   118  
   119  func FuzzNoBMI2Dec(f *testing.F) {
   120  	if !cpuinfo.HasBMI2() {
   121  		f.Skip("No BMI, so already tested")
   122  		return
   123  	}
   124  	defer cpuinfo.DisableBMI2()()
   125  	FuzzDecoder(f)
   126  }
   127  
   128  func FuzzEncoding(f *testing.F) {
   129  	fuzz.AddFromZip(f, "testdata/fuzz/encode-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
   130  	fuzz.AddFromZip(f, "testdata/comp-crashers.zip", fuzz.TypeRaw, false)
   131  	fuzz.AddFromZip(f, "testdata/fuzz/encode-corpus-encoded.zip", fuzz.TypeGoFuzz, testing.Short())
   132  	// Fuzzing tweaks:
   133  	const (
   134  		// Test a subset of encoders.
   135  		startFuzz = SpeedFastest
   136  		endFuzz   = SpeedBestCompression
   137  
   138  		// Also tests with dictionaries...
   139  		testDicts = true
   140  	)
   141  
   142  	var dec *Decoder
   143  	var encs [SpeedBestCompression + 1]*Encoder
   144  	var encsD [SpeedBestCompression + 1]*Encoder
   145  
   146  	var dicts [][]byte
   147  	if testDicts {
   148  		zr := testCreateZipReader("testdata/dict-tests-small.zip", f)
   149  		dicts = readDicts(f, zr)
   150  	}
   151  
   152  	if testing.Short() && *fuzzEndF > int(SpeedBetterCompression) {
   153  		*fuzzEndF = int(SpeedBetterCompression)
   154  	}
   155  
   156  	initEnc := func() func() {
   157  		var err error
   158  		dec, err = NewReader(nil, WithDecoderConcurrency(2), WithDecoderDicts(dicts...), WithDecoderMaxWindow(64<<10), WithDecoderMaxMemory(uint64(*fuzzMaxF)))
   159  		if err != nil {
   160  			panic(err)
   161  		}
   162  		for level := startFuzz; level <= endFuzz; level++ {
   163  			encs[level], err = NewWriter(nil, WithEncoderCRC(true), WithEncoderLevel(level), WithEncoderConcurrency(2), WithWindowSize(64<<10), WithZeroFrames(true), WithLowerEncoderMem(true))
   164  			if testDicts {
   165  				encsD[level], err = NewWriter(nil, WithEncoderCRC(true), WithEncoderLevel(level), WithEncoderConcurrency(2), WithWindowSize(64<<10), WithZeroFrames(true), WithEncoderDict(dicts[int(level)%len(dicts)]), WithLowerEncoderMem(true), WithLowerEncoderMem(true))
   166  			}
   167  		}
   168  		return func() {
   169  			dec.Close()
   170  			for _, enc := range encs {
   171  				if enc != nil {
   172  					enc.Close()
   173  				}
   174  			}
   175  			if testDicts {
   176  				for _, enc := range encsD {
   177  					if enc != nil {
   178  						enc.Close()
   179  					}
   180  				}
   181  			}
   182  		}
   183  	}
   184  
   185  	f.Cleanup(initEnc())
   186  
   187  	var dst bytes.Buffer
   188  
   189  	f.Fuzz(func(t *testing.T, data []byte) {
   190  		// Just test if we crash...
   191  		defer func() {
   192  			if r := recover(); r != nil {
   193  				stack := rdebug.Stack()
   194  				t.Fatalf("%v:\n%v", r, string(stack))
   195  			}
   196  		}()
   197  		if len(data) > *fuzzMaxF {
   198  			return
   199  		}
   200  		var bufSize = len(data)
   201  		if bufSize > 2 {
   202  			// Make deterministic size
   203  			bufSize = int(data[0]) | int(data[1])<<8
   204  			if bufSize >= len(data) {
   205  				bufSize = len(data) / 2
   206  			}
   207  		}
   208  
   209  		for level := *fuzzStartF; level <= *fuzzEndF; level++ {
   210  			enc := encs[level]
   211  			dst.Reset()
   212  			enc.Reset(&dst)
   213  			n, err := enc.Write(data)
   214  			if err != nil {
   215  				t.Fatal(err)
   216  			}
   217  			if n != len(data) {
   218  				t.Fatal(fmt.Sprintln("Level", level, "Short write, got:", n, "want:", len(data)))
   219  			}
   220  
   221  			encoded := enc.EncodeAll(data, make([]byte, 0, bufSize))
   222  			if len(encoded) > enc.MaxEncodedSize(len(data)) {
   223  				t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
   224  			}
   225  
   226  			got, err := dec.DecodeAll(encoded, make([]byte, 0, bufSize))
   227  			if err != nil {
   228  				t.Fatal(fmt.Sprintln("Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
   229  			}
   230  			if !bytes.Equal(got, data) {
   231  				t.Fatal(fmt.Sprintln("Level", level, "DecodeAll output mismatch\n", len(got), "org: \n", len(data), "(want)", "\nencoded:", len(encoded)))
   232  			}
   233  
   234  			err = enc.Close()
   235  			if err != nil {
   236  				t.Fatal(fmt.Sprintln("Level", level, "Close (buffer) error:", err))
   237  			}
   238  			encoded2 := dst.Bytes()
   239  			if len(encoded2) > enc.MaxEncodedSize(len(data)) {
   240  				t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
   241  			}
   242  			if !bytes.Equal(encoded, encoded2) {
   243  				got, err = dec.DecodeAll(encoded2, got[:0])
   244  				if err != nil {
   245  					t.Fatal(fmt.Sprintln("Level", level, "DecodeAll (buffer) error:", err, "\norg:", len(data), "\nencoded", len(encoded2)))
   246  				}
   247  				if !bytes.Equal(got, data) {
   248  					t.Fatal(fmt.Sprintln("Level", level, "DecodeAll (buffer) output mismatch\n", len(got), "org: \n", len(data), "(want)", "\nencoded:", len(encoded2)))
   249  				}
   250  			}
   251  			if !testDicts {
   252  				continue
   253  			}
   254  			enc = encsD[level]
   255  			dst.Reset()
   256  			enc.Reset(&dst)
   257  			n, err = enc.Write(data)
   258  			if err != nil {
   259  				t.Fatal(err)
   260  			}
   261  			if n != len(data) {
   262  				t.Fatal(fmt.Sprintln("Dict Level", level, "Short write, got:", n, "want:", len(data)))
   263  			}
   264  
   265  			encoded = enc.EncodeAll(data, encoded[:0])
   266  			if len(encoded) > enc.MaxEncodedSize(len(data)) {
   267  				t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded), enc.MaxEncodedSize(len(data)))
   268  			}
   269  			got, err = dec.DecodeAll(encoded, got[:0])
   270  			if err != nil {
   271  				t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll error:", err, "\norg:", len(data), "\nencoded", len(encoded)))
   272  			}
   273  			if !bytes.Equal(got, data) {
   274  				t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll output mismatch\n", len(got), "org: \n", len(data), "(want)", "\nencoded:", len(encoded)))
   275  			}
   276  
   277  			err = enc.Close()
   278  			if err != nil {
   279  				t.Fatal(fmt.Sprintln("Dict Level", level, "Close (buffer) error:", err))
   280  			}
   281  			encoded2 = dst.Bytes()
   282  			if len(encoded2) > enc.MaxEncodedSize(len(data)) {
   283  				t.Errorf("max encoded size for %v: got: %d, want max: %d", len(data), len(encoded2), enc.MaxEncodedSize(len(data)))
   284  			}
   285  			if !bytes.Equal(encoded, encoded2) {
   286  				got, err = dec.DecodeAll(encoded2, got[:0])
   287  				if err != nil {
   288  					t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll (buffer) error:", err, "\norg:", len(data), "\nencoded", len(encoded2)))
   289  				}
   290  				if !bytes.Equal(got, data) {
   291  					t.Fatal(fmt.Sprintln("Dict Level", level, "DecodeAll (buffer) output mismatch\n", len(got), "org: \n", len(data), "(want)", "\nencoded:", len(encoded2)))
   292  				}
   293  			}
   294  		}
   295  	})
   296  }
   297  

View as plain text