1
2
3
4
5
6
7
8
9
10
11
12
13
14 package satisfy
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40 import (
41 "fmt"
42 "go/ast"
43 "go/token"
44 "go/types"
45
46 "golang.org/x/tools/go/ast/astutil"
47 "golang.org/x/tools/go/types/typeutil"
48 "golang.org/x/tools/internal/typeparams"
49 )
50
51
52
53
54
55
56
57 type Constraint struct {
58 LHS, RHS types.Type
59 }
60
61
62
63
64
65
66
67
68
69 type Finder struct {
70 Result map[Constraint]bool
71 msetcache typeutil.MethodSetCache
72
73
74 info *types.Info
75 sig *types.Signature
76 }
77
78
79
80
81
82
83
84
85
86
87 func (f *Finder) Find(info *types.Info, files []*ast.File) {
88 if f.Result == nil {
89 f.Result = make(map[Constraint]bool)
90 }
91
92 f.info = info
93 for _, file := range files {
94 for _, d := range file.Decls {
95 switch d := d.(type) {
96 case *ast.GenDecl:
97 if d.Tok == token.VAR {
98 for _, spec := range d.Specs {
99 f.valueSpec(spec.(*ast.ValueSpec))
100 }
101 }
102
103 case *ast.FuncDecl:
104 if d.Body != nil {
105 f.sig = f.info.Defs[d.Name].Type().(*types.Signature)
106 f.stmt(d.Body)
107 f.sig = nil
108 }
109 }
110 }
111 }
112 f.info = nil
113 }
114
115 var (
116 tInvalid = types.Typ[types.Invalid]
117 tUntypedBool = types.Typ[types.UntypedBool]
118 tUntypedNil = types.Typ[types.UntypedNil]
119 )
120
121
122 func (f *Finder) exprN(e ast.Expr) types.Type {
123 typ := f.info.Types[e].Type.(*types.Tuple)
124 switch e := e.(type) {
125 case *ast.ParenExpr:
126 return f.exprN(e.X)
127
128 case *ast.CallExpr:
129
130 sig := coreType(f.expr(e.Fun)).(*types.Signature)
131 f.call(sig, e.Args)
132
133 case *ast.IndexExpr:
134
135 x := f.expr(e.X)
136 f.assign(f.expr(e.Index), coreType(x).(*types.Map).Key())
137
138 case *ast.TypeAssertExpr:
139
140 f.typeAssert(f.expr(e.X), typ.At(0).Type())
141
142 case *ast.UnaryExpr:
143
144 f.expr(e.X)
145
146 default:
147 panic(e)
148 }
149 return typ
150 }
151
152 func (f *Finder) call(sig *types.Signature, args []ast.Expr) {
153 if len(args) == 0 {
154 return
155 }
156
157
158 if _, ok := args[len(args)-1].(*ast.Ellipsis); ok {
159 for i, arg := range args {
160
161 f.assign(sig.Params().At(i).Type(), f.expr(arg))
162 }
163 return
164 }
165
166 var argtypes []types.Type
167
168
169 if tuple, ok := f.info.Types[args[0]].Type.(*types.Tuple); ok {
170
171 f.expr(args[0])
172
173 for i := 0; i < tuple.Len(); i++ {
174 argtypes = append(argtypes, tuple.At(i).Type())
175 }
176 } else {
177 for _, arg := range args {
178 argtypes = append(argtypes, f.expr(arg))
179 }
180 }
181
182
183 if !sig.Variadic() {
184 for i, argtype := range argtypes {
185 f.assign(sig.Params().At(i).Type(), argtype)
186 }
187 } else {
188
189 nnormals := sig.Params().Len() - 1
190 for i, argtype := range argtypes[:nnormals] {
191 f.assign(sig.Params().At(i).Type(), argtype)
192 }
193
194 tElem := sig.Params().At(nnormals).Type().(*types.Slice).Elem()
195 for i := nnormals; i < len(argtypes); i++ {
196 f.assign(tElem, argtypes[i])
197 }
198 }
199 }
200
201
202 func (f *Finder) builtin(obj *types.Builtin, sig *types.Signature, args []ast.Expr) {
203 switch obj.Name() {
204 case "make", "new":
205
206 for _, arg := range args[1:] {
207 f.expr(arg)
208 }
209
210 case "append":
211 s := f.expr(args[0])
212 if _, ok := args[len(args)-1].(*ast.Ellipsis); ok && len(args) == 2 {
213
214 f.expr(args[1])
215 } else {
216
217 tElem := coreType(s).(*types.Slice).Elem()
218 for _, arg := range args[1:] {
219 f.assign(tElem, f.expr(arg))
220 }
221 }
222
223 case "delete":
224 m := f.expr(args[0])
225 k := f.expr(args[1])
226 f.assign(coreType(m).(*types.Map).Key(), k)
227
228 default:
229
230 f.call(sig, args)
231 }
232 }
233
234 func (f *Finder) extract(tuple types.Type, i int) types.Type {
235 if tuple, ok := tuple.(*types.Tuple); ok && i < tuple.Len() {
236 return tuple.At(i).Type()
237 }
238 return tInvalid
239 }
240
241 func (f *Finder) valueSpec(spec *ast.ValueSpec) {
242 var T types.Type
243 if spec.Type != nil {
244 T = f.info.Types[spec.Type].Type
245 }
246 switch len(spec.Values) {
247 case len(spec.Names):
248 for _, value := range spec.Values {
249 v := f.expr(value)
250 if T != nil {
251 f.assign(T, v)
252 }
253 }
254
255 case 1:
256 tuple := f.exprN(spec.Values[0])
257 for i := range spec.Names {
258 if T != nil {
259 f.assign(T, f.extract(tuple, i))
260 }
261 }
262 }
263 }
264
265
266
267
268
269
270
271
272
273 func (f *Finder) assign(lhs, rhs types.Type) {
274 if types.Identical(lhs, rhs) {
275 return
276 }
277 if !isInterface(lhs) {
278 return
279 }
280
281 if f.msetcache.MethodSet(lhs).Len() == 0 {
282 return
283 }
284 if f.msetcache.MethodSet(rhs).Len() == 0 {
285 return
286 }
287
288 f.Result[Constraint{lhs, rhs}] = true
289 }
290
291
292
293 func (f *Finder) typeAssert(I, T types.Type) {
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308 if types.AssignableTo(T, I) {
309 f.assign(I, T)
310 }
311 }
312
313
314 func (f *Finder) compare(x, y types.Type) {
315 if types.AssignableTo(x, y) {
316 f.assign(y, x)
317 } else if types.AssignableTo(y, x) {
318 f.assign(x, y)
319 }
320 }
321
322
323
324 func (f *Finder) expr(e ast.Expr) types.Type {
325 tv := f.info.Types[e]
326 if tv.Value != nil {
327 return tv.Type
328 }
329
330
331
332 switch e := e.(type) {
333 case *ast.BadExpr, *ast.BasicLit:
334
335
336 case *ast.Ident:
337
338 if obj, ok := f.info.Uses[e]; ok {
339 return obj.Type()
340 }
341 if e.Name == "_" {
342 return tInvalid
343 }
344 panic("undefined ident: " + e.Name)
345
346 case *ast.Ellipsis:
347 if e.Elt != nil {
348 f.expr(e.Elt)
349 }
350
351 case *ast.FuncLit:
352 saved := f.sig
353 f.sig = tv.Type.(*types.Signature)
354 f.stmt(e.Body)
355 f.sig = saved
356
357 case *ast.CompositeLit:
358 switch T := coreType(typeparams.Deref(tv.Type)).(type) {
359 case *types.Struct:
360 for i, elem := range e.Elts {
361 if kv, ok := elem.(*ast.KeyValueExpr); ok {
362 f.assign(f.info.Uses[kv.Key.(*ast.Ident)].Type(), f.expr(kv.Value))
363 } else {
364 f.assign(T.Field(i).Type(), f.expr(elem))
365 }
366 }
367
368 case *types.Map:
369 for _, elem := range e.Elts {
370 elem := elem.(*ast.KeyValueExpr)
371 f.assign(T.Key(), f.expr(elem.Key))
372 f.assign(T.Elem(), f.expr(elem.Value))
373 }
374
375 case *types.Array, *types.Slice:
376 tElem := T.(interface {
377 Elem() types.Type
378 }).Elem()
379 for _, elem := range e.Elts {
380 if kv, ok := elem.(*ast.KeyValueExpr); ok {
381
382 f.assign(tElem, f.expr(kv.Value))
383 } else {
384 f.assign(tElem, f.expr(elem))
385 }
386 }
387
388 default:
389 panic(fmt.Sprintf("unexpected composite literal type %T: %v", tv.Type, tv.Type.String()))
390 }
391
392 case *ast.ParenExpr:
393 f.expr(e.X)
394
395 case *ast.SelectorExpr:
396 if _, ok := f.info.Selections[e]; ok {
397 f.expr(e.X)
398 } else {
399 return f.info.Uses[e.Sel].Type()
400 }
401
402 case *ast.IndexExpr:
403 if instance(f.info, e.X) {
404
405 } else {
406
407 x := f.expr(e.X)
408 i := f.expr(e.Index)
409 if ux, ok := coreType(x).(*types.Map); ok {
410 f.assign(ux.Key(), i)
411 }
412 }
413
414 case *ast.IndexListExpr:
415
416
417 case *ast.SliceExpr:
418 f.expr(e.X)
419 if e.Low != nil {
420 f.expr(e.Low)
421 }
422 if e.High != nil {
423 f.expr(e.High)
424 }
425 if e.Max != nil {
426 f.expr(e.Max)
427 }
428
429 case *ast.TypeAssertExpr:
430 x := f.expr(e.X)
431 f.typeAssert(x, f.info.Types[e.Type].Type)
432
433 case *ast.CallExpr:
434 if tvFun := f.info.Types[e.Fun]; tvFun.IsType() {
435
436 arg0 := f.expr(e.Args[0])
437 f.assign(tvFun.Type, arg0)
438 } else {
439
440
441
442
443
444 if s, ok := unparen(e.Fun).(*ast.SelectorExpr); ok {
445 if obj, ok := f.info.Uses[s.Sel].(*types.Builtin); ok && obj.Pkg().Path() == "unsafe" {
446 sig := f.info.Types[e.Fun].Type.(*types.Signature)
447 f.call(sig, e.Args)
448 return tv.Type
449 }
450 }
451
452
453 if id, ok := unparen(e.Fun).(*ast.Ident); ok {
454 if obj, ok := f.info.Uses[id].(*types.Builtin); ok {
455 sig := f.info.Types[id].Type.(*types.Signature)
456 f.builtin(obj, sig, e.Args)
457 return tv.Type
458 }
459 }
460
461
462 f.call(coreType(f.expr(e.Fun)).(*types.Signature), e.Args)
463 }
464
465 case *ast.StarExpr:
466 f.expr(e.X)
467
468 case *ast.UnaryExpr:
469 f.expr(e.X)
470
471 case *ast.BinaryExpr:
472 x := f.expr(e.X)
473 y := f.expr(e.Y)
474 if e.Op == token.EQL || e.Op == token.NEQ {
475 f.compare(x, y)
476 }
477
478 case *ast.KeyValueExpr:
479 f.expr(e.Key)
480 f.expr(e.Value)
481
482 case *ast.ArrayType,
483 *ast.StructType,
484 *ast.FuncType,
485 *ast.InterfaceType,
486 *ast.MapType,
487 *ast.ChanType:
488 panic(e)
489 }
490
491 if tv.Type == nil {
492 panic(fmt.Sprintf("no type for %T", e))
493 }
494
495 return tv.Type
496 }
497
498 func (f *Finder) stmt(s ast.Stmt) {
499 switch s := s.(type) {
500 case *ast.BadStmt,
501 *ast.EmptyStmt,
502 *ast.BranchStmt:
503
504
505 case *ast.DeclStmt:
506 d := s.Decl.(*ast.GenDecl)
507 if d.Tok == token.VAR {
508 for _, spec := range d.Specs {
509 f.valueSpec(spec.(*ast.ValueSpec))
510 }
511 }
512
513 case *ast.LabeledStmt:
514 f.stmt(s.Stmt)
515
516 case *ast.ExprStmt:
517 f.expr(s.X)
518
519 case *ast.SendStmt:
520 ch := f.expr(s.Chan)
521 val := f.expr(s.Value)
522 f.assign(coreType(ch).(*types.Chan).Elem(), val)
523
524 case *ast.IncDecStmt:
525 f.expr(s.X)
526
527 case *ast.AssignStmt:
528 switch s.Tok {
529 case token.ASSIGN, token.DEFINE:
530
531 var rhsTuple types.Type
532 if len(s.Lhs) != len(s.Rhs) {
533 rhsTuple = f.exprN(s.Rhs[0])
534 }
535 for i := range s.Lhs {
536 var lhs, rhs types.Type
537 if rhsTuple == nil {
538 rhs = f.expr(s.Rhs[i])
539 } else {
540 rhs = f.extract(rhsTuple, i)
541 }
542
543 if id, ok := s.Lhs[i].(*ast.Ident); ok {
544 if id.Name != "_" {
545 if obj, ok := f.info.Defs[id]; ok {
546 lhs = obj.Type()
547 }
548 }
549 }
550 if lhs == nil {
551 lhs = f.expr(s.Lhs[i])
552 }
553 f.assign(lhs, rhs)
554 }
555
556 default:
557
558 f.expr(s.Lhs[0])
559 f.expr(s.Rhs[0])
560 }
561
562 case *ast.GoStmt:
563 f.expr(s.Call)
564
565 case *ast.DeferStmt:
566 f.expr(s.Call)
567
568 case *ast.ReturnStmt:
569 formals := f.sig.Results()
570 switch len(s.Results) {
571 case formals.Len():
572 for i, result := range s.Results {
573 f.assign(formals.At(i).Type(), f.expr(result))
574 }
575
576 case 1:
577 tuple := f.exprN(s.Results[0])
578 for i := 0; i < formals.Len(); i++ {
579 f.assign(formals.At(i).Type(), f.extract(tuple, i))
580 }
581 }
582
583 case *ast.SelectStmt:
584 f.stmt(s.Body)
585
586 case *ast.BlockStmt:
587 for _, s := range s.List {
588 f.stmt(s)
589 }
590
591 case *ast.IfStmt:
592 if s.Init != nil {
593 f.stmt(s.Init)
594 }
595 f.expr(s.Cond)
596 f.stmt(s.Body)
597 if s.Else != nil {
598 f.stmt(s.Else)
599 }
600
601 case *ast.SwitchStmt:
602 if s.Init != nil {
603 f.stmt(s.Init)
604 }
605 var tag types.Type = tUntypedBool
606 if s.Tag != nil {
607 tag = f.expr(s.Tag)
608 }
609 for _, cc := range s.Body.List {
610 cc := cc.(*ast.CaseClause)
611 for _, cond := range cc.List {
612 f.compare(tag, f.info.Types[cond].Type)
613 }
614 for _, s := range cc.Body {
615 f.stmt(s)
616 }
617 }
618
619 case *ast.TypeSwitchStmt:
620 if s.Init != nil {
621 f.stmt(s.Init)
622 }
623 var I types.Type
624 switch ass := s.Assign.(type) {
625 case *ast.ExprStmt:
626 I = f.expr(unparen(ass.X).(*ast.TypeAssertExpr).X)
627 case *ast.AssignStmt:
628 I = f.expr(unparen(ass.Rhs[0]).(*ast.TypeAssertExpr).X)
629 }
630 for _, cc := range s.Body.List {
631 cc := cc.(*ast.CaseClause)
632 for _, cond := range cc.List {
633 tCase := f.info.Types[cond].Type
634 if tCase != tUntypedNil {
635 f.typeAssert(I, tCase)
636 }
637 }
638 for _, s := range cc.Body {
639 f.stmt(s)
640 }
641 }
642
643 case *ast.CommClause:
644 if s.Comm != nil {
645 f.stmt(s.Comm)
646 }
647 for _, s := range s.Body {
648 f.stmt(s)
649 }
650
651 case *ast.ForStmt:
652 if s.Init != nil {
653 f.stmt(s.Init)
654 }
655 if s.Cond != nil {
656 f.expr(s.Cond)
657 }
658 if s.Post != nil {
659 f.stmt(s.Post)
660 }
661 f.stmt(s.Body)
662
663 case *ast.RangeStmt:
664 x := f.expr(s.X)
665
666 if s.Tok == token.ASSIGN {
667 if s.Key != nil {
668 k := f.expr(s.Key)
669 var xelem types.Type
670
671
672 switch ux := coreType(x).(type) {
673 case *types.Chan:
674 xelem = ux.Elem()
675 case *types.Map:
676 xelem = ux.Key()
677 }
678 if xelem != nil {
679 f.assign(k, xelem)
680 }
681 }
682 if s.Value != nil {
683 val := f.expr(s.Value)
684 var xelem types.Type
685
686
687 switch ux := coreType(x).(type) {
688 case *types.Array:
689 xelem = ux.Elem()
690 case *types.Map:
691 xelem = ux.Elem()
692 case *types.Pointer:
693 xelem = coreType(typeparams.Deref(ux)).(*types.Array).Elem()
694 case *types.Slice:
695 xelem = ux.Elem()
696 }
697 if xelem != nil {
698 f.assign(val, xelem)
699 }
700 }
701 }
702 f.stmt(s.Body)
703
704 default:
705 panic(s)
706 }
707 }
708
709
710
711 func unparen(e ast.Expr) ast.Expr { return astutil.Unparen(e) }
712
713 func isInterface(T types.Type) bool { return types.IsInterface(T) }
714
715 func coreType(T types.Type) types.Type { return typeparams.CoreType(T) }
716
717 func instance(info *types.Info, expr ast.Expr) bool {
718 var id *ast.Ident
719 switch x := expr.(type) {
720 case *ast.Ident:
721 id = x
722 case *ast.SelectorExpr:
723 id = x.Sel
724 default:
725 return false
726 }
727 _, ok := info.Instances[id]
728 return ok
729 }
730
View as plain text