...

Source file src/golang.org/x/tools/go/packages/packagestest/expect.go

Documentation: golang.org/x/tools/go/packages/packagestest

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package packagestest
     6  
     7  import (
     8  	"fmt"
     9  	"go/token"
    10  	"os"
    11  	"path/filepath"
    12  	"reflect"
    13  	"regexp"
    14  	"strings"
    15  
    16  	"golang.org/x/tools/go/expect"
    17  	"golang.org/x/tools/go/packages"
    18  )
    19  
    20  const (
    21  	markMethod    = "mark"
    22  	eofIdentifier = "EOF"
    23  )
    24  
    25  // Expect invokes the supplied methods for all expectation notes found in
    26  // the exported source files.
    27  //
    28  // All exported go source files are parsed to collect the expectation
    29  // notes.
    30  // See the documentation for expect.Parse for how the notes are collected
    31  // and parsed.
    32  //
    33  // The methods are supplied as a map of name to function, and those functions
    34  // will be matched against the expectations by name.
    35  // Notes with no matching function will be skipped, and functions with no
    36  // matching notes will not be invoked.
    37  // If there are no registered markers yet, a special pass will be run first
    38  // which adds any markers declared with @mark(Name, pattern) or @name. These
    39  // call the Mark method to add the marker to the global set.
    40  // You can register the "mark" method to override these in your own call to
    41  // Expect. The bound Mark function is usable directly in your method map, so
    42  //
    43  //	exported.Expect(map[string]interface{}{"mark": exported.Mark})
    44  //
    45  // replicates the built in behavior.
    46  //
    47  // # Method invocation
    48  //
    49  // When invoking a method the expressions in the parameter list need to be
    50  // converted to values to be passed to the method.
    51  // There are a very limited set of types the arguments are allowed to be.
    52  //
    53  //	expect.Note : passed the Note instance being evaluated.
    54  //	string : can be supplied either a string literal or an identifier.
    55  //	int : can only be supplied an integer literal.
    56  //	*regexp.Regexp : can only be supplied a regular expression literal
    57  //	token.Pos : has a file position calculated as described below.
    58  //	token.Position : has a file position calculated as described below.
    59  //	expect.Range: has a start and end position as described below.
    60  //	interface{} : will be passed any value
    61  //
    62  // # Position calculation
    63  //
    64  // There is some extra handling when a parameter is being coerced into a
    65  // token.Pos, token.Position or Range type argument.
    66  //
    67  // If the parameter is an identifier, it will be treated as the name of an
    68  // marker to look up (as if markers were global variables).
    69  //
    70  // If it is a string or regular expression, then it will be passed to
    71  // expect.MatchBefore to look up a match in the line at which it was declared.
    72  //
    73  // It is safe to call this repeatedly with different method sets, but it is
    74  // not safe to call it concurrently.
    75  func (e *Exported) Expect(methods map[string]interface{}) error {
    76  	if err := e.getNotes(); err != nil {
    77  		return err
    78  	}
    79  	if err := e.getMarkers(); err != nil {
    80  		return err
    81  	}
    82  	var err error
    83  	ms := make(map[string]method, len(methods))
    84  	for name, f := range methods {
    85  		mi := method{f: reflect.ValueOf(f)}
    86  		mi.converters = make([]converter, mi.f.Type().NumIn())
    87  		for i := 0; i < len(mi.converters); i++ {
    88  			mi.converters[i], err = e.buildConverter(mi.f.Type().In(i))
    89  			if err != nil {
    90  				return fmt.Errorf("invalid method %v: %v", name, err)
    91  			}
    92  		}
    93  		ms[name] = mi
    94  	}
    95  	for _, n := range e.notes {
    96  		if n.Args == nil {
    97  			// simple identifier form, convert to a call to mark
    98  			n = &expect.Note{
    99  				Pos:  n.Pos,
   100  				Name: markMethod,
   101  				Args: []interface{}{n.Name, n.Name},
   102  			}
   103  		}
   104  		mi, ok := ms[n.Name]
   105  		if !ok {
   106  			continue
   107  		}
   108  		params := make([]reflect.Value, len(mi.converters))
   109  		args := n.Args
   110  		for i, convert := range mi.converters {
   111  			params[i], args, err = convert(n, args)
   112  			if err != nil {
   113  				return fmt.Errorf("%v: %v", e.ExpectFileSet.Position(n.Pos), err)
   114  			}
   115  		}
   116  		if len(args) > 0 {
   117  			return fmt.Errorf("%v: unwanted args got %+v extra", e.ExpectFileSet.Position(n.Pos), args)
   118  		}
   119  		//TODO: catch the error returned from the method
   120  		mi.f.Call(params)
   121  	}
   122  	return nil
   123  }
   124  
   125  // A Range represents an interval within a source file in go/token notation.
   126  type Range struct {
   127  	TokFile    *token.File // non-nil
   128  	Start, End token.Pos   // both valid and within range of TokFile
   129  }
   130  
   131  // Mark adds a new marker to the known set.
   132  func (e *Exported) Mark(name string, r Range) {
   133  	if e.markers == nil {
   134  		e.markers = make(map[string]Range)
   135  	}
   136  	e.markers[name] = r
   137  }
   138  
   139  func (e *Exported) getNotes() error {
   140  	if e.notes != nil {
   141  		return nil
   142  	}
   143  	notes := []*expect.Note{}
   144  	var dirs []string
   145  	for _, module := range e.written {
   146  		for _, filename := range module {
   147  			dirs = append(dirs, filepath.Dir(filename))
   148  		}
   149  	}
   150  	for filename := range e.Config.Overlay {
   151  		dirs = append(dirs, filepath.Dir(filename))
   152  	}
   153  	pkgs, err := packages.Load(e.Config, dirs...)
   154  	if err != nil {
   155  		return fmt.Errorf("unable to load packages for directories %s: %v", dirs, err)
   156  	}
   157  	seen := make(map[token.Position]struct{})
   158  	for _, pkg := range pkgs {
   159  		for _, filename := range pkg.GoFiles {
   160  			content, err := e.FileContents(filename)
   161  			if err != nil {
   162  				return err
   163  			}
   164  			l, err := expect.Parse(e.ExpectFileSet, filename, content)
   165  			if err != nil {
   166  				return fmt.Errorf("failed to extract expectations: %v", err)
   167  			}
   168  			for _, note := range l {
   169  				pos := e.ExpectFileSet.Position(note.Pos)
   170  				if _, ok := seen[pos]; ok {
   171  					continue
   172  				}
   173  				notes = append(notes, note)
   174  				seen[pos] = struct{}{}
   175  			}
   176  		}
   177  	}
   178  	if _, ok := e.written[e.primary]; !ok {
   179  		e.notes = notes
   180  		return nil
   181  	}
   182  	// Check go.mod markers regardless of mode, we need to do this so that our marker count
   183  	// matches the counts in the summary.txt.golden file for the test directory.
   184  	if gomod, found := e.written[e.primary]["go.mod"]; found {
   185  		// If we are in Modules mode, then we need to check the contents of the go.mod.temp.
   186  		if e.Exporter == Modules {
   187  			gomod += ".temp"
   188  		}
   189  		l, err := goModMarkers(e, gomod)
   190  		if err != nil {
   191  			return fmt.Errorf("failed to extract expectations for go.mod: %v", err)
   192  		}
   193  		notes = append(notes, l...)
   194  	}
   195  	e.notes = notes
   196  	return nil
   197  }
   198  
   199  func goModMarkers(e *Exported, gomod string) ([]*expect.Note, error) {
   200  	if _, err := os.Stat(gomod); os.IsNotExist(err) {
   201  		// If there is no go.mod file, we want to be able to continue.
   202  		return nil, nil
   203  	}
   204  	content, err := e.FileContents(gomod)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  	if e.Exporter == GOPATH {
   209  		return expect.Parse(e.ExpectFileSet, gomod, content)
   210  	}
   211  	gomod = strings.TrimSuffix(gomod, ".temp")
   212  	// If we are in Modules mode, copy the original contents file back into go.mod
   213  	if err := os.WriteFile(gomod, content, 0644); err != nil {
   214  		return nil, nil
   215  	}
   216  	return expect.Parse(e.ExpectFileSet, gomod, content)
   217  }
   218  
   219  func (e *Exported) getMarkers() error {
   220  	if e.markers != nil {
   221  		return nil
   222  	}
   223  	// set markers early so that we don't call getMarkers again from Expect
   224  	e.markers = make(map[string]Range)
   225  	return e.Expect(map[string]interface{}{
   226  		markMethod: e.Mark,
   227  	})
   228  }
   229  
   230  var (
   231  	noteType       = reflect.TypeOf((*expect.Note)(nil))
   232  	identifierType = reflect.TypeOf(expect.Identifier(""))
   233  	posType        = reflect.TypeOf(token.Pos(0))
   234  	positionType   = reflect.TypeOf(token.Position{})
   235  	rangeType      = reflect.TypeOf(Range{})
   236  	fsetType       = reflect.TypeOf((*token.FileSet)(nil))
   237  	regexType      = reflect.TypeOf((*regexp.Regexp)(nil))
   238  	exportedType   = reflect.TypeOf((*Exported)(nil))
   239  )
   240  
   241  // converter converts from a marker's argument parsed from the comment to
   242  // reflect values passed to the method during Invoke.
   243  // It takes the args remaining, and returns the args it did not consume.
   244  // This allows a converter to consume 0 args for well known types, or multiple
   245  // args for compound types.
   246  type converter func(*expect.Note, []interface{}) (reflect.Value, []interface{}, error)
   247  
   248  // method is used to track information about Invoke methods that is expensive to
   249  // calculate so that we can work it out once rather than per marker.
   250  type method struct {
   251  	f          reflect.Value // the reflect value of the passed in method
   252  	converters []converter   // the parameter converters for the method
   253  }
   254  
   255  // buildConverter works out what function should be used to go from an ast expressions to a reflect
   256  // value of the type expected by a method.
   257  // It is called when only the target type is know, it returns converters that are flexible across
   258  // all supported expression types for that target type.
   259  func (e *Exported) buildConverter(pt reflect.Type) (converter, error) {
   260  	switch {
   261  	case pt == noteType:
   262  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   263  			return reflect.ValueOf(n), args, nil
   264  		}, nil
   265  	case pt == fsetType:
   266  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   267  			return reflect.ValueOf(e.ExpectFileSet), args, nil
   268  		}, nil
   269  	case pt == exportedType:
   270  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   271  			return reflect.ValueOf(e), args, nil
   272  		}, nil
   273  	case pt == posType:
   274  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   275  			r, remains, err := e.rangeConverter(n, args)
   276  			if err != nil {
   277  				return reflect.Value{}, nil, err
   278  			}
   279  			return reflect.ValueOf(r.Start), remains, nil
   280  		}, nil
   281  	case pt == positionType:
   282  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   283  			r, remains, err := e.rangeConverter(n, args)
   284  			if err != nil {
   285  				return reflect.Value{}, nil, err
   286  			}
   287  			return reflect.ValueOf(e.ExpectFileSet.Position(r.Start)), remains, nil
   288  		}, nil
   289  	case pt == rangeType:
   290  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   291  			r, remains, err := e.rangeConverter(n, args)
   292  			if err != nil {
   293  				return reflect.Value{}, nil, err
   294  			}
   295  			return reflect.ValueOf(r), remains, nil
   296  		}, nil
   297  	case pt == identifierType:
   298  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   299  			if len(args) < 1 {
   300  				return reflect.Value{}, nil, fmt.Errorf("missing argument")
   301  			}
   302  			arg := args[0]
   303  			args = args[1:]
   304  			switch arg := arg.(type) {
   305  			case expect.Identifier:
   306  				return reflect.ValueOf(arg), args, nil
   307  			default:
   308  				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
   309  			}
   310  		}, nil
   311  
   312  	case pt == regexType:
   313  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   314  			if len(args) < 1 {
   315  				return reflect.Value{}, nil, fmt.Errorf("missing argument")
   316  			}
   317  			arg := args[0]
   318  			args = args[1:]
   319  			if _, ok := arg.(*regexp.Regexp); !ok {
   320  				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to *regexp.Regexp", arg)
   321  			}
   322  			return reflect.ValueOf(arg), args, nil
   323  		}, nil
   324  
   325  	case pt.Kind() == reflect.String:
   326  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   327  			if len(args) < 1 {
   328  				return reflect.Value{}, nil, fmt.Errorf("missing argument")
   329  			}
   330  			arg := args[0]
   331  			args = args[1:]
   332  			switch arg := arg.(type) {
   333  			case expect.Identifier:
   334  				return reflect.ValueOf(string(arg)), args, nil
   335  			case string:
   336  				return reflect.ValueOf(arg), args, nil
   337  			default:
   338  				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
   339  			}
   340  		}, nil
   341  	case pt.Kind() == reflect.Int64:
   342  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   343  			if len(args) < 1 {
   344  				return reflect.Value{}, nil, fmt.Errorf("missing argument")
   345  			}
   346  			arg := args[0]
   347  			args = args[1:]
   348  			switch arg := arg.(type) {
   349  			case int64:
   350  				return reflect.ValueOf(arg), args, nil
   351  			default:
   352  				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to int", arg)
   353  			}
   354  		}, nil
   355  	case pt.Kind() == reflect.Bool:
   356  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   357  			if len(args) < 1 {
   358  				return reflect.Value{}, nil, fmt.Errorf("missing argument")
   359  			}
   360  			arg := args[0]
   361  			args = args[1:]
   362  			b, ok := arg.(bool)
   363  			if !ok {
   364  				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to bool", arg)
   365  			}
   366  			return reflect.ValueOf(b), args, nil
   367  		}, nil
   368  	case pt.Kind() == reflect.Slice:
   369  		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   370  			converter, err := e.buildConverter(pt.Elem())
   371  			if err != nil {
   372  				return reflect.Value{}, nil, err
   373  			}
   374  			result := reflect.MakeSlice(reflect.SliceOf(pt.Elem()), 0, len(args))
   375  			for range args {
   376  				value, remains, err := converter(n, args)
   377  				if err != nil {
   378  					return reflect.Value{}, nil, err
   379  				}
   380  				result = reflect.Append(result, value)
   381  				args = remains
   382  			}
   383  			return result, args, nil
   384  		}, nil
   385  	default:
   386  		if pt.Kind() == reflect.Interface && pt.NumMethod() == 0 {
   387  			return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
   388  				if len(args) < 1 {
   389  					return reflect.Value{}, nil, fmt.Errorf("missing argument")
   390  				}
   391  				return reflect.ValueOf(args[0]), args[1:], nil
   392  			}, nil
   393  		}
   394  		return nil, fmt.Errorf("param has unexpected type %v (kind %v)", pt, pt.Kind())
   395  	}
   396  }
   397  
   398  func (e *Exported) rangeConverter(n *expect.Note, args []interface{}) (Range, []interface{}, error) {
   399  	tokFile := e.ExpectFileSet.File(n.Pos)
   400  	if len(args) < 1 {
   401  		return Range{}, nil, fmt.Errorf("missing argument")
   402  	}
   403  	arg := args[0]
   404  	args = args[1:]
   405  	switch arg := arg.(type) {
   406  	case expect.Identifier:
   407  		// handle the special identifiers
   408  		switch arg {
   409  		case eofIdentifier:
   410  			// end of file identifier
   411  			eof := tokFile.Pos(tokFile.Size())
   412  			return newRange(tokFile, eof, eof), args, nil
   413  		default:
   414  			// look up an marker by name
   415  			mark, ok := e.markers[string(arg)]
   416  			if !ok {
   417  				return Range{}, nil, fmt.Errorf("cannot find marker %v", arg)
   418  			}
   419  			return mark, args, nil
   420  		}
   421  	case string:
   422  		start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
   423  		if err != nil {
   424  			return Range{}, nil, err
   425  		}
   426  		if !start.IsValid() {
   427  			return Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
   428  		}
   429  		return newRange(tokFile, start, end), args, nil
   430  	case *regexp.Regexp:
   431  		start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
   432  		if err != nil {
   433  			return Range{}, nil, err
   434  		}
   435  		if !start.IsValid() {
   436  			return Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
   437  		}
   438  		return newRange(tokFile, start, end), args, nil
   439  	default:
   440  		return Range{}, nil, fmt.Errorf("cannot convert %v to pos", arg)
   441  	}
   442  }
   443  
   444  // newRange creates a new Range from a token.File and two valid positions within it.
   445  func newRange(file *token.File, start, end token.Pos) Range {
   446  	fileBase := file.Base()
   447  	fileEnd := fileBase + file.Size()
   448  	if !start.IsValid() {
   449  		panic("invalid start token.Pos")
   450  	}
   451  	if !end.IsValid() {
   452  		panic("invalid end token.Pos")
   453  	}
   454  	if int(start) < fileBase || int(start) > fileEnd {
   455  		panic(fmt.Sprintf("invalid start: %d not in [%d, %d]", start, fileBase, fileEnd))
   456  	}
   457  	if int(end) < fileBase || int(end) > fileEnd {
   458  		panic(fmt.Sprintf("invalid end: %d not in [%d, %d]", end, fileBase, fileEnd))
   459  	}
   460  	if start > end {
   461  		panic("invalid start: greater than end")
   462  	}
   463  	return Range{
   464  		TokFile: file,
   465  		Start:   start,
   466  		End:     end,
   467  	}
   468  }
   469  

View as plain text