...

Source file src/github.com/google/go-github/v47/github/gen-accessors.go

Documentation: github.com/google/go-github/v47/github

     1  // Copyright 2017 The go-github AUTHORS. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  //go:build ignore
     7  // +build ignore
     8  
     9  // gen-accessors generates accessor methods for structs with pointer fields.
    10  //
    11  // It is meant to be used by go-github contributors in conjunction with the
    12  // go generate tool before sending a PR to GitHub.
    13  // Please see the CONTRIBUTING.md file for more information.
    14  package main
    15  
    16  import (
    17  	"bytes"
    18  	"flag"
    19  	"fmt"
    20  	"go/ast"
    21  	"go/format"
    22  	"go/parser"
    23  	"go/token"
    24  	"io/ioutil"
    25  	"log"
    26  	"os"
    27  	"sort"
    28  	"strings"
    29  	"text/template"
    30  )
    31  
    32  const (
    33  	fileSuffix = "-accessors.go"
    34  )
    35  
    36  var (
    37  	verbose = flag.Bool("v", false, "Print verbose log messages")
    38  
    39  	sourceTmpl = template.Must(template.New("source").Parse(source))
    40  	testTmpl   = template.Must(template.New("test").Parse(test))
    41  
    42  	// skipStructMethods lists "struct.method" combos to skip.
    43  	skipStructMethods = map[string]bool{
    44  		"RepositoryContent.GetContent":    true,
    45  		"Client.GetBaseURL":               true,
    46  		"Client.GetUploadURL":             true,
    47  		"ErrorResponse.GetResponse":       true,
    48  		"RateLimitError.GetResponse":      true,
    49  		"AbuseRateLimitError.GetResponse": true,
    50  	}
    51  	// skipStructs lists structs to skip.
    52  	skipStructs = map[string]bool{
    53  		"Client": true,
    54  	}
    55  )
    56  
    57  func logf(fmt string, args ...interface{}) {
    58  	if *verbose {
    59  		log.Printf(fmt, args...)
    60  	}
    61  }
    62  
    63  func main() {
    64  	flag.Parse()
    65  	fset := token.NewFileSet()
    66  
    67  	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
    68  	if err != nil {
    69  		log.Fatal(err)
    70  		return
    71  	}
    72  
    73  	for pkgName, pkg := range pkgs {
    74  		t := &templateData{
    75  			filename: pkgName + fileSuffix,
    76  			Year:     2017,
    77  			Package:  pkgName,
    78  			Imports:  map[string]string{},
    79  		}
    80  		for filename, f := range pkg.Files {
    81  			logf("Processing %v...", filename)
    82  			if err := t.processAST(f); err != nil {
    83  				log.Fatal(err)
    84  			}
    85  		}
    86  		if err := t.dump(); err != nil {
    87  			log.Fatal(err)
    88  		}
    89  	}
    90  	logf("Done.")
    91  }
    92  
    93  func (t *templateData) processAST(f *ast.File) error {
    94  	for _, decl := range f.Decls {
    95  		gd, ok := decl.(*ast.GenDecl)
    96  		if !ok {
    97  			continue
    98  		}
    99  		for _, spec := range gd.Specs {
   100  			ts, ok := spec.(*ast.TypeSpec)
   101  			if !ok {
   102  				continue
   103  			}
   104  			// Skip unexported identifiers.
   105  			if !ts.Name.IsExported() {
   106  				logf("Struct %v is unexported; skipping.", ts.Name)
   107  				continue
   108  			}
   109  			// Check if the struct should be skipped.
   110  			if skipStructs[ts.Name.Name] {
   111  				logf("Struct %v is in skip list; skipping.", ts.Name)
   112  				continue
   113  			}
   114  			st, ok := ts.Type.(*ast.StructType)
   115  			if !ok {
   116  				continue
   117  			}
   118  			for _, field := range st.Fields.List {
   119  				if len(field.Names) == 0 {
   120  					continue
   121  				}
   122  
   123  				fieldName := field.Names[0]
   124  				// Skip unexported identifiers.
   125  				if !fieldName.IsExported() {
   126  					logf("Field %v is unexported; skipping.", fieldName)
   127  					continue
   128  				}
   129  				// Check if "struct.method" should be skipped.
   130  				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
   131  					logf("Method %v is skip list; skipping.", key)
   132  					continue
   133  				}
   134  
   135  				se, ok := field.Type.(*ast.StarExpr)
   136  				if !ok {
   137  					switch x := field.Type.(type) {
   138  					case *ast.MapType:
   139  						t.addMapType(x, ts.Name.String(), fieldName.String(), false)
   140  						continue
   141  					}
   142  
   143  					logf("Skipping field type %T, fieldName=%v", field.Type, fieldName)
   144  					continue
   145  				}
   146  
   147  				switch x := se.X.(type) {
   148  				case *ast.ArrayType:
   149  					t.addArrayType(x, ts.Name.String(), fieldName.String())
   150  				case *ast.Ident:
   151  					t.addIdent(x, ts.Name.String(), fieldName.String())
   152  				case *ast.MapType:
   153  					t.addMapType(x, ts.Name.String(), fieldName.String(), true)
   154  				case *ast.SelectorExpr:
   155  					t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
   156  				default:
   157  					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
   158  				}
   159  			}
   160  		}
   161  	}
   162  	return nil
   163  }
   164  
   165  func sourceFilter(fi os.FileInfo) bool {
   166  	return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
   167  }
   168  
   169  func (t *templateData) dump() error {
   170  	if len(t.Getters) == 0 {
   171  		logf("No getters for %v; skipping.", t.filename)
   172  		return nil
   173  	}
   174  
   175  	// Sort getters by ReceiverType.FieldName.
   176  	sort.Sort(byName(t.Getters))
   177  
   178  	processTemplate := func(tmpl *template.Template, filename string) error {
   179  		var buf bytes.Buffer
   180  		if err := tmpl.Execute(&buf, t); err != nil {
   181  			return err
   182  		}
   183  		clean, err := format.Source(buf.Bytes())
   184  		if err != nil {
   185  			return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err)
   186  		}
   187  
   188  		logf("Writing %v...", filename)
   189  		if err := os.Chmod(filename, 0644); err != nil {
   190  			return fmt.Errorf("os.Chmod(%q, 0644): %v", filename, err)
   191  		}
   192  
   193  		if err := ioutil.WriteFile(filename, clean, 0444); err != nil {
   194  			return err
   195  		}
   196  
   197  		if err := os.Chmod(filename, 0444); err != nil {
   198  			return fmt.Errorf("os.Chmod(%q, 0444): %v", filename, err)
   199  		}
   200  
   201  		return nil
   202  	}
   203  
   204  	if err := processTemplate(sourceTmpl, t.filename); err != nil {
   205  		return err
   206  	}
   207  	return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go"))
   208  }
   209  
   210  func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
   211  	return &getter{
   212  		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
   213  		ReceiverVar:  strings.ToLower(receiverType[:1]),
   214  		ReceiverType: receiverType,
   215  		FieldName:    fieldName,
   216  		FieldType:    fieldType,
   217  		ZeroValue:    zeroValue,
   218  		NamedStruct:  namedStruct,
   219  	}
   220  }
   221  
   222  func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) {
   223  	var eltType string
   224  	switch elt := x.Elt.(type) {
   225  	case *ast.Ident:
   226  		eltType = elt.String()
   227  	default:
   228  		logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
   229  		return
   230  	}
   231  
   232  	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil", false))
   233  }
   234  
   235  func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
   236  	var zeroValue string
   237  	var namedStruct = false
   238  	switch x.String() {
   239  	case "int", "int64":
   240  		zeroValue = "0"
   241  	case "string":
   242  		zeroValue = `""`
   243  	case "bool":
   244  		zeroValue = "false"
   245  	case "Timestamp":
   246  		zeroValue = "Timestamp{}"
   247  	default:
   248  		zeroValue = "nil"
   249  		namedStruct = true
   250  	}
   251  
   252  	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   253  }
   254  
   255  func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) {
   256  	var keyType string
   257  	switch key := x.Key.(type) {
   258  	case *ast.Ident:
   259  		keyType = key.String()
   260  	default:
   261  		logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
   262  		return
   263  	}
   264  
   265  	var valueType string
   266  	switch value := x.Value.(type) {
   267  	case *ast.Ident:
   268  		valueType = value.String()
   269  	default:
   270  		logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
   271  		return
   272  	}
   273  
   274  	fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
   275  	zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
   276  	ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false)
   277  	ng.MapType = !isAPointer
   278  	t.Getters = append(t.Getters, ng)
   279  }
   280  
   281  func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
   282  	if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
   283  		return
   284  	}
   285  
   286  	var xX string
   287  	if xx, ok := x.X.(*ast.Ident); ok {
   288  		xX = xx.String()
   289  	}
   290  
   291  	switch xX {
   292  	case "time", "json":
   293  		if xX == "json" {
   294  			t.Imports["encoding/json"] = "encoding/json"
   295  		} else {
   296  			t.Imports[xX] = xX
   297  		}
   298  		fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
   299  		zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
   300  		if xX == "time" && x.Sel.Name == "Duration" {
   301  			zeroValue = "0"
   302  		}
   303  		t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
   304  	default:
   305  		logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
   306  	}
   307  }
   308  
   309  type templateData struct {
   310  	filename string
   311  	Year     int
   312  	Package  string
   313  	Imports  map[string]string
   314  	Getters  []*getter
   315  }
   316  
   317  type getter struct {
   318  	sortVal      string // Lower-case version of "ReceiverType.FieldName".
   319  	ReceiverVar  string // The one-letter variable name to match the ReceiverType.
   320  	ReceiverType string
   321  	FieldName    string
   322  	FieldType    string
   323  	ZeroValue    string
   324  	NamedStruct  bool // Getter for named struct.
   325  	MapType      bool
   326  }
   327  
   328  type byName []*getter
   329  
   330  func (b byName) Len() int           { return len(b) }
   331  func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal }
   332  func (b byName) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
   333  
