1
2
3
4
5
6
7 package imports
8
9 import (
10 "bufio"
11 "bytes"
12 "context"
13 "fmt"
14 "go/ast"
15 "go/format"
16 "go/parser"
17 "go/printer"
18 "go/token"
19 "io"
20 "regexp"
21 "strconv"
22 "strings"
23
24 "golang.org/x/tools/go/ast/astutil"
25 "golang.org/x/tools/internal/event"
26 )
27
28
// Options configures how Process, FixImports and ApplyFixes parse, fix and
// format Go source.
type Options struct {
	// Env is passed through to import fixing (fixImports in Process,
	// getFixes in FixImports). ProcessEnv is declared elsewhere in this
	// package; presumably it describes the environment used to resolve
	// packages — confirm against its definition.
	Env *ProcessEnv

	// LocalPrefix is passed to sortImports and importGroup, which decide
	// how imports are ordered and split into blank-line-separated groups.
	LocalPrefix string

	// Fragment lets parse (used by Process and FixImports) accept input that
	// lacks a package clause, retrying it as a declaration list and then as
	// a statement list. ApplyFixes does not consult it.
	Fragment bool

	// AllErrors adds parser.AllErrors to the parser mode.
	AllErrors bool

	// Comments adds parser.ParseComments to the parser mode.
	Comments bool

	// TabIndent adds printer.TabIndent to the printer mode (which always
	// includes printer.UseSpaces).
	TabIndent bool

	// TabWidth is used as the printer's Tabwidth.
	TabWidth int

	// FormatOnly makes Process skip import fixing and only format src.
	FormatOnly bool
}
46
47
// Process parses src, fixes its imports unless opt.FormatOnly is set, and
// returns the formatted source. When opt.Fragment is set, src may be a
// declaration or statement list instead of a complete file; the output then
// keeps the fragment's shape.
func Process(filename string, src []byte, opt *Options) (formatted []byte, err error) {
	fset := token.NewFileSet()
	f, adjust, err := parse(fset, filename, src, opt)
	if err != nil {
		return nil, err
	}

	if !opt.FormatOnly {
		if err = fixImports(fset, f, filename, opt.Env); err != nil {
			return nil, err
		}
	}

	return formatFile(fset, f, src, adjust, opt)
}
62
63
64
65
66
67
68
// FixImports parses src (honoring opt.Fragment) and returns the import fixes
// computed by getFixes, without applying them to src. The work is bracketed
// by event.Start(ctx, "imports.FixImports") and the returned end function.
func FixImports(ctx context.Context, filename string, src []byte, opt *Options) (fixes []*ImportFix, err error) {
	ctx, end := event.Start(ctx, "imports.FixImports")
	defer end()

	fset := token.NewFileSet()
	f, _, err := parse(fset, filename, src, opt)
	if err != nil {
		return nil, err
	}
	return getFixes(ctx, fset, f, filename, opt.Env)
}
81
82
83
84
// ApplyFixes parses src as a complete Go file, applies the given import fixes
// and returns the formatted result. The parser mode is extraMode combined
// with the modes selected by opt.Comments and opt.AllErrors; opt.Fragment is
// not consulted here.
func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options, extraMode parser.Mode) (formatted []byte, err error) {
	mode := extraMode
	if opt.Comments {
		mode |= parser.ParseComments
	}
	if opt.AllErrors {
		mode |= parser.AllErrors
	}

	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, filename, src, mode)
	if f == nil {
		// Give up only when there is no AST at all; go/parser still returns
		// a partial AST for sources with syntax errors, and that is usable.
		return nil, err
	}

	apply(fset, f, fixes)
	return formatFile(fset, f, src, nil, opt)
}
108
109
110
111
112
113
114
// formatFile merges and sorts the imports of file and prints it. It inserts
// blank lines between import groups, then runs the result through
// format.Source.
//
// If adjust is non-nil it is applied to the printer output (with the
// original src) before import spacing is added; parse supplies it for
// fragment input so that the synthetic wrapper is removed again.
func formatFile(fset *token.FileSet, file *ast.File, src []byte, adjust func(orig []byte, src []byte) []byte, opt *Options) ([]byte, error) {
	mergeImports(file)
	sortImports(opt.LocalPrefix, fset.File(file.Pos()), file)

	// Within each contiguous run of imports, remember every path whose group
	// differs from the previous import's group: a blank line goes before it
	// so the output agrees with gofmt's grouping.
	var breaks []string
	for _, section := range astutil.Imports(fset, file) {
		prev := -1
		for _, spec := range section {
			path, _ := strconv.Unquote(spec.Path.Value)
			group := importGroup(opt.LocalPrefix, path)
			if prev != -1 && group != prev {
				breaks = append(breaks, path)
			}
			prev = group
		}
	}

	mode := printer.UseSpaces
	if opt.TabIndent {
		mode |= printer.TabIndent
	}
	cfg := &printer.Config{Mode: mode, Tabwidth: opt.TabWidth}

	var buf bytes.Buffer
	if err := cfg.Fprint(&buf, fset, file); err != nil {
		return nil, err
	}

	out := buf.Bytes()
	if adjust != nil {
		out = adjust(src, out)
	}

	if len(breaks) != 0 {
		spaced, err := addImportSpaces(bytes.NewReader(out), breaks)
		if err != nil {
			return nil, err
		}
		out = spaced
	}

	formatted, err := format.Source(out)
	if err != nil {
		return nil, err
	}
	return formatted, nil
}
164
165
166
// parse parses src as a Go source file. If that fails only because the
// package clause is missing and opt.Fragment is set, it retries src first as
// a list of top-level declarations, then as a list of statements, each time
// wrapping it in synthetic source.
//
// It returns the parsed file and an adjust function that turns printer
// output for the synthetic file back into the shape of the original
// fragment. adjust is nil when src parsed as a complete file, or when the
// fragment declares a main function (the added package clause is then
// kept). On failure it returns the error from the last parse attempt.
func parse(fset *token.FileSet, filename string, src []byte, opt *Options) (*ast.File, func(orig, src []byte) []byte, error) {
	parserMode := parser.Mode(0)
	if opt.Comments {
		parserMode |= parser.ParseComments
	}
	if opt.AllErrors {
		parserMode |= parser.AllErrors
	}

	// Try as a whole source file.
	file, err := parser.ParseFile(fset, filename, src, parserMode)
	if err == nil {
		return file, nil, nil
	}

	// Only a missing package clause on fragment-enabled input is worth
	// retrying; any other error is returned unchanged.
	if !opt.Fragment || !strings.Contains(err.Error(), "expected 'package'") {
		return nil, nil, err
	}

	// Treat src as a declaration list by prepending a package clause.
	// The clause ends in ';' rather than '\n' so that line numbers in psrc
	// match those in src (parse errors point at the right line).
	const prefix = "package main;"
	psrc := append([]byte(prefix), src...)
	file, err = parser.ParseFile(fset, filename, psrc, parserMode)
	if err == nil {
		// The printer will put the clause on its own line anyway; do the
		// same here and rebuild the file's line table from psrc so later
		// position/line lookups agree with that layout.
		psrc[len(prefix)-1] = '\n'
		fset.File(file.Package).SetLinesForContent(psrc)

		// A fragment that declares main() is treated as package main: no
		// adjust function, so the synthetic package clause stays in the output.
		if containsMainFunc(file) {
			return file, nil, nil
		}

		adjust := func(orig, src []byte) []byte {
			// Drop the synthetic package clause. Assumes the printer emits
			// it as "package main\n", the same length as prefix.
			src = src[len(prefix):]
			return matchSpace(orig, src)
		}
		return file, adjust, nil
	}

	// Move on to the statement-list attempt only when the declaration parse
	// failed because src does not begin with a declaration.
	if !strings.Contains(err.Error(), "expected declaration") {
		return nil, nil, err
	}

	// Treat src as a statement list (this also covers expressions) by
	// wrapping it in a package clause and a function body. ';' is used
	// instead of '\n' so that line numbers in fsrc match those in src.
	fsrc := append(append([]byte("package p; func _() {"), src...), '}')
	file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
	if err == nil {
		adjust := func(orig, src []byte) []byte {
			// Strip the wrapper as it appears in printer output: the ';'
			// after the package clause is expected to have become "\n\n",
			// and the closing brace to be followed by a newline.
			src = src[len("package p\n\nfunc _() {"):]
			src = src[:len(src)-len("}\n")]
			// Undo the one level of indentation added to the function body.
			// NOTE(review): only tab indentation is removed; if the printer
			// indents with spaces (opt.TabIndent unset) this is a no-op — confirm.
			src = bytes.ReplaceAll(src, []byte("\n\t"), []byte("\n"))
			return matchSpace(orig, src)
		}
		return file, adjust, nil
	}

	// Every attempt failed; report the statement-list parse error.
	return nil, nil, err
}
246
247
248
// containsMainFunc reports whether file declares a top-level function named
// main that takes no parameters and returns no results, i.e. a program
// entry point.
//
// Methods named main (func (T) main()) are ignored. They are not entry
// points, and counting them would make parse keep a synthetic
// "package main" clause for a fragment that is not a main package.
func containsMainFunc(file *ast.File) bool {
	for _, decl := range file.Decls {
		f, ok := decl.(*ast.FuncDecl)
		if !ok || f.Recv != nil || f.Name.Name != "main" {
			continue
		}
		// NumFields is nil-safe, so a missing parameter or result list
		// counts as empty.
		if f.Type.Params.NumFields() != 0 || f.Type.Results.NumFields() != 0 {
			continue
		}
		return true
	}
	return false
}
270
// cutSpace splits b into its leading run of spaces, tabs and newlines, the
// text in between, and its trailing run of the same characters. All three
// results alias b. When b consists only of such whitespace, before and
// middle are nil and after is all of b.
func cutSpace(b []byte) (before, middle, after []byte) {
	isSpace := func(c byte) bool {
		switch c {
		case ' ', '\t', '\n':
			return true
		}
		return false
	}

	start, end := 0, len(b)
	for start < len(b) && isSpace(b[start]) {
		start++
	}
	for end > 0 && isSpace(b[end-1]) {
		end--
	}

	if start > end {
		// The two scans crossed: there is no non-space text at all.
		return nil, nil, b[end:]
	}
	return b[:start], b[start:end], b[end:]
}
285
286
287
288
289
290
291
// matchSpace reshapes src to sit where orig did. The result is orig's
// leading whitespace up to and including its last newline, then each line of
// src (trimmed of surrounding whitespace) prefixed by orig's indentation
// (blank lines are left unindented), then orig's trailing whitespace.
// The indentation is the text after the last newline of orig's leading
// whitespace.
func matchSpace(orig []byte, src []byte) []byte {
	lead, _, trail := cutSpace(orig)
	cut := bytes.LastIndexByte(lead, '\n') + 1
	prefix, indent := lead[:cut], lead[cut:]

	_, body, _ := cutSpace(src)

	var buf bytes.Buffer
	buf.Write(prefix)
	for _, line := range bytes.SplitAfter(body, []byte{'\n'}) {
		if len(line) != 0 && line[0] != '\n' {
			buf.Write(indent)
		}
		buf.Write(line)
	}
	buf.Write(trail)
	return buf.Bytes()
}
316
317 var impLine = regexp.MustCompile(`^\s+(?:[\w\.]+\s+)?"(.+?)"`)
318
319 func addImportSpaces(r io.Reader, breaks []string) ([]byte, error) {
320 var out bytes.Buffer
321 in := bufio.NewReader(r)
322 inImports := false
323 done := false
324 for {
325 s, err := in.ReadString('\n')
326 if err == io.EOF {
327 break
328 } else if err != nil {
329 return nil, err
330 }
331
332 if !inImports && !done && strings.HasPrefix(s, "import") {
333 inImports = true
334 }
335 if inImports && (strings.HasPrefix(s, "var") ||
336 strings.HasPrefix(s, "func") ||
337 strings.HasPrefix(s, "const") ||
338 strings.HasPrefix(s, "type")) {
339 done = true
340 inImports = false
341 }
342 if inImports && len(breaks) > 0 {
343 if m := impLine.FindStringSubmatch(s); m != nil {
344 if m[1] == breaks[0] {
345 out.WriteByte('\n')
346 breaks = breaks[1:]
347 }
348 }
349 }
350
351 fmt.Fprint(&out, s)
352 }
353 return out.Bytes(), nil
354 }
355
View as plain text