...

Source file src/github.com/google/go-github/v55/github/gen-accessors.go

Documentation: github.com/google/go-github/v55/github

     1  // Copyright 2017 The go-github AUTHORS. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  //go:build ignore
     7  // +build ignore
     8  
     9  // gen-accessors generates accessor methods for structs with pointer fields.
    10  //
    11  // It is meant to be used by go-github contributors in conjunction with the
    12  // go generate tool before sending a PR to GitHub.
    13  // Please see the CONTRIBUTING.md file for more information.
    14  package main
    15  
    16  import (
    17  	"bytes"
    18  	"flag"
    19  	"fmt"
    20  	"go/ast"
    21  	"go/format"
    22  	"go/parser"
    23  	"go/token"
    24  	"log"
    25  	"os"
    26  	"sort"
    27  	"strings"
    28  	"text/template"
    29  )
    30  
    31  const (
    32  	fileSuffix = "-accessors.go"
    33  )
    34  
    35  var (
    36  	verbose = flag.Bool("v", false, "Print verbose log messages")
    37  
    38  	sourceTmpl = template.Must(template.New("source").Parse(source))
    39  	testTmpl   = template.Must(template.New("test").Parse(test))
    40  
    41  	// skipStructMethods lists "struct.method" combos to skip.
    42  	skipStructMethods = map[string]bool{
    43  		"RepositoryContent.GetContent":    true,
    44  		"Client.GetBaseURL":               true,
    45  		"Client.GetUploadURL":             true,
    46  		"ErrorResponse.GetResponse":       true,
    47  		"RateLimitError.GetResponse":      true,
    48  		"AbuseRateLimitError.GetResponse": true,
    49  	}
    50  	// skipStructs lists structs to skip.
    51  	skipStructs = map[string]bool{
    52  		"Client": true,
    53  	}
    54  
    55  	// whitelistSliceGetters lists "struct.field" to add getter method
    56  	whitelistSliceGetters = map[string]bool{
    57  		"PushEvent.Commits": true,
    58  	}
    59  )
    60  
    61  func logf(fmt string, args ...interface{}) {
    62  	if *verbose {
    63  		log.Printf(fmt, args...)
    64  	}
    65  }
    66  
    67  func main() {
    68  	flag.Parse()
    69  	fset := token.NewFileSet()
    70  
    71  	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
    72  	if err != nil {
    73  		log.Fatal(err)
    74  		return
    75  	}
    76  
    77  	for pkgName, pkg := range pkgs {
    78  		t := &templateData{
    79  			filename: pkgName + fileSuffix,
    80  			Year:     2017,
    81  			Package:  pkgName,
    82  			Imports:  map[string]string{},
    83  		}
    84  		for filename, f := range pkg.Files {
    85  			logf("Processing %v...", filename)
    86  			if err := t.processAST(f); err != nil {
    87  				log.Fatal(err)
    88  			}
    89  		}
    90  		if err := t.dump(); err != nil {
    91  			log.Fatal(err)
    92  		}
    93  	}
    94  	logf("Done.")
    95  }
    96  
    97  func (t *templateData) processAST(f *ast.File) error {
    98  	for _, decl := range f.Decls {
    99  		gd, ok := decl.(*ast.GenDecl)
   100  		if !ok {
   101  			continue
   102  		}
   103  		for _, spec := range gd.Specs {
   104  			ts, ok := spec.(*ast.TypeSpec)
   105  			if !ok {
   106  				continue
   107  			}
   108  			// Skip unexported identifiers.
   109  			if !ts.Name.IsExported() {
   110  				logf("Struct %v is unexported; skipping.", ts.Name)
   111  				continue
   112  			}
   113  			// Check if the struct should be skipped.
   114  			if skipStructs[ts.Name.Name] {
   115  				logf("Struct %v is in skip list; skipping.", ts.Name)
   116  				continue
   117  			}
   118  			st, ok := ts.Type.(*ast.StructType)
   119  			if !ok {
   120  				continue
   121  			}
   122  			for _, field := range st.Fields.List {
   123  				if len(field.Names) == 0 {
   124  					continue
   125  				}
   126  
   127  				fieldName := field.Names[0]
   128  				// Skip unexported identifiers.
   129  				if !fieldName.IsExported() {
   130  					logf("Field %v is unexported; skipping.", fieldName)
   131  					continue
   132  				}
   133  				// Check if "struct.method" should be skipped.
   134  				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
   135  					logf("Method %v is skip list; skipping.", key)
   136  					continue
   137  				}
   138  
   139  				se, ok := field.Type.(*ast.StarExpr)
   140  				if !ok {
   141  					switch x := field.Type.(type) {
   142  					case *ast.MapType:
   143  						t.addMapType(x, ts.Name.String(), fieldName.String(), false)
   144  						continue
   145  					case *ast.ArrayType:
   146  						if key := fmt.Sprintf("%v.%v", ts.Name, fieldName); whitelistSliceGetters[key] {
   147  							logf("Method %v is whitelist; adding getter method.", key)
   148  							t.addArrayType(x, ts.Name.String(), fieldName.String(), false)
   149  							continue
   150  						}
   151  					}
   152  
   153  					logf("Skipping field type %T, fieldName=%v", field.Type, fieldName)
   154  					continue
   155  				}
   156  
   157  				switch x := se.X.(type) {
   158  				case *ast.ArrayType:
   159  					t.addArrayType(x, ts.Name.String(), fieldName.String(), true)
   160  				case *ast.Ident:
   161  					t.addIdent(x, ts.Name.String(), fieldName.String())
   162  				case *ast.MapType:
   163  					t.addMapType(x, ts.Name.String(), fieldName.String(), true)
   164  				case *ast.SelectorExpr:
   165  					t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
   166  				default:
   167  					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
   168  				}
   169  			}
   170  		}
   171  	}
   172  	return nil
   173  }
   174  
   175  func sourceFilter(fi os.FileInfo) bool {
   176  	return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
   177  }
   178  
   179  func (t *templateData) dump() error {
   180  	if len(t.Getters) == 0 {
   181  		logf("No getters for %v; skipping.", t.filename)
   182  		return nil
   183  	}
   184  
   185  	// Sort getters by ReceiverType.FieldName.
   186  	sort.Sort(byName(t.Getters))
   187  
   188  	processTemplate := func(tmpl *template.Template, filename string) error {
   189  		var buf bytes.Buffer
   190  		if err := tmpl.Execute(&buf, t); err != nil {
   191  			return err
   192  		}
   193  		clean, err := format.Source(buf.Bytes())
   194  		if err != nil {
   195  			return fmt.Errorf("format.Source:\n%v\n%v", buf.String(), err)
   196  		}
   197  
   198  		logf("Writing %v...", filename)
   199  		if err := os.Chmod(filename, 0644); err != nil {
   200  			return fmt.Errorf("os.Chmod(%q, 0644): %v", filename, err)
   201  		}
   202  
   203  		if err := os.WriteFile(filename, clean, 0444); err != nil {
   204  			return err
   205  		}
   206  
   207  		if err := os.Chmod(filename, 0444); err != nil {
   208  			return fmt.Errorf("os.Chmod(%q, 0444): %v", filename, err)
   209  		}
   210  
   211  		return nil
   212  	}
   213  
   214  	if err := processTemplate(sourceTmpl, t.filename); err != nil {
   215  		return err
   216  	}
   217  	return processTemplate(testTmpl, strings.ReplaceAll(t.filename, ".go", "_test.go"))
   218  }
   219  
   220  func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
   221  	return &getter{
   222  		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
   223  		ReceiverVar:  strings.ToLower(receiverType[:1]),
   224  		ReceiverType: receiverType,
   225  		FieldName:    fieldName,
   226  		FieldType:    fieldType,
   227  		ZeroValue:    zeroValue,
   228  		NamedStruct:  namedStruct,
   229  	}
   230  }
   231  
   232  func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string, isAPointer bool) {
   233  	var eltType string
   234  	var ng *getter
   235  	switch elt := x.Elt.(type) {
   236  	case *ast.Ident:
   237  		eltType = elt.String()
   238  		ng = newGetter(receiverType, fieldName, "[]"+eltType, "nil", false)
   239  	case *ast.StarExpr:
   240  		ident, ok := elt.X.(*ast.Ident)
   241  		if !ok {
   242  			return
   243  		}
   244  		ng = newGetter(receiverType, fieldName, "[]*"+ident.String(), "nil", false)
   245  	default:
   246  		logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
   247  		return
   248  	}
   249  
   250  	ng.ArrayType = !isAPointer
   251  	t.Getters = append(t.Getters, ng)
   252  }
   253  
   254  func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
   255  	var zeroValue string
   256  	var namedStruct = false
   257  	switch x.String() {
   258  	case "int", "int64":
   259  		zeroValue = "0"
   260  	case "string":
   261  		zeroValue = `""`
   262  	case "bool":
   263  		zeroValue = "false"
   264  	case "Timestamp":
   265  		zeroValue = "Timestamp{}"
   266  	default:
   267  		zeroValue = "nil"
   268  		namedStruct = true
   269  	}
   270  
   271  	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   272  }
   273  
   274  func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) {
   275  	var keyType string
   276  	switch key := x.Key.(type) {
   277  	case *ast.Ident:
   278  		keyType = key.String()
   279  	default:
   280  		logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
   281  		return
   282  	}
   283  
   284  	var valueType string
   285  	switch value := x.Value.(type) {
   286  	case *ast.Ident:
   287  		valueType = value.String()
   288  	default:
   289  		logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
   290  		return
   291  	}
   292  
   293  	fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
   294  	zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
   295  	ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false)
   296  	ng.MapType = !isAPointer
   297  	t.Getters = append(t.Getters, ng)
   298  }
   299  
   300  func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
   301  	if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
   302  		return
   303  	}
   304  
   305  	var xX string
   306  	if xx, ok := x.X.(*ast.Ident); ok {
   307  		xX = xx.String()
   308  	}
   309  
   310  	switch xX {
   311  	case "time", "json":
   312  		if xX == "json" {
   313  			t.Imports["encoding/json"] = "encoding/json"
   314  		} else {
   315  			t.Imports[xX] = xX
   316  		}
   317  		fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
   318  		zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
   319  		if xX == "time" && x.Sel.Name == "Duration" {
   320  			zeroValue = "0"
   321  		}
   322  		t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
   323  	default:
   324  		logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
   325  	}
   326  }
   327  
   328  type templateData struct {
   329  	filename string
   330  	Year     int
   331  	Package  string
   332  	Imports  map[string]string
   333  	Getters  []*getter
   334  }
   335  
   336  type getter struct {
   337  	sortVal      string // Lower-case version of "ReceiverType.FieldName".
   338  	ReceiverVar  string // The one-letter variable name to match the ReceiverType.
   339  	ReceiverType string
   340  	FieldName    string
   341  	FieldType    string
   342  	ZeroValue    string
   343  	NamedStruct  bool // Getter for named struct.
   344  	MapType      bool
   345  	ArrayType    bool
   346  }
   347  
   348  type byName []*getter
   349  
   350  func (b byName) Len() int           { return len(b) }
   351  func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal }
   352  func (b byName) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
   353  
