1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package main
18
19 import (
20 "bytes"
21 "flag"
22 "fmt"
23 "go/ast"
24 "go/format"
25 "go/parser"
26 "go/token"
27 "io/ioutil"
28 "log"
29 "os"
30 "strings"
31 "text/template"
32 )
33
// File-name filters and output naming used by sourceFilter and main.
const (
	// outputFileSuffix is appended to a package name to form the name of
	// the generated test file.
	outputFileSuffix = "-stringify_test.go"

	// Source files whose names start with either of these prefixes are
	// not parsed.
	ignoreFilePrefix1 = "gen-"
	ignoreFilePrefix2 = "github-"
)
39
40 var (
41 verbose = flag.Bool("v", false, "Print verbose log messages")
42
43
44 skipStructMethods = map[string]bool{}
45
46 skipStructs = map[string]bool{
47 "RateLimits": true,
48 }
49
50 funcMap = template.FuncMap{
51 "isNotLast": func(index int, slice []*structField) string {
52 if index+1 < len(slice) {
53 return ", "
54 }
55 return ""
56 },
57 "processZeroValue": func(v string) string {
58 switch v {
59 case "Bool(false)":
60 return "false"
61 case "Float64(0.0)":
62 return "0"
63 case "0", "Int(0)", "Int64(0)":
64 return "0"
65 case `""`, `String("")`:
66 return `""`
67 case "Timestamp{}", "&Timestamp{}":
68 return "github.Timestamp{0001-01-01 00:00:00 +0000 UTC}"
69 case "nil":
70 return "map[]"
71 case `[]int{0}`:
72 return `[0]`
73 case `[]string{""}`:
74 return `[""]`
75 case "[]Scope{ScopeNone}":
76 return `["(no scope)"]`
77 }
78 log.Fatalf("Unhandled zero value: %q", v)
79 return ""
80 },
81 }
82
83 sourceTmpl = template.Must(template.New("source").Funcs(funcMap).Parse(source))
84 )
85
// main parses the Go files in the current directory that sourceFilter
// accepts and, for each package found, writes that package's generated
// String test file. Any error is fatal.
func main() {
	flag.Parse()

	fset := token.NewFileSet()
	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
	if err != nil {
		log.Fatal(err)
	}

	for pkgName, pkg := range pkgs {
		if err := generatePackageStringTests(pkgName, pkg); err != nil {
			log.Fatal(err)
		}
	}
	logf("Done.")
}

// generatePackageStringTests collects String methods and struct fields from
// every file of pkg, then renders and writes the output file via dump.
func generatePackageStringTests(pkgName string, pkg *ast.Package) error {
	data := &templateData{
		filename:     pkgName + outputFileSuffix,
		Year:         2019,
		Package:      pkgName,
		Imports:      map[string]string{"testing": "testing"},
		StringFuncs:  map[string]bool{},
		StructFields: map[string][]*structField{},
	}
	for filename, file := range pkg.Files {
		logf("Processing %v...", filename)
		if err := data.processAST(file); err != nil {
			return err
		}
	}
	return data.dump()
}
117
// sourceFilter reports whether the file described by fi should be parsed.
// It rejects test files and files whose names start with
// ignoreFilePrefix1 or ignoreFilePrefix2.
func sourceFilter(fi os.FileInfo) bool {
	name := fi.Name()
	switch {
	case strings.HasSuffix(name, "_test.go"):
		return false
	case strings.HasPrefix(name, ignoreFilePrefix1), strings.HasPrefix(name, ignoreFilePrefix2):
		return false
	default:
		return true
	}
}
123
124 type templateData struct {
125 filename string
126 Year int
127 Package string
128 Imports map[string]string
129 StringFuncs map[string]bool
130 StructFields map[string][]*structField
131 }
132
// structField describes one struct field that a generated String test
// assigns and then expects to see in the String output.
type structField struct {
	// sortVal is the lowercased "receiver.field" pair set by newStructField.
	sortVal string

	ReceiverVar, ReceiverType string
	FieldName, FieldType      string

	// ZeroValue is the Go expression assigned to the field in the generated
	// test; processZeroValue maps it to the expected String text.
	ZeroValue string
	// NamedStruct makes the template assign &FieldType{} instead of ZeroValue.
	NamedStruct bool
}
142
// processAST scans the top-level declarations of f.
//
// A String method whose receiver is T or *T (T a plain identifier) marks T
// in t.StringFuncs. Each exported struct type that is not in skipStructs
// contributes its supported fields to t.StructFields, in declaration order.
// The returned error is currently always nil.
func (t *templateData) processAST(f *ast.File) error {
	for _, decl := range f.Decls {
		switch d := decl.(type) {
		case *ast.FuncDecl:
			t.processFuncDecl(d)
		case *ast.GenDecl:
			t.processGenDecl(d)
		default:
			logf("Ignoring AST decl type %T", decl)
		}
	}
	return nil
}

