1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package rta
39
40 import (
41 "fmt"
42 "go/types"
43 "hash/crc32"
44
45 "golang.org/x/tools/go/callgraph"
46 "golang.org/x/tools/go/ssa"
47 "golang.org/x/tools/go/types/typeutil"
48 "golang.org/x/tools/internal/aliases"
49 )
50
51
52
// Result holds the output of Analyze: the set of reachable functions,
// the set of types whose run-time type information may be needed, and,
// optionally, the call graph.
type Result struct {
	// CallGraph is the call graph discovered by the analysis, or nil if
	// Analyze was called with buildCallGraph=false. When the analyzed
	// program imports "reflect", it also contains synthetic edges, with a
	// nil call site, from (reflect.Value).Call to address-taken functions.
	CallGraph *callgraph.Graph

	// Reachable is the set of functions and methods found to be reachable
	// from the roots, including the exported methods of every concrete
	// type recorded in RuntimeTypes.
	//
	// AddrTaken is true if the function was reached at least once other
	// than as a root or as the target of a static call: through a dynamic
	// call of a function value, an interface method call, a synthetic call
	// from (reflect.Value).Call, or as an exported method of a runtime type.
	// Once set it stays set.
	// (The flag is wrapped in a struct, presumably so that
	// `if Reachable[f]` cannot be mistaken for a membership test.)
	Reachable map[*ssa.Function]struct{ AddrTaken bool }

	// RuntimeTypes is the set of types whose run-time type information may
	// be needed: types of values converted to interfaces, and the types
	// reachable from them (pointer bases, slice/array/chan elements, map
	// keys and values, struct fields, signature parameters and results,
	// pointers to named types, ...). Aliases are never recorded as keys.
	//
	// The bool value is a "skip" flag. It is set for types recorded as the
	// parameter/result tuple of a signature or as the underlying type of a
	// named type, presumably because reflection cannot reach those types
	// directly. NOTE(review): how the flag is combined when a type is
	// recorded through both kinds of path is decided in addRuntimeType.
	RuntimeTypes typeutil.Map
}
78
79
80 type rta struct {
81 result *Result
82
83 prog *ssa.Program
84
85 reflectValueCall *ssa.Function
86
87 worklist []*ssa.Function
88
89
90
91 addrTakenFuncsBySig typeutil.Map
92
93
94
95 dynCallSites typeutil.Map
96
97
98
99
100 invokeSites typeutil.Map
101
102
103
104
105
106
107
108 concreteTypes typeutil.Map
109
110
111
112
113 interfaceTypes typeutil.Map
114 }
115
// concreteTypeInfo caches what the analysis knows about one concrete type.
type concreteTypeInfo struct {
	C types.Type // the concrete type itself

	// mset is the method set of C; fprint is fingerprint(mset), which
	// implements compares before falling back to types.Implements.
	mset   *types.MethodSet
	fprint uint64

	// implements accumulates the interfaces found so far to be
	// implemented by C.
	implements []*types.Interface
}
122
// interfaceTypeInfo caches what the analysis knows about one interface type.
type interfaceTypeInfo struct {
	I *types.Interface // the interface type itself

	// mset is the method set of I; fprint is fingerprint(mset), which
	// implements compares before falling back to types.Implements.
	mset   *types.MethodSet
	fprint uint64

	// implementations accumulates the concrete types found so far to
	// implement I.
	implementations []types.Type
}
129
130
131
132 func (r *rta) addReachable(f *ssa.Function, addrTaken bool) {
133 reachable := r.result.Reachable
134 n := len(reachable)
135 v := reachable[f]
136 if addrTaken {
137 v.AddrTaken = true
138 }
139 reachable[f] = v
140 if len(reachable) > n {
141
142 r.worklist = append(r.worklist, f)
143 }
144 }
145
146
147
148
149 func (r *rta) addEdge(caller *ssa.Function, site ssa.CallInstruction, callee *ssa.Function, addrTaken bool) {
150 r.addReachable(callee, addrTaken)
151
152 if g := r.result.CallGraph; g != nil {
153 if caller == nil {
154 panic(site)
155 }
156 from := g.CreateNode(caller)
157 to := g.CreateNode(callee)
158 callgraph.AddEdge(from, site, to)
159 }
160 }
161
162
163
164
165 func (r *rta) visitAddrTakenFunc(f *ssa.Function) {
166
167 S := f.Signature
168 funcs, _ := r.addrTakenFuncsBySig.At(S).(map[*ssa.Function]bool)
169 if funcs == nil {
170 funcs = make(map[*ssa.Function]bool)
171 r.addrTakenFuncsBySig.Set(S, funcs)
172 }
173 if !funcs[f] {
174
175 funcs[f] = true
176
177
178
179 sites, _ := r.dynCallSites.At(S).([]ssa.CallInstruction)
180 for _, site := range sites {
181 r.addEdge(site.Parent(), site, f, true)
182 }
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206 if r.reflectValueCall != nil {
207 var site ssa.CallInstruction = nil
208 r.addEdge(r.reflectValueCall, site, f, true)
209 }
210 }
211 }
212
213
214 func (r *rta) visitDynCall(site ssa.CallInstruction) {
215 S := site.Common().Signature()
216
217
218 sites, _ := r.dynCallSites.At(S).([]ssa.CallInstruction)
219 r.dynCallSites.Set(S, append(sites, site))
220
221
222
223 funcs, _ := r.addrTakenFuncsBySig.At(S).(map[*ssa.Function]bool)
224 for g := range funcs {
225 r.addEdge(site.Parent(), site, g, true)
226 }
227 }
228
229
230
231
232 func (r *rta) addInvokeEdge(site ssa.CallInstruction, C types.Type) {
233
234 imethod := site.Common().Method
235 cmethod := r.prog.LookupMethod(C, imethod.Pkg(), imethod.Name())
236 r.addEdge(site.Parent(), site, cmethod, true)
237 }
238
239
240 func (r *rta) visitInvoke(site ssa.CallInstruction) {
241 I := site.Common().Value.Type().Underlying().(*types.Interface)
242
243
244 sites, _ := r.invokeSites.At(I).([]ssa.CallInstruction)
245 r.invokeSites.Set(I, append(sites, site))
246
247
248
249 for _, C := range r.implementations(I) {
250 r.addInvokeEdge(site, C)
251 }
252 }
253
254
255
256
257 func (r *rta) visitFunc(f *ssa.Function) {
258 var space [32]*ssa.Value
259
260 for _, b := range f.Blocks {
261 for _, instr := range b.Instrs {
262 rands := instr.Operands(space[:0])
263
264 switch instr := instr.(type) {
265 case ssa.CallInstruction:
266 call := instr.Common()
267 if call.IsInvoke() {
268 r.visitInvoke(instr)
269 } else if g := call.StaticCallee(); g != nil {
270 r.addEdge(f, instr, g, false)
271 } else if _, ok := call.Value.(*ssa.Builtin); !ok {
272 r.visitDynCall(instr)
273 }
274
275
276
277
278 rands = rands[1:]
279
280 case *ssa.MakeInterface:
281
282
283
284
285 r.addRuntimeType(instr.X.Type(), false)
286 }
287
288
289 for _, op := range rands {
290 if g, ok := (*op).(*ssa.Function); ok {
291 r.visitAddrTakenFunc(g)
292 }
293 }
294 }
295 }
296 }
297
298
299
300
301
302
303
304
305
306
307
308
309 func Analyze(roots []*ssa.Function, buildCallGraph bool) *Result {
310 if len(roots) == 0 {
311 return nil
312 }
313
314 r := &rta{
315 result: &Result{Reachable: make(map[*ssa.Function]struct{ AddrTaken bool })},
316 prog: roots[0].Prog,
317 }
318
319 if buildCallGraph {
320
321
322
323 r.result.CallGraph = callgraph.New(roots[0])
324 }
325
326
327
328 if reflectPkg := r.prog.ImportedPackage("reflect"); reflectPkg != nil {
329 reflectValue := reflectPkg.Members["Value"].(*ssa.Type)
330 r.reflectValueCall = r.prog.LookupMethod(reflectValue.Object().Type(), reflectPkg.Pkg, "Call")
331 }
332
333 hasher := typeutil.MakeHasher()
334 r.result.RuntimeTypes.SetHasher(hasher)
335 r.addrTakenFuncsBySig.SetHasher(hasher)
336 r.dynCallSites.SetHasher(hasher)
337 r.invokeSites.SetHasher(hasher)
338 r.concreteTypes.SetHasher(hasher)
339 r.interfaceTypes.SetHasher(hasher)
340
341 for _, root := range roots {
342 r.addReachable(root, false)
343 }
344
345
346
347
348 var shadow []*ssa.Function
349 for len(r.worklist) > 0 {
350 shadow, r.worklist = r.worklist, shadow[:0]
351 for _, f := range shadow {
352 r.visitFunc(f)
353 }
354 }
355 return r.result
356 }
357
358
359 func (r *rta) interfaces(C types.Type) []*types.Interface {
360
361 var cinfo *concreteTypeInfo
362 if v := r.concreteTypes.At(C); v != nil {
363 cinfo = v.(*concreteTypeInfo)
364 } else {
365 mset := r.prog.MethodSets.MethodSet(C)
366 cinfo = &concreteTypeInfo{
367 C: C,
368 mset: mset,
369 fprint: fingerprint(mset),
370 }
371 r.concreteTypes.Set(C, cinfo)
372
373
374
375 r.interfaceTypes.Iterate(func(I types.Type, v interface{}) {
376 iinfo := v.(*interfaceTypeInfo)
377 if I := aliases.Unalias(I).(*types.Interface); implements(cinfo, iinfo) {
378 iinfo.implementations = append(iinfo.implementations, C)
379 cinfo.implements = append(cinfo.implements, I)
380 }
381 })
382 }
383
384 return cinfo.implements
385 }
386
387
388 func (r *rta) implementations(I *types.Interface) []types.Type {
389
390 var iinfo *interfaceTypeInfo
391 if v := r.interfaceTypes.At(I); v != nil {
392 iinfo = v.(*interfaceTypeInfo)
393 } else {
394 mset := r.prog.MethodSets.MethodSet(I)
395 iinfo = &interfaceTypeInfo{
396 I: I,
397 mset: mset,
398 fprint: fingerprint(mset),
399 }
400 r.interfaceTypes.Set(I, iinfo)
401
402
403
404 r.concreteTypes.Iterate(func(C types.Type, v interface{}) {
405 cinfo := v.(*concreteTypeInfo)
406 if implements(cinfo, iinfo) {
407 cinfo.implements = append(cinfo.implements, I)
408 iinfo.implementations = append(iinfo.implementations, C)
409 }
410 })
411 }
412 return iinfo.implementations
413 }
414
415
416
417
418 func (r *rta) addRuntimeType(T types.Type, skip bool) {
419
420 T = aliases.Unalias(T)
421
422 if prev, ok := r.result.RuntimeTypes.At(T).(bool); ok {
423 if skip && !prev {
424 r.result.RuntimeTypes.Set(T, skip)
425 }
426 return
427 }
428 r.result.RuntimeTypes.Set(T, skip)
429
430 mset := r.prog.MethodSets.MethodSet(T)
431
432 if _, ok := T.Underlying().(*types.Interface); !ok {
433
434 for i, n := 0, mset.Len(); i < n; i++ {
435 sel := mset.At(i)
436 m := sel.Obj()
437
438 if m.Exported() {
439
440 r.addReachable(r.prog.MethodValue(sel), true)
441 }
442 }
443
444
445
446 for _, I := range r.interfaces(T) {
447 sites, _ := r.invokeSites.At(I).([]ssa.CallInstruction)
448 for _, site := range sites {
449 r.addInvokeEdge(site, T)
450 }
451 }
452 }
453
454
455
456
457
458 var n *types.Named
459 switch T := aliases.Unalias(T).(type) {
460 case *types.Named:
461 n = T
462 case *types.Pointer:
463 n, _ = aliases.Unalias(T.Elem()).(*types.Named)
464 }
465 if n != nil {
466 owner := n.Obj().Pkg()
467 if owner == nil {
468 return
469 }
470 }
471
472
473 for i := 0; i < mset.Len(); i++ {
474 if mset.At(i).Obj().Exported() {
475 sig := mset.At(i).Type().(*types.Signature)
476 r.addRuntimeType(sig.Params(), true)
477 r.addRuntimeType(sig.Results(), true)
478 }
479 }
480
481 switch t := T.(type) {
482 case *aliases.Alias:
483 panic("unreachable")
484
485 case *types.Basic:
486
487
488 case *types.Interface:
489
490
491 case *types.Pointer:
492 r.addRuntimeType(t.Elem(), false)
493
494 case *types.Slice:
495 r.addRuntimeType(t.Elem(), false)
496
497 case *types.Chan:
498 r.addRuntimeType(t.Elem(), false)
499
500 case *types.Map:
501 r.addRuntimeType(t.Key(), false)
502 r.addRuntimeType(t.Elem(), false)
503
504 case *types.Signature:
505 if t.Recv() != nil {
506 panic(fmt.Sprintf("Signature %s has Recv %s", t, t.Recv()))
507 }
508 r.addRuntimeType(t.Params(), true)
509 r.addRuntimeType(t.Results(), true)
510
511 case *types.Named:
512
513
514 r.addRuntimeType(types.NewPointer(T), false)
515
516
517
518
519
520 r.addRuntimeType(t.Underlying(), true)
521
522 case *types.Array:
523 r.addRuntimeType(t.Elem(), false)
524
525 case *types.Struct:
526 for i, n := 0, t.NumFields(); i < n; i++ {
527 r.addRuntimeType(t.Field(i).Type(), false)
528 }
529
530 case *types.Tuple:
531 for i, n := 0, t.Len(); i < n; i++ {
532 r.addRuntimeType(t.At(i).Type(), false)
533 }
534
535 default:
536 panic(T)
537 }
538 }
539
540
541
// fingerprint returns a 64-bit mask with one bit set for each method in
// mset. The bit is chosen by the CRC-32 of the method's Id together with
// its parameter and result counts. If a concrete type implements an
// interface, every bit of the interface's fingerprint is also set in the
// concrete type's fingerprint. implements uses this to cheaply reject most
// candidates before calling types.Implements.
//
// Fingerprints are only ever compared with other fingerprints produced by
// this function, so the exact bit assignment is an internal detail.
func fingerprint(mset *types.MethodSet) uint64 {
	// Scratch buffer for each method's key. It must be sliced to length 0:
	// appending to buf[:] (length 64) would put the key after 64 zero
	// bytes, forcing a reallocation on every call and hashing junk bytes.
	var buf [64]byte
	var mask uint64
	for i := 0; i < mset.Len(); i++ {
		method := mset.At(i).Obj()
		sig := method.Type().(*types.Signature)
		key := fmt.Appendf(buf[:0], "%s/%d/%d",
			method.Id(),
			sig.Params().Len(),
			sig.Results().Len())
		mask |= 1 << (crc32.ChecksumIEEE(key) % 64)
	}
	return mask
}
556
557
558
559 func implements(cinfo *concreteTypeInfo, iinfo *interfaceTypeInfo) (got bool) {
560
561
562
563 return iinfo.fprint & ^cinfo.fprint == 0 && types.Implements(cinfo.C, iinfo.I)
564 }
565
View as plain text