1
2
3
4
5 package inspector_test
6
7 import (
8 "go/ast"
9 "go/build"
10 "go/parser"
11 "go/token"
12 "log"
13 "path/filepath"
14 "reflect"
15 "strconv"
16 "strings"
17 "testing"
18
19 "golang.org/x/tools/go/ast/inspector"
20 )
21
22 var netFiles []*ast.File
23
24 func init() {
25 files, err := parseNetFiles()
26 if err != nil {
27 log.Fatal(err)
28 }
29 netFiles = files
30 }
31
// parseNetFiles resolves the "net" package through build.Default and parses
// every entry of its GoFiles list into one shared token.FileSet.
//
// It returns the parsed files in GoFiles order, or nil and the first import
// or parse error encountered.
func parseNetFiles() ([]*ast.File, error) {
	netPkg, err := build.Default.Import("net", "", 0)
	if err != nil {
		return nil, err
	}

	var (
		fset   = token.NewFileSet()
		parsed []*ast.File
	)
	for _, name := range netPkg.GoFiles {
		// GoFiles entries are bare file names; anchor them at the package directory.
		path := filepath.Join(netPkg.Dir, name)
		file, err := parser.ParseFile(fset, path, nil, 0)
		if err != nil {
			return nil, err
		}
		parsed = append(parsed, file)
	}
	return parsed, nil
}
49
50
51 func TestInspectAllNodes(t *testing.T) {
52 inspect := inspector.New(netFiles)
53
54 var nodesA []ast.Node
55 inspect.Nodes(nil, func(n ast.Node, push bool) bool {
56 if push {
57 nodesA = append(nodesA, n)
58 }
59 return true
60 })
61 var nodesB []ast.Node
62 for _, f := range netFiles {
63 ast.Inspect(f, func(n ast.Node) bool {
64 if n != nil {
65 nodesB = append(nodesB, n)
66 }
67 return true
68 })
69 }
70 compare(t, nodesA, nodesB)
71 }
72
73 func TestInspectGenericNodes(t *testing.T) {
74
75
76 const src = `package a
77
78 type I interface { ~i0|i1 }
79
80 type T[i2, i3 interface{ ~i4 }] struct {}
81
82 func f[i5, i6 any]() {
83 _ = f[i7, i8]
84 var x T[i9, i10]
85 }
86
87 func (*T[i11, i12]) m()
88
89 var _ i13[i14, i15]
90 `
91 fset := token.NewFileSet()
92 f, _ := parser.ParseFile(fset, "a.go", src, 0)
93 inspect := inspector.New([]*ast.File{f})
94 found := make([]bool, 16)
95
96 indexListExprs := make(map[*ast.IndexListExpr]bool)
97
98
99 inspect.Preorder(nil, func(n ast.Node) {
100 switch n := n.(type) {
101 case *ast.Ident:
102 if n.Name[0] == 'i' {
103 index, err := strconv.Atoi(n.Name[1:])
104 if err != nil {
105 t.Fatal(err)
106 }
107 found[index] = true
108 }
109 case *ast.IndexListExpr:
110 indexListExprs[n] = false
111 }
112 })
113 for i, v := range found {
114 if !v {
115 t.Errorf("missed identifier i%d", i)
116 }
117 }
118
119
120
121 if len(indexListExprs) == 0 {
122 t.Fatal("no index list exprs found")
123 }
124 inspect.Preorder([]ast.Node{&ast.IndexListExpr{}}, func(n ast.Node) {
125 ix := n.(*ast.IndexListExpr)
126 indexListExprs[ix] = true
127 })
128 for ix, v := range indexListExprs {
129 if !v {
130 t.Errorf("inspected node %v not filtered", ix)
131 }
132 }
133 }
134
135
136
137 func TestInspectPruning(t *testing.T) {
138 inspect := inspector.New(netFiles)
139
140 var nodesA []ast.Node
141 inspect.Nodes(nil, func(n ast.Node, push bool) bool {
142 if push {
143 nodesA = append(nodesA, n)
144 _, isCall := n.(*ast.CallExpr)
145 return !isCall
146 }
147 return false
148 })
149 var nodesB []ast.Node
150 for _, f := range netFiles {
151 ast.Inspect(f, func(n ast.Node) bool {
152 if n != nil {
153 nodesB = append(nodesB, n)
154 _, isCall := n.(*ast.CallExpr)
155 return !isCall
156 }
157 return false
158 })
159 }
160 compare(t, nodesA, nodesB)
161 }
162
// compare reports a test error if nodesA and nodesB differ.
//
// A length mismatch is reported once and the element-wise check is skipped.
// Otherwise every index whose two nodes are not the identical interface
// value is reported together with the dynamic types of both nodes.
func compare(t *testing.T, nodesA, nodesB []ast.Node) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	if len(nodesA) != len(nodesB) {
		t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB))
		return
	}
	for i, a := range nodesA {
		if b := nodesB[i]; a != b {
			t.Errorf("node %d is inconsistent: %T, %T", i, a, b)
		}
	}
}
174
175 func TestTypeFiltering(t *testing.T) {
176 const src = `package a
177 func f() {
178 print("hi")
179 panic("oops")
180 }
181 `
182 fset := token.NewFileSet()
183 f, _ := parser.ParseFile(fset, "a.go", src, 0)
184 inspect := inspector.New([]*ast.File{f})
185
186 var got []string
187 fn := func(n ast.Node, push bool) bool {
188 if push {
189 got = append(got, typeOf(n))
190 }
191 return true
192 }
193
194
195 inspect.Nodes(nil, fn)
196 if want := strings.Fields("File Ident FuncDecl Ident FuncType FieldList BlockStmt ExprStmt CallExpr Ident BasicLit ExprStmt CallExpr Ident BasicLit"); !reflect.DeepEqual(got, want) {
197 t.Errorf("inspect: got %s, want %s", got, want)
198 }
199
200
201 nodeTypes := []ast.Node{
202 (*ast.BasicLit)(nil),
203 (*ast.CallExpr)(nil),
204 }
205 got = nil
206 inspect.Nodes(nodeTypes, fn)
207 if want := strings.Fields("CallExpr BasicLit CallExpr BasicLit"); !reflect.DeepEqual(got, want) {
208 t.Errorf("inspect: got %s, want %s", got, want)
209 }
210
211
212 got = nil
213 inspect.WithStack(nodeTypes, func(n ast.Node, push bool, stack []ast.Node) bool {
214 if push {
215 var line []string
216 for _, n := range stack {
217 line = append(line, typeOf(n))
218 }
219 got = append(got, strings.Join(line, " "))
220 }
221 return true
222 })
223 want := []string{
224 "File FuncDecl BlockStmt ExprStmt CallExpr",
225 "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
226 "File FuncDecl BlockStmt ExprStmt CallExpr",
227 "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
228 }
229 if !reflect.DeepEqual(got, want) {
230 t.Errorf("inspect: got %s, want %s", got, want)
231 }
232 }
233
// typeOf returns the dynamic type of n as rendered by reflect, with a
// leading "*ast." removed; for example a *ast.CallExpr yields "CallExpr".
func typeOf(n ast.Node) string {
	const prefix = "*ast."
	name := reflect.TypeOf(n).String()
	return strings.TrimPrefix(name, prefix)
}
237
238
239
240
241
242
243
244
245
246
247
248 func BenchmarkNewInspector(b *testing.B) {
249
250 for i := 0; i < b.N; i++ {
251 inspector.New(netFiles)
252 }
253 }
254
255 func BenchmarkInspect(b *testing.B) {
256 b.StopTimer()
257 inspect := inspector.New(netFiles)
258 b.StartTimer()
259
260
261 var ndecls, nlits int
262 for i := 0; i < b.N; i++ {
263 inspect.Preorder(nil, func(n ast.Node) {
264 switch n.(type) {
265 case *ast.FuncDecl:
266 ndecls++
267 case *ast.FuncLit:
268 nlits++
269 }
270 })
271 }
272 }
273
274 func BenchmarkInspectFilter(b *testing.B) {
275 b.StopTimer()
276 inspect := inspector.New(netFiles)
277 b.StartTimer()
278
279
280 nodeFilter := []ast.Node{(*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)}
281 var ndecls, nlits int
282 for i := 0; i < b.N; i++ {
283 inspect.Preorder(nodeFilter, func(n ast.Node) {
284 switch n.(type) {
285 case *ast.FuncDecl:
286 ndecls++
287 case *ast.FuncLit:
288 nlits++
289 }
290 })
291 }
292 }
293
294 func BenchmarkInspectCalls(b *testing.B) {
295 b.StopTimer()
296 inspect := inspector.New(netFiles)
297 b.StartTimer()
298
299
300 nodeFilter := []ast.Node{(*ast.CallExpr)(nil)}
301 var ncalls int
302 for i := 0; i < b.N; i++ {
303 inspect.Preorder(nodeFilter, func(n ast.Node) {
304 _ = n.(*ast.CallExpr)
305 ncalls++
306 })
307 }
308 }
309
310 func BenchmarkASTInspect(b *testing.B) {
311 var ndecls, nlits int
312 for i := 0; i < b.N; i++ {
313 for _, f := range netFiles {
314 ast.Inspect(f, func(n ast.Node) bool {
315 switch n.(type) {
316 case *ast.FuncDecl:
317 ndecls++
318 case *ast.FuncLit:
319 nlits++
320 }
321 return true
322 })
323 }
324 }
325 }
326
View as plain text