1
2
3
4
5 package ssa
6
7
8
9 import (
10 "bytes"
11 "fmt"
12 "go/types"
13 "reflect"
14 "sort"
15 "strings"
16 "testing"
17
18 "golang.org/x/tools/go/loader"
19 )
20
21
22 func loadProgram(p string) (*loader.Program, error) {
23
24 var conf loader.Config
25 f, err := conf.ParseFile("<input>", p)
26 if err != nil {
27 return nil, fmt.Errorf("parse: %v", err)
28 }
29 conf.CreateFromFiles("p", f)
30
31
32 lprog, err := conf.Load()
33 if err != nil {
34 return nil, fmt.Errorf("Load: %v", err)
35 }
36 return lprog, nil
37 }
38
39
40 func buildPackage(lprog *loader.Program, pkg string, mode BuilderMode) *Package {
41 prog := NewProgram(lprog.Fset, mode)
42
43 for _, info := range lprog.AllPackages {
44 prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable)
45 }
46
47 p := prog.Package(lprog.Package(pkg).Pkg)
48 p.Build()
49 return p
50 }
51
52
53
54 func TestNeedsInstance(t *testing.T) {
55 const input = `
56 package p
57
58 import "unsafe"
59
60 type Pointer[T any] struct {
61 v unsafe.Pointer
62 }
63
64 func (x *Pointer[T]) Load() *T {
65 return (*T)(LoadPointer(&x.v))
66 }
67
68 func LoadPointer(addr *unsafe.Pointer) (val unsafe.Pointer)
69 `
70
71
72
73
74
75
76
77 lprog, err := loadProgram(input)
78 if err != err {
79 t.Fatal(err)
80 }
81
82 for _, mode := range []BuilderMode{BuilderMode(0), InstantiateGenerics} {
83
84 p := buildPackage(lprog, "p", mode)
85 prog := p.Prog
86
87 ptr := p.Type("Pointer").Type().(*types.Named)
88 if ptr.NumMethods() != 1 {
89 t.Fatalf("Expected Pointer to have 1 method. got %d", ptr.NumMethods())
90 }
91
92 obj := ptr.Method(0)
93 if obj.Name() != "Load" {
94 t.Errorf("Expected Pointer to have method named 'Load'. got %q", obj.Name())
95 }
96
97 meth := prog.FuncValue(obj)
98
99 var cr creator
100 intSliceTyp := types.NewSlice(types.Typ[types.Int])
101 instance := meth.instance([]types.Type{intSliceTyp}, &cr)
102 if len(cr) != 1 {
103 t.Errorf("Expected first instance to create a function. got %d created functions", len(cr))
104 }
105 if instance.Origin() != meth {
106 t.Errorf("Expected Origin of %s to be %s. got %s", instance, meth, instance.Origin())
107 }
108 if len(instance.TypeArgs()) != 1 || !types.Identical(instance.TypeArgs()[0], intSliceTyp) {
109 t.Errorf("Expected TypeArgs of %s to be %v. got %v", instance, []types.Type{intSliceTyp}, instance.typeargs)
110 }
111 instances := allInstances(meth)
112 if want := []*Function{instance}; !reflect.DeepEqual(instances, want) {
113 t.Errorf("Expected instances of %s to be %v. got %v", meth, want, instances)
114 }
115
116
117 second := meth.instance([]types.Type{types.NewSlice(types.Typ[types.Int])}, &cr)
118 if second != instance || len(cr) != 1 {
119 t.Error("Expected second identical instantiation to not create a function")
120 }
121
122
123 inst2 := meth.instance([]types.Type{types.NewSlice(types.Typ[types.Uint])}, &cr)
124 instances = allInstances(meth)
125
126
127 sort.Slice(instances, func(i, j int) bool {
128 return instances[i].Name() < instances[j].Name()
129 })
130 if want := []*Function{instance, inst2}; !reflect.DeepEqual(instances, want) {
131 t.Errorf("Expected instances of %s to be %v. got %v", meth, want, instances)
132 }
133
134
135
136
137 var b builder
138 b.buildFunction(instance)
139 var buf bytes.Buffer
140 if !sanityCheck(instance, &buf) {
141 t.Errorf("sanityCheck of %s failed with: %s", instance, buf.String())
142 }
143 }
144 }
145
146
147
148 func TestCallsToInstances(t *testing.T) {
149 const input = `
150 package p
151
152 type I interface {
153 Foo()
154 }
155
156 type A int
157 func (a A) Foo() {}
158
159 type J[T any] interface{ Bar() T }
160 type K[T any] struct{ J[T] }
161
162 func Id[T any] (t T) T {
163 return t
164 }
165
166 func Lambda[T I]() func() func(T) {
167 return func() func(T) {
168 return T.Foo
169 }
170 }
171
172 func NoOp[T any]() {}
173
174 func Bar[T interface { Foo(); ~int | ~string }, U any] (t T, u U) {
175 Id[U](u)
176 Id[T](t)
177 }
178
179 func Make[T any]() interface{} {
180 NoOp[K[T]]()
181 return nil
182 }
183
184 func entry(i int, a A) int {
185 Lambda[A]()()(a)
186
187 x := Make[int]()
188 if j, ok := x.(interface{ Bar() int }); ok {
189 print(j)
190 }
191
192 Bar[A, int](a, i)
193
194 return Id[int](i)
195 }
196 `
197 lprog, err := loadProgram(input)
198 if err != err {
199 t.Fatal(err)
200 }
201
202 p := buildPackage(lprog, "p", SanityCheckFunctions)
203 prog := p.Prog
204
205 for _, ti := range []struct {
206 orig string
207 instance string
208 tparams string
209 targs string
210 chTypeInstrs int
211 }{
212 {"Id", "Id[int]", "[T]", "[int]", 2},
213 {"Lambda", "Lambda[p.A]", "[T]", "[p.A]", 1},
214 {"Make", "Make[int]", "[T]", "[int]", 0},
215 {"NoOp", "NoOp[p.K[T]]", "[T]", "[p.K[T]]", 0},
216 } {
217 test := ti
218 t.Run(test.instance, func(t *testing.T) {
219 f := p.Members[test.orig].(*Function)
220 if f == nil {
221 t.Fatalf("origin function not found")
222 }
223
224 i := instanceOf(f, test.instance, prog)
225 if i == nil {
226 t.Fatalf("instance not found")
227 }
228
229
230 var body strings.Builder
231 i.WriteTo(&body)
232 t.Log(body.String())
233
234 if len(i.Blocks) != 1 {
235 t.Fatalf("body has more than 1 block")
236 }
237
238 if instrs := changeTypeInstrs(i.Blocks[0]); instrs != test.chTypeInstrs {
239 t.Errorf("want %v instructions; got %v", test.chTypeInstrs, instrs)
240 }
241
242 if test.tparams != tparams(i) {
243 t.Errorf("want %v type params; got %v", test.tparams, tparams(i))
244 }
245
246 if test.targs != targs(i) {
247 t.Errorf("want %v type arguments; got %v", test.targs, targs(i))
248 }
249 })
250 }
251 }
252
253 func instanceOf(f *Function, name string, prog *Program) *Function {
254 for _, i := range allInstances(f) {
255 if i.Name() == name {
256 return i
257 }
258 }
259 return nil
260 }
261
262 func tparams(f *Function) string {
263 tplist := f.TypeParams()
264 var tps []string
265 for i := 0; i < tplist.Len(); i++ {
266 tps = append(tps, tplist.At(i).String())
267 }
268 return fmt.Sprint(tps)
269 }
270
271 func targs(f *Function) string {
272 var tas []string
273 for _, ta := range f.TypeArgs() {
274 tas = append(tas, ta.String())
275 }
276 return fmt.Sprint(tas)
277 }
278
279 func changeTypeInstrs(b *BasicBlock) int {
280 cnt := 0
281 for _, i := range b.Instrs {
282 if _, ok := i.(*ChangeType); ok {
283 cnt++
284 }
285 }
286 return cnt
287 }
288
289 func TestInstanceUniqueness(t *testing.T) {
290 const input = `
291 package p
292
293 func H[T any](t T) {
294 print(t)
295 }
296
297 func F[T any](t T) {
298 H[T](t)
299 H[T](t)
300 H[T](t)
301 }
302
303 func G[T any](t T) {
304 H[T](t)
305 H[T](t)
306 }
307
308 func Foo[T any, S any](t T, s S) {
309 Foo[S, T](s, t)
310 Foo[T, S](t, s)
311 }
312 `
313 lprog, err := loadProgram(input)
314 if err != err {
315 t.Fatal(err)
316 }
317
318 p := buildPackage(lprog, "p", SanityCheckFunctions)
319
320 for _, test := range []struct {
321 orig string
322 instances string
323 }{
324 {"H", "[p.H[T] p.H[T]]"},
325 {"Foo", "[p.Foo[S T] p.Foo[T S]]"},
326 } {
327 t.Run(test.orig, func(t *testing.T) {
328 f := p.Members[test.orig].(*Function)
329 if f == nil {
330 t.Fatalf("origin function not found")
331 }
332
333 instances := allInstances(f)
334 sort.Slice(instances, func(i, j int) bool { return instances[i].Name() < instances[j].Name() })
335
336 if got := fmt.Sprintf("%v", instances); !reflect.DeepEqual(got, test.instances) {
337 t.Errorf("got %v instances, want %v", got, test.instances)
338 }
339 })
340 }
341 }
342
343
344
345
346
347
348
349
350
351
352
353 func allInstances(fn *Function) []*Function {
354 if fn.generic == nil {
355 return nil
356 }
357
358 fn.generic.instancesMu.Lock()
359 defer fn.generic.instancesMu.Unlock()
360 return mapValues(fn.generic.instances)
361 }
362
View as plain text