...

Source file src/golang.org/x/tools/go/ssa/subst.go

Documentation: golang.org/x/tools/go/ssa

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"go/types"
     9  
    10  	"golang.org/x/tools/internal/aliases"
    11  )
    12  
    13  // Type substituter for a fixed set of replacement types.
    14  //
    15  // A nil *subster is an valid, empty substitution map. It always acts as
    16  // the identity function. This allows for treating parameterized and
    17  // non-parameterized functions identically while compiling to ssa.
    18  //
    19  // Not concurrency-safe.
    20  type subster struct {
    21  	replacements map[*types.TypeParam]types.Type // values should contain no type params
    22  	cache        map[types.Type]types.Type       // cache of subst results
    23  	ctxt         *types.Context                  // cache for instantiation
    24  	scope        *types.Scope                    // *types.Named declared within this scope can be substituted (optional)
    25  	debug        bool                            // perform extra debugging checks
    26  	// TODO(taking): consider adding Pos
    27  	// TODO(zpavlinovic): replacements can contain type params
    28  	// when generating instances inside of a generic function body.
    29  }
    30  
    31  // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
    32  // targs should not contain any types in tparams.
    33  // scope is the (optional) lexical block of the generic function for which we are substituting.
    34  func makeSubster(ctxt *types.Context, scope *types.Scope, tparams *types.TypeParamList, targs []types.Type, debug bool) *subster {
    35  	assert(tparams.Len() == len(targs), "makeSubster argument count must match")
    36  
    37  	subst := &subster{
    38  		replacements: make(map[*types.TypeParam]types.Type, tparams.Len()),
    39  		cache:        make(map[types.Type]types.Type),
    40  		ctxt:         ctxt,
    41  		scope:        scope,
    42  		debug:        debug,
    43  	}
    44  	for i := 0; i < tparams.Len(); i++ {
    45  		subst.replacements[tparams.At(i)] = targs[i]
    46  	}
    47  	if subst.debug {
    48  		subst.wellFormed()
    49  	}
    50  	return subst
    51  }
    52  
    53  // wellFormed asserts that subst was properly initialized.
    54  func (subst *subster) wellFormed() {
    55  	if subst == nil {
    56  		return
    57  	}
    58  	// Check that all of the type params do not appear in the arguments.
    59  	s := make(map[types.Type]bool, len(subst.replacements))
    60  	for tparam := range subst.replacements {
    61  		s[tparam] = true
    62  	}
    63  	for _, r := range subst.replacements {
    64  		if reaches(r, s) {
    65  			panic(subst)
    66  		}
    67  	}
    68  }
    69  
    70  // typ returns the type of t with the type parameter tparams[i] substituted
    71  // for the type targs[i] where subst was created using tparams and targs.
    72  func (subst *subster) typ(t types.Type) (res types.Type) {
    73  	if subst == nil {
    74  		return t // A nil subst is type preserving.
    75  	}
    76  	if r, ok := subst.cache[t]; ok {
    77  		return r
    78  	}
    79  	defer func() {
    80  		subst.cache[t] = res
    81  	}()
    82  
    83  	switch t := t.(type) {
    84  	case *types.TypeParam:
    85  		r := subst.replacements[t]
    86  		assert(r != nil, "type param without replacement encountered")
    87  		return r
    88  
    89  	case *types.Basic:
    90  		return t
    91  
    92  	case *types.Array:
    93  		if r := subst.typ(t.Elem()); r != t.Elem() {
    94  			return types.NewArray(r, t.Len())
    95  		}
    96  		return t
    97  
    98  	case *types.Slice:
    99  		if r := subst.typ(t.Elem()); r != t.Elem() {
   100  			return types.NewSlice(r)
   101  		}
   102  		return t
   103  
   104  	case *types.Pointer:
   105  		if r := subst.typ(t.Elem()); r != t.Elem() {
   106  			return types.NewPointer(r)
   107  		}
   108  		return t
   109  
   110  	case *types.Tuple:
   111  		return subst.tuple(t)
   112  
   113  	case *types.Struct:
   114  		return subst.struct_(t)
   115  
   116  	case *types.Map:
   117  		key := subst.typ(t.Key())
   118  		elem := subst.typ(t.Elem())
   119  		if key != t.Key() || elem != t.Elem() {
   120  			return types.NewMap(key, elem)
   121  		}
   122  		return t
   123  
   124  	case *types.Chan:
   125  		if elem := subst.typ(t.Elem()); elem != t.Elem() {
   126  			return types.NewChan(t.Dir(), elem)
   127  		}
   128  		return t
   129  
   130  	case *types.Signature:
   131  		return subst.signature(t)
   132  
   133  	case *types.Union:
   134  		return subst.union(t)
   135  
   136  	case *types.Interface:
   137  		return subst.interface_(t)
   138  
   139  	case *aliases.Alias:
   140  		return subst.alias(t)
   141  
   142  	case *types.Named:
   143  		return subst.named(t)
   144  
   145  	default:
   146  		panic("unreachable")
   147  	}
   148  }
   149  
   150  // types returns the result of {subst.typ(ts[i])}.
   151  func (subst *subster) types(ts []types.Type) []types.Type {
   152  	res := make([]types.Type, len(ts))
   153  	for i := range ts {
   154  		res[i] = subst.typ(ts[i])
   155  	}
   156  	return res
   157  }
   158  
   159  func (subst *subster) tuple(t *types.Tuple) *types.Tuple {
   160  	if t != nil {
   161  		if vars := subst.varlist(t); vars != nil {
   162  			return types.NewTuple(vars...)
   163  		}
   164  	}
   165  	return t
   166  }
   167  
   168  type varlist interface {
   169  	At(i int) *types.Var
   170  	Len() int
   171  }
   172  
   173  // fieldlist is an adapter for structs for the varlist interface.
   174  type fieldlist struct {
   175  	str *types.Struct
   176  }
   177  
   178  func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) }
   179  func (fl fieldlist) Len() int            { return fl.str.NumFields() }
   180  
   181  func (subst *subster) struct_(t *types.Struct) *types.Struct {
   182  	if t != nil {
   183  		if fields := subst.varlist(fieldlist{t}); fields != nil {
   184  			tags := make([]string, t.NumFields())
   185  			for i, n := 0, t.NumFields(); i < n; i++ {
   186  				tags[i] = t.Tag(i)
   187  			}
   188  			return types.NewStruct(fields, tags)
   189  		}
   190  	}
   191  	return t
   192  }
   193  
   194  // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
   195  func (subst *subster) varlist(in varlist) []*types.Var {
   196  	var out []*types.Var // nil => no updates
   197  	for i, n := 0, in.Len(); i < n; i++ {
   198  		v := in.At(i)
   199  		w := subst.var_(v)
   200  		if v != w && out == nil {
   201  			out = make([]*types.Var, n)
   202  			for j := 0; j < i; j++ {
   203  				out[j] = in.At(j)
   204  			}
   205  		}
   206  		if out != nil {
   207  			out[i] = w
   208  		}
   209  	}
   210  	return out
   211  }
   212  
   213  func (subst *subster) var_(v *types.Var) *types.Var {
   214  	if v != nil {
   215  		if typ := subst.typ(v.Type()); typ != v.Type() {
   216  			if v.IsField() {
   217  				return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded())
   218  			}
   219  			return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ)
   220  		}
   221  	}
   222  	return v
   223  }
   224  
   225  func (subst *subster) union(u *types.Union) *types.Union {
   226  	var out []*types.Term // nil => no updates
   227  
   228  	for i, n := 0, u.Len(); i < n; i++ {
   229  		t := u.Term(i)
   230  		r := subst.typ(t.Type())
   231  		if r != t.Type() && out == nil {
   232  			out = make([]*types.Term, n)
   233  			for j := 0; j < i; j++ {
   234  				out[j] = u.Term(j)
   235  			}
   236  		}
   237  		if out != nil {
   238  			out[i] = types.NewTerm(t.Tilde(), r)
   239  		}
   240  	}
   241  
   242  	if out != nil {
   243  		return types.NewUnion(out)
   244  	}
   245  	return u
   246  }
   247  
