1
2
3
4
5 package packagestest
6
7 import (
8 "fmt"
9 "go/token"
10 "os"
11 "path/filepath"
12 "reflect"
13 "regexp"
14 "strings"
15
16 "golang.org/x/tools/go/expect"
17 "golang.org/x/tools/go/packages"
18 )
19
// markMethod is the note name whose notes record named markers via
// (*Exported).Mark; a note with no argument list is rewritten to this name.
const markMethod = "mark"

// eofIdentifier is the identifier that, used as a position argument, denotes
// the end of the file containing the note.
const eofIdentifier = "EOF"
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75 func (e *Exported) Expect(methods map[string]interface{}) error {
76 if err := e.getNotes(); err != nil {
77 return err
78 }
79 if err := e.getMarkers(); err != nil {
80 return err
81 }
82 var err error
83 ms := make(map[string]method, len(methods))
84 for name, f := range methods {
85 mi := method{f: reflect.ValueOf(f)}
86 mi.converters = make([]converter, mi.f.Type().NumIn())
87 for i := 0; i < len(mi.converters); i++ {
88 mi.converters[i], err = e.buildConverter(mi.f.Type().In(i))
89 if err != nil {
90 return fmt.Errorf("invalid method %v: %v", name, err)
91 }
92 }
93 ms[name] = mi
94 }
95 for _, n := range e.notes {
96 if n.Args == nil {
97
98 n = &expect.Note{
99 Pos: n.Pos,
100 Name: markMethod,
101 Args: []interface{}{n.Name, n.Name},
102 }
103 }
104 mi, ok := ms[n.Name]
105 if !ok {
106 continue
107 }
108 params := make([]reflect.Value, len(mi.converters))
109 args := n.Args
110 for i, convert := range mi.converters {
111 params[i], args, err = convert(n, args)
112 if err != nil {
113 return fmt.Errorf("%v: %v", e.ExpectFileSet.Position(n.Pos), err)
114 }
115 }
116 if len(args) > 0 {
117 return fmt.Errorf("%v: unwanted args got %+v extra", e.ExpectFileSet.Position(n.Pos), args)
118 }
119
120 mi.f.Call(params)
121 }
122 return nil
123 }
124
125
// Range is a span of source positions within a single file. Ranges built by
// newRange have both positions inside TokFile and Start <= End.
type Range struct {
	TokFile *token.File // file that Start and End belong to
	Start   token.Pos   // first position of the span
	End     token.Pos   // last position of the span
}
130
131
132 func (e *Exported) Mark(name string, r Range) {
133 if e.markers == nil {
134 e.markers = make(map[string]Range)
135 }
136 e.markers[name] = r
137 }
138
139 func (e *Exported) getNotes() error {
140 if e.notes != nil {
141 return nil
142 }
143 notes := []*expect.Note{}
144 var dirs []string
145 for _, module := range e.written {
146 for _, filename := range module {
147 dirs = append(dirs, filepath.Dir(filename))
148 }
149 }
150 for filename := range e.Config.Overlay {
151 dirs = append(dirs, filepath.Dir(filename))
152 }
153 pkgs, err := packages.Load(e.Config, dirs...)
154 if err != nil {
155 return fmt.Errorf("unable to load packages for directories %s: %v", dirs, err)
156 }
157 seen := make(map[token.Position]struct{})
158 for _, pkg := range pkgs {
159 for _, filename := range pkg.GoFiles {
160 content, err := e.FileContents(filename)
161 if err != nil {
162 return err
163 }
164 l, err := expect.Parse(e.ExpectFileSet, filename, content)
165 if err != nil {
166 return fmt.Errorf("failed to extract expectations: %v", err)
167 }
168 for _, note := range l {
169 pos := e.ExpectFileSet.Position(note.Pos)
170 if _, ok := seen[pos]; ok {
171 continue
172 }
173 notes = append(notes, note)
174 seen[pos] = struct{}{}
175 }
176 }
177 }
178 if _, ok := e.written[e.primary]; !ok {
179 e.notes = notes
180 return nil
181 }
182
183
184 if gomod, found := e.written[e.primary]["go.mod"]; found {
185
186 if e.Exporter == Modules {
187 gomod += ".temp"
188 }
189 l, err := goModMarkers(e, gomod)
190 if err != nil {
191 return fmt.Errorf("failed to extract expectations for go.mod: %v", err)
192 }
193 notes = append(notes, l...)
194 }
195 e.notes = notes
196 return nil
197 }
198
199 func goModMarkers(e *Exported, gomod string) ([]*expect.Note, error) {
200 if _, err := os.Stat(gomod); os.IsNotExist(err) {
201
202 return nil, nil
203 }
204 content, err := e.FileContents(gomod)
205 if err != nil {
206 return nil, err
207 }
208 if e.Exporter == GOPATH {
209 return expect.Parse(e.ExpectFileSet, gomod, content)
210 }
211 gomod = strings.TrimSuffix(gomod, ".temp")
212
213 if err := os.WriteFile(gomod, content, 0644); err != nil {
214 return nil, nil
215 }
216 return expect.Parse(e.ExpectFileSet, gomod, content)
217 }
218
219 func (e *Exported) getMarkers() error {
220 if e.markers != nil {
221 return nil
222 }
223
224 e.markers = make(map[string]Range)
225 return e.Expect(map[string]interface{}{
226 markMethod: e.Mark,
227 })
228 }
229
230 var (
231 noteType = reflect.TypeOf((*expect.Note)(nil))
232 identifierType = reflect.TypeOf(expect.Identifier(""))
233 posType = reflect.TypeOf(token.Pos(0))
234 positionType = reflect.TypeOf(token.Position{})
235 rangeType = reflect.TypeOf(Range{})
236 fsetType = reflect.TypeOf((*token.FileSet)(nil))
237 regexType = reflect.TypeOf((*regexp.Regexp)(nil))
238 exportedType = reflect.TypeOf((*Exported)(nil))
239 )
240
241
242
243
244
245
246 type converter func(*expect.Note, []interface{}) (reflect.Value, []interface{}, error)
247
248
249
250 type method struct {
251 f reflect.Value
252 converters []converter
253 }
254
255
256
257
258
259 func (e *Exported) buildConverter(pt reflect.Type) (converter, error) {
260 switch {
261 case pt == noteType:
262 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
263 return reflect.ValueOf(n), args, nil
264 }, nil
265 case pt == fsetType:
266 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
267 return reflect.ValueOf(e.ExpectFileSet), args, nil
268 }, nil
269 case pt == exportedType:
270 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
271 return reflect.ValueOf(e), args, nil
272 }, nil
273 case pt == posType:
274 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
275 r, remains, err := e.rangeConverter(n, args)
276 if err != nil {
277 return reflect.Value{}, nil, err
278 }
279 return reflect.ValueOf(r.Start), remains, nil
280 }, nil
281 case pt == positionType:
282 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
283 r, remains, err := e.rangeConverter(n, args)
284 if err != nil {
285 return reflect.Value{}, nil, err
286 }
287 return reflect.ValueOf(e.ExpectFileSet.Position(r.Start)), remains, nil
288 }, nil
289 case pt == rangeType:
290 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
291 r, remains, err := e.rangeConverter(n, args)
292 if err != nil {
293 return reflect.Value{}, nil, err
294 }
295 return reflect.ValueOf(r), remains, nil
296 }, nil
297 case pt == identifierType:
298 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
299 if len(args) < 1 {
300 return reflect.Value{}, nil, fmt.Errorf("missing argument")
301 }
302 arg := args[0]
303 args = args[1:]
304 switch arg := arg.(type) {
305 case expect.Identifier:
306 return reflect.ValueOf(arg), args, nil
307 default:
308 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
309 }
310 }, nil
311
312 case pt == regexType:
313 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
314 if len(args) < 1 {
315 return reflect.Value{}, nil, fmt.Errorf("missing argument")
316 }
317 arg := args[0]
318 args = args[1:]
319 if _, ok := arg.(*regexp.Regexp); !ok {
320 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to *regexp.Regexp", arg)
321 }
322 return reflect.ValueOf(arg), args, nil
323 }, nil
324
325 case pt.Kind() == reflect.String:
326 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
327 if len(args) < 1 {
328 return reflect.Value{}, nil, fmt.Errorf("missing argument")
329 }
330 arg := args[0]
331 args = args[1:]
332 switch arg := arg.(type) {
333 case expect.Identifier:
334 return reflect.ValueOf(string(arg)), args, nil
335 case string:
336 return reflect.ValueOf(arg), args, nil
337 default:
338 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
339 }
340 }, nil
341 case pt.Kind() == reflect.Int64:
342 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
343 if len(args) < 1 {
344 return reflect.Value{}, nil, fmt.Errorf("missing argument")
345 }
346 arg := args[0]
347 args = args[1:]
348 switch arg := arg.(type) {
349 case int64:
350 return reflect.ValueOf(arg), args, nil
351 default:
352 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to int", arg)
353 }
354 }, nil
355 case pt.Kind() == reflect.Bool:
356 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
357 if len(args) < 1 {
358 return reflect.Value{}, nil, fmt.Errorf("missing argument")
359 }
360 arg := args[0]
361 args = args[1:]
362 b, ok := arg.(bool)
363 if !ok {
364 return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to bool", arg)
365 }
366 return reflect.ValueOf(b), args, nil
367 }, nil
368 case pt.Kind() == reflect.Slice:
369 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
370 converter, err := e.buildConverter(pt.Elem())
371 if err != nil {
372 return reflect.Value{}, nil, err
373 }
374 result := reflect.MakeSlice(reflect.SliceOf(pt.Elem()), 0, len(args))
375 for range args {
376 value, remains, err := converter(n, args)
377 if err != nil {
378 return reflect.Value{}, nil, err
379 }
380 result = reflect.Append(result, value)
381 args = remains
382 }
383 return result, args, nil
384 }, nil
385 default:
386 if pt.Kind() == reflect.Interface && pt.NumMethod() == 0 {
387 return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
388 if len(args) < 1 {
389 return reflect.Value{}, nil, fmt.Errorf("missing argument")
390 }
391 return reflect.ValueOf(args[0]), args[1:], nil
392 }, nil
393 }
394 return nil, fmt.Errorf("param has unexpected type %v (kind %v)", pt, pt.Kind())
395 }
396 }
397
398 func (e *Exported) rangeConverter(n *expect.Note, args []interface{}) (Range, []interface{}, error) {
399 tokFile := e.ExpectFileSet.File(n.Pos)
400 if len(args) < 1 {
401 return Range{}, nil, fmt.Errorf("missing argument")
402 }
403 arg := args[0]
404 args = args[1:]
405 switch arg := arg.(type) {
406 case expect.Identifier:
407
408 switch arg {
409 case eofIdentifier:
410
411 eof := tokFile.Pos(tokFile.Size())
412 return newRange(tokFile, eof, eof), args, nil
413 default:
414
415 mark, ok := e.markers[string(arg)]
416 if !ok {
417 return Range{}, nil, fmt.Errorf("cannot find marker %v", arg)
418 }
419 return mark, args, nil
420 }
421 case string:
422 start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
423 if err != nil {
424 return Range{}, nil, err
425 }
426 if !start.IsValid() {
427 return Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
428 }
429 return newRange(tokFile, start, end), args, nil
430 case *regexp.Regexp:
431 start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
432 if err != nil {
433 return Range{}, nil, err
434 }
435 if !start.IsValid() {
436 return Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
437 }
438 return newRange(tokFile, start, end), args, nil
439 default:
440 return Range{}, nil, fmt.Errorf("cannot convert %v to pos", arg)
441 }
442 }
443
444
445 func newRange(file *token.File, start, end token.Pos) Range {
446 fileBase := file.Base()
447 fileEnd := fileBase + file.Size()
448 if !start.IsValid() {
449 panic("invalid start token.Pos")
450 }
451 if !end.IsValid() {
452 panic("invalid end token.Pos")
453 }
454 if int(start) < fileBase || int(start) > fileEnd {
455 panic(fmt.Sprintf("invalid start: %d not in [%d, %d]", start, fileBase, fileEnd))
456 }
457 if int(end) < fileBase || int(end) > fileEnd {
458 panic(fmt.Sprintf("invalid end: %d not in [%d, %d]", end, fileBase, fileEnd))
459 }
460 if start > end {
461 panic("invalid start: greater than end")
462 }
463 return Range{
464 TokFile: file,
465 Start: start,
466 End: end,
467 }
468 }
469
View as plain text