1 package code
2
3 import (
4 "go/ast"
5 "go/importer"
6 "go/parser"
7 "go/token"
8 "go/types"
9 "testing"
10
11 "github.com/stretchr/testify/require"
12 )
13
14 func TestCompatibleTypes(t *testing.T) {
15 valid := []struct {
16 expected string
17 actual string
18 }{
19 {"string", "string"},
20 {"*string", "string"},
21 {"string", "*string"},
22 {"*string", "*string"},
23 {"[]string", "[]string"},
24 {"*[]string", "[]string"},
25 {"*[]string", "[]*string"},
26 {"*[]*[]*[]string", "[][][]string"},
27 {"map[string]interface{}", "map[string]interface{}"},
28 {"map[string]string", "map[string]string"},
29 {"Bar", "Bar"},
30 {"interface{}", "interface{}"},
31 {"interface{Foo() bool}", "interface{Foo() bool}"},
32 {"struct{Foo bool}", "struct{Foo bool}"},
33 }
34
35 for _, tc := range valid {
36 t.Run(tc.expected+"="+tc.actual, func(t *testing.T) {
37 expectedType := parseTypeStr(t, tc.expected)
38 actualType := parseTypeStr(t, tc.actual)
39 require.NoError(t, CompatibleTypes(expectedType, actualType))
40 })
41 }
42
43 invalid := []struct {
44 expected string
45 actual string
46 }{
47 {"string", "int"},
48 {"*string", "[]string"},
49 {"[]string", "[][]string"},
50 {"Bar", "Baz"},
51 {"map[string]interface{}", "map[string]string"},
52 {"map[string]string", "[]string"},
53 {"interface{Foo() bool}", "interface{}"},
54 {"struct{Foo bool}", "struct{Bar bool}"},
55 }
56
57 for _, tc := range invalid {
58 t.Run(tc.expected+"!="+tc.actual, func(t *testing.T) {
59 expectedType := parseTypeStr(t, tc.expected)
60 actualType := parseTypeStr(t, tc.actual)
61 require.Error(t, CompatibleTypes(expectedType, actualType))
62 })
63 }
64 }
65
66 func parseTypeStr(t *testing.T, s string) types.Type {
67 t.Helper()
68
69 fset := token.NewFileSet()
70 f, err := parser.ParseFile(fset, "test.go", `package test
71 type Bar string
72 type Baz string
73
74 type Foo struct {
75 Field `+s+`
76 }
77 `, 0)
78 require.NoError(t, err)
79
80 conf := types.Config{Importer: importer.Default()}
81 pkg, err := conf.Check("test", fset, []*ast.File{f}, nil)
82 require.NoError(t, err)
83
84 return pkg.Scope().Lookup("Foo").Type().(*types.Named).Underlying().(*types.Struct).Field(0).Type()
85 }
86
View as plain text