...

Source file src/golang.org/x/tools/go/ast/inspector/inspector_test.go

Documentation: golang.org/x/tools/go/ast/inspector

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package inspector_test
     6  
     7  import (
     8  	"go/ast"
     9  	"go/build"
    10  	"go/parser"
    11  	"go/token"
    12  	"log"
    13  	"path/filepath"
    14  	"reflect"
    15  	"strconv"
    16  	"strings"
    17  	"testing"
    18  
    19  	"golang.org/x/tools/go/ast/inspector"
    20  )
    21  
    22  var netFiles []*ast.File
    23  
    24  func init() {
    25  	files, err := parseNetFiles()
    26  	if err != nil {
    27  		log.Fatal(err)
    28  	}
    29  	netFiles = files
    30  }
    31  
    32  func parseNetFiles() ([]*ast.File, error) {
    33  	pkg, err := build.Default.Import("net", "", 0)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	fset := token.NewFileSet()
    38  	var files []*ast.File
    39  	for _, filename := range pkg.GoFiles {
    40  		filename = filepath.Join(pkg.Dir, filename)
    41  		f, err := parser.ParseFile(fset, filename, nil, 0)
    42  		if err != nil {
    43  			return nil, err
    44  		}
    45  		files = append(files, f)
    46  	}
    47  	return files, nil
    48  }
    49  
    50  // TestAllNodes compares Inspector against ast.Inspect.
    51  func TestInspectAllNodes(t *testing.T) {
    52  	inspect := inspector.New(netFiles)
    53  
    54  	var nodesA []ast.Node
    55  	inspect.Nodes(nil, func(n ast.Node, push bool) bool {
    56  		if push {
    57  			nodesA = append(nodesA, n)
    58  		}
    59  		return true
    60  	})
    61  	var nodesB []ast.Node
    62  	for _, f := range netFiles {
    63  		ast.Inspect(f, func(n ast.Node) bool {
    64  			if n != nil {
    65  				nodesB = append(nodesB, n)
    66  			}
    67  			return true
    68  		})
    69  	}
    70  	compare(t, nodesA, nodesB)
    71  }
    72  
    73  func TestInspectGenericNodes(t *testing.T) {
    74  	// src is using the 16 identifiers i0, i1, ... i15 so
    75  	// we can easily verify that we've found all of them.
    76  	const src = `package a
    77  
    78  type I interface { ~i0|i1 }
    79  
    80  type T[i2, i3 interface{ ~i4 }] struct {}
    81  
    82  func f[i5, i6 any]() {
    83  	_ = f[i7, i8]
    84  	var x T[i9, i10]
    85  }
    86  
    87  func (*T[i11, i12]) m()
    88  
    89  var _ i13[i14, i15]
    90  `
    91  	fset := token.NewFileSet()
    92  	f, _ := parser.ParseFile(fset, "a.go", src, 0)
    93  	inspect := inspector.New([]*ast.File{f})
    94  	found := make([]bool, 16)
    95  
    96  	indexListExprs := make(map[*ast.IndexListExpr]bool)
    97  
    98  	// Verify that we reach all i* identifiers, and collect IndexListExpr nodes.
    99  	inspect.Preorder(nil, func(n ast.Node) {
   100  		switch n := n.(type) {
   101  		case *ast.Ident:
   102  			if n.Name[0] == 'i' {
   103  				index, err := strconv.Atoi(n.Name[1:])
   104  				if err != nil {
   105  					t.Fatal(err)
   106  				}
   107  				found[index] = true
   108  			}
   109  		case *ast.IndexListExpr:
   110  			indexListExprs[n] = false
   111  		}
   112  	})
   113  	for i, v := range found {
   114  		if !v {
   115  			t.Errorf("missed identifier i%d", i)
   116  		}
   117  	}
   118  
   119  	// Verify that we can filter to IndexListExprs that we found in the first
   120  	// step.
   121  	if len(indexListExprs) == 0 {
   122  		t.Fatal("no index list exprs found")
   123  	}
   124  	inspect.Preorder([]ast.Node{&ast.IndexListExpr{}}, func(n ast.Node) {
   125  		ix := n.(*ast.IndexListExpr)
   126  		indexListExprs[ix] = true
   127  	})
   128  	for ix, v := range indexListExprs {
   129  		if !v {
   130  			t.Errorf("inspected node %v not filtered", ix)
   131  		}
   132  	}
   133  }
   134  
   135  // TestPruning compares Inspector against ast.Inspect,
   136  // pruning descent within ast.CallExpr nodes.
   137  func TestInspectPruning(t *testing.T) {
   138  	inspect := inspector.New(netFiles)
   139  
   140  	var nodesA []ast.Node
   141  	inspect.Nodes(nil, func(n ast.Node, push bool) bool {
   142  		if push {
   143  			nodesA = append(nodesA, n)
   144  			_, isCall := n.(*ast.CallExpr)
   145  			return !isCall // don't descend into function calls
   146  		}
   147  		return false
   148  	})
   149  	var nodesB []ast.Node
   150  	for _, f := range netFiles {
   151  		ast.Inspect(f, func(n ast.Node) bool {
   152  			if n != nil {
   153  				nodesB = append(nodesB, n)
   154  				_, isCall := n.(*ast.CallExpr)
   155  				return !isCall // don't descend into function calls
   156  			}
   157  			return false
   158  		})
   159  	}
   160  	compare(t, nodesA, nodesB)
   161  }
   162  
   163  func compare(t *testing.T, nodesA, nodesB []ast.Node) {
   164  	if len(nodesA) != len(nodesB) {
   165  		t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB))
   166  	} else {
   167  		for i := range nodesA {
   168  			if a, b := nodesA[i], nodesB[i]; a != b {
   169  				t.Errorf("node %d is inconsistent: %T, %T", i, a, b)
   170  			}
   171  		}
   172  	}
   173  }
   174  
   175  func TestTypeFiltering(t *testing.T) {
   176  	const src = `package a
   177  func f() {
   178  	print("hi")
   179  	panic("oops")
   180  }
   181  `
   182  	fset := token.NewFileSet()
   183  	f, _ := parser.ParseFile(fset, "a.go", src, 0)
   184  	inspect := inspector.New([]*ast.File{f})
   185  
   186  	var got []string
   187  	fn := func(n ast.Node, push bool) bool {
   188  		if push {
   189  			got = append(got, typeOf(n))
   190  		}
   191  		return true
   192  	}
   193  
   194  	// no type filtering
   195  	inspect.Nodes(nil, fn)
   196  	if want := strings.Fields("File Ident FuncDecl Ident FuncType FieldList BlockStmt ExprStmt CallExpr Ident BasicLit ExprStmt CallExpr Ident BasicLit"); !reflect.DeepEqual(got, want) {
   197  		t.Errorf("inspect: got %s, want %s", got, want)
   198  	}
   199  
   200  	// type filtering
   201  	nodeTypes := []ast.Node{
   202  		(*ast.BasicLit)(nil),
   203  		(*ast.CallExpr)(nil),
   204  	}
   205  	got = nil
   206  	inspect.Nodes(nodeTypes, fn)
   207  	if want := strings.Fields("CallExpr BasicLit CallExpr BasicLit"); !reflect.DeepEqual(got, want) {
   208  		t.Errorf("inspect: got %s, want %s", got, want)
   209  	}
   210  
   211  	// inspect with stack
   212  	got = nil
   213  	inspect.WithStack(nodeTypes, func(n ast.Node, push bool, stack []ast.Node) bool {
   214  		if push {
   215  			var line []string
   216  			for _, n := range stack {
   217  				line = append(line, typeOf(n))
   218  			}
   219  			got = append(got, strings.Join(line, " "))
   220  		}
   221  		return true
   222  	})
   223  	want := []string{
   224  		"File FuncDecl BlockStmt ExprStmt CallExpr",
   225  		"File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
   226  		"File FuncDecl BlockStmt ExprStmt CallExpr",
   227  		"File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
   228  	}
   229  	if !reflect.DeepEqual(got, want) {
   230  		t.Errorf("inspect: got %s, want %s", got, want)
   231  	}
   232  }
   233  
   234  func typeOf(n ast.Node) string {
   235  	return strings.TrimPrefix(reflect.TypeOf(n).String(), "*ast.")
   236  }
   237  
