1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67 package main
68
69 import (
70 "bytes"
71 "encoding/json"
72 "flag"
73 "fmt"
74 "go/ast"
75 "go/format"
76 "go/parser"
77 "go/token"
78 "io"
79 "log"
80 "os"
81 "os/exec"
82 "path"
83 "path/filepath"
84 "sort"
85 "strconv"
86 "strings"
87 )
88
89
// Command-line flags.
var (
	// dryrun, when set, reports the changes that would be made without
	// writing any files.
	dryrun = flag.Bool("n", false, "dry run: show changes, but don't apply them")

	// badDomains is a comma-separated list of first path elements whose
	// packages are considered noncanonical with no known replacement.
	badDomains = flag.String(
		"baddomains",
		"code.google.com",
		"a comma-separated list of domains from which packages should not be imported",
	)

	// replaceFlag is a comma-separated list of old=new import path pairs;
	// a trailing "..." on both sides makes the pair a prefix rewrite.
	replaceFlag = flag.String(
		"replace",
		"",
		"a comma-separated list of noncanonical=canonical pairs of package paths. If both items in a pair end with '...', they are treated as path prefixes.",
	)
)
97
98
// stderr is the destination of the tool's diagnostic messages. It is a
// variable so that it can be redirected, e.g. by tests.
var stderr io.Writer = os.Stderr

