1
2
3
4
5
6
7
8
9
10
11
12
13
14 package main
15
16 import (
17 "bytes"
18 "flag"
19 "fmt"
20 "go/ast"
21 "go/format"
22 "go/parser"
23 "go/token"
24 "io/ioutil"
25 "log"
26 "os"
27 "sort"
28 "strings"
29 "text/template"
30 )
31
// fileSuffix is appended to a package name to form the generated accessor
// file name (for example "github" becomes "github-accessors.go"). sourceFilter
// also uses it to keep previously generated files out of the parse.
const fileSuffix = "-accessors.go"
35
36 var (
37 verbose = flag.Bool("v", false, "Print verbose log messages")
38
39 sourceTmpl = template.Must(template.New("source").Parse(source))
40 testTmpl = template.Must(template.New("test").Parse(test))
41
42
43 skipStructMethods = map[string]bool{
44 "RepositoryContent.GetContent": true,
45 "Client.GetBaseURL": true,
46 "Client.GetUploadURL": true,
47 "ErrorResponse.GetResponse": true,
48 "RateLimitError.GetResponse": true,
49 "AbuseRateLimitError.GetResponse": true,
50 }
51
52 skipStructs = map[string]bool{
53 "Client": true,
54 }
55 )
56
57 func logf(fmt string, args ...interface{}) {
58 if *verbose {
59 log.Printf(fmt, args...)
60 }
61 }
62
63 func main() {
64 flag.Parse()
65 fset := token.NewFileSet()
66
67 pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
68 if err != nil {
69 log.Fatal(err)
70 return
71 }
72
73 for pkgName, pkg := range pkgs {
74 t := &templateData{
75 filename: pkgName + fileSuffix,
76 Year: 2017,
77 Package: pkgName,
78 Imports: map[string]string{},
79 }
80 for filename, f := range pkg.Files {
81 logf("Processing %v...", filename)
82 if err := t.processAST(f); err != nil {
83 log.Fatal(err)
84 }
85 }
86 if err := t.dump(); err != nil {
87 log.Fatal(err)
88 }
89 }
90 logf("Done.")
91 }
92
93 func (t *templateData) processAST(f *ast.File) error {
94 for _, decl := range f.Decls {
95 gd, ok := decl.(*ast.GenDecl)
96 if !ok {
97 continue
98 }
99 for _, spec := range gd.Specs {
100 ts, ok := spec.(*ast.TypeSpec)
101 if !ok {
102 continue
103 }
104
105 if !ts.Name.IsExported() {
106 logf("Struct %v is unexported; skipping.", ts.Name)
107 continue
108 }
109
110 if skipStructs[ts.Name.Name] {
111 logf("Struct %v is in skip list; skipping.", ts.Name)
112 continue
113 }
114 st, ok := ts.Type.(*ast.StructType)
115 if !ok {
116 continue
117 }
118 for _, field := range st.Fields.List {
119 if len(field.Names) == 0 {
120 continue
121 }
122
123 fieldName := field.Names[0]
124
125 if !fieldName.IsExported() {
126 logf("Field %v is unexported; skipping.", fieldName)
127 continue
128 }
129
130 if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
131 logf("Method %v is skip list; skipping.", key)
132 continue
133 }
134
135 se, ok := field.Type.(*ast.StarExpr)
136 if !ok {
137 switch x := field.Type.(type) {
138 case *ast.MapType:
139 t.addMapType(x, ts.Name.String(), fieldName.String(), false)
140 continue
141 }
142
143 logf("Skipping field type %T, fieldName=%v", field.Type, fieldName)
144 continue
145 }
146
147 switch x := se.X.(type) {
148 case *ast.ArrayType:
149 t.addArrayType(x, ts.Name.String(), fieldName.String())
150 case *ast.Ident:
151 t.addIdent(x, ts.Name.String(), fieldName.String())
152 case *ast.MapType:
153 t.addMapType(x, ts.Name.String(), fieldName.String(), true)
154 case *ast.SelectorExpr:
155 t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
156 default:
157 logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
158 }
159 }
160 }
161 }
162 return nil
163 }
164
165 func sourceFilter(fi os.FileInfo) bool {
166 return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
167 }
168
169 func (t *templateData) dump() error {
170 if len(t.Getters) == 0 {
171 logf("No getters for %v; skipping.", t.filename)
172 return nil
173 }
174
175
176 sort.Sort(byName(t.Getters))
177
178 processTemplate := func(tmpl *template.Template, filename string) error {
179 var buf bytes.Buffer
180 if err := tmpl.Execute(&buf, t); err != nil {
181 return err
182 }
183 clean, err := format.Source(buf.Bytes())
184 if err != nil {
185 return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err)
186 }
187
188 logf("Writing %v...", filename)
189 if err := os.Chmod(filename, 0644); err != nil {
190 return fmt.Errorf("os.Chmod(%q, 0644): %v", filename, err)
191 }
192
193 if err := ioutil.WriteFile(filename, clean, 0444); err != nil {
194 return err
195 }
196
197 if err := os.Chmod(filename, 0444); err != nil {
198 return fmt.Errorf("os.Chmod(%q, 0444): %v", filename, err)
199 }
200
201 return nil
202 }
203
204 if err := processTemplate(sourceTmpl, t.filename); err != nil {
205 return err
206 }
207 return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go"))
208 }
209
210 func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
211 return &getter{
212 sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
213 ReceiverVar: strings.ToLower(receiverType[:1]),
214 ReceiverType: receiverType,
215 FieldName: fieldName,
216 FieldType: fieldType,
217 ZeroValue: zeroValue,
218 NamedStruct: namedStruct,
219 }
220 }
221
222 func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) {
223 var eltType string
224 switch elt := x.Elt.(type) {
225 case *ast.Ident:
226 eltType = elt.String()
227 default:
228 logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
229 return
230 }
231
232 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil", false))
233 }
234
235 func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
236 var zeroValue string
237 var namedStruct = false
238 switch x.String() {
239 case "int", "int64":
240 zeroValue = "0"
241 case "string":
242 zeroValue = `""`
243 case "bool":
244 zeroValue = "false"
245 case "Timestamp":
246 zeroValue = "Timestamp{}"
247 default:
248 zeroValue = "nil"
249 namedStruct = true
250 }
251
252 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
253 }
254
255 func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) {
256 var keyType string
257 switch key := x.Key.(type) {
258 case *ast.Ident:
259 keyType = key.String()
260 default:
261 logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
262 return
263 }
264
265 var valueType string
266 switch value := x.Value.(type) {
267 case *ast.Ident:
268 valueType = value.String()
269 default:
270 logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
271 return
272 }
273
274 fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
275 zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
276 ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false)
277 ng.MapType = !isAPointer
278 t.Getters = append(t.Getters, ng)
279 }
280
281 func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
282 if strings.ToLower(fieldName[:1]) == fieldName[:1] {
283 return
284 }
285
286 var xX string
287 if xx, ok := x.X.(*ast.Ident); ok {
288 xX = xx.String()
289 }
290
291 switch xX {
292 case "time", "json":
293 if xX == "json" {
294 t.Imports["encoding/json"] = "encoding/json"
295 } else {
296 t.Imports[xX] = xX
297 }
298 fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
299 zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
300 if xX == "time" && x.Sel.Name == "Duration" {
301 zeroValue = "0"
302 }
303 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
304 default:
305 logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
306 }
307 }
308
// templateData is the value passed to the source and test templates for one
// package.
type templateData struct {
	filename string // name of the generated source file

	Year    int               // copyright year written into the file header
	Package string            // package clause of the generated files
	Imports map[string]string // import paths required by the getters (key == value)
	Getters []*getter         // accessors to generate
}

