1
2
3
4
5
6
7
8
9
10 package rta_test
11
12 import (
13 "fmt"
14 "go/ast"
15 "go/parser"
16 "go/types"
17 "sort"
18 "strings"
19 "testing"
20
21 "golang.org/x/tools/go/callgraph"
22 "golang.org/x/tools/go/callgraph/rta"
23 "golang.org/x/tools/go/loader"
24 "golang.org/x/tools/go/ssa"
25 "golang.org/x/tools/go/ssa/ssautil"
26 "golang.org/x/tools/internal/aliases"
27 )
28
29
30
31 func TestRTA(t *testing.T) {
32 filenames := []string{
33 "testdata/func.go",
34 "testdata/generics.go",
35 "testdata/iface.go",
36 "testdata/reflectcall.go",
37 "testdata/rtype.go",
38 }
39 for _, filename := range filenames {
40 t.Run(filename, func(t *testing.T) {
41
42
43 conf := loader.Config{ParserMode: parser.ParseComments}
44 f, err := conf.ParseFile(filename, nil)
45 if err != nil {
46 t.Fatal(err)
47 }
48 conf.CreateFromFiles("main", f)
49 lprog, err := conf.Load()
50 if err != nil {
51 t.Fatal(err)
52 }
53 prog := ssautil.CreateProgram(lprog, ssa.InstantiateGenerics)
54 prog.Build()
55 mainPkg := prog.Package(lprog.Created[0].Pkg)
56
57 res := rta.Analyze([]*ssa.Function{
58 mainPkg.Func("main"),
59 mainPkg.Func("init"),
60 }, true)
61
62 check(t, f, mainPkg, res)
63 })
64 }
65 }
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80 func check(t *testing.T, f *ast.File, pkg *ssa.Package, res *rta.Result) {
81 tokFile := pkg.Prog.Fset.File(f.Pos())
82
83
84 expectation := func(f *ast.File) (string, int) {
85 for _, c := range f.Comments {
86 text := strings.TrimSpace(c.Text())
87 if t := strings.TrimPrefix(text, "WANT:\n"); t != text {
88 return t, tokFile.Line(c.Pos())
89 }
90 }
91 t.Fatalf("No WANT: comment in %s", tokFile.Name())
92 return "", 0
93 }
94 want, linenum := expectation(f)
95
96
97 var (
98 wantEdge = make(map[string]bool)
99 wantReachable = make(map[string]bool)
100 wantRtype = make(map[string]bool)
101 )
102 for _, line := range strings.Split(want, "\n") {
103 linenum++
104 orig := line
105 bad := func() {
106 t.Fatalf("%s:%d: invalid assertion: %q", tokFile.Name(), linenum, orig)
107 }
108
109 line := strings.TrimSpace(line)
110 if line == "" {
111 continue
112 }
113
114
115 sense := true
116 if rest := strings.TrimPrefix(line, "!"); rest != line {
117 sense = false
118 line = strings.TrimSpace(rest)
119 if line == "" {
120 bad()
121 }
122 }
123
124
125 var want map[string]bool
126 kind := strings.Fields(line)[0]
127 switch kind {
128 case "edge":
129 want = wantEdge
130 case "reachable":
131 want = wantReachable
132 case "rtype":
133 want = wantRtype
134 default:
135 bad()
136 }
137
138
139 str := strings.TrimSpace(line[len(kind):])
140 want[str] = sense
141 }
142
143 type stringset = map[string]bool
144
145
146
147
148 compare := func(kind string, got stringset, want map[string]bool) {
149 ok := true
150 for str, sense := range want {
151 if got[str] != sense {
152 ok = false
153 if sense {
154 t.Errorf("missing %s %q", kind, str)
155 } else {
156 t.Errorf("unwanted %s %q", kind, str)
157 }
158 }
159 }
160
161
162 if !ok {
163 var strs []string
164 for s := range got {
165 strs = append(strs, s)
166 }
167 sort.Strings(strs)
168 var buf strings.Builder
169 for _, str := range strs {
170 fmt.Fprintf(&buf, "%s %s\n", kind, str)
171 }
172 t.Errorf("got:\n%s", &buf)
173 }
174 }
175
176
177 {
178 got := make(stringset)
179 callgraph.GraphVisitEdges(res.CallGraph, func(e *callgraph.Edge) error {
180 edge := fmt.Sprintf("%s --%s--> %s",
181 e.Caller.Func.RelString(pkg.Pkg),
182 e.Description(),
183 e.Callee.Func.RelString(pkg.Pkg))
184 got[edge] = true
185 return nil
186 })
187 compare("edge", got, wantEdge)
188 }
189
190
191 {
192 got := make(stringset)
193 for f := range res.Reachable {
194 got[f.RelString(pkg.Pkg)] = true
195 }
196 compare("reachable", got, wantReachable)
197 }
198
199
200 {
201 got := make(stringset)
202 res.RuntimeTypes.Iterate(func(key types.Type, value interface{}) {
203 if !value.(bool) {
204 typ := types.TypeString(aliases.Unalias(key), types.RelativeTo(pkg.Pkg))
205 got[typ] = true
206 }
207 })
208 compare("rtype", got, wantRtype)
209 }
210 }
211
View as plain text