// interface_ applies the substitution to the explicit method signatures and
// the embedded types of iface.
//
// It returns nil for a nil iface, and iface itself when neither any method
// signature nor any embedded type changes. Otherwise it builds a new interface
// with types.NewInterfaceType from the (possibly substituted) methods and
// embeddeds, and returns it completed.
func (subst *subster) interface_(iface *types.Interface) *types.Interface {
	if iface == nil {
		return nil
	}

	// methods for the interface. Initially nil if there is no known change needed.
	// Signatures for the method where recv is nil. NewInterfaceType fills in the receivers.
	var methods []*types.Func
	// initMethods allocates room for every explicit method and fills the first
	// n entries with copies of the original methods whose signatures have
	// their receiver stripped (via changeRecv(..., nil)).
	initMethods := func(n int) { // copy first n explicit methods
		methods = make([]*types.Func, iface.NumExplicitMethods())
		for i := 0; i < n; i++ {
			f := iface.ExplicitMethod(i)
			norecv := changeRecv(f.Type().(*types.Signature), nil)
			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
		}
	}
	for i := 0; i < iface.NumExplicitMethods(); i++ {
		f := iface.ExplicitMethod(i)
		// On interfaces, we need to cycle break on anonymous interface types
		// being in a cycle with their signatures being in cycles with their receivers
		// that do not go through a Named.
		norecv := changeRecv(f.Type().(*types.Signature), nil)
		sig := subst.typ(norecv)
		// On the first changed signature, materialize the unchanged prefix.
		if sig != norecv && methods == nil {
			initMethods(i)
		}
		// Once methods is allocated, every later method is recorded,
		// whether or not its own signature changed.
		if methods != nil {
			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
		}
	}

	// embeds follows the same lazy scheme as methods: nil until some
	// embedded type changes.
	var embeds []types.Type
	initEmbeds := func(n int) { // copy first n embedded types
		embeds = make([]types.Type, iface.NumEmbeddeds())
		for i := 0; i < n; i++ {
			embeds[i] = iface.EmbeddedType(i)
		}
	}
	for i := 0; i < iface.NumEmbeddeds(); i++ {
		e := iface.EmbeddedType(i)
		r := subst.typ(e)
		if e != r && embeds == nil {
			initEmbeds(i)
		}
		if embeds != nil {
			embeds[i] = r
		}
	}

	// Nothing changed: keep the original interface.
	if methods == nil && embeds == nil {
		return iface
	}
	// Something changed, so a new interface is needed; fill in whichever
	// list was never materialized with copies of the originals.
	if methods == nil {
		initMethods(iface.NumExplicitMethods())
	}
	if embeds == nil {
		initEmbeds(iface.NumEmbeddeds())
	}
	return types.NewInterfaceType(methods, embeds).Complete()
}
   308  
   309  func (subst *subster) alias(t *aliases.Alias) types.Type {
   310  	// TODO(go.dev/issues/46477): support TypeParameters once these are available from go/types.
   311  	u := aliases.Unalias(t)
   312  	if s := subst.typ(u); s != u {
   313  		// If there is any change, do not create a new alias.
   314  		return s
   315  	}
   316  	// If there is no change, t did not reach any type parameter.
   317  	// Keep the Alias.
   318  	return t
   319  }
   320  
   321  func (subst *subster) named(t *types.Named) types.Type {
   322  	// A named type may be:
   323  	// (1) ordinary named type (non-local scope, no type parameters, no type arguments),
   324  	// (2) locally scoped type,
   325  	// (3) generic (type parameters but no type arguments), or
   326  	// (4) instantiated (type parameters and type arguments).
   327  	tparams := t.TypeParams()
   328  	if tparams.Len() == 0 {
   329  		if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) {
   330  			// Outside the current function scope?
   331  			return t // case (1) ordinary
   332  		}
   333  
   334  		// case (2) locally scoped type.
   335  		// Create a new named type to represent this instantiation.
   336  		// We assume that local types of distinct instantiations of a
   337  		// generic function are distinct, even if they don't refer to
   338  		// type parameters, but the spec is unclear; see golang/go#58573.
   339  		//
   340  		// Subtle: We short circuit substitution and use a newly created type in
   341  		// subst, i.e. cache[t]=n, to pre-emptively replace t with n in recursive
   342  		// types during traversal. This both breaks infinite cycles and allows for
   343  		// constructing types with the replacement applied in subst.typ(under).
   344  		//
   345  		// Example:
   346  		// func foo[T any]() {
   347  		//   type linkedlist struct {
   348  		//     next *linkedlist
   349  		//     val T
   350  		//   }
   351  		// }
   352  		//
   353  		// When the field `next *linkedlist` is visited during subst.typ(under),
   354  		// we want the substituted type for the field `next` to be `*n`.
   355  		n := types.NewNamed(t.Obj(), nil, nil)
   356  		subst.cache[t] = n
   357  		subst.cache[n] = n
   358  		n.SetUnderlying(subst.typ(t.Underlying()))
   359  		return n
   360  	}
   361  	targs := t.TypeArgs()
   362  
   363  	// insts are arguments to instantiate using.
   364  	insts := make([]types.Type, tparams.Len())
   365  
   366  	// case (3) generic ==> targs.Len() == 0
   367  	// Instantiating a generic with no type arguments should be unreachable.
   368  	// Please report a bug if you encounter this.
   369  	assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported")
   370  
   371  	// case (4) instantiated.
   372  	// Substitute into the type arguments and instantiate the replacements/
   373  	// Example:
   374  	//    type N[A any] func() A
   375  	//    func Foo[T](g N[T]) {}
   376  	//  To instantiate Foo[string], one goes through {T->string}. To get the type of g
   377  	//  one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} }
   378  	//  to get {N with TypeArgs == {string} and typeparams == {A} }.
   379  	assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present")
   380  	for i, n := 0, targs.Len(); i < n; i++ {
   381  		inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion
   382  		insts[i] = inst
   383  	}
   384  	r, err := types.Instantiate(subst.ctxt, t.Origin(), insts, false)
   385  	assert(err == nil, "failed to Instantiate Named type")
   386  	return r
   387  }
   388  
   389  func (subst *subster) signature(t *types.Signature) types.Type {
   390  	tparams := t.TypeParams()
   391  
   392  	// We are choosing not to support tparams.Len() > 0 until a need has been observed in practice.
   393  	//
   394  	// There are some known usages for types.Types coming from types.{Eval,CheckExpr}.
   395  	// To support tparams.Len() > 0, we just need to do the following [psuedocode]:
   396  	//   targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false)
   397  
   398  	assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.")
   399  
   400  	// Either:
   401  	// (1)non-generic function.
   402  	//    no type params to substitute
   403  	// (2)generic method and recv needs to be substituted.
   404  
   405  	// Receivers can be either:
   406  	// named
   407  	// pointer to named
   408  	// interface
   409  	// nil
   410  	// interface is the problematic case. We need to cycle break there!
   411  	recv := subst.var_(t.Recv())
   412  	params := subst.tuple(t.Params())
   413  	results := subst.tuple(t.Results())
   414  	if recv != t.Recv() || params != t.Params() || results != t.Results() {
   415  		return types.NewSignatureType(recv, nil, nil, params, results, t.Variadic())
   416  	}
   417  	return t
   418  }
   419  
   420  // reaches returns true if a type t reaches any type t' s.t. c[t'] == true.
   421  // It updates c to cache results.
   422  //
   423  // reaches is currently only part of the wellFormed debug logic, and
   424  // in practice c is initially only type parameters. It is not currently
   425  // relied on in production.
   426  func reaches(t types.Type, c map[types.Type]bool) (res bool) {
   427  	if c, ok := c[t]; ok {
   428  		return c
   429  	}
   430  
   431  	// c is populated with temporary false entries as types are visited.
   432  	// This avoids repeat visits and break cycles.
   433  	c[t] = false
   434  	defer func() {
   435  		c[t] = res
   436  	}()
   437  
   438  	switch t := t.(type) {
   439  	case *types.TypeParam, *types.Basic:
   440  		return false
   441  	case *types.Array:
   442  		return reaches(t.Elem(), c)
   443  	case *types.Slice:
   444  		return reaches(t.Elem(), c)
   445  	case *types.Pointer:
   446  		return reaches(t.Elem(), c)
   447  	case *types.Tuple:
   448  		for i := 0; i < t.Len(); i++ {
   449  			if reaches(t.At(i).Type(), c) {
   450  				return true
   451  			}
   452  		}
   453  	case *types.Struct:
   454  		for i := 0; i < t.NumFields(); i++ {
   455  			if reaches(t.Field(i).Type(), c) {
   456  				return true
   457  			}
   458  		}
   459  	case *types.Map:
   460  		return reaches(t.Key(), c) || reaches(t.Elem(), c)
   461  	case *types.Chan:
   462  		return reaches(t.Elem(), c)
   463  	case *types.Signature:
   464  		if t.Recv() != nil && reaches(t.Recv().Type(), c) {
   465  			return true
   466  		}
   467  		return reaches(t.Params(), c) || reaches(t.Results(), c)
   468  	case *types.Union:
   469  		for i := 0; i < t.Len(); i++ {
   470  			if reaches(t.Term(i).Type(), c) {
   471  				return true
   472  			}
   473  		}
   474  	case *types.Interface:
   475  		for i := 0; i < t.NumEmbeddeds(); i++ {
   476  			if reaches(t.Embedded(i), c) {
   477  				return true
   478  			}
   479  		}
   480  		for i := 0; i < t.NumExplicitMethods(); i++ {
   481  			if reaches(t.ExplicitMethod(i).Type(), c) {
   482  				return true
   483  			}
   484  		}
   485  	case *types.Named, *aliases.Alias:
   486  		return reaches(t.Underlying(), c)
   487  	default:
   488  		panic("unreachable")
   489  	}
   490  	return false
   491  }
   492  

View as plain text