1
2
3
4
5 package ssa_test
6
7
8
9 import (
10 "fmt"
11 "go/ast"
12 "go/constant"
13 "go/parser"
14 "go/token"
15 "go/types"
16 "os"
17 "runtime"
18 "strings"
19 "testing"
20
21 "golang.org/x/tools/go/ast/astutil"
22 "golang.org/x/tools/go/expect"
23 "golang.org/x/tools/go/loader"
24 "golang.org/x/tools/go/ssa"
25 "golang.org/x/tools/go/ssa/ssautil"
26 )
27
28 func TestObjValueLookup(t *testing.T) {
29 if runtime.GOOS == "android" {
30 t.Skipf("no testdata directory on %s", runtime.GOOS)
31 }
32
33 conf := loader.Config{ParserMode: parser.ParseComments}
34 src, err := os.ReadFile("testdata/objlookup.go")
35 if err != nil {
36 t.Fatal(err)
37 }
38 readFile := func(filename string) ([]byte, error) { return src, nil }
39 f, err := conf.ParseFile("testdata/objlookup.go", src)
40 if err != nil {
41 t.Fatal(err)
42 }
43 conf.CreateFromFiles("main", f)
44
45
46
47 expectations := make(map[string]string)
48
49
50
51
52 notes, err := expect.ExtractGo(conf.Fset, f)
53 if err != nil {
54 t.Fatal(err)
55 }
56 for _, n := range notes {
57 if n.Name != "ssa" {
58 t.Errorf("%v: unexpected note type %q, want \"ssa\"", conf.Fset.Position(n.Pos), n.Name)
59 continue
60 }
61 if len(n.Args) != 2 {
62 t.Errorf("%v: ssa has %d args, want 2", conf.Fset.Position(n.Pos), len(n.Args))
63 continue
64 }
65 ident, ok := n.Args[0].(expect.Identifier)
66 if !ok {
67 t.Errorf("%v: got %v for arg 1, want identifier", conf.Fset.Position(n.Pos), n.Args[0])
68 continue
69 }
70 exp, ok := n.Args[1].(string)
71 if !ok {
72 t.Errorf("%v: got %v for arg 2, want string", conf.Fset.Position(n.Pos), n.Args[1])
73 continue
74 }
75 p, _, err := expect.MatchBefore(conf.Fset, readFile, n.Pos, string(ident))
76 if err != nil {
77 t.Error(err)
78 continue
79 }
80 pos := conf.Fset.Position(p)
81 key := fmt.Sprintf("%s:%d", ident, pos.Line)
82 expectations[key] = exp
83 }
84
85 iprog, err := conf.Load()
86 if err != nil {
87 t.Error(err)
88 return
89 }
90
91 prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0) )
92 mainInfo := iprog.Created[0]
93 mainPkg := prog.Package(mainInfo.Pkg)
94 mainPkg.SetDebugMode(true)
95 mainPkg.Build()
96
97 var varIds []*ast.Ident
98 var varObjs []*types.Var
99 for id, obj := range mainInfo.Defs {
100
101 switch obj := obj.(type) {
102 case *types.Func:
103 checkFuncValue(t, prog, obj)
104
105 case *types.Const:
106 checkConstValue(t, prog, obj)
107
108 case *types.Var:
109 if id.Name == "_" {
110 continue
111 }
112 varIds = append(varIds, id)
113 varObjs = append(varObjs, obj)
114 }
115 }
116 for id, obj := range mainInfo.Uses {
117 if obj, ok := obj.(*types.Var); ok {
118 varIds = append(varIds, id)
119 varObjs = append(varObjs, obj)
120 }
121 }
122
123
124
125 for i, id := range varIds {
126 obj := varObjs[i]
127 ref, _ := astutil.PathEnclosingInterval(f, id.Pos(), id.Pos())
128 pos := prog.Fset.Position(id.Pos())
129 exp := expectations[fmt.Sprintf("%s:%d", id.Name, pos.Line)]
130 if exp == "" {
131 t.Errorf("%s: no expectation for var ident %s ", pos, id.Name)
132 continue
133 }
134 wantAddr := false
135 if exp[0] == '&' {
136 wantAddr = true
137 exp = exp[1:]
138 }
139 checkVarValue(t, prog, mainPkg, ref, obj, exp, wantAddr)
140 }
141 }
142
143 func checkFuncValue(t *testing.T, prog *ssa.Program, obj *types.Func) {
144 fn := prog.FuncValue(obj)
145
146 if fn == nil {
147 if obj.Name() != "interfaceMethod" {
148 t.Errorf("FuncValue(%s) == nil", obj)
149 }
150 return
151 }
152 if fnobj := fn.Object(); fnobj != obj {
153 t.Errorf("FuncValue(%s).Object() == %s; value was %s",
154 obj, fnobj, fn.Name())
155 return
156 }
157 if !types.Identical(fn.Type(), obj.Type()) {
158 t.Errorf("FuncValue(%s).Type() == %s", obj, fn.Type())
159 return
160 }
161 }
162
163 func checkConstValue(t *testing.T, prog *ssa.Program, obj *types.Const) {
164 c := prog.ConstValue(obj)
165
166 if c == nil {
167 t.Errorf("ConstValue(%s) == nil", obj)
168 return
169 }
170 if !types.Identical(c.Type(), obj.Type()) {
171 t.Errorf("ConstValue(%s).Type() == %s", obj, c.Type())
172 return
173 }
174 if obj.Name() != "nil" {
175 if !constant.Compare(c.Value, token.EQL, obj.Val()) {
176 t.Errorf("ConstValue(%s).Value (%s) != %s",
177 obj, c.Value, obj.Val())
178 return
179 }
180 }
181 }
182
183 func checkVarValue(t *testing.T, prog *ssa.Program, pkg *ssa.Package, ref []ast.Node, obj *types.Var, expKind string, wantAddr bool) {
184
185 prefix := fmt.Sprintf("VarValue(%s @ L%d)",
186 obj, prog.Fset.Position(ref[0].Pos()).Line)
187
188 v, gotAddr := prog.VarValue(obj, pkg, ref)
189
190
191 gotKind := "nil"
192 if v != nil {
193 gotKind = fmt.Sprintf("%T", v)[len("*ssa."):]
194 }
195
196
197
198
199
200 if expKind != gotKind {
201 t.Errorf("%s concrete type == %s, want %s", prefix, gotKind, expKind)
202 }
203
204
205
206 if v != nil {
207 expType := obj.Type()
208 if wantAddr {
209 expType = types.NewPointer(expType)
210 if !gotAddr {
211 t.Errorf("%s: got value, want address", prefix)
212 }
213 } else if gotAddr {
214 t.Errorf("%s: got address, want value", prefix)
215 }
216 if !types.Identical(v.Type(), expType) {
217 t.Errorf("%s.Type() == %s, want %s", prefix, v.Type(), expType)
218 }
219 }
220 }
221
222
223
224 func TestValueForExpr(t *testing.T) {
225 testValueForExpr(t, "testdata/valueforexpr.go")
226 }
227
228 func TestValueForExprStructConv(t *testing.T) {
229 testValueForExpr(t, "testdata/structconv.go")
230 }
231
232 func testValueForExpr(t *testing.T, testfile string) {
233 if runtime.GOOS == "android" {
234 t.Skipf("no testdata dir on %s", runtime.GOOS)
235 }
236
237 conf := loader.Config{ParserMode: parser.ParseComments}
238 f, err := conf.ParseFile(testfile, nil)
239 if err != nil {
240 t.Error(err)
241 return
242 }
243 conf.CreateFromFiles("main", f)
244
245 iprog, err := conf.Load()
246 if err != nil {
247 t.Error(err)
248 return
249 }
250
251 mainInfo := iprog.Created[0]
252
253 prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0))
254 mainPkg := prog.Package(mainInfo.Pkg)
255 mainPkg.SetDebugMode(true)
256 mainPkg.Build()
257
258 if false {
259
260 for _, mem := range mainPkg.Members {
261 if fn, ok := mem.(*ssa.Function); ok {
262 fn.WriteTo(os.Stderr)
263 }
264 }
265 }
266
267 var parenExprs []*ast.ParenExpr
268 ast.Inspect(f, func(n ast.Node) bool {
269 if n != nil {
270 if e, ok := n.(*ast.ParenExpr); ok {
271 parenExprs = append(parenExprs, e)
272 }
273 }
274 return true
275 })
276
277 notes, err := expect.ExtractGo(prog.Fset, f)
278 if err != nil {
279 t.Fatal(err)
280 }
281 for _, n := range notes {
282 want := n.Name
283 if want == "nil" {
284 want = "<nil>"
285 }
286 position := prog.Fset.Position(n.Pos)
287 var e ast.Expr
288 for _, paren := range parenExprs {
289 if paren.Pos() > n.Pos {
290 e = paren.X
291 break
292 }
293 }
294 if e == nil {
295 t.Errorf("%s: note doesn't precede ParenExpr: %q", position, want)
296 continue
297 }
298
299 path, _ := astutil.PathEnclosingInterval(f, n.Pos, n.Pos)
300 if path == nil {
301 t.Errorf("%s: can't find AST path from root to comment: %s", position, want)
302 continue
303 }
304
305 fn := ssa.EnclosingFunction(mainPkg, path)
306 if fn == nil {
307 t.Errorf("%s: can't find enclosing function", position)
308 continue
309 }
310
311 v, gotAddr := fn.ValueForExpr(e)
312 got := strings.TrimPrefix(fmt.Sprintf("%T", v), "*ssa.")
313 if got != want {
314 t.Errorf("%s: got value %q, want %q", position, got, want)
315 }
316 if v != nil {
317 T := v.Type()
318 if gotAddr {
319 T = T.Underlying().(*types.Pointer).Elem()
320 }
321 if !types.Identical(T, mainInfo.TypeOf(e)) {
322 t.Errorf("%s: got type %s, want %s", position, mainInfo.TypeOf(e), T)
323 }
324 }
325 }
326 }
327
328
329
330
// findInterval parses input as a Go file named "<input>" in fset and returns
// the file together with the interval of the first occurrence of substr in
// input: start is the position of its first byte, end the position just past
// its last byte.
//
// On failure it reports the problem with t.Errorf and returns a nil file.
// Failure means a parse error, or substr not occurring in input. Callers
// therefore need only check f == nil.
func findInterval(t *testing.T, fset *token.FileSet, input, substr string) (f *ast.File, start, end token.Pos) {
	t.Helper()

	f, err := parser.ParseFile(fset, "<input>", input, 0)
	if err != nil {
		// ParseFile returns a partial (non-nil) AST alongside syntax errors.
		// Discard it so that callers checking f == nil skip this input
		// instead of proceeding with a broken file and zero positions.
		t.Errorf("parse error: %s", err)
		return nil, token.NoPos, token.NoPos
	}

	i := strings.Index(input, substr)
	if i < 0 {
		t.Errorf("%q is not a substring of input", substr)
		return nil, token.NoPos, token.NoPos
	}

	tf := fset.File(f.Package)
	return f, tf.Pos(i), tf.Pos(i + len(substr))
}
348
349 func TestEnclosingFunction(t *testing.T) {
350 tests := []struct {
351 input string
352 substr string
353 fn string
354 }{
355
356
357
358 {`package main
359 func f() { println(1003) }`,
360 "100", "main.f"},
361
362 {`package main
363 type T int
364 func (t T) f() { println(200) }`,
365 "200", "(main.T).f"},
366
367 {`package main
368 func f() { println(func() { print(300) }) }`,
369 "300", "main.f$1"},
370
371 {`package main
372 func f() { println(func() { print(func() { print(350) })})}`,
373 "350", "main.f$1$1"},
374
375 {"package main; var a = 400", "400", "main.init"},
376
377 {"package main; const a = 500", "500", "(none)"},
378
379 {"package main; func init() { println(600) }", "600", "main.init#1"},
380
381 {`package main
382 func init() { println("foo") }
383 func init() { println(800) }`,
384 "800", "main.init#2"},
385
386 {`package main
387 func init() { println(func(){print(900)}) }`,
388 "900", "main.init#1$1"},
389
390 {`package main
391 type S[T any] struct{}
392 func (*S[T]) Foo() { println(1000) }
393 type P[T any] struct{ *S[T] }`,
394 "1000", "(*main.S[T]).Foo",
395 },
396 }
397 for _, test := range tests {
398 conf := loader.Config{Fset: token.NewFileSet()}
399 f, start, end := findInterval(t, conf.Fset, test.input, test.substr)
400 if f == nil {
401 continue
402 }
403 path, exact := astutil.PathEnclosingInterval(f, start, end)
404 if !exact {
405 t.Errorf("EnclosingFunction(%q) not exact", test.substr)
406 continue
407 }
408
409 conf.CreateFromFiles("main", f)
410
411 iprog, err := conf.Load()
412 if err != nil {
413 t.Error(err)
414 continue
415 }
416 prog := ssautil.CreateProgram(iprog, ssa.BuilderMode(0))
417 pkg := prog.Package(iprog.Created[0].Pkg)
418 pkg.Build()
419
420 name := "(none)"
421 fn := ssa.EnclosingFunction(pkg, path)
422 if fn != nil {
423 name = fn.String()
424 }
425
426 if name != test.fn {
427 t.Errorf("EnclosingFunction(%q in %q) got %s, want %s",
428 test.substr, test.input, name, test.fn)
429 continue
430 }
431
432
433 if has := ssa.HasEnclosingFunction(pkg, path); has != (fn != nil) {
434 t.Errorf("HasEnclosingFunction(%q in %q) got %v, want %v",
435 test.substr, test.input, has, fn != nil)
436 continue
437 }
438 }
439 }
440
View as plain text