1
2
3
4
5 package astutil_test
6
7 import (
8 "bytes"
9 "go/ast"
10 "go/format"
11 "go/parser"
12 "go/token"
13 "testing"
14
15 "golang.org/x/tools/go/ast/astutil"
16 )
17
18 type rewriteTest struct {
19 name string
20 orig, want string
21 pre, post astutil.ApplyFunc
22 }
23
24 var rewriteTests = []rewriteTest{
25 {name: "nop", orig: "package p\n", want: "package p\n"},
26
27 {name: "replace",
28 orig: `package p
29
30 var x int
31 `,
32 want: `package p
33
34 var t T
35 `,
36 post: func(c *astutil.Cursor) bool {
37 if _, ok := c.Node().(*ast.ValueSpec); ok {
38 c.Replace(valspec("t", "T"))
39 return false
40 }
41 return true
42 },
43 },
44
45 {name: "set doc strings",
46 orig: `package p
47
48 const z = 0
49
50 type T struct{}
51
52 var x int
53 `,
54 want: `package p
55 // a foo is a foo
56 const z = 0
57 // a foo is a foo
58 type T struct{}
59 // a foo is a foo
60 var x int
61 `,
62 post: func(c *astutil.Cursor) bool {
63 if _, ok := c.Parent().(*ast.GenDecl); ok && c.Name() == "Doc" && c.Node() == nil {
64 c.Replace(&ast.CommentGroup{List: []*ast.Comment{{Text: "// a foo is a foo"}}})
65 }
66 return true
67 },
68 },
69
70 {name: "insert names",
71 orig: `package p
72
73 const a = 1
74 `,
75 want: `package p
76
77 const a, b, c = 1, 2, 3
78 `,
79 pre: func(c *astutil.Cursor) bool {
80 if _, ok := c.Parent().(*ast.ValueSpec); ok {
81 switch c.Name() {
82 case "Names":
83 c.InsertAfter(ast.NewIdent("c"))
84 c.InsertAfter(ast.NewIdent("b"))
85 case "Values":
86 c.InsertAfter(&ast.BasicLit{Kind: token.INT, Value: "3"})
87 c.InsertAfter(&ast.BasicLit{Kind: token.INT, Value: "2"})
88 }
89 }
90 return true
91 },
92 },
93
94 {name: "insert",
95 orig: `package p
96
97 var (
98 x int
99 y int
100 )
101 `,
102 want: `package p
103
104 var before1 int
105 var before2 int
106
107 var (
108 x int
109 y int
110 )
111 var after2 int
112 var after1 int
113 `,
114 pre: func(c *astutil.Cursor) bool {
115 if _, ok := c.Node().(*ast.GenDecl); ok {
116 c.InsertBefore(vardecl("before1", "int"))
117 c.InsertAfter(vardecl("after1", "int"))
118 c.InsertAfter(vardecl("after2", "int"))
119 c.InsertBefore(vardecl("before2", "int"))
120 }
121 return true
122 },
123 },
124
125 {name: "delete",
126 orig: `package p
127
128 var x int
129 var y int
130 var z int
131 `,
132 want: `package p
133
134 var y int
135 var z int
136 `,
137 pre: func(c *astutil.Cursor) bool {
138 n := c.Node()
139 if d, ok := n.(*ast.GenDecl); ok && d.Specs[0].(*ast.ValueSpec).Names[0].Name == "x" {
140 c.Delete()
141 }
142 return true
143 },
144 },
145
146 {name: "insertafter-delete",
147 orig: `package p
148
149 var x int
150 var y int
151 var z int
152 `,
153 want: `package p
154
155 var x1 int
156
157 var y int
158 var z int
159 `,
160 pre: func(c *astutil.Cursor) bool {
161 n := c.Node()
162 if d, ok := n.(*ast.GenDecl); ok && d.Specs[0].(*ast.ValueSpec).Names[0].Name == "x" {
163 c.InsertAfter(vardecl("x1", "int"))
164 c.Delete()
165 }
166 return true
167 },
168 },
169
170 {name: "delete-insertafter",
171 orig: `package p
172
173 var x int
174 var y int
175 var z int
176 `,
177 want: `package p
178
179 var y int
180 var x1 int
181 var z int
182 `,
183 pre: func(c *astutil.Cursor) bool {
184 n := c.Node()
185 if d, ok := n.(*ast.GenDecl); ok && d.Specs[0].(*ast.ValueSpec).Names[0].Name == "x" {
186 c.Delete()
187
188 c.InsertAfter(vardecl("x1", "int"))
189 }
190 return true
191 },
192 },
193 {
194 name: "replace",
195 orig: `package p
196
197 type T[P1, P2 any] int
198
199 type R T[int, string]
200
201 func F[Q1 any](q Q1) {}
202 `,
203
204
205 want: `package p
206
207 type S[R1, P2 any] int32
208
209 type R S[int32, string]
210
211 func F[X1 any](q X1,) {}
212 `,
213 post: func(c *astutil.Cursor) bool {
214 if ident, ok := c.Node().(*ast.Ident); ok {
215 switch ident.Name {
216 case "int":
217 c.Replace(ast.NewIdent("int32"))
218 case "T":
219 c.Replace(ast.NewIdent("S"))
220 case "P1":
221 c.Replace(ast.NewIdent("R1"))
222 case "Q1":
223 c.Replace(ast.NewIdent("X1"))
224 }
225 }
226 return true
227 },
228 },
229 }
230
// valspec builds the spec "name typ" (one name, an explicit type, and no
// initial values), e.g. the "x int" part of "var x int".
func valspec(name, typ string) *ast.ValueSpec {
	ident := ast.NewIdent(name)
	return &ast.ValueSpec{
		Names: []*ast.Ident{ident},
		Type:  ast.NewIdent(typ),
	}
}
236
237 func vardecl(name, typ string) *ast.GenDecl {
238 return &ast.GenDecl{
239 Tok: token.VAR,
240 Specs: []ast.Spec{valspec(name, typ)},
241 }
242 }
243
244 func TestRewrite(t *testing.T) {
245 t.Run("*", func(t *testing.T) {
246 for _, test := range rewriteTests {
247 test := test
248 t.Run(test.name, func(t *testing.T) {
249 t.Parallel()
250 fset := token.NewFileSet()
251 f, err := parser.ParseFile(fset, test.name, test.orig, parser.ParseComments)
252 if err != nil {
253 t.Fatal(err)
254 }
255 n := astutil.Apply(f, test.pre, test.post)
256 var buf bytes.Buffer
257 if err := format.Node(&buf, fset, n); err != nil {
258 t.Fatal(err)
259 }
260 got := buf.String()
261 if got != test.want {
262 t.Errorf("got:\n\n%s\nwant:\n\n%s\n", got, test.want)
263 }
264 })
265 }
266 })
267 }
268
269 var sink ast.Node
270
271 func BenchmarkRewrite(b *testing.B) {
272 for _, test := range rewriteTests {
273 b.Run(test.name, func(b *testing.B) {
274 for i := 0; i < b.N; i++ {
275 b.StopTimer()
276 fset := token.NewFileSet()
277 f, err := parser.ParseFile(fset, test.name, test.orig, parser.ParseComments)
278 if err != nil {
279 b.Fatal(err)
280 }
281 b.StartTimer()
282 sink = astutil.Apply(f, test.pre, test.post)
283 }
284 })
285 }
286 }
287
View as plain text