1
2
3
4
5 package vta
6
7 import (
8 "fmt"
9 "go/types"
10 "reflect"
11 "sort"
12 "strings"
13 "testing"
14
15 "golang.org/x/tools/go/callgraph/cha"
16 "golang.org/x/tools/go/ssa"
17 "golang.org/x/tools/go/ssa/ssautil"
18 "golang.org/x/tools/internal/aliases"
19 )
20
21 func TestNodeInterface(t *testing.T) {
22
23
24
25
26
27
28
29 prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0))
30 if err != nil {
31 t.Fatalf("couldn't load testdata/src/simple.go program: %v", err)
32 }
33
34 pkg := prog.AllPackages()[0]
35 main := pkg.Func("main")
36 reg := firstRegInstr(main)
37 X := pkg.Type("X").Type()
38 gl := pkg.Var("gl")
39 glPtrType, ok := aliases.Unalias(gl.Type()).(*types.Pointer)
40 if !ok {
41 t.Fatalf("could not cast gl variable to pointer type")
42 }
43 bint := glPtrType.Elem()
44
45 pint := types.NewPointer(bint)
46 i := types.NewInterface(nil, nil)
47
48 voidFunc := main.Signature.Underlying()
49
50 for _, test := range []struct {
51 n node
52 s string
53 t types.Type
54 }{
55 {constant{typ: bint}, "Constant(int)", bint},
56 {pointer{typ: pint}, "Pointer(*int)", pint},
57 {mapKey{typ: bint}, "MapKey(int)", bint},
58 {mapValue{typ: pint}, "MapValue(*int)", pint},
59 {sliceElem{typ: bint}, "Slice([]int)", bint},
60 {channelElem{typ: pint}, "Channel(chan *int)", pint},
61 {field{StructType: X, index: 0}, "Field(testdata.X:a)", bint},
62 {field{StructType: X, index: 1}, "Field(testdata.X:b)", bint},
63 {global{val: gl}, "Global(gl)", gl.Type()},
64 {local{val: reg}, "Local(t0)", bint},
65 {indexedLocal{val: reg, typ: X, index: 0}, "Local(t0[0])", X},
66 {function{f: main}, "Function(main)", voidFunc},
67 {nestedPtrInterface{typ: i}, "PtrInterface(interface{})", i},
68 {nestedPtrFunction{typ: voidFunc}, "PtrFunction(func())", voidFunc},
69 {panicArg{}, "Panic", nil},
70 {recoverReturn{}, "Recover", nil},
71 } {
72 if test.s != test.n.String() {
73 t.Errorf("want %s; got %s", test.s, test.n.String())
74 }
75 if test.t != test.n.Type() {
76 t.Errorf("want %s; got %s", test.t, test.n.Type())
77 }
78 }
79 }
80
81 func TestVtaGraph(t *testing.T) {
82
83 prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0))
84 if err != nil {
85 t.Fatalf("couldn't load testdata/src/simple.go program: %v", err)
86 }
87
88 glPtrType, ok := prog.AllPackages()[0].Var("gl").Type().(*types.Pointer)
89 if !ok {
90 t.Fatalf("could not cast gl variable to pointer type")
91 }
92 bint := glPtrType.Elem()
93
94 n1 := constant{typ: bint}
95 n2 := pointer{typ: types.NewPointer(bint)}
96 n3 := mapKey{typ: types.NewMap(bint, bint)}
97 n4 := mapValue{typ: types.NewMap(bint, bint)}
98
99
100
101
102
103
104
105 g := make(vtaGraph)
106 g.addEdge(n1, n3)
107 g.addEdge(n2, n3)
108 g.addEdge(n3, n4)
109 g.addEdge(n2, n4)
110
111 g.addEdge(n1, n3)
112
113 want := vtaGraph{
114 n1: map[node]bool{n3: true},
115 n2: map[node]bool{n3: true, n4: true},
116 n3: map[node]bool{n4: true},
117 }
118
119 if !reflect.DeepEqual(want, g) {
120 t.Errorf("want %v; got %v", want, g)
121 }
122
123 for _, test := range []struct {
124 n node
125 l int
126 }{
127 {n1, 1},
128 {n2, 2},
129 {n3, 1},
130 {n4, 0},
131 } {
132 if sl := len(g.successors(test.n)); sl != test.l {
133 t.Errorf("want %d successors; got %d", test.l, sl)
134 }
135 }
136 }
137
138
139
140
141
142 func vtaGraphStr(g vtaGraph) []string {
143 var vgs []string
144 for n, succ := range g {
145 var succStr []string
146 for s := range succ {
147 succStr = append(succStr, s.String())
148 }
149 sort.Strings(succStr)
150 entry := fmt.Sprintf("%v -> %v", n.String(), strings.Join(succStr, ", "))
151 vgs = append(vgs, entry)
152 }
153 return vgs
154 }
155
156
// setdiff returns the elements of X that do not occur in Y, sorted in
// increasing order. An element that appears several times in X and not in Y
// appears that many times in the result. The result is nil when no element
// of X is missing from Y.
func setdiff(X, Y []string) []string {
	// Index Y for constant-time membership checks.
	exclude := make(map[string]struct{}, len(Y))
	for _, name := range Y {
		exclude[name] = struct{}{}
	}

	var missing []string
	for _, name := range X {
		if _, found := exclude[name]; found {
			continue
		}
		missing = append(missing, name)
	}
	sort.Strings(missing)
	return missing
}
172
173 func TestVTAGraphConstruction(t *testing.T) {
174 for _, file := range []string{
175 "testdata/src/store.go",
176 "testdata/src/phi.go",
177 "testdata/src/type_conversions.go",
178 "testdata/src/type_assertions.go",
179 "testdata/src/fields.go",
180 "testdata/src/node_uniqueness.go",
181 "testdata/src/store_load_alias.go",
182 "testdata/src/phi_alias.go",
183 "testdata/src/channels.go",
184 "testdata/src/generic_channels.go",
185 "testdata/src/select.go",
186 "testdata/src/stores_arrays.go",
187 "testdata/src/maps.go",
188 "testdata/src/ranges.go",
189 "testdata/src/closures.go",
190 "testdata/src/function_alias.go",
191 "testdata/src/static_calls.go",
192 "testdata/src/dynamic_calls.go",
193 "testdata/src/returns.go",
194 "testdata/src/panic.go",
195 } {
196 t.Run(file, func(t *testing.T) {
197 prog, want, err := testProg(file, ssa.BuilderMode(0))
198 if err != nil {
199 t.Fatalf("couldn't load test file '%s': %s", file, err)
200 }
201 if len(want) == 0 {
202 t.Fatalf("couldn't find want in `%s`", file)
203 }
204
205 g, _ := typePropGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog))
206 got := vtaGraphStr(g)
207 if diff := setdiff(want, got); len(diff) > 0 {
208 t.Errorf("`%s`: want superset of %v;\n got %v\ndiff: %v", file, want, got, diff)
209 }
210 })
211 }
212 }
213
View as plain text