1
2
3
4
5 package ssa
6
7 import (
8 "go/types"
9
10 "golang.org/x/tools/internal/aliases"
11 )
12
13
14
15
16
17
18
19
// subster rewrites types by replacing type parameters with the type
// arguments recorded in replacements. A nil *subster performs no
// substitution: its typ method returns its argument unchanged.
type subster struct {
	// replacements maps each type parameter to the type that replaces it.
	replacements map[*types.TypeParam]types.Type

	// cache memoizes the results of typ, keyed by the input type.
	cache map[types.Type]types.Type

	// ctxt is handed to types.Instantiate when generic named types are
	// re-instantiated with substituted type arguments.
	ctxt *types.Context

	// scope, when non-nil, restricts which non-generic named types are
	// rebuilt: named types whose object position lies outside scope are
	// returned as-is.
	scope *types.Scope

	// debug enables the wellFormed self-check when the subster is built.
	debug bool
}
30
31
32
33
34 func makeSubster(ctxt *types.Context, scope *types.Scope, tparams *types.TypeParamList, targs []types.Type, debug bool) *subster {
35 assert(tparams.Len() == len(targs), "makeSubster argument count must match")
36
37 subst := &subster{
38 replacements: make(map[*types.TypeParam]types.Type, tparams.Len()),
39 cache: make(map[types.Type]types.Type),
40 ctxt: ctxt,
41 scope: scope,
42 debug: debug,
43 }
44 for i := 0; i < tparams.Len(); i++ {
45 subst.replacements[tparams.At(i)] = targs[i]
46 }
47 if subst.debug {
48 subst.wellFormed()
49 }
50 return subst
51 }
52
53
54 func (subst *subster) wellFormed() {
55 if subst == nil {
56 return
57 }
58
59 s := make(map[types.Type]bool, len(subst.replacements))
60 for tparam := range subst.replacements {
61 s[tparam] = true
62 }
63 for _, r := range subst.replacements {
64 if reaches(r, s) {
65 panic(subst)
66 }
67 }
68 }
69
70
71
72 func (subst *subster) typ(t types.Type) (res types.Type) {
73 if subst == nil {
74 return t
75 }
76 if r, ok := subst.cache[t]; ok {
77 return r
78 }
79 defer func() {
80 subst.cache[t] = res
81 }()
82
83 switch t := t.(type) {
84 case *types.TypeParam:
85 r := subst.replacements[t]
86 assert(r != nil, "type param without replacement encountered")
87 return r
88
89 case *types.Basic:
90 return t
91
92 case *types.Array:
93 if r := subst.typ(t.Elem()); r != t.Elem() {
94 return types.NewArray(r, t.Len())
95 }
96 return t
97
98 case *types.Slice:
99 if r := subst.typ(t.Elem()); r != t.Elem() {
100 return types.NewSlice(r)
101 }
102 return t
103
104 case *types.Pointer:
105 if r := subst.typ(t.Elem()); r != t.Elem() {
106 return types.NewPointer(r)
107 }
108 return t
109
110 case *types.Tuple:
111 return subst.tuple(t)
112
113 case *types.Struct:
114 return subst.struct_(t)
115
116 case *types.Map:
117 key := subst.typ(t.Key())
118 elem := subst.typ(t.Elem())
119 if key != t.Key() || elem != t.Elem() {
120 return types.NewMap(key, elem)
121 }
122 return t
123
124 case *types.Chan:
125 if elem := subst.typ(t.Elem()); elem != t.Elem() {
126 return types.NewChan(t.Dir(), elem)
127 }
128 return t
129
130 case *types.Signature:
131 return subst.signature(t)
132
133 case *types.Union:
134 return subst.union(t)
135
136 case *types.Interface:
137 return subst.interface_(t)
138
139 case *aliases.Alias:
140 return subst.alias(t)
141
142 case *types.Named:
143 return subst.named(t)
144
145 default:
146 panic("unreachable")
147 }
148 }
149
150
151 func (subst *subster) types(ts []types.Type) []types.Type {
152 res := make([]types.Type, len(ts))
153 for i := range ts {
154 res[i] = subst.typ(ts[i])
155 }
156 return res
157 }
158
159 func (subst *subster) tuple(t *types.Tuple) *types.Tuple {
160 if t != nil {
161 if vars := subst.varlist(t); vars != nil {
162 return types.NewTuple(vars...)
163 }
164 }
165 return t
166 }
167
// varlist is the read-only, indexed view over a sequence of *types.Var
// that subster.varlist consumes. *types.Tuple satisfies it directly;
// struct fields are adapted through fieldlist.
type varlist interface {
	Len() int
	At(i int) *types.Var
}
172
173
// fieldlist adapts the fields of a *types.Struct to the varlist shape
// (Len/At), so struct fields can be passed to subster.varlist the same way
// tuple elements are.
type fieldlist struct {
	str *types.Struct
}

