...

Source file src/gotest.tools/v3/internal/source/update.go

Documentation: gotest.tools/v3/internal/source

     1  package source
     2  
import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"runtime"
	"strconv"
	"strings"
	"unicode/utf8"
)
    16  
    17  // IsUpdate is returns true if the -update flag is set. It indicates the user
    18  // running the tests would like to update any golden values.
    19  func IsUpdate() bool {
    20  	if Update {
    21  		return true
    22  	}
    23  	return flag.Lookup("update").Value.(flag.Getter).Get().(bool)
    24  }
    25  
    26  // Update is a shim for testing, and for compatibility with the old -update-golden
    27  // flag.
    28  var Update bool
    29  
    30  func init() {
    31  	if f := flag.Lookup("update"); f != nil {
    32  		getter, ok := f.Value.(flag.Getter)
    33  		msg := "some other package defined an incompatible -update flag, expected a flag.Bool"
    34  		if !ok {
    35  			panic(msg)
    36  		}
    37  		if _, ok := getter.Get().(bool); !ok {
    38  			panic(msg)
    39  		}
    40  		return
    41  	}
    42  	flag.Bool("update", false, "update golden values")
    43  }
    44  
    45  // ErrNotFound indicates that UpdateExpectedValue failed to find the
    46  // variable to update, likely because it is not a package level variable.
    47  var ErrNotFound = fmt.Errorf("failed to find variable for update of golden value")
    48  
    49  // UpdateExpectedValue looks for a package-level variable with a name that
    50  // starts with expected in the arguments to the caller. If the variable is
    51  // found, the value of the variable will be updated to value of the other
    52  // argument to the caller.
    53  func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
    54  	_, filename, line, ok := runtime.Caller(stackIndex + 1)
    55  	if !ok {
    56  		return errors.New("failed to get call stack")
    57  	}
    58  	debug("call stack position: %s:%d", filename, line)
    59  
    60  	fileset := token.NewFileSet()
    61  	astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
    62  	if err != nil {
    63  		return fmt.Errorf("failed to parse source file %s: %w", filename, err)
    64  	}
    65  
    66  	expr, err := getCallExprArgs(fileset, astFile, line)
    67  	if err != nil {
    68  		return fmt.Errorf("call from %s:%d: %w", filename, line, err)
    69  	}
    70  
    71  	if len(expr) < 3 {
    72  		debug("not enough arguments %d: %v",
    73  			len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}})
    74  		return ErrNotFound
    75  	}
    76  
    77  	argIndex, ident := getIdentForExpectedValueArg(expr)
    78  	if argIndex < 0 || ident == nil {
    79  		debug("no arguments started with the word 'expected': %v",
    80  			debugFormatNode{Node: &ast.CallExpr{Args: expr}})
    81  		return ErrNotFound
    82  	}
    83  
    84  	value := x
    85  	if argIndex == 1 {
    86  		value = y
    87  	}
    88  
    89  	strValue, ok := value.(string)
    90  	if !ok {
    91  		debug("value must be type string, got %T", value)
    92  		return ErrNotFound
    93  	}
    94  	return UpdateVariable(filename, fileset, astFile, ident, strValue)
    95  }
    96  
    97  // UpdateVariable writes to filename the contents of astFile with the value of
    98  // the variable updated to value.
    99  func UpdateVariable(
   100  	filename string,
   101  	fileset *token.FileSet,
   102  	astFile *ast.File,
   103  	ident *ast.Ident,
   104  	value string,
   105  ) error {
   106  	obj := ident.Obj
   107  	if obj == nil {
   108  		return ErrNotFound
   109  	}
   110  	if obj.Kind != ast.Con && obj.Kind != ast.Var {
   111  		debug("can only update var and const, found %v", obj.Kind)
   112  		return ErrNotFound
   113  	}
   114  
   115  	switch decl := obj.Decl.(type) {
   116  	case *ast.ValueSpec:
   117  		if len(decl.Names) != 1 {
   118  			debug("more than one name in ast.ValueSpec")
   119  			return ErrNotFound
   120  		}
   121  
   122  		decl.Values[0] = &ast.BasicLit{
   123  			Kind:  token.STRING,
   124  			Value: "`" + value + "`",
   125  		}
   126  
   127  	case *ast.AssignStmt:
   128  		if len(decl.Lhs) != 1 {
   129  			debug("more than one name in ast.AssignStmt")
   130  			return ErrNotFound
   131  		}
   132  
   133  		decl.Rhs[0] = &ast.BasicLit{
   134  			Kind:  token.STRING,
   135  			Value: "`" + value + "`",
   136  		}
   137  
   138  	default:
   139  		debug("can only update *ast.ValueSpec, found %T", obj.Decl)
   140  		return ErrNotFound
   141  	}
   142  
   143  	var buf bytes.Buffer
   144  	if err := format.Node(&buf, fileset, astFile); err != nil {
   145  		return fmt.Errorf("failed to format file after update: %w", err)
   146  	}
   147  
   148  	fh, err := os.Create(filename)
   149  	if err != nil {
   150  		return fmt.Errorf("failed to open file %v: %w", filename, err)
   151  	}
   152  	if _, err = fh.Write(buf.Bytes()); err != nil {
   153  		return fmt.Errorf("failed to write file %v: %w", filename, err)
   154  	}
   155  	if err := fh.Sync(); err != nil {
   156  		return fmt.Errorf("failed to sync file %v: %w", filename, err)
   157  	}
   158  	return nil
   159  }
   160  
   161  func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
   162  	for i := 1; i < 3; i++ {
   163  		switch e := expr[i].(type) {
   164  		case *ast.Ident:
   165  			if strings.HasPrefix(strings.ToLower(e.Name), "expected") {
   166  				return i, e
   167  			}
   168  		}
   169  	}
   170  	return -1, nil
   171  }
   172  

View as plain text