1 package source
2
3 import (
4 "bytes"
5 "errors"
6 "flag"
7 "fmt"
8 "go/ast"
9 "go/format"
10 "go/parser"
11 "go/token"
12 "os"
13 "runtime"
14 "strings"
15 )
16
17
18
19 func IsUpdate() bool {
20 if Update {
21 return true
22 }
23 return flag.Lookup("update").Value.(flag.Getter).Get().(bool)
24 }
25
26
27
// Update forces update mode when set to true: IsUpdate returns true without
// consulting the "update" command line flag.
var Update bool
29
// init makes sure a boolean "update" flag exists in the default flag set.
//
// When no flag with that name is registered yet, it registers one that
// defaults to false. When one is already registered, it must expose a bool
// through flag.Getter; otherwise init panics.
func init() {
	const incompatible = "some other package defined an incompatible -update flag, expected a flag.Bool"

	existing := flag.Lookup("update")
	if existing == nil {
		flag.Bool("update", false, "update golden values")
		return
	}

	// Another package got there first; accept its flag only if it is a bool.
	getter, isGetter := existing.Value.(flag.Getter)
	if !isGetter {
		panic(incompatible)
	}
	if _, isBool := getter.Get().(bool); !isBool {
		panic(incompatible)
	}
}
44
45
46
// ErrNotFound is the sentinel error returned when no variable suitable for a
// golden value update could be located. Compare against it with errors.Is.
var ErrNotFound = errors.New("failed to find variable for update of golden value")
48
49
50
51
52
53 func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
54 _, filename, line, ok := runtime.Caller(stackIndex + 1)
55 if !ok {
56 return errors.New("failed to get call stack")
57 }
58 debug("call stack position: %s:%d", filename, line)
59
60 fileset := token.NewFileSet()
61 astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
62 if err != nil {
63 return fmt.Errorf("failed to parse source file %s: %w", filename, err)
64 }
65
66 expr, err := getCallExprArgs(fileset, astFile, line)
67 if err != nil {
68 return fmt.Errorf("call from %s:%d: %w", filename, line, err)
69 }
70
71 if len(expr) < 3 {
72 debug("not enough arguments %d: %v",
73 len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}})
74 return ErrNotFound
75 }
76
77 argIndex, ident := getIdentForExpectedValueArg(expr)
78 if argIndex < 0 || ident == nil {
79 debug("no arguments started with the word 'expected': %v",
80 debugFormatNode{Node: &ast.CallExpr{Args: expr}})
81 return ErrNotFound
82 }
83
84 value := x
85 if argIndex == 1 {
86 value = y
87 }
88
89 strValue, ok := value.(string)
90 if !ok {
91 debug("value must be type string, got %T", value)
92 return ErrNotFound
93 }
94 return UpdateVariable(filename, fileset, astFile, ident, strValue)
95 }
96
97
98
99 func UpdateVariable(
100 filename string,
101 fileset *token.FileSet,
102 astFile *ast.File,
103 ident *ast.Ident,
104 value string,
105 ) error {
106 obj := ident.Obj
107 if obj == nil {
108 return ErrNotFound
109 }
110 if obj.Kind != ast.Con && obj.Kind != ast.Var {
111 debug("can only update var and const, found %v", obj.Kind)
112 return ErrNotFound
113 }
114
115 switch decl := obj.Decl.(type) {
116 case *ast.ValueSpec:
117 if len(decl.Names) != 1 {
118 debug("more than one name in ast.ValueSpec")
119 return ErrNotFound
120 }
121
122 decl.Values[0] = &ast.BasicLit{
123 Kind: token.STRING,
124 Value: "`" + value + "`",
125 }
126
127 case *ast.AssignStmt:
128 if len(decl.Lhs) != 1 {
129 debug("more than one name in ast.AssignStmt")
130 return ErrNotFound
131 }
132
133 decl.Rhs[0] = &ast.BasicLit{
134 Kind: token.STRING,
135 Value: "`" + value + "`",
136 }
137
138 default:
139 debug("can only update *ast.ValueSpec, found %T", obj.Decl)
140 return ErrNotFound
141 }
142
143 var buf bytes.Buffer
144 if err := format.Node(&buf, fileset, astFile); err != nil {
145 return fmt.Errorf("failed to format file after update: %w", err)
146 }
147
148 fh, err := os.Create(filename)
149 if err != nil {
150 return fmt.Errorf("failed to open file %v: %w", filename, err)
151 }
152 if _, err = fh.Write(buf.Bytes()); err != nil {
153 return fmt.Errorf("failed to write file %v: %w", filename, err)
154 }
155 if err := fh.Sync(); err != nil {
156 return fmt.Errorf("failed to sync file %v: %w", filename, err)
157 }
158 return nil
159 }
160
// getIdentForExpectedValueArg inspects the call arguments at index 1 and 2
// (index 0 is skipped) and returns the index and node of the first one that
// is a plain identifier whose name starts with "expected", compared
// case-insensitively.
//
// It returns -1 and nil when neither argument qualifies, including when expr
// has fewer than three elements.
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
	// Bound by len(expr) as well so short argument lists cannot panic.
	for i := 1; i < 3 && i < len(expr); i++ {
		ident, ok := expr[i].(*ast.Ident)
		if !ok {
			continue
		}
		if strings.HasPrefix(strings.ToLower(ident.Name), "expected") {
			return i, ident
		}
	}
	return -1, nil
}
172
View as plain text