// writeFile writes a rewritten source file back to disk. It is a variable so
// that file writes can be intercepted, e.g. by tests.
var writeFile = os.WriteFile
103
// usage is the help text printed when fiximports is run without package
// arguments. It documents every flag the program defines, including -replace.
const usage = `fiximports: rewrite import paths to use canonical package names.

Usage: fiximports [-n] package...

The package... arguments specify a list of packages
in the style of the go tool; see "go help packages".
Hint: use "all" or "..." to match the entire workspace.

For details, see https://pkg.go.dev/golang.org/x/tools/cmd/fiximports

Flags:
  -n:          dry run: show changes, but don't apply them
  -baddomains  a comma-separated list of domains from which packages
               should not be imported
  -replace     a comma-separated list of noncanonical=canonical pairs of
               package paths. If both items in a pair end with '...',
               they are treated as path prefixes.
`
119
120 func main() {
121 flag.Parse()
122
123 if len(flag.Args()) == 0 {
124 fmt.Fprint(stderr, usage)
125 os.Exit(1)
126 }
127 if !fiximports(flag.Args()...) {
128 os.Exit(1)
129 }
130 }
131
// canonicalName is the canonical import path of a package together with the
// package name to use when importing it. The zero value records that a path
// is noncanonical but that its canonical replacement is unknown.
type canonicalName struct {
	path string
	name string
}
133
134
135
136 func fiximports(packages ...string) bool {
137
138 importedBy := make(map[string]map[*listPackage]bool)
139
140
141 addEdge := func(from *listPackage, to string) {
142 if to == "C" || to == "unsafe" {
143 return
144 }
145 pkgs := importedBy[to]
146 if pkgs == nil {
147 pkgs = make(map[*listPackage]bool)
148 importedBy[to] = pkgs
149 }
150 pkgs[from] = true
151 }
152
153
154 pkgs, err := list("...")
155 if err != nil {
156 fmt.Fprintf(stderr, "importfix: %v\n", err)
157 return false
158 }
159
160
161 packageName := make(map[string]string)
162 for _, p := range pkgs {
163 packageName[p.ImportPath] = p.Name
164 }
165
166
167
168
169
170 canonical := make(map[string]canonicalName)
171 domains := strings.Split(*badDomains, ",")
172
173 type replaceItem struct {
174 old, new string
175 matchPrefix bool
176 }
177 var replace []replaceItem
178 for _, pair := range strings.Split(*replaceFlag, ",") {
179 if pair == "" {
180 continue
181 }
182 words := strings.Split(pair, "=")
183 if len(words) != 2 {
184 fmt.Fprintf(stderr, "importfix: -replace: %q is not of the form \"canonical=noncanonical\".\n", pair)
185 return false
186 }
187 replace = append(replace, replaceItem{
188 old: strings.TrimSuffix(words[0], "..."),
189 new: strings.TrimSuffix(words[1], "..."),
190 matchPrefix: strings.HasSuffix(words[0], "...") &&
191 strings.HasSuffix(words[1], "..."),
192 })
193 }
194
195
196 for _, p := range pkgs {
197 if p.Error != nil {
198 msg := p.Error.Err
199 if strings.Contains(msg, "code in directory") &&
200 strings.Contains(msg, "expects import") {
201
202 } else {
203 fmt.Fprintln(stderr, p.Error)
204 }
205 }
206
207 for _, imp := range p.Imports {
208 addEdge(p, imp)
209 }
210 for _, imp := range p.TestImports {
211 addEdge(p, imp)
212 }
213 for _, imp := range p.XTestImports {
214 addEdge(p, imp)
215 }
216
217
218 if p.ImportComment != "" {
219 if p.ImportComment != p.ImportPath {
220 canonical[p.ImportPath] = canonicalName{
221 path: p.ImportComment,
222 name: p.Name,
223 }
224 }
225 } else {
226
227 var newPath string
228 for _, item := range replace {
229 if item.matchPrefix {
230 if strings.HasPrefix(p.ImportPath, item.old) {
231 newPath = item.new + p.ImportPath[len(item.old):]
232 break
233 }
234 } else if p.ImportPath == item.old {
235 newPath = item.new
236 break
237 }
238 }
239 if newPath != "" {
240 newName := packageName[newPath]
241 if newName == "" {
242 newName = filepath.Base(newPath)
243 }
244 canonical[p.ImportPath] = canonicalName{
245 path: newPath,
246 name: newName,
247 }
248 continue
249 }
250
251
252 for _, domain := range domains {
253 slash := strings.Index(p.ImportPath, "/")
254 if slash < 0 {
255 continue
256 }
257 if p.ImportPath[:slash] == domain {
258
259
260 canonical[p.ImportPath] = canonicalName{}
261
262
263
264
265
266 }
267 break
268 }
269 }
270 }
271
272
273
274 clients := make(map[*listPackage]bool)
275 for path := range canonical {
276 for client := range importedBy[path] {
277 clients[client] = true
278 }
279 }
280
281
282 if len(packages) == 1 && (packages[0] == "all" || packages[0] == "...") {
283
284 } else {
285 pkgs, err := list(packages...)
286 if err != nil {
287 fmt.Fprintf(stderr, "importfix: %v\n", err)
288 return false
289 }
290 seen := make(map[string]bool)
291 for _, p := range pkgs {
292 seen[p.ImportPath] = true
293 }
294 for client := range clients {
295 if !seen[client.ImportPath] {
296 delete(clients, client)
297 }
298 }
299 }
300
301
302 ok := true
303 for client := range clients {
304 if !rewritePackage(client, canonical) {
305 ok = false
306
307
308
309 seen := make(map[string]bool)
310 var direct, indirect []string
311 for p := range importedBy[client.ImportPath] {
312 direct = append(direct, p.ImportPath)
313 seen[p.ImportPath] = true
314 }
315
316 var visit func(path string)
317 visit = func(path string) {
318 for q := range importedBy[path] {
319 qpath := q.ImportPath
320 if !seen[qpath] {
321 seen[qpath] = true
322 indirect = append(indirect, qpath)
323 visit(qpath)
324 }
325 }
326 }
327
328 if direct != nil {
329 fmt.Fprintf(stderr, "\timported directly by:\n")
330 sort.Strings(direct)
331 for _, path := range direct {
332 fmt.Fprintf(stderr, "\t\t%s\n", path)
333 visit(path)
334 }
335
336 if indirect != nil {
337 fmt.Fprintf(stderr, "\timported indirectly by:\n")
338 sort.Strings(indirect)
339 for _, path := range indirect {
340 fmt.Fprintf(stderr, "\t\t%s\n", path)
341 }
342 }
343 }
344 }
345 }
346
347 return ok
348 }
349
350
351 func rewritePackage(client *listPackage, canonical map[string]canonicalName) bool {
352 ok := true
353
354 used := make(map[string]bool)
355 var filenames []string
356 filenames = append(filenames, client.GoFiles...)
357 filenames = append(filenames, client.TestGoFiles...)
358 filenames = append(filenames, client.XTestGoFiles...)
359 var first bool
360 for _, filename := range filenames {
361 if !first {
362 first = true
363 fmt.Fprintf(stderr, "%s\n", client.ImportPath)
364 }
365 err := rewriteFile(filepath.Join(client.Dir, filename), canonical, used)
366 if err != nil {
367 fmt.Fprintf(stderr, "\tERROR: %v\n", err)
368 ok = false
369 }
370 }
371
372
373 var keys []string
374 for key := range used {
375 keys = append(keys, key)
376 }
377 sort.Strings(keys)
378 for _, key := range keys {
379 if p := canonical[key]; p.path != "" {
380 fmt.Fprintf(stderr, "\tfixed: %s -> %s\n", key, p.path)
381 } else {
382 fmt.Fprintf(stderr, "\tERROR: %s has no import comment\n", key)
383 ok = false
384 }
385 }
386
387 return ok
388 }
389
390
391
392
393
394 func rewriteFile(filename string, canonical map[string]canonicalName, used map[string]bool) error {
395 fset := token.NewFileSet()
396 f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
397 if err != nil {
398 return err
399 }
400 var changed bool
401 for _, imp := range f.Imports {
402 impPath, err := strconv.Unquote(imp.Path.Value)
403 if err != nil {
404 log.Printf("%s: bad import spec %q: %v",
405 fset.Position(imp.Pos()), imp.Path.Value, err)
406 continue
407 }
408 canon, ok := canonical[impPath]
409 if !ok {
410 continue
411 }
412
413 used[impPath] = true
414
415 if canon.path == "" {
416
417
418
419 fmt.Fprintf(stderr, "\t%s:%d: import %q\n",
420 shortPath(filename),
421 fset.Position(imp.Pos()).Line, impPath)
422 continue
423 }
424
425 changed = true
426
427 imp.Path.Value = strconv.Quote(canon.path)
428
429
430
431
432
433
434 newBase := path.Base(canon.path)
435 if imp.Name == nil && newBase != canon.name {
436 imp.Name = &ast.Ident{Name: canon.name}
437 }
438 }
439
440 if changed && !*dryrun {
441 var buf bytes.Buffer
442 if err := format.Node(&buf, fset, f); err != nil {
443 return fmt.Errorf("%s: couldn't format file: %v", filename, err)
444 }
445 return writeFile(filename, buf.Bytes(), 0644)
446 }
447
448 return nil
449 }
450
451
452
453 type listPackage struct {
454 Name string
455 Dir string
456 ImportPath string
457 GoFiles []string
458 TestGoFiles []string
459 XTestGoFiles []string
460 Imports []string
461 TestImports []string
462 XTestImports []string
463 ImportComment string
464 Error *packageError
465 }
466
467
// packageError is a package loading error as reported in the Error field of
// "go list -e -json" output.
type packageError struct {
	ImportStack []string // not consulted by this program
	Pos         string   // position of the error, if any
	Err         string   // error message
}

