1
2
3
4
5
6
7
8
9
10
11
12
13
14 package main
15
16 import (
17 "bytes"
18 "flag"
19 "fmt"
20 "go/ast"
21 "go/format"
22 "go/parser"
23 "go/token"
24 "log"
25 "os"
26 "sort"
27 "strings"
28 "text/template"
29 )
30
// fileSuffix is appended to a package name to form the name of the generated
// accessor file ("<pkg>-accessors.go"). Files ending in it are also excluded
// from parsing so that generated code is never fed back into the generator.
const fileSuffix = "-accessors.go"
34
35 var (
36 verbose = flag.Bool("v", false, "Print verbose log messages")
37
38 sourceTmpl = template.Must(template.New("source").Parse(source))
39 testTmpl = template.Must(template.New("test").Parse(test))
40
41
42 skipStructMethods = map[string]bool{
43 "RepositoryContent.GetContent": true,
44 "Client.GetBaseURL": true,
45 "Client.GetUploadURL": true,
46 "ErrorResponse.GetResponse": true,
47 "RateLimitError.GetResponse": true,
48 "AbuseRateLimitError.GetResponse": true,
49 }
50
51 skipStructs = map[string]bool{
52 "Client": true,
53 }
54
55
56 whitelistSliceGetters = map[string]bool{
57 "PushEvent.Commits": true,
58 }
59 )
60
61 func logf(fmt string, args ...interface{}) {
62 if *verbose {
63 log.Printf(fmt, args...)
64 }
65 }
66
67 func main() {
68 flag.Parse()
69 fset := token.NewFileSet()
70
71 pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
72 if err != nil {
73 log.Fatal(err)
74 return
75 }
76
77 for pkgName, pkg := range pkgs {
78 t := &templateData{
79 filename: pkgName + fileSuffix,
80 Year: 2017,
81 Package: pkgName,
82 Imports: map[string]string{},
83 }
84 for filename, f := range pkg.Files {
85 logf("Processing %v...", filename)
86 if err := t.processAST(f); err != nil {
87 log.Fatal(err)
88 }
89 }
90 if err := t.dump(); err != nil {
91 log.Fatal(err)
92 }
93 }
94 logf("Done.")
95 }
96
97 func (t *templateData) processAST(f *ast.File) error {
98 for _, decl := range f.Decls {
99 gd, ok := decl.(*ast.GenDecl)
100 if !ok {
101 continue
102 }
103 for _, spec := range gd.Specs {
104 ts, ok := spec.(*ast.TypeSpec)
105 if !ok {
106 continue
107 }
108
109 if !ts.Name.IsExported() {
110 logf("Struct %v is unexported; skipping.", ts.Name)
111 continue
112 }
113
114 if skipStructs[ts.Name.Name] {
115 logf("Struct %v is in skip list; skipping.", ts.Name)
116 continue
117 }
118 st, ok := ts.Type.(*ast.StructType)
119 if !ok {
120 continue
121 }
122 for _, field := range st.Fields.List {
123 if len(field.Names) == 0 {
124 continue
125 }
126
127 fieldName := field.Names[0]
128
129 if !fieldName.IsExported() {
130 logf("Field %v is unexported; skipping.", fieldName)
131 continue
132 }
133
134 if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
135 logf("Method %v is skip list; skipping.", key)
136 continue
137 }
138
139 se, ok := field.Type.(*ast.StarExpr)
140 if !ok {
141 switch x := field.Type.(type) {
142 case *ast.MapType:
143 t.addMapType(x, ts.Name.String(), fieldName.String(), false)
144 continue
145 case *ast.ArrayType:
146 if key := fmt.Sprintf("%v.%v", ts.Name, fieldName); whitelistSliceGetters[key] {
147 logf("Method %v is whitelist; adding getter method.", key)
148 t.addArrayType(x, ts.Name.String(), fieldName.String(), false)
149 continue
150 }
151 }
152
153 logf("Skipping field type %T, fieldName=%v", field.Type, fieldName)
154 continue
155 }
156
157 switch x := se.X.(type) {
158 case *ast.ArrayType:
159 t.addArrayType(x, ts.Name.String(), fieldName.String(), true)
160 case *ast.Ident:
161 t.addIdent(x, ts.Name.String(), fieldName.String())
162 case *ast.MapType:
163 t.addMapType(x, ts.Name.String(), fieldName.String(), true)
164 case *ast.SelectorExpr:
165 t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
166 default:
167 logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
168 }
169 }
170 }
171 }
172 return nil
173 }
174
175 func sourceFilter(fi os.FileInfo) bool {
176 return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
177 }
178
179 func (t *templateData) dump() error {
180 if len(t.Getters) == 0 {
181 logf("No getters for %v; skipping.", t.filename)
182 return nil
183 }
184
185
186 sort.Sort(byName(t.Getters))
187
188 processTemplate := func(tmpl *template.Template, filename string) error {
189 var buf bytes.Buffer
190 if err := tmpl.Execute(&buf, t); err != nil {
191 return err
192 }
193 clean, err := format.Source(buf.Bytes())
194 if err != nil {
195 return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err)
196 }
197
198 logf("Writing %v...", filename)
199 if err := os.Chmod(filename, 0644); err != nil {
200 return fmt.Errorf("os.Chmod(%q, 0644): %v", filename, err)
201 }
202
203 if err := os.WriteFile(filename, clean, 0444); err != nil {
204 return err
205 }
206
207 if err := os.Chmod(filename, 0444); err != nil {
208 return fmt.Errorf("os.Chmod(%q, 0444): %v", filename, err)
209 }
210
211 return nil
212 }
213
214 if err := processTemplate(sourceTmpl, t.filename); err != nil {
215 return err
216 }
217 return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go"))
218 }
219
220 func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
221 return &getter{
222 sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
223 ReceiverVar: strings.ToLower(receiverType[:1]),
224 ReceiverType: receiverType,
225 FieldName: fieldName,
226 FieldType: fieldType,
227 ZeroValue: zeroValue,
228 NamedStruct: namedStruct,
229 }
230 }
231
232 func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string, isAPointer bool) {
233 var eltType string
234 var ng *getter
235 switch elt := x.Elt.(type) {
236 case *ast.Ident:
237 eltType = elt.String()
238 ng = newGetter(receiverType, fieldName, "[]"+eltType, "nil", false)
239 case *ast.StarExpr:
240 ident, ok := elt.X.(*ast.Ident)
241 if !ok {
242 return
243 }
244 ng = newGetter(receiverType, fieldName, "[]*"+ident.String(), "nil", false)
245 default:
246 logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
247 return
248 }
249
250 ng.ArrayType = !isAPointer
251 t.Getters = append(t.Getters, ng)
252 }
253
254 func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
255 var zeroValue string
256 var namedStruct = false
257 switch x.String() {
258 case "int", "int64":
259 zeroValue = "0"
260 case "string":
261 zeroValue = `""`
262 case "bool":
263 zeroValue = "false"
264 case "Timestamp":
265 zeroValue = "Timestamp{}"
266 default:
267 zeroValue = "nil"
268 namedStruct = true
269 }
270
271 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
272 }
273
274 func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) {
275 var keyType string
276 switch key := x.Key.(type) {
277 case *ast.Ident:
278 keyType = key.String()
279 default:
280 logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
281 return
282 }
283
284 var valueType string
285 switch value := x.Value.(type) {
286 case *ast.Ident:
287 valueType = value.String()
288 default:
289 logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
290 return
291 }
292
293 fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
294 zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
295 ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false)
296 ng.MapType = !isAPointer
297 t.Getters = append(t.Getters, ng)
298 }
299
300 func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
301 if strings.ToLower(fieldName[:1]) == fieldName[:1] {
302 return
303 }
304
305 var xX string
306 if xx, ok := x.X.(*ast.Ident); ok {
307 xX = xx.String()
308 }
309
310 switch xX {
311 case "time", "json":
312 if xX == "json" {
313 t.Imports["encoding/json"] = "encoding/json"
314 } else {
315 t.Imports[xX] = xX
316 }
317 fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
318 zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
319 if xX == "time" && x.Sel.Name == "Duration" {
320 zeroValue = "0"
321 }
322 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
323 default:
324 logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
325 }
326 }
327
328 type templateData struct {
329 filename string
330 Year int
331 Package string
332 Imports map[string]string
333 Getters []*getter
334 }
335
// getter describes one generated accessor method (and its generated test).
type getter struct {
	// sortVal is the lower-cased "receivertype.fieldname", used for ordering.
	sortVal string

	ReceiverVar  string // name of the method's receiver variable
	ReceiverType string // struct type that declares the field
	FieldName    string // field read by the accessor
	FieldType    string // Go source for the accessor's result type
	ZeroValue    string // Go source returned when the receiver (or field) is nil

	// The flags below select which template form is emitted.
	NamedStruct bool // return the field pointer itself rather than dereferencing it
	MapType     bool // field is a non-pointer map, returned as-is
	ArrayType   bool // field is a non-pointer slice, returned as-is
}
347
348 type byName []*getter
349
350 func (b byName) Len() int { return len(b) }
351 func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal }
352 func (b byName) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
353
// source is the text/template for the generated accessor file. It is executed
// with a *templateData. For each getter it emits one of three method forms,
// chosen by the getter's flags:
//   - NamedStruct: returns the field pointer itself (ZeroValue if the receiver is nil);
//   - MapType or ArrayType: returns the map/slice, or ZeroValue if the receiver or field is nil;
//   - otherwise: dereferences the pointer field, or returns ZeroValue if it is nil.
//
// The import block is emitted only when Imports is non-empty. dump passes the
// output through format.Source, so indentation here need not be gofmt-clean.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{.Package}}
{{with .Imports}}
import (
{{- range . -}}
"{{.}}"
{{end -}}
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
if {{.ReceiverVar}} == nil {
return {{.ZeroValue}}
}
return {{.ReceiverVar}}.{{.FieldName}}
}
{{else if or .MapType .ArrayType }}
// Get{{.FieldName}} returns the {{.FieldName}} {{if .MapType}}map{{else if .ArrayType }}slice{{end}} if it's non-nil, {{if .MapType}}an empty map{{else if .ArrayType }}nil{{end}} otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
return {{.ZeroValue}}
}
return {{.ReceiverVar}}.{{.FieldName}}
}
{{else}}
// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
return {{.ZeroValue}}
}
return *{{.ReceiverVar}}.{{.FieldName}}
}
{{end}}
{{end}}
`
399
// test is the text/template for the generated accessor test file. It is
// executed with a *templateData. For every getter it emits a
// Test<Type>_Get<Field> function that calls the accessor on a receiver with the
// field set (except for NamedStruct getters), on an empty receiver, and on a
// nil receiver. A nil-pointer panic in a generated accessor therefore fails
// the test. dump gofmt-s the output.
//
// The import block must be emitted unconditionally, because every generated
// test uses *testing.T. If it were wrapped in {{with .Imports}}, a package
// without time/json fields would get a test file with no "testing" import,
// and that file would not compile.
const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{.Package}}

import (
	"testing"{{range .Imports}}
	"{{.}}"{{end}}
)
{{range .Getters}}
{{if .NamedStruct}}
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
	{{.ReceiverVar}} := &{{.ReceiverType}}{}
	{{.ReceiverVar}}.Get{{.FieldName}}()
	{{.ReceiverVar}} = nil
	{{.ReceiverVar}}.Get{{.FieldName}}()
}
{{else if or .MapType .ArrayType}}
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
	zeroValue := {{.FieldType}}{}
	{{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue }
	{{.ReceiverVar}}.Get{{.FieldName}}()
	{{.ReceiverVar}} = &{{.ReceiverType}}{}
	{{.ReceiverVar}}.Get{{.FieldName}}()
	{{.ReceiverVar}} = nil
	{{.ReceiverVar}}.Get{{.FieldName}}()
}
{{else}}
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
	var zeroValue {{.FieldType}}
	{{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue }
	{{.ReceiverVar}}.Get{{.FieldName}}()
	{{.ReceiverVar}} = &{{.ReceiverType}}{}
	{{.ReceiverVar}}.Get{{.FieldName}}()
	{{.ReceiverVar}} = nil
	{{.ReceiverVar}}.Get{{.FieldName}}()
}
{{end}}
{{end}}
`
449
View as plain text