...
1
2
3
4
5 package cmpopts
6
7 import (
8 "fmt"
9 "reflect"
10 "strings"
11
12 "github.com/google/go-cmp/cmp"
13 )
14
15
16
17
18
19
20
21 func filterField(typ interface{}, name string, opt cmp.Option) cmp.Option {
22
23
24
25
26 sf := newStructFilter(typ, name)
27 return cmp.FilterPath(sf.filter, opt)
28 }
29
30 type structFilter struct {
31 t reflect.Type
32 ft fieldTree
33 }
34
35 func newStructFilter(typ interface{}, names ...string) structFilter {
36
37
38
39
40
41
42
43 t := reflect.TypeOf(typ)
44 if t == nil || t.Kind() != reflect.Struct {
45 panic(fmt.Sprintf("%T must be a non-pointer struct", typ))
46 }
47 var ft fieldTree
48 for _, name := range names {
49 cname, err := canonicalName(t, name)
50 if err != nil {
51 panic(fmt.Sprintf("%s: %v", strings.Join(cname, "."), err))
52 }
53 ft.insert(cname)
54 }
55 return structFilter{t, ft}
56 }
57
58 func (sf structFilter) filter(p cmp.Path) bool {
59 for i, ps := range p {
60 if ps.Type().AssignableTo(sf.t) && sf.ft.matchPrefix(p[i+1:]) {
61 return true
62 }
63 }
64 return false
65 }
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// fieldTree is a prefix tree (trie) of canonical field paths. Each edge is
// labelled with a single field name, and ok marks nodes where a complete
// path ends.
type fieldTree struct {
	ok  bool                 // a complete field path ends at this node
	sub map[string]fieldTree // children keyed by the next field name
}

// insert adds the canonical field path cname to the tree and creates any
// missing intermediate nodes. Inserting an empty path marks the receiver
// itself as terminal. Every node reached, including the final one, ends up
// with a non-nil (possibly empty) child map.
func (ft *fieldTree) insert(cname []string) {
	if ft.sub == nil {
		ft.sub = make(map[string]fieldTree)
	}
	if len(cname) > 0 {
		// Children are stored by value, so the updated copy has to be
		// written back into the map after recursing.
		head := cname[0]
		child := ft.sub[head]
		child.insert(cname[1:])
		ft.sub[head] = child
		return
	}
	ft.ok = true
}
109
110
111
112 func (ft fieldTree) matchPrefix(p cmp.Path) bool {
113 for _, ps := range p {
114 switch ps := ps.(type) {
115 case cmp.StructField:
116 ft = ft.sub[ps.Name()]
117 if ft.ok {
118 return true
119 }
120 if len(ft.sub) == 0 {
121 return false
122 }
123 case cmp.Indirect:
124 default:
125 return false
126 }
127 }
128 return false
129 }
130
131
132
133
134
135
136
137
138
139
140
141
142
143 func canonicalName(t reflect.Type, sel string) ([]string, error) {
144 var name string
145 sel = strings.TrimPrefix(sel, ".")
146 if sel == "" {
147 return nil, fmt.Errorf("name must not be empty")
148 }
149 if i := strings.IndexByte(sel, '.'); i < 0 {
150 name, sel = sel, ""
151 } else {
152 name, sel = sel[:i], sel[i:]
153 }
154
155
156 if t.Kind() == reflect.Ptr {
157 t = t.Elem()
158 }
159 if t.Kind() != reflect.Struct {
160 return nil, fmt.Errorf("%v must be a struct", t)
161 }
162
163
164
165 sf, _ := t.FieldByName(name)
166 if !isExported(name) {
167
168
169
170 sf = reflect.StructField{}
171 for i := 0; i < t.NumField() && sf.Name == ""; i++ {
172 if t.Field(i).Name == name {
173 sf = t.Field(i)
174 }
175 }
176 }
177 if sf.Name == "" {
178 return []string{name}, fmt.Errorf("does not exist")
179 }
180 var ss []string
181 for i := range sf.Index {
182 ss = append(ss, t.FieldByIndex(sf.Index[:i+1]).Name)
183 }
184 if sel == "" {
185 return ss, nil
186 }
187 ssPost, err := canonicalName(sf.Type, sel)
188 return append(ss, ssPost...), err
189 }
190
View as plain text