...
1
2
3
4
5 package vta
6
7 import (
8 "bytes"
9 "fmt"
10 "go/ast"
11 "go/parser"
12 "os"
13 "sort"
14 "strings"
15 "testing"
16
17 "golang.org/x/tools/go/callgraph"
18 "golang.org/x/tools/go/ssa/ssautil"
19
20 "golang.org/x/tools/go/loader"
21 "golang.org/x/tools/go/ssa"
22 )
23
24
25
26
// want scans the comment groups of f in order and returns the body of the
// first one whose text (as produced by CommentGroup.Text, with surrounding
// whitespace trimmed) starts with "WANT:" immediately followed by a newline.
// The header is dropped and the remainder is split into lines on "\n".
// It returns nil when no comment group carries such a header.
func want(f *ast.File) []string {
	const header = "WANT:\n"
	for _, group := range f.Comments {
		body := strings.TrimSpace(group.Text())
		if !strings.HasPrefix(body, header) {
			continue
		}
		return strings.Split(body[len(header):], "\n")
	}
	return nil
}
36
37
38
39
40 func testProg(path string, mode ssa.BuilderMode) (*ssa.Program, []string, error) {
41 content, err := os.ReadFile(path)
42 if err != nil {
43 return nil, nil, err
44 }
45
46 conf := loader.Config{
47 ParserMode: parser.ParseComments,
48 }
49
50 f, err := conf.ParseFile(path, content)
51 if err != nil {
52 return nil, nil, err
53 }
54
55 conf.CreateFromFiles("testdata", f)
56 iprog, err := conf.Load()
57 if err != nil {
58 return nil, nil, err
59 }
60
61 prog := ssautil.CreateProgram(iprog, mode)
62
63 prog.Package(iprog.Created[0].Pkg).SetDebugMode(true)
64 prog.Build()
65 return prog, want(f), nil
66 }
67
68 func firstRegInstr(f *ssa.Function) ssa.Value {
69 for _, b := range f.Blocks {
70 for _, i := range b.Instrs {
71 if v, ok := i.(ssa.Value); ok {
72 return v
73 }
74 }
75 }
76 return nil
77 }
78
79
80
81 func funcName(f *ssa.Function) string {
82 recv := f.Signature.Recv()
83 if recv == nil {
84 return f.Name()
85 }
86 tp := recv.Type().String()
87 return tp[strings.LastIndex(tp, ".")+1:] + "." + f.Name()
88 }
89
90
91
92
93
94
95
96
97 func callGraphStr(g *callgraph.Graph) []string {
98 var gs []string
99 for f, n := range g.Nodes {
100 c := make(map[string][]string)
101 for _, edge := range n.Out {
102 cs := edge.Site.String()
103 c[cs] = append(c[cs], funcName(edge.Callee.Func))
104 }
105
106 var cs []string
107 for site, fs := range c {
108 sort.Strings(fs)
109 entry := fmt.Sprintf("%v -> %v", site, strings.Join(fs, ", "))
110 cs = append(cs, entry)
111 }
112
113 sort.Strings(cs)
114 entry := fmt.Sprintf("%v: %v", funcName(f), strings.Join(cs, "; "))
115 gs = append(gs, entry)
116 }
117 return gs
118 }
119
120
121 func logFns(t testing.TB, prog *ssa.Program) {
122 for fn := range ssautil.AllFunctions(prog) {
123 var buf bytes.Buffer
124 fn.WriteTo(&buf)
125 t.Log(buf.String())
126 }
127 }
128
View as plain text