...

Source file src/github.com/google/go-github/v55/github/gen-stringify-test.go

Documentation: github.com/google/go-github/v55/github

     1  // Copyright 2019 The go-github AUTHORS. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  //go:build ignore
     7  // +build ignore
     8  
     9  // gen-stringify-test generates test methods to test the String methods.
    10  //
    11  // These tests eliminate most of the code coverage problems so that real
    12  // code coverage issues can be more readily identified.
    13  //
    14  // It is meant to be used by go-github contributors in conjunction with the
    15  // go generate tool before sending a PR to GitHub.
    16  // Please see the CONTRIBUTING.md file for more information.
    17  package main
    18  
    19  import (
    20  	"bytes"
    21  	"flag"
    22  	"fmt"
    23  	"go/ast"
    24  	"go/format"
    25  	"go/parser"
    26  	"go/token"
    27  	"log"
    28  	"os"
    29  	"strings"
    30  	"text/template"
    31  )
    32  
    33  const (
    34  	ignoreFilePrefix1 = "gen-"
    35  	ignoreFilePrefix2 = "github-"
    36  	outputFileSuffix  = "-stringify_test.go"
    37  )
    38  
    39  var (
    40  	verbose = flag.Bool("v", false, "Print verbose log messages")
    41  
    42  	// skipStructMethods lists "struct.method" combos to skip.
    43  	skipStructMethods = map[string]bool{}
    44  	// skipStructs lists structs to skip.
    45  	skipStructs = map[string]bool{
    46  		"RateLimits": true,
    47  	}
    48  
    49  	funcMap = template.FuncMap{
    50  		"isNotLast": func(index int, slice []*structField) string {
    51  			if index+1 < len(slice) {
    52  				return ", "
    53  			}
    54  			return ""
    55  		},
    56  		"processZeroValue": func(v string) string {
    57  			switch v {
    58  			case "Bool(false)":
    59  				return "false"
    60  			case "Float64(0.0)":
    61  				return "0"
    62  			case "0", "Int(0)", "Int64(0)":
    63  				return "0"
    64  			case `""`, `String("")`:
    65  				return `""`
    66  			case "Timestamp{}", "&Timestamp{}":
    67  				return "github.Timestamp{0001-01-01 00:00:00 +0000 UTC}"
    68  			case "nil":
    69  				return "map[]"
    70  			case `[]int{0}`:
    71  				return `[0]`
    72  			case `[]string{""}`:
    73  				return `[""]`
    74  			case "[]Scope{ScopeNone}":
    75  				return `["(no scope)"]`
    76  			}
    77  			log.Fatalf("Unhandled zero value: %q", v)
    78  			return ""
    79  		},
    80  	}
    81  
    82  	sourceTmpl = template.Must(template.New("source").Funcs(funcMap).Parse(source))
    83  )
    84  
    85  func main() {
    86  	flag.Parse()
    87  	fset := token.NewFileSet()
    88  
    89  	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
    90  	if err != nil {
    91  		log.Fatal(err)
    92  		return
    93  	}
    94  
    95  	for pkgName, pkg := range pkgs {
    96  		t := &templateData{
    97  			filename:     pkgName + outputFileSuffix,
    98  			Year:         2019, // No need to change this once set (even in following years).
    99  			Package:      pkgName,
   100  			Imports:      map[string]string{"testing": "testing"},
   101  			StringFuncs:  map[string]bool{},
   102  			StructFields: map[string][]*structField{},
   103  		}
   104  		for filename, f := range pkg.Files {
   105  			logf("Processing %v...", filename)
   106  			if err := t.processAST(f); err != nil {
   107  				log.Fatal(err)
   108  			}
   109  		}
   110  		if err := t.dump(); err != nil {
   111  			log.Fatal(err)
   112  		}
   113  	}
   114  	logf("Done.")
   115  }
   116  
   117  func sourceFilter(fi os.FileInfo) bool {
   118  	return !strings.HasSuffix(fi.Name(), "_test.go") &&
   119  		!strings.HasPrefix(fi.Name(), ignoreFilePrefix1) &&
   120  		!strings.HasPrefix(fi.Name(), ignoreFilePrefix2)
   121  }
   122  
   123  type templateData struct {
   124  	filename     string
   125  	Year         int
   126  	Package      string
   127  	Imports      map[string]string
   128  	StringFuncs  map[string]bool
   129  	StructFields map[string][]*structField
   130  }
   131  
   132  type structField struct {
   133  	sortVal      string // Lower-case version of "ReceiverType.FieldName".
   134  	ReceiverVar  string // The one-letter variable name to match the ReceiverType.
   135  	ReceiverType string
   136  	FieldName    string
   137  	FieldType    string
   138  	ZeroValue    string
   139  	NamedStruct  bool // Getter for named struct.
   140  }
   141  
   142  func (t *templateData) processAST(f *ast.File) error {
   143  	for _, decl := range f.Decls {
   144  		fn, ok := decl.(*ast.FuncDecl)
   145  		if ok {
   146  			if fn.Recv != nil && len(fn.Recv.List) > 0 {
   147  				id, ok := fn.Recv.List[0].Type.(*ast.Ident)
   148  				if ok && fn.Name.Name == "String" {
   149  					logf("Got FuncDecl: Name=%q, id.Name=%#v", fn.Name.Name, id.Name)
   150  					t.StringFuncs[id.Name] = true
   151  				} else {
   152  					star, ok := fn.Recv.List[0].Type.(*ast.StarExpr)
   153  					if ok && fn.Name.Name == "String" {
   154  						id, ok := star.X.(*ast.Ident)
   155  						if ok {
   156  							logf("Got FuncDecl: Name=%q, id.Name=%#v", fn.Name.Name, id.Name)
   157  							t.StringFuncs[id.Name] = true
   158  						} else {
   159  							logf("Ignoring FuncDecl: Name=%q, Type=%T", fn.Name.Name, fn.Recv.List[0].Type)
   160  						}
   161  					} else {
   162  						logf("Ignoring FuncDecl: Name=%q, Type=%T", fn.Name.Name, fn.Recv.List[0].Type)
   163  					}
   164  				}
   165  			} else {
   166  				logf("Ignoring FuncDecl: Name=%q, fn=%#v", fn.Name.Name, fn)
   167  			}
   168  			continue
   169  		}
   170  
   171  		gd, ok := decl.(*ast.GenDecl)
   172  		if !ok {
   173  			logf("Ignoring AST decl type %T", decl)
   174  			continue
   175  		}
   176  
   177  		for _, spec := range gd.Specs {
   178  			ts, ok := spec.(*ast.TypeSpec)
   179  			if !ok {
   180  				continue
   181  			}
   182  			// Skip unexported identifiers.
   183  			if !ts.Name.IsExported() {
   184  				logf("Struct %v is unexported; skipping.", ts.Name)
   185  				continue
   186  			}
   187  			// Check if the struct should be skipped.
   188  			if skipStructs[ts.Name.Name] {
   189  				logf("Struct %v is in skip list; skipping.", ts.Name)
   190  				continue
   191  			}
   192  			st, ok := ts.Type.(*ast.StructType)
   193  			if !ok {
   194  				logf("Ignoring AST type %T, Name=%q", ts.Type, ts.Name.String())
   195  				continue
   196  			}
   197  			for _, field := range st.Fields.List {
   198  				if len(field.Names) == 0 {
   199  					continue
   200  				}
   201  
   202  				fieldName := field.Names[0]
   203  				if id, ok := field.Type.(*ast.Ident); ok {
   204  					t.addIdent(id, ts.Name.String(), fieldName.String())
   205  					continue
   206  				}
   207  
   208  				if at, ok := field.Type.(*ast.ArrayType); ok {
   209  					if id, ok := at.Elt.(*ast.Ident); ok {
   210  						t.addIdentSlice(id, ts.Name.String(), fieldName.String())
   211  						continue
   212  					}
   213  				}
   214  
   215  				se, ok := field.Type.(*ast.StarExpr)
   216  				if !ok {
   217  					logf("Ignoring type %T for Name=%q, FieldName=%q", field.Type, ts.Name.String(), fieldName.String())
   218  					continue
   219  				}
   220  
   221  				// Skip unexported identifiers.
   222  				if !fieldName.IsExported() {
   223  					logf("Field %v is unexported; skipping.", fieldName)
   224  					continue
   225  				}
   226  				// Check if "struct.method" should be skipped.
   227  				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); skipStructMethods[key] {
   228  					logf("Method %v is in skip list; skipping.", key)
   229  					continue
   230  				}
   231  
   232  				switch x := se.X.(type) {
   233  				case *ast.ArrayType:
   234  				case *ast.Ident:
   235  					t.addIdentPtr(x, ts.Name.String(), fieldName.String())
   236  				case *ast.MapType:
   237  				case *ast.SelectorExpr:
   238  				default:
   239  					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
   240  				}
   241  			}
   242  		}
   243  	}
   244  	return nil
   245  }
   246  
   247  func (t *templateData) addMapType(receiverType, fieldName string) {
   248  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, "map[]", "nil", false))
   249  }
   250  
   251  func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
   252  	var zeroValue string
   253  	var namedStruct = false
   254  	switch x.String() {
   255  	case "int":
   256  		zeroValue = "0"
   257  	case "int64":
   258  		zeroValue = "0"
   259  	case "float64":
   260  		zeroValue = "0.0"
   261  	case "string":
   262  		zeroValue = `""`
   263  	case "bool":
   264  		zeroValue = "false"
   265  	case "Timestamp":
   266  		zeroValue = "Timestamp{}"
   267  	default:
   268  		zeroValue = "nil"
   269  		namedStruct = true
   270  	}
   271  
   272  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   273  }
   274  
   275  func (t *templateData) addIdentPtr(x *ast.Ident, receiverType, fieldName string) {
   276  	var zeroValue string
   277  	var namedStruct = false
   278  	switch x.String() {
   279  	case "int":
   280  		zeroValue = "Int(0)"
   281  	case "int64":
   282  		zeroValue = "Int64(0)"
   283  	case "float64":
   284  		zeroValue = "Float64(0.0)"
   285  	case "string":
   286  		zeroValue = `String("")`
   287  	case "bool":
   288  		zeroValue = "Bool(false)"
   289  	case "Timestamp":
   290  		zeroValue = "&Timestamp{}"
   291  	default:
   292  		zeroValue = "nil"
   293  		namedStruct = true
   294  	}
   295  
   296  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   297  }
   298  
   299  func (t *templateData) addIdentSlice(x *ast.Ident, receiverType, fieldName string) {
   300  	var zeroValue string
   301  	var namedStruct = false
   302  	switch x.String() {
   303  	case "int":
   304  		zeroValue = "[]int{0}"
   305  	case "int64":
   306  		zeroValue = "[]int64{0}"
   307  	case "float64":
   308  		zeroValue = "[]float64{0}"
   309  	case "string":
   310  		zeroValue = `[]string{""}`
   311  	case "bool":
   312  		zeroValue = "[]bool{false}"
   313  	case "Scope":
   314  		zeroValue = "[]Scope{ScopeNone}"
   315  	// case "Timestamp":
   316  	// 	zeroValue = "&Timestamp{}"
   317  	default:
   318  		zeroValue = "nil"
   319  		namedStruct = true
   320  	}
   321  
   322  	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
   323  }
   324  
   325  func (t *templateData) dump() error {
   326  	if len(t.StructFields) == 0 {
   327  		logf("No StructFields for %v; skipping.", t.filename)
   328  		return nil
   329  	}
   330  
   331  	// Remove unused structs.
   332  	var toDelete []string
   333  	for k := range t.StructFields {
   334  		if !t.StringFuncs[k] {
   335  			toDelete = append(toDelete, k)
   336  			continue
   337  		}
   338  	}
   339  	for _, k := range toDelete {
   340  		delete(t.StructFields, k)
   341  	}
   342  
   343  	var buf bytes.Buffer
   344  	if err := sourceTmpl.Execute(&buf, t); err != nil {
   345  		return err
   346  	}
   347  	clean, err := format.Source(buf.Bytes())
   348  	if err != nil {
   349  		log.Printf("failed-to-format source:\n%v", buf.String())
   350  		return err
   351  	}
   352  
   353  	logf("Writing %v...", t.filename)
   354  	if err := os.Chmod(t.filename, 0644); err != nil {
   355  		return fmt.Errorf("os.Chmod(%q, 0644): %v", t.filename, err)
   356  	}
   357  
   358  	if err := os.WriteFile(t.filename, clean, 0444); err != nil {
   359  		return err
   360  	}
   361  
   362  	if err := os.Chmod(t.filename, 0444); err != nil {
   363  		return fmt.Errorf("os.Chmod(%q, 0444): %v", t.filename, err)
   364  	}
   365  
   366  	return nil
   367  }
   368  
   369  func newStructField(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *structField {
   370  	return &structField{
   371  		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
   372  		ReceiverVar:  strings.ToLower(receiverType[:1]),
   373  		ReceiverType: receiverType,
   374  		FieldName:    fieldName,
   375  		FieldType:    fieldType,
   376  		ZeroValue:    zeroValue,
   377  		NamedStruct:  namedStruct,
   378  	}
   379  }
   380  
   381  func logf(fmt string, args ...interface{}) {
   382  	if *verbose {
   383  		log.Printf(fmt, args...)
   384  	}
   385  }
   386  
// source is the text/template (parsed into sourceTmpl) that renders the
// generated <package>-stringify_test.go file from a templateData value.
//
// For each entry in StructFields it emits a Test<Struct>_String function. That
// function builds a value whose fields are set to ZeroValue, or to
// &FieldType{} when NamedStruct is true. It then compares v.String() with an
// expected string assembled via the isNotLast and processZeroValue template
// functions.
//
// The generated file also defines a Float64 helper; addIdentPtr records
// Float64(0.0) as the zero value of *float64 fields, which calls a function
// of that name.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-stringify-tests; DO NOT EDIT.
// Instead, please run "go generate ./..." as described here:
// https://github.com/google/go-github/blob/master/CONTRIBUTING.md#submitting-a-patch

package {{ $package := .Package}}{{$package}}
{{with .Imports}}
import (
  {{- range . -}}
  "{{.}}"
  {{end -}}
)
{{end}}
func Float64(v float64) *float64 { return &v }
{{range $key, $value := .StructFields}}
func Test{{ $key }}_String(t *testing.T) {
  v := {{ $key }}{ {{range .}}{{if .NamedStruct}}
    {{ .FieldName }}: &{{ .FieldType }}{},{{else}}
    {{ .FieldName }}: {{.ZeroValue}},{{end}}{{end}}
  }
 	want := ` + "`" + `{{ $package }}.{{ $key }}{{ $slice := . }}{
{{- range $ind, $val := .}}{{if .NamedStruct}}{{ .FieldName }}:{{ $package }}.{{ .FieldType }}{}{{else}}{{ .FieldName }}:{{ processZeroValue .ZeroValue }}{{end}}{{ isNotLast $ind $slice }}{{end}}}` + "`" + `
	if got := v.String(); got != want {
		t.Errorf("{{ $key }}.String = %v, want %v", got, want)
	}
}
{{end}}
`
   419  

View as plain text