1 package main
2
3 import (
4 "go/ast"
5 "go/token"
6 "go/types"
7 "log"
8 "path"
9
10 "golang.org/x/tools/go/ast/astutil"
11 )
12
// Import paths of the testify packages that are migrated away from, and of
// the gotest.tools packages that replace them.
const (
	pkgTestifyAssert       = "github.com/stretchr/testify/assert"
	pkgTestifyRequire      = "github.com/stretchr/testify/require"
	pkgGopkgTestifyAssert  = "gopkg.in/stretchr/testify.v1/assert"
	pkgGopkgTestifyRequire = "gopkg.in/stretchr/testify.v1/require"

	pkgAssert = "gotest.tools/v3/assert"
	pkgCmp    = "gotest.tools/v3/assert/cmp"
)

// Names of the gotest.tools assert functions a converted call can target.
// A call's `assert` field holds one of these and is compared against
// funcNameAssert to pick the dedicated (Assert-style) conversions.
const (
	funcNameAssert = "Assert"
	funcNameCheck  = "Check"
)

// allTestifyPks lists every testify import path that updateImports deletes
// from a migrated file.
var allTestifyPks = []string{
	pkgTestifyAssert, pkgTestifyRequire,
	pkgGopkgTestifyAssert, pkgGopkgTestifyRequire,
}
33
34 type migration struct {
35 file *ast.File
36 fileset *token.FileSet
37 importNames importNames
38 pkgInfo *types.Info
39 }
40
41 func migrateFile(migration migration) {
42 astutil.Apply(migration.file, nil, replaceCalls(migration))
43 updateImports(migration)
44 }
45
46 func updateImports(migration migration) {
47 for _, remove := range allTestifyPks {
48 astutil.DeleteImport(migration.fileset, migration.file, remove)
49 }
50
51 var alias string
52 if migration.importNames.assert != path.Base(pkgAssert) {
53 alias = migration.importNames.assert
54 }
55 astutil.AddNamedImport(migration.fileset, migration.file, alias, pkgAssert)
56
57 if migration.importNames.cmp != path.Base(pkgCmp) {
58 alias = migration.importNames.cmp
59 }
60 astutil.AddNamedImport(migration.fileset, migration.file, alias, pkgCmp)
61 }
62
// emptyNode is a position-less ast.Node used purely as a sentinel value:
// replaceCalls compares replacement results against removeNode to decide
// that the visited node must be deleted rather than replaced.
type emptyNode struct{}

// Pos implements ast.Node; the sentinel has no source position.
func (emptyNode) Pos() token.Pos { return token.NoPos }

// End implements ast.Node; the sentinel has no source position.
func (emptyNode) End() token.Pos { return token.NoPos }

