1 package main
2
3 import (
4 "bytes"
5 "errors"
6 "flag"
7 "fmt"
8 "go/ast"
9 "go/format"
10 "go/token"
11 "log"
12 "os"
13 "path"
14 "path/filepath"
15 "strings"
16
17 "golang.org/x/tools/go/packages"
18 "golang.org/x/tools/imports"
19 )
20
// options holds the command-line configuration for a migration run.
type options struct {
	// pkgs are the package patterns to load, taken from the positional args.
	pkgs []string
	// dryRun skips formatting and writing migrated files back to disk.
	dryRun bool
	// debug turns on DEBUG-prefixed log output.
	debug bool
	// cmpImportName is the import alias to use for the assert/cmp package.
	cmpImportName string
	// showLoaderErrors prints errors reported while loading packages.
	showLoaderErrors bool
	// buildFlags are passed through to the Go build system when loading.
	buildFlags []string
	// localImportPath is used as the goimports local-import prefix.
	localImportPath string
}
30
31 func main() {
32 name := os.Args[0]
33 flags, opts := setupFlags(name)
34 handleExitError(name, flags.Parse(os.Args[1:]))
35 setupLogging(opts)
36 opts.pkgs = flags.Args()
37 handleExitError(name, run(*opts))
38 }
39
40 func setupLogging(opts *options) {
41 log.SetFlags(0)
42 enableDebug = opts.debug
43 }
44
// enableDebug gates debugf output. It defaults to off.
var enableDebug bool