// processFuncDecl records the receiver type of a String method declared on
// T or *T, where T is a plain identifier.
func (t *templateData) processFuncDecl(fn *ast.FuncDecl) {
	if fn.Recv == nil || len(fn.Recv.List) == 0 {
		logf("Ignoring FuncDecl: Name=%q, fn=%#v", fn.Name.Name, fn)
		return
	}
	recvType := fn.Recv.List[0].Type
	if fn.Name.Name != "String" {
		logf("Ignoring FuncDecl: Name=%q, Type=%T", fn.Name.Name, recvType)
		return
	}
	base := recvType
	if star, ok := base.(*ast.StarExpr); ok {
		base = star.X
	}
	id, ok := base.(*ast.Ident)
	if !ok {
		logf("Ignoring FuncDecl: Name=%q, Type=%T", fn.Name.Name, recvType)
		return
	}
	logf("Got FuncDecl: Name=%q, id.Name=%#v", fn.Name.Name, id.Name)
	t.StringFuncs[id.Name] = true
}

// processGenDecl collects fields from the exported, non-skipped struct type
// specs in gd.
func (t *templateData) processGenDecl(gd *ast.GenDecl) {
	for _, spec := range gd.Specs {
		ts, ok := spec.(*ast.TypeSpec)
		if !ok {
			continue
		}
		if !ts.Name.IsExported() {
			logf("Struct %v is unexported; skipping.", ts.Name)
			continue
		}
		if skipStructs[ts.Name.Name] {
			logf("Struct %v is in skip list; skipping.", ts.Name)
			continue
		}
		st, ok := ts.Type.(*ast.StructType)
		if !ok {
			logf("Ignoring AST type %T, Name=%q", ts.Type, ts.Name.String())
			continue
		}
		structName := ts.Name.String()
		for _, field := range st.Fields.List {
			// "A, B int" declares several fields sharing one type. Every
			// name must be handled, in order, or the expected String output
			// would omit fields that String actually prints. Embedded fields
			// (no names) are skipped.
			for _, name := range field.Names {
				t.processField(structName, name, field.Type)
			}
		}
	}
}