// Error formats the error as "Pos: Err", or as just Err when Pos is empty.
func (e packageError) Error() string {
	if e.Pos == "" {
		return e.Err
	}
	return e.Pos + ": " + e.Err
}
480
481
482
483 func list(args ...string) ([]*listPackage, error) {
484 cmd := exec.Command("go", append([]string{"list", "-e", "-json"}, args...)...)
485 cmd.Stdout = new(bytes.Buffer)
486 cmd.Stderr = stderr
487 if err := cmd.Run(); err != nil {
488 return nil, err
489 }
490
491 dec := json.NewDecoder(cmd.Stdout.(io.Reader))
492 var pkgs []*listPackage
493 for {
494 var p listPackage
495 if err := dec.Decode(&p); err == io.EOF {
496 break
497 } else if err != nil {
498 return nil, err
499 }
500 pkgs = append(pkgs, &p)
501 }
502 return pkgs, nil
503 }
504
505
506
507
508
509
// cwd is the working directory captured at program start; the program exits
// via log.Fatalf if it cannot be determined.
var cwd = func() string {
	dir, err := os.Getwd()
	if err != nil {
		log.Fatalf("os.Getwd: %v", err)
	}
	return dir
}()

// shortPath returns path relative to cwd when that form is strictly shorter,
// and path unchanged otherwise, including when no relative form exists.
func shortPath(path string) string {
	rel, err := filepath.Rel(cwd, path)
	if err != nil || len(rel) >= len(path) {
		return path
	}
	return rel
}
526
View as plain text