...

Source file src/github.com/klauspost/compress/internal/fuzz/helpers.go

Documentation: github.com/klauspost/compress/internal/fuzz

     1  // Copyright (c) 2024+ Klaus Post. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package fuzz provides a way to add test cases to a testing.F instance from a zip file.
     6  package fuzz
     7  
     8  import (
     9  	"archive/zip"
    10  	"bytes"
    11  	"encoding/binary"
    12  	"fmt"
    13  	"go/ast"
    14  	"go/parser"
    15  	"go/token"
    16  	"io"
    17  	"os"
    18  	"strconv"
    19  	"testing"
    20  )
    21  
    22  type InputType uint8
    23  
    24  const (
    25  	// TypeRaw indicates that files are raw bytes.
    26  	TypeRaw InputType = iota
    27  	// TypeGoFuzz indicates files are from Go Fuzzer.
    28  	TypeGoFuzz
    29  	// TypeOSSFuzz indicates that files are from OSS fuzzer with size before data.
    30  	TypeOSSFuzz
    31  )
    32  
    33  // AddFromZip will read the supplied zip and add all as corpus for f.
    34  // Byte slices only.
    35  func AddFromZip(f *testing.F, filename string, t InputType, short bool) {
    36  	file, err := os.Open(filename)
    37  	if err != nil {
    38  		f.Fatal(err)
    39  	}
    40  	fi, err := file.Stat()
    41  	if fi == nil {
    42  		return
    43  	}
    44  
    45  	if err != nil {
    46  		f.Fatal(err)
    47  	}
    48  	zr, err := zip.NewReader(file, fi.Size())
    49  	if err != nil {
    50  		f.Fatal(err)
    51  	}
    52  	for i, file := range zr.File {
    53  		if short && i%10 != 0 {
    54  			continue
    55  		}
    56  		rc, err := file.Open()
    57  		if err != nil {
    58  			f.Fatal(err)
    59  		}
    60  
    61  		b, err := io.ReadAll(rc)
    62  		if err != nil {
    63  			f.Fatal(err)
    64  		}
    65  		rc.Close()
    66  		t := t
    67  		if t == TypeOSSFuzz {
    68  			t = TypeRaw // Fallback
    69  			if len(b) >= 4 {
    70  				sz := binary.BigEndian.Uint32(b)
    71  				if sz <= uint32(len(b))-4 {
    72  					f.Add(b[4 : 4+sz])
    73  					continue
    74  				}
    75  			}
    76  		}
    77  
    78  		if bytes.HasPrefix(b, []byte("go test fuzz")) {
    79  			t = TypeGoFuzz
    80  		} else {
    81  			t = TypeRaw
    82  		}
    83  
    84  		if t == TypeRaw {
    85  			f.Add(b)
    86  			continue
    87  		}
    88  		vals, err := unmarshalCorpusFile(b)
    89  		if err != nil {
    90  			f.Fatal(err)
    91  		}
    92  		for _, v := range vals {
    93  			f.Add(v)
    94  		}
    95  	}
    96  }
    97  
    98  // ReturnFromZip will read the supplied zip and add all as corpus for f.
    99  // Byte slices only.
   100  func ReturnFromZip(tb testing.TB, filename string, t InputType, fn func([]byte)) {
   101  	file, err := os.Open(filename)
   102  	if err != nil {
   103  		tb.Fatal(err)
   104  	}
   105  	fi, err := file.Stat()
   106  	if fi == nil {
   107  		return
   108  	}
   109  	if err != nil {
   110  		tb.Fatal(err)
   111  	}
   112  	zr, err := zip.NewReader(file, fi.Size())
   113  	if err != nil {
   114  		tb.Fatal(err)
   115  	}
   116  	for _, file := range zr.File {
   117  		rc, err := file.Open()
   118  		if err != nil {
   119  			tb.Fatal(err)
   120  		}
   121  
   122  		b, err := io.ReadAll(rc)
   123  		if err != nil {
   124  			tb.Fatal(err)
   125  		}
   126  		rc.Close()
   127  		t := t
   128  		if t == TypeOSSFuzz {
   129  			t = TypeRaw // Fallback
   130  			if len(b) >= 4 {
   131  				sz := binary.BigEndian.Uint32(b)
   132  				if sz <= uint32(len(b))-4 {
   133  					fn(b[4 : 4+sz])
   134  					continue
   135  				}
   136  			}
   137  		}
   138  
   139  		if bytes.HasPrefix(b, []byte("go test fuzz")) {
   140  			t = TypeGoFuzz
   141  		} else {
   142  			t = TypeRaw
   143  		}
   144  
   145  		if t == TypeRaw {
   146  			fn(b)
   147  			continue
   148  		}
   149  		vals, err := unmarshalCorpusFile(b)
   150  		if err != nil {
   151  			tb.Fatal(err)
   152  		}
   153  		for _, v := range vals {
   154  			fn(v)
   155  		}
   156  	}
   157  }
   158  
   159  // unmarshalCorpusFile decodes corpus bytes into their respective values.
   160  func unmarshalCorpusFile(b []byte) ([][]byte, error) {
   161  	if len(b) == 0 {
   162  		return nil, fmt.Errorf("cannot unmarshal empty string")
   163  	}
   164  	lines := bytes.Split(b, []byte("\n"))
   165  	if len(lines) < 2 {
   166  		return nil, fmt.Errorf("must include version and at least one value")
   167  	}
   168  	var vals = make([][]byte, 0, len(lines)-1)
   169  	for _, line := range lines[1:] {
   170  		line = bytes.TrimSpace(line)
   171  		if len(line) == 0 {
   172  			continue
   173  		}
   174  		v, err := parseCorpusValue(line)
   175  		if err != nil {
   176  			return nil, fmt.Errorf("malformed line %q: %v", line, err)
   177  		}
   178  		vals = append(vals, v)
   179  	}
   180  	return vals, nil
   181  }
   182  
   183  // parseCorpusValue
   184  func parseCorpusValue(line []byte) ([]byte, error) {
   185  	fs := token.NewFileSet()
   186  	expr, err := parser.ParseExprFrom(fs, "(test)", line, 0)
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  	call, ok := expr.(*ast.CallExpr)
   191  	if !ok {
   192  		return nil, fmt.Errorf("expected call expression")
   193  	}
   194  	if len(call.Args) != 1 {
   195  		return nil, fmt.Errorf("expected call expression with 1 argument; got %d", len(call.Args))
   196  	}
   197  	arg := call.Args[0]
   198  
   199  	if arrayType, ok := call.Fun.(*ast.ArrayType); ok {
   200  		if arrayType.Len != nil {
   201  			return nil, fmt.Errorf("expected []byte or primitive type")
   202  		}
   203  		elt, ok := arrayType.Elt.(*ast.Ident)
   204  		if !ok || elt.Name != "byte" {
   205  			return nil, fmt.Errorf("expected []byte")
   206  		}
   207  		lit, ok := arg.(*ast.BasicLit)
   208  		if !ok || lit.Kind != token.STRING {
   209  			return nil, fmt.Errorf("string literal required for type []byte")
   210  		}
   211  		s, err := strconv.Unquote(lit.Value)
   212  		if err != nil {
   213  			return nil, err
   214  		}
   215  		return []byte(s), nil
   216  	}
   217  	return nil, fmt.Errorf("expected []byte")
   218  }
   219  

View as plain text