...

Source file src/golang.org/x/tools/go/ssa/source_test.go

Documentation: golang.org/x/tools/go/ssa

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa_test
     6  
     7  // This file defines tests of source-level debugging utilities.
     8  
     9  import (
    10  	"fmt"
    11  	"go/ast"
    12  	"go/constant"
    13  	"go/parser"
    14  	"go/token"
    15  	"go/types"
    16  	"os"
    17  	"runtime"
    18  	"strings"
    19  	"testing"
    20  
    21  	"golang.org/x/tools/go/ast/astutil"
    22  	"golang.org/x/tools/go/expect"
    23  	"golang.org/x/tools/go/loader"
    24  	"golang.org/x/tools/go/ssa"
    25  	"golang.org/x/tools/go/ssa/ssautil"
    26  )
    27  
    28  func TestObjValueLookup(t *testing.T) {
    29  	if runtime.GOOS == "android" {
    30  		t.Skipf("no testdata directory on %s", runtime.GOOS)
    31  	}
    32  
    33  	conf := loader.Config{ParserMode: parser.ParseComments}
    34  	src, err := os.ReadFile("testdata/objlookup.go")
    35  	if err != nil {
    36  		t.Fatal(err)
    37  	}
    38  	readFile := func(filename string) ([]byte, error) { return src, nil }
    39  	f, err := conf.ParseFile("testdata/objlookup.go", src)
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	conf.CreateFromFiles("main", f)
    44  
    45  	// Maps each var Ident (represented "name:linenum") to the
    46  	// kind of ssa.Value we expect (represented "Constant", "&Alloc").
    47  	expectations := make(map[string]string)
    48  
    49  	// Each note of the form @ssa(x, "BinOp") in testdata/objlookup.go
    50  	// specifies an expectation that an object named x declared on the
    51  	// same line is associated with an ssa.Value of type *ssa.BinOp.
    52  	notes, err := expect.ExtractGo(conf.Fset, f)
    53  	if err != nil {
    54  		t.Fatal(err)
    55  	}
    56  	for _, n := range notes {
    57  		if n.Name != "ssa" {
    58  			t.Errorf("%v: unexpected note type %q, want \"ssa\"", conf.Fset.Position(n.Pos), n.Name)
    59  			continue
    60  		}
    61  		if len(n.Args) != 2 {
    62  			t.Errorf("%v: ssa has %d args, want 2", conf.Fset.Position(n.Pos), len(n.Args))
    63  			continue
    64  		}
    65  		ident, ok := n.Args[0].(expect.Identifier)
    66  		if !ok {
    67  			t.Errorf("%v: got %v for arg 1, want identifier", conf.Fset.Position(n.Pos), n.Args[0])
    68  			continue
    69  		}
    70  		exp, ok := n.Args[1].(string)
    71  		if !ok {
    72  			t.Errorf("%v: got %v for arg 2, want string", conf.Fset.Position(n.Pos), n.Args[1])
    73  			continue
    74  		}
    75  		p, _, err := expect.MatchBefore(conf.Fset, readFile, n.Pos, string(ident))
    76  		if err != nil {
    77  			t.Error(err)
    78  			continue
    79  		}
    80  		pos := conf.Fset.Position(p)
    81  		key := fmt.Sprintf("%s:%d", ident, pos.Line)
    82  		expectations[key] = exp
    83  	}
    84  
    85  	iprog, err := conf.Load()
    86  	if err != nil {
    87  		t.Error(err)
    88  		return
    89  	}
    90  
    91  	prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0) /*|ssa.PrintFunctions*/)
    92  	mainInfo := iprog.Created[0]
    93  	mainPkg := prog.Package(mainInfo.Pkg)
    94  	mainPkg.SetDebugMode(true)
    95  	mainPkg.Build()
    96  
    97  	var varIds []*ast.Ident
    98  	var varObjs []*types.Var
    99  	for id, obj := range mainInfo.Defs {
   100  		// Check invariants for func and const objects.
   101  		switch obj := obj.(type) {
   102  		case *types.Func:
   103  			checkFuncValue(t, prog, obj)
   104  
   105  		case *types.Const:
   106  			checkConstValue(t, prog, obj)
   107  
   108  		case *types.Var:
   109  			if id.Name == "_" {
   110  				continue
   111  			}
   112  			varIds = append(varIds, id)
   113  			varObjs = append(varObjs, obj)
   114  		}
   115  	}
   116  	for id, obj := range mainInfo.Uses {
   117  		if obj, ok := obj.(*types.Var); ok {
   118  			varIds = append(varIds, id)
   119  			varObjs = append(varObjs, obj)
   120  		}
   121  	}
   122  
   123  	// Check invariants for var objects.
   124  	// The result varies based on the specific Ident.
   125  	for i, id := range varIds {
   126  		obj := varObjs[i]
   127  		ref, _ := astutil.PathEnclosingInterval(f, id.Pos(), id.Pos())
   128  		pos := prog.Fset.Position(id.Pos())
   129  		exp := expectations[fmt.Sprintf("%s:%d", id.Name, pos.Line)]
   130  		if exp == "" {
   131  			t.Errorf("%s: no expectation for var ident %s ", pos, id.Name)
   132  			continue
   133  		}
   134  		wantAddr := false
   135  		if exp[0] == '&' {
   136  			wantAddr = true
   137  			exp = exp[1:]
   138  		}
   139  		checkVarValue(t, prog, mainPkg, ref, obj, exp, wantAddr)
   140  	}
   141  }
   142  
   143  func checkFuncValue(t *testing.T, prog *ssa.Program, obj *types.Func) {
   144  	fn := prog.FuncValue(obj)
   145  	// fmt.Printf("FuncValue(%s) = %s\n", obj, fn) // debugging
   146  	if fn == nil {
   147  		if obj.Name() != "interfaceMethod" {
   148  			t.Errorf("FuncValue(%s) == nil", obj)
   149  		}
   150  		return
   151  	}
   152  	if fnobj := fn.Object(); fnobj != obj {
   153  		t.Errorf("FuncValue(%s).Object() == %s; value was %s",
   154  			obj, fnobj, fn.Name())
   155  		return
   156  	}
   157  	if !types.Identical(fn.Type(), obj.Type()) {
   158  		t.Errorf("FuncValue(%s).Type() == %s", obj, fn.Type())
   159  		return
   160  	}
   161  }
   162  
   163  func checkConstValue(t *testing.T, prog *ssa.Program, obj *types.Const) {
   164  	c := prog.ConstValue(obj)
   165  	// fmt.Printf("ConstValue(%s) = %s\n", obj, c) // debugging
   166  	if c == nil {
   167  		t.Errorf("ConstValue(%s) == nil", obj)
   168  		return
   169  	}
   170  	if !types.Identical(c.Type(), obj.Type()) {
   171  		t.Errorf("ConstValue(%s).Type() == %s", obj, c.Type())
   172  		return
   173  	}
   174  	if obj.Name() != "nil" {
   175  		if !constant.Compare(c.Value, token.EQL, obj.Val()) {
   176  			t.Errorf("ConstValue(%s).Value (%s) != %s",
   177  				obj, c.Value, obj.Val())
   178  			return
   179  		}
   180  	}
   181  }
   182  
   183  func checkVarValue(t *testing.T, prog *ssa.Program, pkg *ssa.Package, ref []ast.Node, obj *types.Var, expKind string, wantAddr bool) {
   184  	// The prefix of all assertions messages.
   185  	prefix := fmt.Sprintf("VarValue(%s @ L%d)",
   186  		obj, prog.Fset.Position(ref[0].Pos()).Line)
   187  
   188  	v, gotAddr := prog.VarValue(obj, pkg, ref)
   189  
   190  	// Kind is the concrete type of the ssa Value.
   191  	gotKind := "nil"
   192  	if v != nil {
   193  		gotKind = fmt.Sprintf("%T", v)[len("*ssa."):]
   194  	}
   195  
   196  	// fmt.Printf("%s = %v (kind %q; expect %q) wantAddr=%t gotAddr=%t\n", prefix, v, gotKind, expKind, wantAddr, gotAddr) // debugging
   197  
   198  	// Check the kinds match.
   199  	// "nil" indicates expected failure (e.g. optimized away).
   200  	if expKind != gotKind {
   201  		t.Errorf("%s concrete type == %s, want %s", prefix, gotKind, expKind)
   202  	}
   203  
   204  	// Check the types match.
   205  	// If wantAddr, the expected type is the object's address.
   206  	if v != nil {
   207  		expType := obj.Type()
   208  		if wantAddr {
   209  			expType = types.NewPointer(expType)
   210  			if !gotAddr {
   211  				t.Errorf("%s: got value, want address", prefix)
   212  			}
   213  		} else if gotAddr {
   214  			t.Errorf("%s: got address, want value", prefix)
   215  		}
   216  		if !types.Identical(v.Type(), expType) {
   217  			t.Errorf("%s.Type() == %s, want %s", prefix, v.Type(), expType)
   218  		}
   219  	}
   220  }
   221  
   222  // Ensure that, in debug mode, we can determine the ssa.Value
   223  // corresponding to every ast.Expr.
   224  func TestValueForExpr(t *testing.T) {
   225  	testValueForExpr(t, "testdata/valueforexpr.go")
   226  }
   227  
   228  func TestValueForExprStructConv(t *testing.T) {
   229  	testValueForExpr(t, "testdata/structconv.go")
   230  }
   231  
   232  func testValueForExpr(t *testing.T, testfile string) {
   233  	if runtime.GOOS == "android" {
   234  		t.Skipf("no testdata dir on %s", runtime.GOOS)
   235  	}
   236  
   237  	conf := loader.Config{ParserMode: parser.ParseComments}
   238  	f, err := conf.ParseFile(testfile, nil)
   239  	if err != nil {
   240  		t.Error(err)
   241  		return
   242  	}
   243  	conf.CreateFromFiles("main", f)
   244  
   245  	iprog, err := conf.Load()
   246  	if err != nil {
   247  		t.Error(err)
   248  		return
   249  	}
   250  
   251  	mainInfo := iprog.Created[0]
   252  
   253  	prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0))
   254  	mainPkg := prog.Package(mainInfo.Pkg)
   255  	mainPkg.SetDebugMode(true)
   256  	mainPkg.Build()
   257  
   258  	if false {
   259  		// debugging
   260  		for _, mem := range mainPkg.Members {
   261  			if fn, ok := mem.(*ssa.Function); ok {
   262  				fn.WriteTo(os.Stderr)
   263  			}
   264  		}
   265  	}
   266  
   267  	var parenExprs []*ast.ParenExpr
   268  	ast.Inspect(f, func(n ast.Node) bool {
   269  		if n != nil {
   270  			if e, ok := n.(*ast.ParenExpr); ok {
   271  				parenExprs = append(parenExprs, e)
   272  			}
   273  		}
   274  		return true
   275  	})
   276  
   277  	notes, err := expect.ExtractGo(prog.Fset, f)
   278  	if err != nil {
   279  		t.Fatal(err)
   280  	}
   281  	for _, n := range notes {
   282  		want := n.Name
   283  		if want == "nil" {
   284  			want = "<nil>"
   285  		}
   286  		position := prog.Fset.Position(n.Pos)
   287  		var e ast.Expr
   288  		for _, paren := range parenExprs {
   289  			if paren.Pos() > n.Pos {
   290  				e = paren.X
   291  				break
   292  			}
   293  		}
   294  		if e == nil {
   295  			t.Errorf("%s: note doesn't precede ParenExpr: %q", position, want)
   296  			continue
   297  		}
   298  
   299  		path, _ := astutil.PathEnclosingInterval(f, n.Pos, n.Pos)
   300  		if path == nil {
   301  			t.Errorf("%s: can't find AST path from root to comment: %s", position, want)
   302  			continue
   303  		}
   304  
   305  		fn := ssa.EnclosingFunction(mainPkg, path)
   306  		if fn == nil {
   307  			t.Errorf("%s: can't find enclosing function", position)
   308  			continue
   309  		}
   310  
   311  		v, gotAddr := fn.ValueForExpr(e) // (may be nil)
   312  		got := strings.TrimPrefix(fmt.Sprintf("%T", v), "*ssa.")
   313  		if got != want {
   314  			t.Errorf("%s: got value %q, want %q", position, got, want)
   315  		}
   316  		if v != nil {
   317  			T := v.Type()
   318  			if gotAddr {
   319  				T = T.Underlying().(*types.Pointer).Elem() // deref
   320  			}
   321  			if !types.Identical(T, mainInfo.TypeOf(e)) {
   322  				t.Errorf("%s: got type %s, want %s", position, mainInfo.TypeOf(e), T)
   323  			}
   324  		}
   325  	}
   326  }
   327  
   328  // findInterval parses input and returns the [start, end) positions of
   329  // the first occurrence of substr in input.  f==nil indicates failure;
   330  // an error has already been reported in that case.
   331  func findInterval(t *testing.T, fset *token.FileSet, input, substr string) (f *ast.File, start, end token.Pos) {
   332  	f, err := parser.ParseFile(fset, "<input>", input, 0)
   333  	if err != nil {
   334  		t.Errorf("parse error: %s", err)
   335  		return
   336  	}
   337  
   338  	i := strings.Index(input, substr)
   339  	if i < 0 {
   340  		t.Errorf("%q is not a substring of input", substr)
   341  		f = nil
   342  		return
   343  	}
   344  
   345  	filePos := fset.File(f.Package)
   346  	return f, filePos.Pos(i), filePos.Pos(i + len(substr))
   347  }
   348  
   349  func TestEnclosingFunction(t *testing.T) {
   350  	tests := []struct {
   351  		input  string // the input file
   352  		substr string // first occurrence of this string denotes interval
   353  		fn     string // name of expected containing function
   354  	}{
   355  		// We use distinctive numbers as syntactic landmarks.
   356  
   357  		// Ordinary function:
   358  		{`package main
   359  		  func f() { println(1003) }`,
   360  			"100", "main.f"},
   361  		// Methods:
   362  		{`package main
   363                    type T int
   364  		  func (t T) f() { println(200) }`,
   365  			"200", "(main.T).f"},
   366  		// Function literal:
   367  		{`package main
   368  		  func f() { println(func() { print(300) }) }`,
   369  			"300", "main.f$1"},
   370  		// Doubly nested
   371  		{`package main
   372  		  func f() { println(func() { print(func() { print(350) })})}`,
   373  			"350", "main.f$1$1"},
   374  		// Implicit init for package-level var initializer.
   375  		{"package main; var a = 400", "400", "main.init"},
   376  		// No code for constants:
   377  		{"package main; const a = 500", "500", "(none)"},
   378  		// Explicit init()
   379  		{"package main; func init() { println(600) }", "600", "main.init#1"},
   380  		// Multiple explicit init functions:
   381  		{`package main
   382  		  func init() { println("foo") }
   383  		  func init() { println(800) }`,
   384  			"800", "main.init#2"},
   385  		// init() containing FuncLit.
   386  		{`package main
   387  		  func init() { println(func(){print(900)}) }`,
   388  			"900", "main.init#1$1"},
   389  		// generics
   390  		{`package main
   391  			type S[T any] struct{}
   392  			func (*S[T]) Foo() { println(1000) }
   393  			type P[T any] struct{ *S[T] }`,
   394  			"1000", "(*main.S[T]).Foo",
   395  		},
   396  	}
   397  	for _, test := range tests {
   398  		conf := loader.Config{Fset: token.NewFileSet()}
   399  		f, start, end := findInterval(t, conf.Fset, test.input, test.substr)
   400  		if f == nil {
   401  			continue
   402  		}
   403  		path, exact := astutil.PathEnclosingInterval(f, start, end)
   404  		if !exact {
   405  			t.Errorf("EnclosingFunction(%q) not exact", test.substr)
   406  			continue
   407  		}
   408  
   409  		conf.CreateFromFiles("main", f)
   410  
   411  		iprog, err := conf.Load()
   412  		if err != nil {
   413  			t.Error(err)
   414  			continue
   415  		}
   416  		prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0))
   417  		pkg := prog.Package(iprog.Created[0].Pkg)
   418  		pkg.Build()
   419  
   420  		name := "(none)"
   421  		fn := ssa.EnclosingFunction(pkg, path)
   422  		if fn != nil {
   423  			name = fn.String()
   424  		}
   425  
   426  		if name != test.fn {
   427  			t.Errorf("EnclosingFunction(%q in %q) got %s, want %s",
   428  				test.substr, test.input, name, test.fn)
   429  			continue
   430  		}
   431  
   432  		// While we're here: test HasEnclosingFunction.
   433  		if has := ssa.HasEnclosingFunction(pkg, path); has != (fn != nil) {
   434  			t.Errorf("HasEnclosingFunction(%q in %q) got %v, want %v",
   435  				test.substr, test.input, has, fn != nil)
   436  			continue
   437  		}
   438  	}
   439  }
   440  

View as plain text