1
2
3
4
5
6
7
8
9
10 package cha_test
11
12 import (
13 "bytes"
14 "fmt"
15 "go/ast"
16 "go/build"
17 "go/parser"
18 "go/token"
19 "go/types"
20 "os"
21 "sort"
22 "strings"
23 "testing"
24
25 "golang.org/x/tools/go/buildutil"
26 "golang.org/x/tools/go/callgraph"
27 "golang.org/x/tools/go/callgraph/cha"
28 "golang.org/x/tools/go/loader"
29 "golang.org/x/tools/go/ssa"
30 "golang.org/x/tools/go/ssa/ssautil"
31 )
32
// inputs lists the testdata source files that TestCHA loads, one program
// per file. Each file is expected to carry a WANT comment (see expectation).
var inputs = []string{
	"testdata/func.go",
	"testdata/iface.go",
	"testdata/recv.go",
	"testdata/issue23925.go",
}
39
// expectation looks through the comment groups of f for the first one whose
// text, after surrounding whitespace is trimmed, starts with a "WANT:" line.
// It returns the text following that line and the position of the comment
// group. If no comment group matches, it returns "" and token.NoPos.
func expectation(f *ast.File) (string, token.Pos) {
	const marker = "WANT:\n"
	for _, group := range f.Comments {
		body := strings.TrimSpace(group.Text())
		if !strings.HasPrefix(body, marker) {
			continue
		}
		return body[len(marker):], group.Pos()
	}
	return "", token.NoPos
}
49
50
51
52
53 func TestCHA(t *testing.T) {
54 for _, filename := range inputs {
55 prog, f, mainPkg, err := loadProgInfo(filename, ssa.InstantiateGenerics)
56 if err != nil {
57 t.Error(err)
58 continue
59 }
60
61 want, pos := expectation(f)
62 if pos == token.NoPos {
63 t.Error(fmt.Errorf("No WANT: comment in %s", filename))
64 continue
65 }
66
67 cg := cha.CallGraph(prog)
68
69 if got := printGraph(cg, mainPkg.Pkg, "dynamic", "Dynamic calls"); got != want {
70 t.Errorf("%s: got:\n%s\nwant:\n%s",
71 prog.Fset.Position(pos), got, want)
72 }
73 }
74 }
75
76
77 func TestCHAGenerics(t *testing.T) {
78 filename := "testdata/generics.go"
79 prog, f, mainPkg, err := loadProgInfo(filename, ssa.InstantiateGenerics)
80 if err != nil {
81 t.Fatal(err)
82 }
83
84 want, pos := expectation(f)
85 if pos == token.NoPos {
86 t.Fatal(fmt.Errorf("No WANT: comment in %s", filename))
87 }
88
89 cg := cha.CallGraph(prog)
90
91 if got := printGraph(cg, mainPkg.Pkg, "", "All calls"); got != want {
92 t.Errorf("%s: got:\n%s\nwant:\n%s",
93 prog.Fset.Position(pos), got, want)
94 }
95 }
96
97
98 func TestCHAUnexported(t *testing.T) {
99
100
101
102
103
104
105
106
107
108
109
110
111
112 main := `package main
113 import "p2"
114 type I1 interface { m() }
115 type S1 struct { p2.I2 }
116 func (s S1) m() { }
117 func main() {
118 var s S1
119 var o I1 = s
120 o.m()
121 p2.Foo(s)
122 }`
123
124 p2 := `package p2
125 type I2 interface { m() }
126 type S2 struct { }
127 func (s S2) m() { }
128 func Foo(i I2) { i.m() }`
129
130 want := `All calls
131 main.init --> p2.init
132 main.main --> (main.S1).m
133 main.main --> p2.Foo
134 p2.Foo --> (p2.S2).m`
135
136 conf := loader.Config{
137 Build: fakeContext(map[string]string{"main": main, "p2": p2}),
138 }
139 conf.Import("main")
140 iprog, err := conf.Load()
141 if err != nil {
142 t.Fatalf("Load failed: %v", err)
143 }
144 prog := ssautil.CreateProgram(iprog, ssa.InstantiateGenerics)
145 prog.Build()
146
147 cg := cha.CallGraph(prog)
148
149
150 cg.DeleteSyntheticNodes()
151
152 if got := printGraph(cg, nil, "", "All calls"); got != want {
153 t.Errorf("cha.CallGraph: got:\n%s\nwant:\n%s", got, want)
154 }
155 }
156
157
158 func fakeContext(pkgs map[string]string) *build.Context {
159 pkgs2 := make(map[string]map[string]string)
160 for path, content := range pkgs {
161 pkgs2[path] = map[string]string{"x.go": content}
162 }
163 return buildutil.FakeContext(pkgs2)
164 }
165
166 func loadProgInfo(filename string, mode ssa.BuilderMode) (*ssa.Program, *ast.File, *ssa.Package, error) {
167 content, err := os.ReadFile(filename)
168 if err != nil {
169 return nil, nil, nil, fmt.Errorf("couldn't read file '%s': %s", filename, err)
170 }
171
172 conf := loader.Config{
173 ParserMode: parser.ParseComments,
174 }
175 f, err := conf.ParseFile(filename, content)
176 if err != nil {
177 return nil, nil, nil, err
178 }
179
180 conf.CreateFromFiles("main", f)
181 iprog, err := conf.Load()
182 if err != nil {
183 return nil, nil, nil, err
184 }
185
186 prog := ssautil.CreateProgram(iprog, mode)
187 prog.Build()
188
189 return prog, f, prog.Package(iprog.Created[0].Pkg), nil
190 }
191
192
193
194
195 func printGraph(cg *callgraph.Graph, from *types.Package, edgeMatch string, desc string) string {
196 var edges []string
197 callgraph.GraphVisitEdges(cg, func(e *callgraph.Edge) error {
198 if strings.Contains(e.Description(), edgeMatch) {
199 edges = append(edges, fmt.Sprintf("%s --> %s",
200 e.Caller.Func.RelString(from),
201 e.Callee.Func.RelString(from)))
202 }
203 return nil
204 })
205 sort.Strings(edges)
206
207 var buf bytes.Buffer
208 buf.WriteString(desc + "\n")
209 for _, edge := range edges {
210 fmt.Fprintf(&buf, " %s\n", edge)
211 }
212 return strings.TrimSpace(buf.String())
213 }
214
View as plain text