...

Source file src/gotest.tools/v3/assert/cmd/gty-migrate-from-testify/main.go

Documentation: gotest.tools/v3/assert/cmd/gty-migrate-from-testify

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"flag"
     7  	"fmt"
     8  	"go/ast"
     9  	"go/format"
    10  	"go/token"
    11  	"log"
    12  	"os"
    13  	"path"
    14  	"path/filepath"
    15  	"strings"
    16  
    17  	"golang.org/x/tools/go/packages"
    18  	"golang.org/x/tools/imports"
    19  )
    20  
    21  type options struct {
    22  	pkgs             []string
    23  	dryRun           bool
    24  	debug            bool
    25  	cmpImportName    string
    26  	showLoaderErrors bool
    27  	buildFlags       []string
    28  	localImportPath  string
    29  }
    30  
    31  func main() {
    32  	name := os.Args[0]
    33  	flags, opts := setupFlags(name)
    34  	handleExitError(name, flags.Parse(os.Args[1:]))
    35  	setupLogging(opts)
    36  	opts.pkgs = flags.Args()
    37  	handleExitError(name, run(*opts))
    38  }
    39  
    40  func setupLogging(opts *options) {
    41  	log.SetFlags(0)
    42  	enableDebug = opts.debug
    43  }
    44  
    45  var enableDebug = false
    46  
    47  func debugf(msg string, args ...interface{}) {
    48  	if enableDebug {
    49  		log.Printf("DEBUG: "+msg, args...)
    50  	}
    51  }
    52  
    53  func setupFlags(name string) (*flag.FlagSet, *options) {
    54  	opts := options{}
    55  	flags := flag.NewFlagSet(name, flag.ContinueOnError)
    56  	flags.BoolVar(&opts.dryRun, "dry-run", false,
    57  		"don't write changes to file")
    58  	flags.BoolVar(&opts.debug, "debug", false, "enable debug logging")
    59  	flags.StringVar(&opts.cmpImportName, "cmp-pkg-import-alias", "is",
    60  		"import alias to use for the assert/cmp package")
    61  	flags.BoolVar(&opts.showLoaderErrors, "print-loader-errors", false,
    62  		"print errors from loading source")
    63  	flags.Var((*stringSliceValue)(&opts.buildFlags), "build-flags",
    64  		"build flags to pass to Go when loading source files")
    65  	flags.StringVar(&opts.localImportPath, "local-import-path", "",
    66  		"value to pass to 'goimports -local' flag for sorting local imports")
    67  	flags.Usage = func() {
    68  		fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS] PACKAGE [PACKAGE...]
    69  
    70  Migrate calls from testify/{assert|require} to gotest.tools/v3/assert.
    71  
    72  `, name)
    73  		flags.PrintDefaults()
    74  	}
    75  	return flags, &opts
    76  }
    77  
    78  func handleExitError(name string, err error) {
    79  	switch {
    80  	case err == nil:
    81  		return
    82  	case errors.Is(err, flag.ErrHelp):
    83  		os.Exit(0)
    84  	default:
    85  		log.Println(name + ": Error: " + err.Error())
    86  		os.Exit(3)
    87  	}
    88  }
    89  
    90  func run(opts options) error {
    91  	imports.LocalPrefix = opts.localImportPath
    92  
    93  	fset := token.NewFileSet()
    94  	pkgs, err := loadPackages(opts, fset)
    95  	if err != nil {
    96  		return fmt.Errorf("failed to load program: %w", err)
    97  	}
    98  
    99  	debugf("package count: %d", len(pkgs))
   100  	for _, pkg := range pkgs {
   101  		debugf("file count for package %v: %d", pkg.PkgPath, len(pkg.Syntax))
   102  		for _, astFile := range pkg.Syntax {
   103  			absFilename := fset.File(astFile.Pos()).Name()
   104  			filename := relativePath(absFilename)
   105  			importNames := newImportNames(astFile.Imports, opts)
   106  			if !importNames.hasTestifyImports() {
   107  				debugf("skipping file %s, no imports", filename)
   108  				continue
   109  			}
   110  
   111  			debugf("migrating %s with imports: %#v", filename, importNames)
   112  			m := migration{
   113  				file:        astFile,
   114  				fileset:     fset,
   115  				importNames: importNames,
   116  				pkgInfo:     pkg.TypesInfo,
   117  			}
   118  			migrateFile(m)
   119  			if opts.dryRun {
   120  				continue
   121  			}
   122  
   123  			raw, err := formatFile(m)
   124  			if err != nil {
   125  				return fmt.Errorf("failed to format %s: %w", filename, err)
   126  			}
   127  
   128  			if err := os.WriteFile(absFilename, raw, 0); err != nil {
   129  				return fmt.Errorf("failed to write file %s: %w", filename, err)
   130  			}
   131  		}
   132  	}
   133  
   134  	return nil
   135  }
   136  
   137  var loadMode = packages.NeedName |
   138  	packages.NeedFiles |
   139  	packages.NeedCompiledGoFiles |
   140  	packages.NeedDeps |
   141  	packages.NeedImports |
   142  	packages.NeedTypes |
   143  	packages.NeedTypesInfo |
   144  	packages.NeedTypesSizes |
   145  	packages.NeedSyntax
   146  
   147  func loadPackages(opts options, fset *token.FileSet) ([]*packages.Package, error) {
   148  	conf := &packages.Config{
   149  		Mode:       loadMode,
   150  		Fset:       fset,
   151  		Tests:      true,
   152  		Logf:       debugf,
   153  		BuildFlags: opts.buildFlags,
   154  	}
   155  
   156  	pkgs, err := packages.Load(conf, opts.pkgs...)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	if opts.showLoaderErrors {
   161  		packages.PrintErrors(pkgs)
   162  	}
   163  	return pkgs, nil
   164  }
   165  
   166  func relativePath(p string) string {
   167  	cwd, err := os.Getwd()
   168  	if err != nil {
   169  		return p
   170  	}
   171  	rel, err := filepath.Rel(cwd, p)
   172  	if err != nil {
   173  		return p
   174  	}
   175  	return rel
   176  }
   177  
   178  type importNames struct {
   179  	testifyAssert  string
   180  	testifyRequire string
   181  	assert         string
   182  	cmp            string
   183  }
   184  
   185  func (p importNames) hasTestifyImports() bool {
   186  	return p.testifyAssert != "" || p.testifyRequire != ""
   187  }
   188  
   189  func (p importNames) matchesTestify(ident *ast.Ident) bool {
   190  	return ident.Name == p.testifyAssert || ident.Name == p.testifyRequire
   191  }
   192  
   193  func (p importNames) funcNameFromTestifyName(name string) string {
   194  	switch name {
   195  	case p.testifyAssert:
   196  		return funcNameCheck
   197  	case p.testifyRequire:
   198  		return funcNameAssert
   199  	default:
   200  		panic("unexpected testify import name " + name)
   201  	}
   202  }
   203  
   204  func newImportNames(imports []*ast.ImportSpec, opt options) importNames {
   205  	defaultAssertAlias := path.Base(pkgAssert)
   206  	importNames := importNames{
   207  		assert: defaultAssertAlias,
   208  		cmp:    path.Base(pkgCmp),
   209  	}
   210  	for _, spec := range imports {
   211  		switch strings.Trim(spec.Path.Value, `"`) {
   212  		case pkgTestifyAssert, pkgGopkgTestifyAssert:
   213  			importNames.testifyAssert = identOrDefault(spec.Name, "assert")
   214  		case pkgTestifyRequire, pkgGopkgTestifyRequire:
   215  			importNames.testifyRequire = identOrDefault(spec.Name, "require")
   216  		default:
   217  			pkgPath := strings.Trim(spec.Path.Value, `"`)
   218  
   219  			switch {
   220  			// v3/assert is already imported and has an alias
   221  			case pkgPath == pkgAssert:
   222  				if spec.Name != nil && spec.Name.Name != "" {
   223  					importNames.assert = spec.Name.Name
   224  				}
   225  				continue
   226  
   227  			// some other package is imported as assert
   228  			case importedAs(spec, path.Base(pkgAssert)) && importNames.assert == defaultAssertAlias:
   229  				importNames.assert = "gtyassert"
   230  			}
   231  		}
   232  	}
   233  
   234  	if opt.cmpImportName != "" {
   235  		importNames.cmp = opt.cmpImportName
   236  	}
   237  	return importNames
   238  }
   239  
   240  func importedAs(spec *ast.ImportSpec, pkg string) bool {
   241  	if path.Base(strings.Trim(spec.Path.Value, `"`)) == pkg {
   242  		return true
   243  	}
   244  	return spec.Name != nil && spec.Name.Name == pkg
   245  }
   246  
   247  func identOrDefault(ident *ast.Ident, def string) string {
   248  	if ident != nil {
   249  		return ident.Name
   250  	}
   251  	return def
   252  }
   253  
   254  func formatFile(migration migration) ([]byte, error) {
   255  	buf := new(bytes.Buffer)
   256  	err := format.Node(buf, migration.fileset, migration.file)
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  	filename := migration.fileset.File(migration.file.Pos()).Name()
   261  	return imports.Process(filename, buf.Bytes(), nil)
   262  }
   263  

View as plain text