// At returns the i'th field of the wrapped struct.
func (l fieldlist) At(i int) *types.Var {
	return l.str.Field(i)
}

// Len returns the number of fields in the wrapped struct.
func (l fieldlist) Len() int {
	return l.str.NumFields()
}
180
181 func (subst *subster) struct_(t *types.Struct) *types.Struct {
182 if t != nil {
183 if fields := subst.varlist(fieldlist{t}); fields != nil {
184 tags := make([]string, t.NumFields())
185 for i, n := 0, t.NumFields(); i < n; i++ {
186 tags[i] = t.Tag(i)
187 }
188 return types.NewStruct(fields, tags)
189 }
190 }
191 return t
192 }
193
194
195 func (subst *subster) varlist(in varlist) []*types.Var {
196 var out []*types.Var
197 for i, n := 0, in.Len(); i < n; i++ {
198 v := in.At(i)
199 w := subst.var_(v)
200 if v != w && out == nil {
201 out = make([]*types.Var, n)
202 for j := 0; j < i; j++ {
203 out[j] = in.At(j)
204 }
205 }
206 if out != nil {
207 out[i] = w
208 }
209 }
210 return out
211 }
212
213 func (subst *subster) var_(v *types.Var) *types.Var {
214 if v != nil {
215 if typ := subst.typ(v.Type()); typ != v.Type() {
216 if v.IsField() {
217 return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded())
218 }
219 return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ)
220 }
221 }
222 return v
223 }
224
225 func (subst *subster) union(u *types.Union) *types.Union {
226 var out []*types.Term
227
228 for i, n := 0, u.Len(); i < n; i++ {
229 t := u.Term(i)
230 r := subst.typ(t.Type())
231 if r != t.Type() && out == nil {
232 out = make([]*types.Term, n)
233 for j := 0; j < i; j++ {
234 out[j] = u.Term(j)
235 }
236 }
237 if out != nil {
238 out[i] = types.NewTerm(t.Tilde(), r)
239 }
240 }
241
242 if out != nil {
243 return types.NewUnion(out)
244 }
245 return u
246 }
247
248 func (subst *subster) interface_(iface *types.Interface) *types.Interface {
249 if iface == nil {
250 return nil
251 }
252
253
254
255 var methods []*types.Func
256 initMethods := func(n int) {
257 methods = make([]*types.Func, iface.NumExplicitMethods())
258 for i := 0; i < n; i++ {
259 f := iface.ExplicitMethod(i)
260 norecv := changeRecv(f.Type().(*types.Signature), nil)
261 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
262 }
263 }
264 for i := 0; i < iface.NumExplicitMethods(); i++ {
265 f := iface.ExplicitMethod(i)
266
267
268
269 norecv := changeRecv(f.Type().(*types.Signature), nil)
270 sig := subst.typ(norecv)
271 if sig != norecv && methods == nil {
272 initMethods(i)
273 }
274 if methods != nil {
275 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
276 }
277 }
278
279 var embeds []types.Type
280 initEmbeds := func(n int) {
281 embeds = make([]types.Type, iface.NumEmbeddeds())
282 for i := 0; i < n; i++ {
283 embeds[i] = iface.EmbeddedType(i)
284 }
285 }
286 for i := 0; i < iface.NumEmbeddeds(); i++ {
287 e := iface.EmbeddedType(i)
288 r := subst.typ(e)
289 if e != r && embeds == nil {
290 initEmbeds(i)
291 }
292 if embeds != nil {
293 embeds[i] = r
294 }
295 }
296
297 if methods == nil && embeds == nil {
298 return iface
299 }
300 if methods == nil {
301 initMethods(iface.NumExplicitMethods())
302 }
303 if embeds == nil {
304 initEmbeds(iface.NumEmbeddeds())
305 }
306 return types.NewInterfaceType(methods, embeds).Complete()
307 }
308
309 func (subst *subster) alias(t *aliases.Alias) types.Type {
310
311 u := aliases.Unalias(t)
312 if s := subst.typ(u); s != u {
313
314 return s
315 }
316
317
318 return t
319 }
320
321 func (subst *subster) named(t *types.Named) types.Type {
322
323
324
325
326
327 tparams := t.TypeParams()
328 if tparams.Len() == 0 {
329 if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) {
330
331 return t
332 }
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355 n := types.NewNamed(t.Obj(), nil, nil)
356 subst.cache[t] = n
357 subst.cache[n] = n
358 n.SetUnderlying(subst.typ(t.Underlying()))
359 return n
360 }
361 targs := t.TypeArgs()
362
363
364 insts := make([]types.Type, tparams.Len())
365
366
367
368
369 assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported")
370
371
372
373
374
375
376
377
378
379 assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present")
380 for i, n := 0, targs.Len(); i < n; i++ {
381 inst := subst.typ(targs.At(i))
382 insts[i] = inst
383 }
384 r, err := types.Instantiate(subst.ctxt, t.Origin(), insts, false)
385 assert(err == nil, "failed to Instantiate Named type")
386 return r
387 }
388
389 func (subst *subster) signature(t *types.Signature) types.Type {
390 tparams := t.TypeParams()
391
392
393
394
395
396
397
398 assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.")
399
400
401
402
403
404
405
406
407
408
409
410
411 recv := subst.var_(t.Recv())
412 params := subst.tuple(t.Params())
413 results := subst.tuple(t.Results())
414 if recv != t.Recv() || params != t.Params() || results != t.Results() {
415 return types.NewSignatureType(recv, nil, nil, params, results, t.Variadic())
416 }
417 return t
418 }
419
420
421
422
423
424
425
426 func reaches(t types.Type, c map[types.Type]bool) (res bool) {
427 if c, ok := c[t]; ok {
428 return c
429 }
430
431
432
433 c[t] = false
434 defer func() {
435 c[t] = res
436 }()
437
438 switch t := t.(type) {
439 case *types.TypeParam, *types.Basic:
440 return false
441 case *types.Array:
442 return reaches(t.Elem(), c)
443 case *types.Slice:
444 return reaches(t.Elem(), c)
445 case *types.Pointer:
446 return reaches(t.Elem(), c)
447 case *types.Tuple:
448 for i := 0; i < t.Len(); i++ {
449 if reaches(t.At(i).Type(), c) {
450 return true
451 }
452 }
453 case *types.Struct:
454 for i := 0; i < t.NumFields(); i++ {
455 if reaches(t.Field(i).Type(), c) {
456 return true
457 }
458 }
459 case *types.Map:
460 return reaches(t.Key(), c) || reaches(t.Elem(), c)
461 case *types.Chan:
462 return reaches(t.Elem(), c)
463 case *types.Signature:
464 if t.Recv() != nil && reaches(t.Recv().Type(), c) {
465 return true
466 }
467 return reaches(t.Params(), c) || reaches(t.Results(), c)
468 case *types.Union:
469 for i := 0; i < t.Len(); i++ {
470 if reaches(t.Term(i).Type(), c) {
471 return true
472 }
473 }
474 case *types.Interface:
475 for i := 0; i < t.NumEmbeddeds(); i++ {
476 if reaches(t.Embedded(i), c) {
477 return true
478 }
479 }
480 for i := 0; i < t.NumExplicitMethods(); i++ {
481 if reaches(t.ExplicitMethod(i).Type(), c) {
482 return true
483 }
484 }
485 case *types.Named, *aliases.Alias:
486 return reaches(t.Underlying(), c)
487 default:
488 panic("unreachable")
489 }
490 return false
491 }
492
View as plain text