// debugf logs a fmt-style message with a "DEBUG: " prefix, but only when
// enableDebug is true; otherwise it is a no-op.
func debugf(msg string, args ...interface{}) {
	if !enableDebug {
		return
	}
	log.Printf("DEBUG: "+msg, args...)
}
52
53 func setupFlags(name string) (*flag.FlagSet, *options) {
54 opts := options{}
55 flags := flag.NewFlagSet(name, flag.ContinueOnError)
56 flags.BoolVar(&opts.dryRun, "dry-run", false,
57 "don't write changes to file")
58 flags.BoolVar(&opts.debug, "debug", false, "enable debug logging")
59 flags.StringVar(&opts.cmpImportName, "cmp-pkg-import-alias", "is",
60 "import alias to use for the assert/cmp package")
61 flags.BoolVar(&opts.showLoaderErrors, "print-loader-errors", false,
62 "print errors from loading source")
63 flags.Var((*stringSliceValue)(&opts.buildFlags), "build-flags",
64 "build flags to pass to Go when loading source files")
65 flags.StringVar(&opts.localImportPath, "local-import-path", "",
66 "value to pass to 'goimports -local' flag for sorting local imports")
67 flags.Usage = func() {
68 fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS] PACKAGE [PACKAGE...]
69
70 Migrate calls from testify/{assert|require} to gotest.tools/v3/assert.
71
72 `, name)
73 flags.PrintDefaults()
74 }
75 return flags, &opts
76 }
77
// handleExitError returns normally when err is nil. A help request
// (flag.ErrHelp) exits the process with status 0. Any other error is logged
// with the program name as a prefix and the process exits with status 3.
func handleExitError(name string, err error) {
	if err == nil {
		return
	}
	if errors.Is(err, flag.ErrHelp) {
		os.Exit(0)
	}
	log.Println(name + ": Error: " + err.Error())
	os.Exit(3)
}
89
90 func run(opts options) error {
91 imports.LocalPrefix = opts.localImportPath
92
93 fset := token.NewFileSet()
94 pkgs, err := loadPackages(opts, fset)
95 if err != nil {
96 return fmt.Errorf("failed to load program: %w", err)
97 }
98
99 debugf("package count: %d", len(pkgs))
100 for _, pkg := range pkgs {
101 debugf("file count for package %v: %d", pkg.PkgPath, len(pkg.Syntax))
102 for _, astFile := range pkg.Syntax {
103 absFilename := fset.File(astFile.Pos()).Name()
104 filename := relativePath(absFilename)
105 importNames := newImportNames(astFile.Imports, opts)
106 if !importNames.hasTestifyImports() {
107 debugf("skipping file %s, no imports", filename)
108 continue
109 }
110
111 debugf("migrating %s with imports: %#v", filename, importNames)
112 m := migration{
113 file: astFile,
114 fileset: fset,
115 importNames: importNames,
116 pkgInfo: pkg.TypesInfo,
117 }
118 migrateFile(m)
119 if opts.dryRun {
120 continue
121 }
122
123 raw, err := formatFile(m)
124 if err != nil {
125 return fmt.Errorf("failed to format %s: %w", filename, err)
126 }
127
128 if err := os.WriteFile(absFilename, raw, 0); err != nil {
129 return fmt.Errorf("failed to write file %s: %w", filename, err)
130 }
131 }
132 }
133
134 return nil
135 }
136
137 var loadMode = packages.NeedName |
138 packages.NeedFiles |
139 packages.NeedCompiledGoFiles |
140 packages.NeedDeps |
141 packages.NeedImports |
142 packages.NeedTypes |
143 packages.NeedTypesInfo |
144 packages.NeedTypesSizes |
145 packages.NeedSyntax
146
147 func loadPackages(opts options, fset *token.FileSet) ([]*packages.Package, error) {
148 conf := &packages.Config{
149 Mode: loadMode,
150 Fset: fset,
151 Tests: true,
152 Logf: debugf,
153 BuildFlags: opts.buildFlags,
154 }
155
156 pkgs, err := packages.Load(conf, opts.pkgs...)
157 if err != nil {
158 return nil, err
159 }
160 if opts.showLoaderErrors {
161 packages.PrintErrors(pkgs)
162 }
163 return pkgs, nil
164 }
165
// relativePath returns p relative to the current working directory. If the
// working directory cannot be determined, or p cannot be expressed relative
// to it, p is returned unchanged.
func relativePath(p string) string {
	if wd, wdErr := os.Getwd(); wdErr == nil {
		if rel, relErr := filepath.Rel(wd, p); relErr == nil {
			return rel
		}
	}
	return p
}
177
// importNames records the local identifiers a file uses for the testify
// packages being replaced and the gotest.tools packages replacing them.
type importNames struct {
	testifyAssert  string // local name of the testify assert import; "" when absent
	testifyRequire string // local name of the testify require import; "" when absent
	assert         string // identifier to use for the gotest.tools assert package
	cmp            string // identifier to use for the gotest.tools assert/cmp package
}

// hasTestifyImports reports whether the file imports at least one of the
// testify assert/require packages.
func (p importNames) hasTestifyImports() bool {
	return !(p.testifyAssert == "" && p.testifyRequire == "")
}

// matchesTestify reports whether ident names one of the file's testify
// assert/require imports.
func (p importNames) matchesTestify(ident *ast.Ident) bool {
	switch ident.Name {
	case p.testifyAssert, p.testifyRequire:
		return true
	}
	return false
}
192
193 func (p importNames) funcNameFromTestifyName(name string) string {
194 switch name {
195 case p.testifyAssert:
196 return funcNameCheck
197 case p.testifyRequire:
198 return funcNameAssert
199 default:
200 panic("unexpected testify import name " + name)
201 }
202 }
203
204 func newImportNames(imports []*ast.ImportSpec, opt options) importNames {
205 defaultAssertAlias := path.Base(pkgAssert)
206 importNames := importNames{
207 assert: defaultAssertAlias,
208 cmp: path.Base(pkgCmp),
209 }
210 for _, spec := range imports {
211 switch strings.Trim(spec.Path.Value, `"`) {
212 case pkgTestifyAssert, pkgGopkgTestifyAssert:
213 importNames.testifyAssert = identOrDefault(spec.Name, "assert")
214 case pkgTestifyRequire, pkgGopkgTestifyRequire:
215 importNames.testifyRequire = identOrDefault(spec.Name, "require")
216 default:
217 pkgPath := strings.Trim(spec.Path.Value, `"`)
218
219 switch {
220
221 case pkgPath == pkgAssert:
222 if spec.Name != nil && spec.Name.Name != "" {
223 importNames.assert = spec.Name.Name
224 }
225 continue
226
227
228 case importedAs(spec, path.Base(pkgAssert)) && importNames.assert == defaultAssertAlias:
229 importNames.assert = "gtyassert"
230 }
231 }
232 }
233
234 if opt.cmpImportName != "" {
235 importNames.cmp = opt.cmpImportName
236 }
237 return importNames
238 }
239
// importedAs reports whether the import spec makes its package available in
// the file under the identifier pkg.
//
// An explicit alias decides the local name: `x "a/b"` occupies x, and
// `_ "a/b"` and `. "a/b"` occupy no package identifier. Only an unaliased
// import falls back to the last element of its path. That fallback assumes
// the package name matches the path base, which holds for the common case
// but not for every package.
func importedAs(spec *ast.ImportSpec, pkg string) bool {
	if spec.Name != nil {
		return spec.Name.Name == pkg
	}
	return path.Base(strings.Trim(spec.Path.Value, `"`)) == pkg
}
246
// identOrDefault returns ident's name, or def when ident is nil (for example,
// an import spec without an explicit alias).
func identOrDefault(ident *ast.Ident, def string) string {
	if ident == nil {
		return def
	}
	return ident.Name
}
253
254 func formatFile(migration migration) ([]byte, error) {
255 buf := new(bytes.Buffer)
256 err := format.Node(buf, migration.fileset, migration.file)
257 if err != nil {
258 return nil, err
259 }
260 filename := migration.fileset.File(migration.file.Pos()).Name()
261 return imports.Process(filename, buf.Bytes(), nil)
262 }
263
View as plain text