1
16
17
18 package main
19
20 import (
21 "bytes"
22 "encoding/json"
23 "flag"
24 "fmt"
25 "go/ast"
26 "go/build"
27 "go/format"
28 "go/parser"
29 "go/token"
30 "log"
31 "os"
32 "path/filepath"
33 "regexp"
34 "sort"
35 "strings"
36
37 "golang.org/x/term"
38 )
39
40 var (
41 importAliases = flag.String("import-aliases", "hack/.import-aliases", "json file with import aliases")
42 confirm = flag.Bool("confirm", false, "update file with the preferred aliases for imports")
43 regex = flag.String("include-path", "(test/e2e/|test/e2e_node)", "only files with paths matching this regex is touched")
44 isTerminal = term.IsTerminal(int(os.Stdout.Fd()))
45 logPrefix = ""
46 aliases = map[*regexp.Regexp]string{}
47 )
48
// analyzer holds the state shared by every call to collect during one run.
type analyzer struct {
	// donePaths records directories collect has already handled; only the
	// keys are meaningful, every value is nil.
	donePaths map[string]interface{}
	// fset is the single file set used for all parsing, so that positions
	// can be mapped back to file names.
	fset *token.FileSet
	// ctx is a copy of build.Default with cgo enabled.
	// NOTE(review): not read by any code visible in this file — confirm it is still needed.
	ctx build.Context
	// failed records that at least one error was reported; main exits with
	// status 1 when it is set.
	failed bool
}

// newAnalyzer returns an analyzer with an empty file set, no visited
// directories, and a build context derived from build.Default with
// CgoEnabled switched on.
func newAnalyzer() *analyzer {
	a := new(analyzer)
	a.fset = token.NewFileSet()
	a.ctx = build.Default // struct copy; build.Default itself is not modified
	a.ctx.CgoEnabled = true
	a.donePaths = map[string]interface{}{}
	return a
}
68
69
70 func (a *analyzer) collect(dir string) {
71 if _, ok := a.donePaths[dir]; ok {
72 return
73 }
74 a.donePaths[dir] = nil
75
76
77 fs, err := parser.ParseDir(a.fset, dir, nil, parser.AllErrors|parser.ParseComments)
78
79 if err != nil {
80 fmt.Fprintln(os.Stderr, "ERROR(syntax)", logPrefix, err)
81 a.failed = true
82 return
83 }
84
85 for _, p := range fs {
86
87 files := a.filterFiles(p.Files)
88 for _, file := range files {
89 replacements := make(map[string]string)
90 pathToFile := a.fset.File(file.Pos()).Name()
91 for _, imp := range file.Imports {
92 importPath := strings.Replace(imp.Path.Value, "\"", "", -1)
93 pathSegments := strings.Split(importPath, "/")
94 importName := pathSegments[len(pathSegments)-1]
95 if imp.Name != nil {
96 importName = imp.Name.Name
97 }
98 for re, template := range aliases {
99 match := re.FindStringSubmatchIndex(importPath)
100 if match == nil {
101
102 continue
103 }
104 if match[0] > 0 || match[1] < len(importPath) {
105
106 continue
107 }
108 alias := string(re.ExpandString(nil, template, importPath, match))
109 if alias != importName {
110 if !*confirm {
111 fmt.Fprintf(os.Stderr, "%sERROR wrong alias for import \"%s\" should be %s in file %s\n", logPrefix, importPath, alias, pathToFile)
112 a.failed = true
113 }
114 replacements[importName] = alias
115 if imp.Name != nil {
116 imp.Name.Name = alias
117 } else {
118 imp.Name = ast.NewIdent(alias)
119 }
120 }
121 break
122 }
123 }
124
125 if len(replacements) > 0 {
126 if *confirm {
127 fmt.Printf("%sReplacing imports with aliases in file %s\n", logPrefix, pathToFile)
128 for key, value := range replacements {
129 renameImportUsages(file, key, value)
130 }
131 ast.SortImports(a.fset, file)
132 var buffer bytes.Buffer
133 if err = format.Node(&buffer, a.fset, file); err != nil {
134 panic(fmt.Sprintf("Error formatting ast node after rewriting import.\n%s\n", err.Error()))
135 }
136
137 fileInfo, err := os.Stat(pathToFile)
138 if err != nil {
139 panic(fmt.Sprintf("Error stat'ing file: %s\n%s\n", pathToFile, err.Error()))
140 }
141
142 err = os.WriteFile(pathToFile, buffer.Bytes(), fileInfo.Mode())
143 if err != nil {
144 panic(fmt.Sprintf("Error writing file: %s\n%s\n", pathToFile, err.Error()))
145 }
146 }
147 }
148 }
149 }
150 }
151
// renameImportUsages rewrites package-qualified references old.X in f to
// new.X.
//
// Only the qualifier of a selector expression is renamed, and only when the
// parser did not resolve that identifier to a declaration in the file.
// References to imported packages are never resolved to an object. Struct
// fields, methods, keyed composite-literal fields and local variables that
// merely share the name old are left untouched. The import spec itself is not
// modified; the caller updates it separately.
//
// NOTE(review): relies on parser object resolution (Ident.Obj), which
// go/parser performs unless parser.SkipObjectResolution is set.
func renameImportUsages(f *ast.File, old, new string) {
	ast.Inspect(f, func(node ast.Node) bool {
		sel, ok := node.(*ast.SelectorExpr)
		if !ok {
			return true
		}
		if id, ok := sel.X.(*ast.Ident); ok && id.Name == old && id.Obj == nil {
			id.Name = new
		}
		return true
	})
}
179
180 func (a *analyzer) filterFiles(fs map[string]*ast.File) []*ast.File {
181 var files []*ast.File
182 for _, f := range fs {
183 files = append(files, f)
184 }
185 return files
186 }
187
// collector gathers, through handlePath, the directories whose paths match
// regex while filepath.Walk traverses a tree.
type collector struct {
	dirs  []string
	regex *regexp.Regexp
}

