1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package main
18
19 import (
20 "bytes"
21 "flag"
22 "fmt"
23 "go/ast"
24 "go/format"
25 "go/parser"
26 "go/token"
27 "log"
28 "os"
29 "strings"
30 "text/template"
31 )
32
// outputFileSuffix is appended to a package name to form the name of the
// generated test file written for that package.
const outputFileSuffix = "-stringify_test.go"

// Source files whose names start with either of these prefixes are
// excluded from parsing by sourceFilter.
const (
	ignoreFilePrefix1 = "gen-"
	ignoreFilePrefix2 = "github-"
)
38
39 var (
40 verbose = flag.Bool("v", false, "Print verbose log messages")
41
42
43 skipStructMethods = map[string]bool{}
44
45 skipStructs = map[string]bool{
46 "RateLimits": true,
47 }
48
49 funcMap = template.FuncMap{
50 "isNotLast": func(index int, slice []*structField) string {
51 if index+1 < len(slice) {
52 return ", "
53 }
54 return ""
55 },
56 "processZeroValue": func(v string) string {
57 switch v {
58 case "Bool(false)":
59 return "false"
60 case "Float64(0.0)":
61 return "0"
62 case "0", "Int(0)", "Int64(0)":
63 return "0"
64 case `""`, `String("")`:
65 return `""`
66 case "Timestamp{}", "&Timestamp{}":
67 return "github.Timestamp{0001-01-01 00:00:00 +0000 UTC}"
68 case "nil":
69 return "map[]"
70 case `[]int{0}`:
71 return `[0]`
72 case `[]string{""}`:
73 return `[""]`
74 case "[]Scope{ScopeNone}":
75 return `["(no scope)"]`
76 }
77 log.Fatalf("Unhandled zero value: %q", v)
78 return ""
79 },
80 }
81
82 sourceTmpl = template.Must(template.New("source").Funcs(funcMap).Parse(source))
83 )
84
85 func main() {
86 flag.Parse()
87 fset := token.NewFileSet()
88
89 pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
90 if err != nil {
91 log.Fatal(err)
92 return
93 }
94
95 for pkgName, pkg := range pkgs {
96 t := &templateData{
97 filename: pkgName + outputFileSuffix,
98 Year: 2019,
99 Package: pkgName,
100 Imports: map[string]string{"testing": "testing"},
101 StringFuncs: map[string]bool{},
102 StructFields: map[string][]*structField{},
103 }
104 for filename, f := range pkg.Files {
105 logf("Processing %v...", filename)
106 if err := t.processAST(f); err != nil {
107 log.Fatal(err)
108 }
109 }
110 if err := t.dump(); err != nil {
111 log.Fatal(err)
112 }
113 }
114 logf("Done.")
115 }
116
117 func sourceFilter(fi os.FileInfo) bool {
118 return !strings.HasSuffix(fi.Name(), "_test.go") &&
119 !strings.HasPrefix(fi.Name(), ignoreFilePrefix1) &&
120 !strings.HasPrefix(fi.Name(), ignoreFilePrefix2)
121 }
122
123 type templateData struct {
124 filename string
125 Year int
126 Package string
127 Imports map[string]string
128 StringFuncs map[string]bool
129 StructFields map[string][]*structField
130 }
131
// structField describes one field that a generated String test populates.
type structField struct {
	// Where the field lives.
	ReceiverType string // struct type that owns the field
	FieldName    string // field name
	FieldType    string // type identifier (the element type for slices, the pointee for pointers)

	// How the generated test sets the field: when NamedStruct is true the
	// template assigns &FieldType{}, otherwise it assigns ZeroValue.
	ZeroValue   string
	NamedStruct bool

	// Values derived by newStructField.
	sortVal     string // lower-cased "receiverType.fieldName"
	ReceiverVar string // lower-cased first character of ReceiverType
}
141
142 func (t *templateData) processAST(f *ast.File) error {
143 for _, decl := range f.Decls {
144 fn, ok := decl.(*ast.FuncDecl)
145 if ok {
146 if fn.Recv != nil && len(fn.Recv.List) > 0 {
147 id, ok := fn.Recv.List[0].Type.(*ast.Ident)
148 if ok && fn.Name.Name == "String" {
149 logf("Got FuncDecl: Name=%q, id.Name=%#v", fn.Name.Name, id.Name)
150 t.StringFuncs[id.Name] = true
151 } else {
152 star, ok := fn.Recv.List[0].Type.(*ast.StarExpr)
153 if ok && fn.Name.Name == "String" {
154 id, ok := star.X.(*ast.Ident)
155 if ok {
156 logf("Got FuncDecl: Name=%q, id.Name=%#v", fn.Name.Name, id.Name)
157 t.StringFuncs[id.Name] = true
158 } else {
159 logf("Ignoring FuncDecl: Name=%q, Type=%T", fn.Name.Name, fn.Recv.List[0].Type)
160 }
161 } else {
162 logf("Ignoring FuncDecl: Name=%q, Type=%T", fn.Name.Name, fn.Recv.List[0].Type)
163 }
164 }
165 } else {
166 logf("Ignoring FuncDecl: Name=%q, fn=%#v", fn.Name.Name, fn)
167 }
168 continue
169 }
170
171 gd, ok := decl.(*ast.GenDecl)
172 if !ok {
173 logf("Ignoring AST decl type %T", decl)
174 continue
175 }
176
177 for _, spec := range gd.Specs {
178 ts, ok := spec.(*ast.TypeSpec)
179 if !ok {
180 continue
181 }
182
183 if !ts.Name.IsExported() {
184 logf("Struct %v is unexported; skipping.", ts.Name)
185 continue
186 }
187
188 if skipStructs[ts.Name.Name] {
189 logf("Struct %v is in skip list; skipping.", ts.Name)
190 continue
191 }
192 st, ok := ts.Type.(*ast.StructType)
193 if !ok {
194 logf("Ignoring AST type %T, Name=%q", ts.Type, ts.Name.String())
195 continue
196 }
197 for _, field := range st.Fields.List {
198 if len(field.Names) == 0 {
199 continue
200 }
201
202 fieldName := field.Names[0]
203 if id, ok := field.Type.(*ast.Ident); ok {
204 t.addIdent(id, ts.Name.String(), fieldName.String())
205 continue
206 }
207
208 if at, ok := field.Type.(*ast.ArrayType); ok {
209 if id, ok := at.Elt.(*ast.Ident); ok {
210 t.addIdentSlice(id, ts.Name.String(), fieldName.String())
211 continue
212 }
213 }
214
215 se, ok := field.Type.(*ast.StarExpr)
216 if !ok {
217 logf("Ignoring type %T for Name=%q, FieldName=%q", field.Type, ts.Name.String(), fieldName.String())
218 continue
219 }
220
221
222 if !fieldName.IsExported() {
223 logf("Field %v is unexported; skipping.", fieldName)
224 continue
225 }
226
227 if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
228 logf("Method %v is in skip list; skipping.", key)
229 continue
230 }
231
232 switch x := se.X.(type) {
233 case *ast.ArrayType:
234 case *ast.Ident:
235 t.addIdentPtr(x, ts.Name.String(), fieldName.String())
236 case *ast.MapType:
237 case *ast.SelectorExpr:
238 default:
239 logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
240 }
241 }
242 }
243 }
244 return nil
245 }
246
247 func (t *templateData) addMapType(receiverType, fieldName string) {
248 t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, "map[]", "nil", false))
249 }
250
251 func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
252 var zeroValue string
253 var namedStruct = false
254 switch x.String() {
255 case "int":
256 zeroValue = "0"
257 case "int64":
258 zeroValue = "0"
259 case "float64":
260 zeroValue = "0.0"
261 case "string":
262 zeroValue = `""`
263 case "bool":
264 zeroValue = "false"
265 case "Timestamp":
266 zeroValue = "Timestamp{}"
267 default:
268 zeroValue = "nil"
269 namedStruct = true
270 }
271
272 t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
273 }
274
275 func (t *templateData) addIdentPtr(x *ast.Ident, receiverType, fieldName string) {
276 var zeroValue string
277 var namedStruct = false
278 switch x.String() {
279 case "int":
280 zeroValue = "Int(0)"
281 case "int64":
282 zeroValue = "Int64(0)"
283 case "float64":
284 zeroValue = "Float64(0.0)"
285 case "string":
286 zeroValue = `String("")`
287 case "bool":
288 zeroValue = "Bool(false)"
289 case "Timestamp":
290 zeroValue = "&Timestamp{}"
291 default:
292 zeroValue = "nil"
293 namedStruct = true
294 }
295
296 t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
297 }
298
299 func (t *templateData) addIdentSlice(x *ast.Ident, receiverType, fieldName string) {
300 var zeroValue string
301 var namedStruct = false
302 switch x.String() {
303 case "int":
304 zeroValue = "[]int{0}"
305 case "int64":
306 zeroValue = "[]int64{0}"
307 case "float64":
308 zeroValue = "[]float64{0}"
309 case "string":
310 zeroValue = `[]string{""}`
311 case "bool":
312 zeroValue = "[]bool{false}"
313 case "Scope":
314 zeroValue = "[]Scope{ScopeNone}"
315
316
317 default:
318 zeroValue = "nil"
319 namedStruct = true
320 }
321
322 t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
323 }
324
325 func (t *templateData) dump() error {
326 if len(t.StructFields) == 0 {
327 logf("No StructFields for %v; skipping.", t.filename)
328 return nil
329 }
330
331
332 var toDelete []string
333 for k := range t.StructFields {
334 if !t.StringFuncs[k] {
335 toDelete = append(toDelete, k)
336 continue
337 }
338 }
339 for _, k := range toDelete {
340 delete(t.StructFields, k)
341 }
342
343 var buf bytes.Buffer
344 if err := sourceTmpl.Execute(&buf, t); err != nil {
345 return err
346 }
347 clean, err := format.Source(buf.Bytes())
348 if err != nil {
349 log.Printf("failed-to-format source:\n%v", buf.String())
350 return err
351 }
352
353 logf("Writing %v...", t.filename)
354 if err := os.Chmod(t.filename, 0644); err != nil {
355 return fmt.Errorf("os.Chmod(%q, 0644): %v", t.filename, err)
356 }
357
358 if err := os.WriteFile(t.filename, clean, 0444); err != nil {
359 return err
360 }
361
362 if err := os.Chmod(t.filename, 0444); err != nil {
363 return fmt.Errorf("os.Chmod(%q, 0444): %v", t.filename, err)
364 }
365
366 return nil
367 }
368
369 func newStructField(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *structField {
370 return &structField{
371 sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
372 ReceiverVar: strings.ToLower(receiverType[:1]),
373 ReceiverType: receiverType,
374 FieldName: fieldName,
375 FieldType: fieldType,
376 ZeroValue: zeroValue,
377 NamedStruct: namedStruct,
378 }
379 }
380
381 func logf(fmt string, args ...interface{}) {
382 if *verbose {
383 log.Printf(fmt, args...)
384 }
385 }
386
// source is the text/template, rendered with a *templateData, that produces
// each generated test file. For every entry in StructFields (text/template
// visits map keys in sorted order) it emits a Test<Type>_String function.
// That function builds a value with each recorded field set: &FieldType{}
// for named structs, otherwise ZeroValue. It then compares v.String()
// against an expected string assembled with processZeroValue and isNotLast.
// The rendered output is passed through format.Source by dump before it is
// written.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-stringify-tests; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{ $package := .Package}}{{$package}}
{{with .Imports}}
import (
{{- range . -}}
"{{.}}"
{{end -}}
)
{{end}}
func Float64(v float64) *float64 { return &v }
{{range $key, $value := .StructFields}}
func Test{{ $key }}_String(t *testing.T) {
v := {{ $key }}{ {{range .}}{{if .NamedStruct}}
{{ .FieldName }}: &{{ .FieldType }}{},{{else}}
{{ .FieldName }}: {{.ZeroValue}},{{end}}{{end}}
}
want := ` + "`" + `{{ $package }}.{{ $key }}{{ $slice := . }}{
{{- range $ind, $val := .}}{{if .NamedStruct}}{{ .FieldName }}:{{ $package }}.{{ .FieldType }}{}{{else}}{{ .FieldName }}:{{ processZeroValue .ZeroValue }}{{end}}{{ isNotLast $ind $slice }}{{end}}}` + "`" + `
if got := v.String(); got != want {
t.Errorf("{{ $key }}.String = %v, want %v", got, want)
}
}
{{end}}
`
419
View as plain text