// removeNode is returned (e.g. by getReplacementAssignment) to request that
// the visited node be deleted.
var removeNode = emptyNode{}
74
75 func replaceCalls(migration migration) func(cursor *astutil.Cursor) bool {
76 return func(cursor *astutil.Cursor) bool {
77 var newNode ast.Node
78 switch typed := cursor.Node().(type) {
79 case *ast.SelectorExpr:
80 newNode = getReplacementTestingT(typed, migration.importNames)
81 case *ast.CallExpr:
82 newNode = getReplacementAssertion(typed, migration)
83 case *ast.AssignStmt:
84 newNode = getReplacementAssignment(typed, migration)
85 }
86
87 switch newNode {
88 case nil:
89 case removeNode:
90 cursor.Delete()
91 default:
92 cursor.Replace(newNode)
93 }
94 return true
95 }
96 }
97
98 func getReplacementTestingT(selector *ast.SelectorExpr, names importNames) ast.Node {
99 xIdent, ok := selector.X.(*ast.Ident)
100 if !ok {
101 return nil
102 }
103 if selector.Sel.Name != "TestingT" || !names.matchesTestify(xIdent) {
104 return nil
105 }
106 return &ast.SelectorExpr{
107 X: &ast.Ident{Name: names.assert, NamePos: xIdent.NamePos},
108 Sel: selector.Sel,
109 }
110 }
111
112 func getReplacementAssertion(callExpr *ast.CallExpr, migration migration) ast.Node {
113 tcall, ok := newTestifyCallFromNode(callExpr, migration)
114 if !ok {
115 return nil
116 }
117 if len(tcall.expr.Args) < 2 {
118 return convertTestifySingleArgCall(tcall)
119 }
120 return convertTestifyAssertion(tcall, migration)
121 }
122
123 func getReplacementAssignment(assign *ast.AssignStmt, migration migration) ast.Node {
124 if isAssignmentFromAssertNew(assign, migration) {
125 return removeNode
126 }
127 return nil
128 }
129
130 func convertTestifySingleArgCall(tcall call) ast.Node {
131 switch tcall.selExpr.Sel.Name {
132 case "TestingT":
133
134 return nil
135 case "New":
136
137 return nil
138 default:
139 log.Printf("%s: skipping unknown selector", tcall.StringWithFileInfo())
140 return nil
141 }
142 }
143
144 func convertTestifyAssertion(tcall call, migration migration) ast.Node {
145 imports := migration.importNames
146
147 switch tcall.selExpr.Sel.Name {
148 case "NoError", "NoErrorf":
149 return convertNoError(tcall, imports)
150 case "True", "Truef":
151 return convertTrue(tcall, imports)
152 case "False", "Falsef":
153 return convertFalse(tcall, imports)
154 case "Equal", "Equalf", "Exactly", "Exactlyf", "EqualValues", "EqualValuesf":
155 return convertEqual(tcall, migration)
156 case "Contains", "Containsf":
157 return convertTwoArgComparison(tcall, imports, "Contains")
158 case "Len", "Lenf":
159 return convertTwoArgComparison(tcall, imports, "Len")
160 case "Panics", "Panicsf":
161 return convertOneArgComparison(tcall, imports, "Panics")
162 case "EqualError", "EqualErrorf":
163 return convertEqualError(tcall, imports)
164 case "Error", "Errorf":
165 return convertError(tcall, imports)
166 case "ErrorContains", "ErrorContainsf":
167 return convertErrorContains(tcall, imports)
168 case "Empty", "Emptyf":
169 return convertEmpty(tcall, imports)
170 case "Nil", "Nilf":
171 return convertNil(tcall, migration)
172 case "NotNil", "NotNilf":
173 return convertNegativeComparison(tcall, imports, &ast.Ident{Name: "nil"}, 2)
174 case "NotEqual", "NotEqualf":
175 return convertNegativeComparison(tcall, imports, tcall.arg(2), 3)
176 case "Fail", "Failf":
177 return convertFail(tcall, "Error")
178 case "FailNow", "FailNowf":
179 return convertFail(tcall, "Fatal")
180 case "NotEmpty", "NotEmptyf":
181 return convertNotEmpty(tcall, imports)
182 case "NotZero", "NotZerof":
183 zero := &ast.BasicLit{Kind: token.INT, Value: "0"}
184 return convertNegativeComparison(tcall, imports, zero, 2)
185 }
186 log.Printf("%s: skipping unsupported assertion", tcall.StringWithFileInfo())
187 return nil
188 }
189
// newCallExpr builds the call expression `x.sel(args...)`. The identifiers
// carry no source position, and args is used as given (not copied).
func newCallExpr(x, sel string, args []ast.Expr) *ast.CallExpr {
	fun := &ast.SelectorExpr{X: ast.NewIdent(x), Sel: ast.NewIdent(sel)}
	return &ast.CallExpr{Fun: fun, Args: args}
}
199
// newCallExprArgs assembles an argument list of the form
// (t, cmp, extra...) in a freshly allocated slice.
func newCallExprArgs(t ast.Expr, cmp ast.Expr, extra ...ast.Expr) []ast.Expr {
	args := make([]ast.Expr, 0, 2+len(extra))
	args = append(args, t, cmp)
	return append(args, extra...)
}
203
204 func newCallExprWithPosition(tcall call, imports importNames, args []ast.Expr) *ast.CallExpr {
205 return &ast.CallExpr{
206 Fun: &ast.SelectorExpr{
207 X: &ast.Ident{
208 Name: imports.assert,
209 NamePos: tcall.xIdent.NamePos,
210 },
211 Sel: &ast.Ident{Name: tcall.assert},
212 },
213 Args: args,
214 }
215 }
216
217 func convertNoError(tcall call, imports importNames) ast.Node {
218
219 if tcall.assert == funcNameAssert {
220 return newCallExprWithoutComparison(tcall, imports, "NilError")
221 }
222
223 return newCallExprWithoutComparison(tcall, imports, "Check")
224 }
225
226 func convertEqualError(tcall call, imports importNames) ast.Node {
227 if tcall.assert == funcNameAssert {
228 return newCallExprWithoutComparison(tcall, imports, "Error")
229 }
230 return convertTwoArgComparison(tcall, imports, "Error")
231 }
232
233 func newCallExprWithoutComparison(tcall call, imports importNames, name string) ast.Node {
234 return &ast.CallExpr{
235 Fun: &ast.SelectorExpr{
236 X: &ast.Ident{
237 Name: imports.assert,
238 NamePos: tcall.xIdent.NamePos,
239 },
240 Sel: &ast.Ident{Name: name},
241 },
242 Args: tcall.expr.Args,
243 }
244 }
245
246 func convertOneArgComparison(tcall call, imports importNames, cmpName string) ast.Node {
247 return newCallExprWithPosition(tcall, imports,
248 newCallExprArgs(
249 tcall.testingT(),
250 newCallExpr(imports.cmp, cmpName, []ast.Expr{tcall.arg(1)}),
251 tcall.extraArgs(2)...))
252 }
253
254 func convertTrue(tcall call, imports importNames) ast.Node {
255 return newCallExprWithPosition(tcall, imports, tcall.expr.Args)
256 }
257
258 func convertFalse(tcall call, imports importNames) ast.Node {
259 return newCallExprWithPosition(tcall, imports,
260 newCallExprArgs(
261 tcall.testingT(),
262 &ast.UnaryExpr{Op: token.NOT, X: tcall.arg(1)},
263 tcall.extraArgs(2)...))
264 }
265
266 func convertEqual(tcall call, migration migration) ast.Node {
267 imports := migration.importNames
268
269 hasExtraArgs := len(tcall.extraArgs(3)) > 0
270
271 cmpEqual := convertTwoArgComparison(tcall, imports, "Equal")
272 if tcall.assert == funcNameAssert {
273 cmpEqual = newCallExprWithoutComparison(tcall, imports, "Equal")
274 }
275 cmpDeepEqual := convertTwoArgComparison(tcall, imports, "DeepEqual")
276 if tcall.assert == funcNameAssert && !hasExtraArgs {
277 cmpDeepEqual = newCallExprWithoutComparison(tcall, imports, "DeepEqual")
278 }
279
280 gotype := walkForType(migration.pkgInfo, tcall.arg(1))
281 if isUnknownType(gotype) {
282 gotype = walkForType(migration.pkgInfo, tcall.arg(2))
283 }
284 if isUnknownType(gotype) {
285 return cmpDeepEqual
286 }
287
288 switch gotype.Underlying().(type) {
289 case *types.Basic:
290 return cmpEqual
291 default:
292 return cmpDeepEqual
293 }
294 }
295
296 func convertTwoArgComparison(tcall call, imports importNames, cmpName string) ast.Node {
297 return newCallExprWithPosition(tcall, imports,
298 newCallExprArgs(
299 tcall.testingT(),
300 newCallExpr(imports.cmp, cmpName, tcall.args(1, 3)),
301 tcall.extraArgs(3)...))
302 }
303
304 func convertError(tcall call, imports importNames) ast.Node {
305 cmpArgs := []ast.Expr{
306 tcall.arg(1),
307 &ast.BasicLit{Kind: token.STRING, Value: `""`}}
308
309 return newCallExprWithPosition(tcall, imports,
310 newCallExprArgs(
311 tcall.testingT(),
312 newCallExpr(imports.cmp, "ErrorContains", cmpArgs),
313 tcall.extraArgs(2)...))
314 }
315
316 func convertErrorContains(tcall call, imports importNames) ast.Node {
317 return &ast.CallExpr{
318 Fun: &ast.SelectorExpr{
319 X: &ast.Ident{
320 Name: imports.assert,
321 NamePos: tcall.xIdent.NamePos,
322 },
323 Sel: &ast.Ident{Name: "ErrorContains"},
324 },
325 Args: tcall.expr.Args,
326 }
327 }
328
329 func convertEmpty(tcall call, imports importNames) ast.Node {
330 cmpArgs := []ast.Expr{
331 tcall.arg(1),
332 &ast.BasicLit{Kind: token.INT, Value: "0"},
333 }
334 return newCallExprWithPosition(tcall, imports,
335 newCallExprArgs(
336 tcall.testingT(),
337 newCallExpr(imports.cmp, "Len", cmpArgs),
338 tcall.extraArgs(2)...))
339 }
340
341 func convertNil(tcall call, migration migration) ast.Node {
342 gotype := walkForType(migration.pkgInfo, tcall.arg(1))
343 if gotype != nil && gotype.String() == "error" {
344 return convertNoError(tcall, migration.importNames)
345 }
346 return convertOneArgComparison(tcall, migration.importNames, "Nil")
347 }
348
349 func convertNegativeComparison(
350 tcall call,
351 imports importNames,
352 right ast.Expr,
353 extra int,
354 ) ast.Node {
355 return newCallExprWithPosition(tcall, imports,
356 newCallExprArgs(
357 tcall.testingT(),
358 &ast.BinaryExpr{X: tcall.arg(1), Op: token.NEQ, Y: right},
359 tcall.extraArgs(extra)...))
360 }
361
362 func convertFail(tcall call, selector string) ast.Node {
363 extraArgs := tcall.extraArgs(1)
364 if len(extraArgs) > 1 {
365 selector += "f"
366 }
367
368 return &ast.CallExpr{
369 Fun: &ast.SelectorExpr{
370 X: tcall.testingT(),
371 Sel: &ast.Ident{Name: selector},
372 },
373 Args: extraArgs,
374 }
375 }
376
377 func convertNotEmpty(tcall call, imports importNames) ast.Node {
378 lenExpr := &ast.CallExpr{
379 Fun: &ast.Ident{Name: "len"},
380 Args: tcall.args(1, 2),
381 }
382 zeroExpr := &ast.BasicLit{Kind: token.INT, Value: "0"}
383 return newCallExprWithPosition(tcall, imports,
384 newCallExprArgs(
385 tcall.testingT(),
386 &ast.BinaryExpr{X: lenExpr, Op: token.NEQ, Y: zeroExpr},
387 tcall.extraArgs(2)...))
388 }
389
View as plain text