// getter describes one generated Get<FieldName> accessor.
type getter struct {
	// sortVal is the lower-cased "receiver.field" key used by byName.
	sortVal string

	// Names and Go source fragments spliced into the templates: receiver
	// variable, receiver type, field name, the field's type (the pointee type
	// for pointer fields), and the expression returned when a nil is met.
	ReceiverVar, ReceiverType, FieldName, FieldType, ZeroValue string

	// NamedStruct makes the getter return the field pointer unchanged.
	// MapType selects the "map, or empty map when nil" getter shape.
	NamedStruct, MapType bool
}

// byName implements sort.Interface, ordering getters by their lower-cased
// "receiver.field" key.
type byName []*getter

func (g byName) Len() int           { return len(g) }
func (g byName) Less(i, j int) bool { return g[i].sortVal < g[j].sortVal }
func (g byName) Swap(i, j int)      { g[i], g[j] = g[j], g[i] }
333
// source is the text/template for the generated <pkg>-accessors.go file.
// It is executed with a *templateData.
//
// The import block is emitted only when .Imports is non-empty. For each
// getter, one of three accessor shapes is produced:
//   - NamedStruct: return the field pointer itself (ZeroValue for a nil
//     receiver);
//   - MapType: return the map, or ZeroValue when the receiver or map is nil;
//   - otherwise: dereference the field, or return ZeroValue when the receiver
//     or field is nil.
//
// Indentation and blank lines are cleaned up afterwards by format.Source.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{.Package}}
{{with .Imports}}
import (
{{- range . -}}
"{{.}}"
{{end -}}
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
if {{.ReceiverVar}} == nil {
return {{.ZeroValue}}
}
return {{.ReceiverVar}}.{{.FieldName}}
}
{{else if .MapType}}
// Get{{.FieldName}} returns the {{.FieldName}} map if it's non-nil, an empty map otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
return {{.ZeroValue}}
}
return {{.ReceiverVar}}.{{.FieldName}}
}
{{else}}
// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
return {{.ZeroValue}}
}
return *{{.ReceiverVar}}.{{.FieldName}}
}
{{end}}
{{end}}
`
379
// test is the text/template for the generated <pkg>-accessors_test.go file.
// It is executed with a *templateData.
//
// For each getter it emits a test that calls the accessor on a populated
// receiver (for non-NamedStruct getters), on an empty receiver, and on a nil
// receiver.
//
// The "testing" import is emitted unconditionally. Every generated test takes
// a *testing.T, so the import is required even when .Imports is empty.
// Previously it sat inside {{with .Imports}} and vanished in that case,
// producing a test file that does not compile.
const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{.Package}}

import (
"testing"
{{range .Imports -}}
"{{.}}"
{{end -}}
)
{{range .Getters}}
{{if .NamedStruct}}
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
{{.ReceiverVar}} := &{{.ReceiverType}}{}
{{.ReceiverVar}}.Get{{.FieldName}}()
{{.ReceiverVar}} = nil
{{.ReceiverVar}}.Get{{.FieldName}}()
}
{{else if .MapType}}
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
zeroValue := {{.FieldType}}{}
{{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue }
{{.ReceiverVar}}.Get{{.FieldName}}()
{{.ReceiverVar}} = &{{.ReceiverType}}{}
{{.ReceiverVar}}.Get{{.FieldName}}()
{{.ReceiverVar}} = nil
{{.ReceiverVar}}.Get{{.FieldName}}()
}
{{else}}
func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
var zeroValue {{.FieldType}}
{{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue }
{{.ReceiverVar}}.Get{{.FieldName}}()
{{.ReceiverVar}} = &{{.ReceiverType}}{}
{{.ReceiverVar}}.Get{{.FieldName}}()
{{.ReceiverVar}} = nil
{{.ReceiverVar}}.Get{{.FieldName}}()
}
{{end}}
{{end}}
`
429
View as plain text