// handlePath is a filepath.WalkFunc. It prunes hidden directories and a few
// fixed directories, and records every other directory whose path matches
// c.regex. A walk error is returned unchanged, which aborts the walk.
func (c *collector) handlePath(path string, info os.FileInfo, err error) error {
	if err != nil {
		return err
	}
	if !info.IsDir() {
		return nil
	}

	// Compare using forward slashes so the fixed paths and the include
	// pattern behave the same regardless of the OS path separator.
	slashPath := filepath.ToSlash(path)

	// A directory is hidden (.git, .cache, ...) when its final element starts
	// with a dot. Checking the final element, not the first byte of the whole
	// path, keeps walk roots such as "./test" and catches nested hidden dirs.
	base := filepath.Base(path)
	hidden := strings.HasPrefix(base, ".") && base != "." && base != ".."

	if hidden ||
		slashPath == "vendor" ||
		slashPath == "_output" ||
		slashPath == "pkg/kubectl/cmd/testdata/edit" {
		return filepath.SkipDir
	}
	if c.regex.MatchString(slashPath) {
		c.dirs = append(c.dirs, path)
	}
	return nil
}
221
222 func main() {
223 flag.Parse()
224 args := flag.Args()
225
226 if len(args) == 0 {
227 args = append(args, ".")
228 }
229
230 regex, err := regexp.Compile(*regex)
231 if err != nil {
232 log.Fatalf("Error compiling regex: %v", err)
233 }
234 c := collector{regex: regex}
235 for _, arg := range args {
236 err := filepath.Walk(arg, c.handlePath)
237 if err != nil {
238 log.Fatalf("Error walking: %v", err)
239 }
240 }
241 sort.Strings(c.dirs)
242
243 if len(*importAliases) > 0 {
244 bytes, err := os.ReadFile(*importAliases)
245 if err != nil {
246 log.Fatalf("Error reading import aliases: %v", err)
247 }
248 var stringAliases map[string]string
249 err = json.Unmarshal(bytes, &stringAliases)
250 if err != nil {
251 log.Fatalf("Error loading aliases: %v", err)
252 }
253 for pattern, name := range stringAliases {
254 re, err := regexp.Compile(pattern)
255 if err != nil {
256 log.Fatalf("Error parsing import path pattern %q as regular expression: %v", pattern, err)
257 }
258 aliases[re] = name
259 }
260 }
261 if isTerminal {
262 logPrefix = "\r"
263 }
264 fmt.Println("checking-imports: ")
265
266 a := newAnalyzer()
267 for _, dir := range c.dirs {
268 if isTerminal {
269 fmt.Printf("\r\033[0m %-80s", dir)
270 }
271 a.collect(dir)
272 }
273 fmt.Println()
274 if a.failed {
275 os.Exit(1)
276 }
277 }
278
View as plain text