// source is the text/template for the generated accessor file (named
// <package>-accessors.go by main). It is executed with a *templateData. For
// each getter it emits one of three method shapes:
//   - NamedStruct: returns the pointer field itself, or ZeroValue when the
//     receiver is nil.
//   - MapType: returns the map, or ZeroValue (an empty map literal) when the
//     receiver or the map is nil.
//   - otherwise: dereferences the pointer field, returning ZeroValue when the
//     receiver or the field is nil.
//
// The import block is emitted only when Imports is non-empty. The output is
// gofmt-formatted afterwards, so whitespace here only needs to parse.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{.Package}}
{{with .Imports}}
import (
  {{- range . -}}
  "{{.}}"
  {{end -}}
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
  if {{.ReceiverVar}} == nil {
    return {{.ZeroValue}}
  }
  return {{.ReceiverVar}}.{{.FieldName}}
}
{{else if .MapType}}
// Get{{.FieldName}} returns the {{.FieldName}} map if it's non-nil, an empty map otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
  if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
    return {{.ZeroValue}}
  }
  return {{.ReceiverVar}}.{{.FieldName}}
}
{{else}}
// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
  if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
    return {{.ZeroValue}}
  }
  return *{{.ReceiverVar}}.{{.FieldName}}
}
{{end}}
{{end}}
`
   379  
   380  const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
   381  //
   382  // Use of this source code is governed by a BSD-style
   383  // license that can be found in the LICENSE file.
   384  
   385  // Code generated by gen-accessors; DO NOT EDIT.
   386  // Instead, please run "go generate ./..." as described here:
   387  // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch
   388  
   389  package {{.Package}}
   390  {{with .Imports}}
   391  import (
   392    "testing"
   393    {{range . -}}
   394    "{{.}}"
   395    {{end -}}
   396  )
   397  {{end}}
   398  {{range .Getters}}
   399  {{if .NamedStruct}}
   400  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   401    {{.ReceiverVar}} := &{{.ReceiverType}}{}
   402    {{.ReceiverVar}}.Get{{.FieldName}}()
   403    {{.ReceiverVar}} = nil
   404    {{.ReceiverVar}}.Get{{.FieldName}}()
   405  }
   406  {{else if .MapType}}
   407  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   408    zeroValue := {{.FieldType}}{}
   409    {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue }
   410    {{.ReceiverVar}}.Get{{.FieldName}}()
   411    {{.ReceiverVar}} = &{{.ReceiverType}}{}
   412    {{.ReceiverVar}}.Get{{.FieldName}}()
   413    {{.ReceiverVar}} = nil
   414    {{.ReceiverVar}}.Get{{.FieldName}}()
   415  }
   416  {{else}}
   417  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   418    var zeroValue {{.FieldType}}
   419    {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue }
   420    {{.ReceiverVar}}.Get{{.FieldName}}()
   421    {{.ReceiverVar}} = &{{.ReceiverType}}{}
   422    {{.ReceiverVar}}.Get{{.FieldName}}()
   423    {{.ReceiverVar}} = nil
   424    {{.ReceiverVar}}.Get{{.FieldName}}()
   425  }
   426  {{end}}
   427  {{end}}
   428  `
   429  

View as plain text