...

Source file src/gotest.tools/v3/assert/cmd/gty-migrate-from-testify/migrate.go

Documentation: gotest.tools/v3/assert/cmd/gty-migrate-from-testify

     1  package main
     2  
     3  import (
     4  	"go/ast"
     5  	"go/token"
     6  	"go/types"
     7  	"log"
     8  	"path"
     9  
    10  	"golang.org/x/tools/go/ast/astutil"
    11  )
    12  
    13  const (
    14  	pkgTestifyAssert       = "github.com/stretchr/testify/assert"
    15  	pkgGopkgTestifyAssert  = "gopkg.in/stretchr/testify.v1/assert"
    16  	pkgTestifyRequire      = "github.com/stretchr/testify/require"
    17  	pkgGopkgTestifyRequire = "gopkg.in/stretchr/testify.v1/require"
    18  	pkgAssert              = "gotest.tools/v3/assert"
    19  	pkgCmp                 = "gotest.tools/v3/assert/cmp"
    20  )
    21  
    22  const (
    23  	funcNameAssert = "Assert"
    24  	funcNameCheck  = "Check"
    25  )
    26  
    27  var allTestifyPks = []string{
    28  	pkgTestifyAssert,
    29  	pkgTestifyRequire,
    30  	pkgGopkgTestifyAssert,
    31  	pkgGopkgTestifyRequire,
    32  }
    33  
    34  type migration struct {
    35  	file        *ast.File
    36  	fileset     *token.FileSet
    37  	importNames importNames
    38  	pkgInfo     *types.Info
    39  }
    40  
    41  func migrateFile(migration migration) {
    42  	astutil.Apply(migration.file, nil, replaceCalls(migration))
    43  	updateImports(migration)
    44  }
    45  
    46  func updateImports(migration migration) {
    47  	for _, remove := range allTestifyPks {
    48  		astutil.DeleteImport(migration.fileset, migration.file, remove)
    49  	}
    50  
    51  	var alias string
    52  	if migration.importNames.assert != path.Base(pkgAssert) {
    53  		alias = migration.importNames.assert
    54  	}
    55  	astutil.AddNamedImport(migration.fileset, migration.file, alias, pkgAssert)
    56  
    57  	if migration.importNames.cmp != path.Base(pkgCmp) {
    58  		alias = migration.importNames.cmp
    59  	}
    60  	astutil.AddNamedImport(migration.fileset, migration.file, alias, pkgCmp)
    61  }
    62  
    63  type emptyNode struct{}
    64  
    65  func (n emptyNode) Pos() token.Pos {
    66  	return 0
    67  }
    68  
    69  func (n emptyNode) End() token.Pos {
    70  	return 0
    71  }
    72  
    73  var removeNode = emptyNode{}
    74  
    75  func replaceCalls(migration migration) func(cursor *astutil.Cursor) bool {
    76  	return func(cursor *astutil.Cursor) bool {
    77  		var newNode ast.Node
    78  		switch typed := cursor.Node().(type) {
    79  		case *ast.SelectorExpr:
    80  			newNode = getReplacementTestingT(typed, migration.importNames)
    81  		case *ast.CallExpr:
    82  			newNode = getReplacementAssertion(typed, migration)
    83  		case *ast.AssignStmt:
    84  			newNode = getReplacementAssignment(typed, migration)
    85  		}
    86  
    87  		switch newNode {
    88  		case nil:
    89  		case removeNode:
    90  			cursor.Delete()
    91  		default:
    92  			cursor.Replace(newNode)
    93  		}
    94  		return true
    95  	}
    96  }
    97  
    98  func getReplacementTestingT(selector *ast.SelectorExpr, names importNames) ast.Node {
    99  	xIdent, ok := selector.X.(*ast.Ident)
   100  	if !ok {
   101  		return nil
   102  	}
   103  	if selector.Sel.Name != "TestingT" || !names.matchesTestify(xIdent) {
   104  		return nil
   105  	}
   106  	return &ast.SelectorExpr{
   107  		X:   &ast.Ident{Name: names.assert, NamePos: xIdent.NamePos},
   108  		Sel: selector.Sel,
   109  	}
   110  }
   111  
   112  func getReplacementAssertion(callExpr *ast.CallExpr, migration migration) ast.Node {
   113  	tcall, ok := newTestifyCallFromNode(callExpr, migration)
   114  	if !ok {
   115  		return nil
   116  	}
   117  	if len(tcall.expr.Args) < 2 {
   118  		return convertTestifySingleArgCall(tcall)
   119  	}
   120  	return convertTestifyAssertion(tcall, migration)
   121  }
   122  
   123  func getReplacementAssignment(assign *ast.AssignStmt, migration migration) ast.Node {
   124  	if isAssignmentFromAssertNew(assign, migration) {
   125  		return removeNode
   126  	}
   127  	return nil
   128  }
   129  
   130  func convertTestifySingleArgCall(tcall call) ast.Node {
   131  	switch tcall.selExpr.Sel.Name {
   132  	case "TestingT":
   133  		// handled as SelectorExpr
   134  		return nil
   135  	case "New":
   136  		// handled by getReplacementAssignment
   137  		return nil
   138  	default:
   139  		log.Printf("%s: skipping unknown selector", tcall.StringWithFileInfo())
   140  		return nil
   141  	}
   142  }
   143  
   144  func convertTestifyAssertion(tcall call, migration migration) ast.Node {
   145  	imports := migration.importNames
   146  
   147  	switch tcall.selExpr.Sel.Name {
   148  	case "NoError", "NoErrorf":
   149  		return convertNoError(tcall, imports)
   150  	case "True", "Truef":
   151  		return convertTrue(tcall, imports)
   152  	case "False", "Falsef":
   153  		return convertFalse(tcall, imports)
   154  	case "Equal", "Equalf", "Exactly", "Exactlyf", "EqualValues", "EqualValuesf":
   155  		return convertEqual(tcall, migration)
   156  	case "Contains", "Containsf":
   157  		return convertTwoArgComparison(tcall, imports, "Contains")
   158  	case "Len", "Lenf":
   159  		return convertTwoArgComparison(tcall, imports, "Len")
   160  	case "Panics", "Panicsf":
   161  		return convertOneArgComparison(tcall, imports, "Panics")
   162  	case "EqualError", "EqualErrorf":
   163  		return convertEqualError(tcall, imports)
   164  	case "Error", "Errorf":
   165  		return convertError(tcall, imports)
   166  	case "ErrorContains", "ErrorContainsf":
   167  		return convertErrorContains(tcall, imports)
   168  	case "Empty", "Emptyf":
   169  		return convertEmpty(tcall, imports)
   170  	case "Nil", "Nilf":
   171  		return convertNil(tcall, migration)
   172  	case "NotNil", "NotNilf":
   173  		return convertNegativeComparison(tcall, imports, &ast.Ident{Name: "nil"}, 2)
   174  	case "NotEqual", "NotEqualf":
   175  		return convertNegativeComparison(tcall, imports, tcall.arg(2), 3)
   176  	case "Fail", "Failf":
   177  		return convertFail(tcall, "Error")
   178  	case "FailNow", "FailNowf":
   179  		return convertFail(tcall, "Fatal")
   180  	case "NotEmpty", "NotEmptyf":
   181  		return convertNotEmpty(tcall, imports)
   182  	case "NotZero", "NotZerof":
   183  		zero := &ast.BasicLit{Kind: token.INT, Value: "0"}
   184  		return convertNegativeComparison(tcall, imports, zero, 2)
   185  	}
   186  	log.Printf("%s: skipping unsupported assertion", tcall.StringWithFileInfo())
   187  	return nil
   188  }
   189  
   190  func newCallExpr(x, sel string, args []ast.Expr) *ast.CallExpr {
   191  	return &ast.CallExpr{
   192  		Fun: &ast.SelectorExpr{
   193  			X:   &ast.Ident{Name: x},
   194  			Sel: &ast.Ident{Name: sel},
   195  		},
   196  		Args: args,
   197  	}
   198  }
   199  
   200  func newCallExprArgs(t ast.Expr, cmp ast.Expr, extra ...ast.Expr) []ast.Expr {
   201  	return append(append([]ast.Expr{t}, cmp), extra...)
   202  }
   203  
   204  func newCallExprWithPosition(tcall call, imports importNames, args []ast.Expr) *ast.CallExpr {
   205  	return &ast.CallExpr{
   206  		Fun: &ast.SelectorExpr{
   207  			X: &ast.Ident{
   208  				Name:    imports.assert,
   209  				NamePos: tcall.xIdent.NamePos,
   210  			},
   211  			Sel: &ast.Ident{Name: tcall.assert},
   212  		},
   213  		Args: args,
   214  	}
   215  }
   216  
   217  func convertNoError(tcall call, imports importNames) ast.Node {
   218  	// use assert.NilError() for require.NoError()
   219  	if tcall.assert == funcNameAssert {
   220  		return newCallExprWithoutComparison(tcall, imports, "NilError")
   221  	}
   222  	// use assert.Check() for assert.NoError()
   223  	return newCallExprWithoutComparison(tcall, imports, "Check")
   224  }
   225  
   226  func convertEqualError(tcall call, imports importNames) ast.Node {
   227  	if tcall.assert == funcNameAssert {
   228  		return newCallExprWithoutComparison(tcall, imports, "Error")
   229  	}
   230  	return convertTwoArgComparison(tcall, imports, "Error")
   231  }
   232  
   233  func newCallExprWithoutComparison(tcall call, imports importNames, name string) ast.Node {
   234  	return &ast.CallExpr{
   235  		Fun: &ast.SelectorExpr{
   236  			X: &ast.Ident{
   237  				Name:    imports.assert,
   238  				NamePos: tcall.xIdent.NamePos,
   239  			},
   240  			Sel: &ast.Ident{Name: name},
   241  		},
   242  		Args: tcall.expr.Args,
   243  	}
   244  }
   245  
   246  func convertOneArgComparison(tcall call, imports importNames, cmpName string) ast.Node {
   247  	return newCallExprWithPosition(tcall, imports,
   248  		newCallExprArgs(
   249  			tcall.testingT(),
   250  			newCallExpr(imports.cmp, cmpName, []ast.Expr{tcall.arg(1)}),
   251  			tcall.extraArgs(2)...))
   252  }
   253  
   254  func convertTrue(tcall call, imports importNames) ast.Node {
   255  	return newCallExprWithPosition(tcall, imports, tcall.expr.Args)
   256  }
   257  
   258  func convertFalse(tcall call, imports importNames) ast.Node {
   259  	return newCallExprWithPosition(tcall, imports,
   260  		newCallExprArgs(
   261  			tcall.testingT(),
   262  			&ast.UnaryExpr{Op: token.NOT, X: tcall.arg(1)},
   263  			tcall.extraArgs(2)...))
   264  }
   265  
   266  func convertEqual(tcall call, migration migration) ast.Node {
   267  	imports := migration.importNames
   268  
   269  	hasExtraArgs := len(tcall.extraArgs(3)) > 0
   270  
   271  	cmpEqual := convertTwoArgComparison(tcall, imports, "Equal")
   272  	if tcall.assert == funcNameAssert {
   273  		cmpEqual = newCallExprWithoutComparison(tcall, imports, "Equal")
   274  	}
   275  	cmpDeepEqual := convertTwoArgComparison(tcall, imports, "DeepEqual")
   276  	if tcall.assert == funcNameAssert && !hasExtraArgs {
   277  		cmpDeepEqual = newCallExprWithoutComparison(tcall, imports, "DeepEqual")
   278  	}
   279  
   280  	gotype := walkForType(migration.pkgInfo, tcall.arg(1))
   281  	if isUnknownType(gotype) {
   282  		gotype = walkForType(migration.pkgInfo, tcall.arg(2))
   283  	}
   284  	if isUnknownType(gotype) {
   285  		return cmpDeepEqual
   286  	}
   287  
   288  	switch gotype.Underlying().(type) {
   289  	case *types.Basic:
   290  		return cmpEqual
   291  	default:
   292  		return cmpDeepEqual
   293  	}
   294  }
   295  
   296  func convertTwoArgComparison(tcall call, imports importNames, cmpName string) ast.Node {
   297  	return newCallExprWithPosition(tcall, imports,
   298  		newCallExprArgs(
   299  			tcall.testingT(),
   300  			newCallExpr(imports.cmp, cmpName, tcall.args(1, 3)),
   301  			tcall.extraArgs(3)...))
   302  }
   303  
   304  func convertError(tcall call, imports importNames) ast.Node {
   305  	cmpArgs := []ast.Expr{
   306  		tcall.arg(1),
   307  		&ast.BasicLit{Kind: token.STRING, Value: `""`}}
   308  
   309  	return newCallExprWithPosition(tcall, imports,
   310  		newCallExprArgs(
   311  			tcall.testingT(),
   312  			newCallExpr(imports.cmp, "ErrorContains", cmpArgs),
   313  			tcall.extraArgs(2)...))
   314  }
   315  
   316  func convertErrorContains(tcall call, imports importNames) ast.Node {
   317  	return &ast.CallExpr{
   318  		Fun: &ast.SelectorExpr{
   319  			X: &ast.Ident{
   320  				Name:    imports.assert,
   321  				NamePos: tcall.xIdent.NamePos,
   322  			},
   323  			Sel: &ast.Ident{Name: "ErrorContains"},
   324  		},
   325  		Args: tcall.expr.Args,
   326  	}
   327  }
   328  
   329  func convertEmpty(tcall call, imports importNames) ast.Node {
   330  	cmpArgs := []ast.Expr{
   331  		tcall.arg(1),
   332  		&ast.BasicLit{Kind: token.INT, Value: "0"},
   333  	}
   334  	return newCallExprWithPosition(tcall, imports,
   335  		newCallExprArgs(
   336  			tcall.testingT(),
   337  			newCallExpr(imports.cmp, "Len", cmpArgs),
   338  			tcall.extraArgs(2)...))
   339  }
   340  
   341  func convertNil(tcall call, migration migration) ast.Node {
   342  	gotype := walkForType(migration.pkgInfo, tcall.arg(1))
   343  	if gotype != nil && gotype.String() == "error" {
   344  		return convertNoError(tcall, migration.importNames)
   345  	}
   346  	return convertOneArgComparison(tcall, migration.importNames, "Nil")
   347  }
   348  
   349  func convertNegativeComparison(
   350  	tcall call,
   351  	imports importNames,
   352  	right ast.Expr,
   353  	extra int,
   354  ) ast.Node {
   355  	return newCallExprWithPosition(tcall, imports,
   356  		newCallExprArgs(
   357  			tcall.testingT(),
   358  			&ast.BinaryExpr{X: tcall.arg(1), Op: token.NEQ, Y: right},
   359  			tcall.extraArgs(extra)...))
   360  }
   361  
   362  func convertFail(tcall call, selector string) ast.Node {
   363  	extraArgs := tcall.extraArgs(1)
   364  	if len(extraArgs) > 1 {
   365  		selector += "f"
   366  	}
   367  
   368  	return &ast.CallExpr{
   369  		Fun: &ast.SelectorExpr{
   370  			X:   tcall.testingT(),
   371  			Sel: &ast.Ident{Name: selector},
   372  		},
   373  		Args: extraArgs,
   374  	}
   375  }
   376  
   377  func convertNotEmpty(tcall call, imports importNames) ast.Node {
   378  	lenExpr := &ast.CallExpr{
   379  		Fun:  &ast.Ident{Name: "len"},
   380  		Args: tcall.args(1, 2),
   381  	}
   382  	zeroExpr := &ast.BasicLit{Kind: token.INT, Value: "0"}
   383  	return newCallExprWithPosition(tcall, imports,
   384  		newCallExprArgs(
   385  			tcall.testingT(),
   386  			&ast.BinaryExpr{X: lenExpr, Op: token.NEQ, Y: zeroExpr},
   387  			tcall.extraArgs(2)...))
   388  }
   389  

View as plain text