// The timings below show a marginal improvement (ASTInspect/Inspect) of
// about 2.6x (1.0/0.39), but a break-even point
// (NewInspector/(ASTInspect-Inspect)) of about 4 traversals (2.2/0.61).
   241  //
   242  // BenchmarkASTInspect     1.0 ms
   243  // BenchmarkNewInspector   2.2 ms
   244  // BenchmarkInspect        0.39ms
   245  // BenchmarkInspectFilter  0.01ms
   246  // BenchmarkInspectCalls   0.14ms
   247  
   248  func BenchmarkNewInspector(b *testing.B) {
   249  	// Measure one-time construction overhead.
   250  	for i := 0; i < b.N; i++ {
   251  		inspector.New(netFiles)
   252  	}
   253  }
   254  
   255  func BenchmarkInspect(b *testing.B) {
   256  	b.StopTimer()
   257  	inspect := inspector.New(netFiles)
   258  	b.StartTimer()
   259  
   260  	// Measure marginal cost of traversal.
   261  	var ndecls, nlits int
   262  	for i := 0; i < b.N; i++ {
   263  		inspect.Preorder(nil, func(n ast.Node) {
   264  			switch n.(type) {
   265  			case *ast.FuncDecl:
   266  				ndecls++
   267  			case *ast.FuncLit:
   268  				nlits++
   269  			}
   270  		})
   271  	}
   272  }
   273  
   274  func BenchmarkInspectFilter(b *testing.B) {
   275  	b.StopTimer()
   276  	inspect := inspector.New(netFiles)
   277  	b.StartTimer()
   278  
   279  	// Measure marginal cost of traversal.
   280  	nodeFilter := []ast.Node{(*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)}
   281  	var ndecls, nlits int
   282  	for i := 0; i < b.N; i++ {
   283  		inspect.Preorder(nodeFilter, func(n ast.Node) {
   284  			switch n.(type) {
   285  			case *ast.FuncDecl:
   286  				ndecls++
   287  			case *ast.FuncLit:
   288  				nlits++
   289  			}
   290  		})
   291  	}
   292  }
   293  
   294  func BenchmarkInspectCalls(b *testing.B) {
   295  	b.StopTimer()
   296  	inspect := inspector.New(netFiles)
   297  	b.StartTimer()
   298  
   299  	// Measure marginal cost of traversal.
   300  	nodeFilter := []ast.Node{(*ast.CallExpr)(nil)}
   301  	var ncalls int
   302  	for i := 0; i < b.N; i++ {
   303  		inspect.Preorder(nodeFilter, func(n ast.Node) {
   304  			_ = n.(*ast.CallExpr)
   305  			ncalls++
   306  		})
   307  	}
   308  }
   309  
   310  func BenchmarkASTInspect(b *testing.B) {
   311  	var ndecls, nlits int
   312  	for i := 0; i < b.N; i++ {
   313  		for _, f := range netFiles {
   314  			ast.Inspect(f, func(n ast.Node) bool {
   315  				switch n.(type) {
   316  				case *ast.FuncDecl:
   317  					ndecls++
   318  				case *ast.FuncLit:
   319  					nlits++
   320  				}
   321  				return true
   322  			})
   323  		}
   324  	}
   325  }
   326  

View as plain text