// source is the text/template for the generated accessors file. It is
// executed with a *templateData (see dump) and emits one method per getter:
//   - NamedStruct: returns the pointer field itself, or ZeroValue for a nil
//     receiver;
//   - MapType or ArrayType: returns the field as-is, or ZeroValue when the
//     receiver or the field is nil;
//   - otherwise: dereferences the pointer field, or returns ZeroValue when
//     the receiver or the field is nil.
//
// The import block is emitted only when Imports is non-empty. The output is
// run through format.Source afterwards, so indentation here is not significant.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{.Package}}
{{with .Imports}}
import (
  {{- range . -}}
  "{{.}}"
  {{end -}}
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
  if {{.ReceiverVar}} == nil {
    return {{.ZeroValue}}
  }
  return {{.ReceiverVar}}.{{.FieldName}}
}
{{else if or .MapType .ArrayType }}
// Get{{.FieldName}} returns the {{.FieldName}} {{if .MapType}}map{{else if .ArrayType }}slice{{end}} if it's non-nil, {{if .MapType}}an empty map{{else if .ArrayType }}nil{{end}} otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
  if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
    return {{.ZeroValue}}
  }
  return {{.ReceiverVar}}.{{.FieldName}}
}
{{else}}
// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
  if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
    return {{.ZeroValue}}
  }
  return *{{.ReceiverVar}}.{{.FieldName}}
}
{{end}}
{{end}}
`
   399  
   400  const test = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
   401  //
   402  // Use of this source code is governed by a BSD-style
   403  // license that can be found in the LICENSE file.
   404  
   405  // Code generated by gen-accessors; DO NOT EDIT.
   406  // Instead, please run "go generate ./..." as described here:
   407  // https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch
   408  
   409  package {{.Package}}
   410  {{with .Imports}}
   411  import (
   412    "testing"
   413    {{range . -}}
   414    "{{.}}"
   415    {{end -}}
   416  )
   417  {{end}}
   418  {{range .Getters}}
   419  {{if .NamedStruct}}
   420  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   421    {{.ReceiverVar}} := &{{.ReceiverType}}{}
   422    {{.ReceiverVar}}.Get{{.FieldName}}()
   423    {{.ReceiverVar}} = nil
   424    {{.ReceiverVar}}.Get{{.FieldName}}()
   425  }
   426  {{else if or .MapType .ArrayType}}
   427  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   428    zeroValue := {{.FieldType}}{}
   429    {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue }
   430    {{.ReceiverVar}}.Get{{.FieldName}}()
   431    {{.ReceiverVar}} = &{{.ReceiverType}}{}
   432    {{.ReceiverVar}}.Get{{.FieldName}}()
   433    {{.ReceiverVar}} = nil
   434    {{.ReceiverVar}}.Get{{.FieldName}}()
   435  }
   436  {{else}}
   437  func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
   438    var zeroValue {{.FieldType}}
   439    {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: &zeroValue }
   440    {{.ReceiverVar}}.Get{{.FieldName}}()
   441    {{.ReceiverVar}} = &{{.ReceiverType}}{}
   442    {{.ReceiverVar}}.Get{{.FieldName}}()
   443    {{.ReceiverVar}} = nil
   444    {{.ReceiverVar}}.Get{{.FieldName}}()
   445  }
   446  {{end}}
   447  {{end}}
   448  `
   449  

View as plain text