// processField adds the field fieldName of type typ on struct structName to
// t.StructFields when its type is supported, logging why it is skipped
// otherwise.
func (t *templateData) processField(structName string, fieldName *ast.Ident, typ ast.Expr) {
	name := fieldName.String()

	if id, ok := typ.(*ast.Ident); ok {
		t.addIdent(id, structName, name)
		return
	}
	if at, ok := typ.(*ast.ArrayType); ok {
		if id, ok := at.Elt.(*ast.Ident); ok {
			t.addIdentSlice(id, structName, name)
			return
		}
	}

	se, ok := typ.(*ast.StarExpr)
	if !ok {
		logf("Ignoring type %T for Name=%q, FieldName=%q", typ, structName, name)
		return
	}
	if !fieldName.IsExported() {
		logf("Field %v is unexported; skipping.", fieldName)
		return
	}
	if key := fmt.Sprintf("%v.Get%v", structName, fieldName); skipStructMethods[key] {
		logf("Method %v is in skip list; skipping.", key)
		return
	}

	switch x := se.X.(type) {
	case *ast.Ident:
		t.addIdentPtr(x, structName, name)
	case *ast.ArrayType, *ast.MapType, *ast.SelectorExpr:
		// Pointers to slices, maps and qualified types are not handled.
	default:
		logf("processAST: type %q, field %q, unknown %T: %+v", structName, name, x, x)
	}
}
247
248 func (t *templateData) addMapType(receiverType, fieldName string) {
249 t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, "map[]", "nil", false))
250 }
251
252 func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
253 var zeroValue string
254 var namedStruct = false
255 switch x.String() {
256 case "int":
257 zeroValue = "0"
258 case "int64":
259 zeroValue = "0"
260 case "float64":
261 zeroValue = "0.0"
262 case "string":
263 zeroValue = `""`
264 case "bool":
265 zeroValue = "false"
266 case "Timestamp":
267 zeroValue = "Timestamp{}"
268 default:
269 zeroValue = "nil"
270 namedStruct = true
271 }
272
273 t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
274 }
275
276 func (t *templateData) addIdentPtr(x *ast.Ident, receiverType, fieldName string) {
277 var zeroValue string
278 var namedStruct = false
279 switch x.String() {
280 case "int":
281 zeroValue = "Int(0)"
282 case "int64":
283 zeroValue = "Int64(0)"
284 case "float64":
285 zeroValue = "Float64(0.0)"
286 case "string":
287 zeroValue = `String("")`
288 case "bool":
289 zeroValue = "Bool(false)"
290 case "Timestamp":
291 zeroValue = "&Timestamp{}"
292 default:
293 zeroValue = "nil"
294 namedStruct = true
295 }
296
297 t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
298 }
299
300 func (t *templateData) addIdentSlice(x *ast.Ident, receiverType, fieldName string) {
301 var zeroValue string
302 var namedStruct = false
303 switch x.String() {
304 case "int":
305 zeroValue = "[]int{0}"
306 case "int64":
307 zeroValue = "[]int64{0}"
308 case "float64":
309 zeroValue = "[]float64{0}"
310 case "string":
311 zeroValue = `[]string{""}`
312 case "bool":
313 zeroValue = "[]bool{false}"
314 case "Scope":
315 zeroValue = "[]Scope{ScopeNone}"
316
317
318 default:
319 zeroValue = "nil"
320 namedStruct = true
321 }
322
323 t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
324 }
325
// dump renders sourceTmpl for t, formats the result with go/format and
// writes it to t.filename, leaving the file with mode 0444.
//
// Struct types without a String method are dropped from t.StructFields
// first. If none remain, nothing is written and nil is returned. A
// formatting failure logs the unformatted source before returning the error.
func (t *templateData) dump() error {
	// Only types with a String method can be tested. Deleting entries while
	// ranging over a map is permitted by the Go spec.
	for name := range t.StructFields {
		if !t.StringFuncs[name] {
			delete(t.StructFields, name)
		}
	}
	// Check after filtering: a file without tests would still import
	// "testing" and fail to compile.
	if len(t.StructFields) == 0 {
		logf("No StructFields for %v; skipping.", t.filename)
		return nil
	}

	var buf bytes.Buffer
	if err := sourceTmpl.Execute(&buf, t); err != nil {
		return err
	}
	clean, err := format.Source(buf.Bytes())
	if err != nil {
		log.Printf("failed-to-format source:\n%v", buf.String())
		return err
	}

	logf("Writing %v...", t.filename)
	// A previously generated file is read-only, so make it writable before
	// overwriting it. A missing file is fine: WriteFile creates it.
	if err := os.Chmod(t.filename, 0644); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("os.Chmod(%q, 0644): %v", t.filename, err)
	}
	if err := ioutil.WriteFile(t.filename, clean, 0444); err != nil {
		return err
	}
	// WriteFile leaves an existing file's mode unchanged (and applies umask
	// to a new one), so set the final read-only mode explicitly.
	if err := os.Chmod(t.filename, 0444); err != nil {
		return fmt.Errorf("os.Chmod(%q, 0444): %v", t.filename, err)
	}
	return nil
}
369
// newStructField builds the structField for field fieldName (of type
// fieldType) of struct receiverType.
//
// ReceiverVar is the lowercased first byte of receiverType, which therefore
// must be non-empty. sortVal is the lowercased "receiver.field" pair.
func newStructField(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *structField {
	sf := new(structField)
	sf.sortVal = strings.ToLower(receiverType) + "." + strings.ToLower(fieldName)
	sf.ReceiverVar = strings.ToLower(receiverType[:1])
	sf.ReceiverType = receiverType
	sf.FieldName = fieldName
	sf.FieldType = fieldType
	sf.ZeroValue = zeroValue
	sf.NamedStruct = namedStruct
	return sf
}
381
// logf formats its arguments like log.Printf and logs them, but only when
// the -v flag (verbose) is set.
//
// The format parameter is deliberately not named fmt, which would shadow
// the imported fmt package inside this function.
func logf(format string, args ...interface{}) {
	if !*verbose {
		return
	}
	log.Printf(format, args...)
}
387
// source is the text/template (parsed into sourceTmpl with funcMap) that
// dump executes with a *templateData to produce one package's test file.
//
// For every entry of .StructFields (text/template visits map keys in sorted
// order), it emits a Test<Type>_String function. That function:
//   - builds a value with every collected field set, using &FieldType{} for
//     NamedStruct fields and ZeroValue otherwise;
//   - compares v.String() against a `want` string of the form
//     "<package>.<Type>{Field:value, ...}", with values produced by
//     processZeroValue and separators by isNotLast.
//
// The rendered text is run through go/format by dump before being written.
// The {{- / -}} trim markers control newlines, so edit with care.
// NOTE(review): the generated file declares its own Float64 helper; this
// would clash if the target package ever defines Float64. Confirm.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-stringify-tests; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{ $package := .Package}}{{$package}}
{{with .Imports}}
import (
{{- range . -}}
"{{.}}"
{{end -}}
)
{{end}}
func Float64(v float64) *float64 { return &v }
{{range $key, $value := .StructFields}}
func Test{{ $key }}_String(t *testing.T) {
v := {{ $key }}{ {{range .}}{{if .NamedStruct}}
{{ .FieldName }}: &{{ .FieldType }}{},{{else}}
{{ .FieldName }}: {{.ZeroValue}},{{end}}{{end}}
}
want := ` + "`" + `{{ $package }}.{{ $key }}{{ $slice := . }}{
{{- range $ind, $val := .}}{{if .NamedStruct}}{{ .FieldName }}:{{ $package }}.{{ .FieldType }}{}{{else}}{{ .FieldName }}:{{ processZeroValue .ZeroValue }}{{end}}{{ isNotLast $ind $slice }}{{end}}}` + "`" + `
if got := v.String(); got != want {
t.Errorf("{{ $key }}.String = %v, want %v", got, want)
}
}
{{end}}
`
420
View as plain text