...

Source file src/github.com/mailru/easyjson/gen/generator.go

Documentation: github.com/mailru/easyjson/gen

     1  package gen
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"hash/fnv"
     7  	"io"
     8  	"path"
     9  	"reflect"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  	"unicode"
    14  )
    15  
    16  const pkgWriter = "github.com/mailru/easyjson/jwriter"
    17  const pkgLexer = "github.com/mailru/easyjson/jlexer"
    18  const pkgEasyJSON = "github.com/mailru/easyjson"
    19  
    20  // FieldNamer defines a policy for generating names for struct fields.
    21  type FieldNamer interface {
    22  	GetJSONFieldName(t reflect.Type, f reflect.StructField) string
    23  }
    24  
    25  // Generator generates the requested marshaler/unmarshalers.
    26  type Generator struct {
    27  	out *bytes.Buffer
    28  
    29  	pkgName    string
    30  	pkgPath    string
    31  	buildTags  string
    32  	hashString string
    33  
    34  	varCounter int
    35  
    36  	noStdMarshalers          bool
    37  	omitEmpty                bool
    38  	disallowUnknownFields    bool
    39  	fieldNamer               FieldNamer
    40  	simpleBytes              bool
    41  	skipMemberNameUnescaping bool
    42  
    43  	// package path to local alias map for tracking imports
    44  	imports map[string]string
    45  
    46  	// types that marshalers were requested for by user
    47  	marshalers map[reflect.Type]bool
    48  
    49  	// types that encoders were already generated for
    50  	typesSeen map[reflect.Type]bool
    51  
    52  	// types that encoders were requested for (e.g. by encoders of other types)
    53  	typesUnseen []reflect.Type
    54  
    55  	// function name to relevant type maps to track names of de-/encoders in
    56  	// case of a name clash or unnamed structs
    57  	functionNames map[string]reflect.Type
    58  }
    59  
    60  // NewGenerator initializes and returns a Generator.
    61  func NewGenerator(filename string) *Generator {
    62  	ret := &Generator{
    63  		imports: map[string]string{
    64  			pkgWriter:       "jwriter",
    65  			pkgLexer:        "jlexer",
    66  			pkgEasyJSON:     "easyjson",
    67  			"encoding/json": "json",
    68  		},
    69  		fieldNamer:    DefaultFieldNamer{},
    70  		marshalers:    make(map[reflect.Type]bool),
    71  		typesSeen:     make(map[reflect.Type]bool),
    72  		functionNames: make(map[string]reflect.Type),
    73  	}
    74  
    75  	// Use a file-unique prefix on all auxiliary funcs to avoid
    76  	// name clashes.
    77  	hash := fnv.New32()
    78  	hash.Write([]byte(filename))
    79  	ret.hashString = fmt.Sprintf("%x", hash.Sum32())
    80  
    81  	return ret
    82  }
    83  
    84  // SetPkg sets the name and path of output package.
    85  func (g *Generator) SetPkg(name, path string) {
    86  	g.pkgName = name
    87  	g.pkgPath = path
    88  }
    89  
    90  // SetBuildTags sets build tags for the output file.
    91  func (g *Generator) SetBuildTags(tags string) {
    92  	g.buildTags = tags
    93  }
    94  
    95  // SetFieldNamer sets field naming strategy.
    96  func (g *Generator) SetFieldNamer(n FieldNamer) {
    97  	g.fieldNamer = n
    98  }
    99  
   100  // UseSnakeCase sets snake_case field naming strategy.
   101  func (g *Generator) UseSnakeCase() {
   102  	g.fieldNamer = SnakeCaseFieldNamer{}
   103  }
   104  
   105  // UseLowerCamelCase sets lowerCamelCase field naming strategy.
   106  func (g *Generator) UseLowerCamelCase() {
   107  	g.fieldNamer = LowerCamelCaseFieldNamer{}
   108  }
   109  
   110  // NoStdMarshalers instructs not to generate standard MarshalJSON/UnmarshalJSON
   111  // methods (only the custom interface).
   112  func (g *Generator) NoStdMarshalers() {
   113  	g.noStdMarshalers = true
   114  }
   115  
   116  // DisallowUnknownFields instructs not to skip unknown fields in json and return error.
   117  func (g *Generator) DisallowUnknownFields() {
   118  	g.disallowUnknownFields = true
   119  }
   120  
   121  // SkipMemberNameUnescaping instructs to skip member names unescaping to improve performance
   122  func (g *Generator) SkipMemberNameUnescaping() {
   123  	g.skipMemberNameUnescaping = true
   124  }
   125  
   126  // OmitEmpty triggers `json=",omitempty"` behaviour by default.
   127  func (g *Generator) OmitEmpty() {
   128  	g.omitEmpty = true
   129  }
   130  
   131  // SimpleBytes triggers generate output bytes as slice byte
   132  func (g *Generator) SimpleBytes() {
   133  	g.simpleBytes = true
   134  }
   135  
   136  // addTypes requests to generate encoding/decoding funcs for the given type.
   137  func (g *Generator) addType(t reflect.Type) {
   138  	if g.typesSeen[t] {
   139  		return
   140  	}
   141  	for _, t1 := range g.typesUnseen {
   142  		if t1 == t {
   143  			return
   144  		}
   145  	}
   146  	g.typesUnseen = append(g.typesUnseen, t)
   147  }
   148  
   149  // Add requests to generate marshaler/unmarshalers and encoding/decoding
   150  // funcs for the type of given object.
   151  func (g *Generator) Add(obj interface{}) {
   152  	t := reflect.TypeOf(obj)
   153  	if t.Kind() == reflect.Ptr {
   154  		t = t.Elem()
   155  	}
   156  	g.addType(t)
   157  	g.marshalers[t] = true
   158  }
   159  
   160  // printHeader prints package declaration and imports.
   161  func (g *Generator) printHeader() {
   162  	if g.buildTags != "" {
   163  		fmt.Println("// +build ", g.buildTags)
   164  		fmt.Println()
   165  	}
   166  	fmt.Println("// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT.")
   167  	fmt.Println()
   168  	fmt.Println("package ", g.pkgName)
   169  	fmt.Println()
   170  
   171  	byAlias := make(map[string]string, len(g.imports))
   172  	aliases := make([]string, 0, len(g.imports))
   173  
   174  	for path, alias := range g.imports {
   175  		aliases = append(aliases, alias)
   176  		byAlias[alias] = path
   177  	}
   178  
   179  	sort.Strings(aliases)
   180  	fmt.Println("import (")
   181  	for _, alias := range aliases {
   182  		fmt.Printf("  %s %q\n", alias, byAlias[alias])
   183  	}
   184  
   185  	fmt.Println(")")
   186  	fmt.Println("")
   187  	fmt.Println("// suppress unused package warning")
   188  	fmt.Println("var (")
   189  	fmt.Println("   _ *json.RawMessage")
   190  	fmt.Println("   _ *jlexer.Lexer")
   191  	fmt.Println("   _ *jwriter.Writer")
   192  	fmt.Println("   _ easyjson.Marshaler")
   193  	fmt.Println(")")
   194  
   195  	fmt.Println()
   196  }
   197  
   198  // Run runs the generator and outputs generated code to out.
   199  func (g *Generator) Run(out io.Writer) error {
   200  	g.out = &bytes.Buffer{}
   201  
   202  	for len(g.typesUnseen) > 0 {
   203  		t := g.typesUnseen[len(g.typesUnseen)-1]
   204  		g.typesUnseen = g.typesUnseen[:len(g.typesUnseen)-1]
   205  		g.typesSeen[t] = true
   206  
   207  		if err := g.genDecoder(t); err != nil {
   208  			return err
   209  		}
   210  		if err := g.genEncoder(t); err != nil {
   211  			return err
   212  		}
   213  
   214  		if !g.marshalers[t] {
   215  			continue
   216  		}
   217  
   218  		if err := g.genStructMarshaler(t); err != nil {
   219  			return err
   220  		}
   221  		if err := g.genStructUnmarshaler(t); err != nil {
   222  			return err
   223  		}
   224  	}
   225  	g.printHeader()
   226  	_, err := out.Write(g.out.Bytes())
   227  	return err
   228  }
   229  
   230  // fixes vendored paths
   231  func fixPkgPathVendoring(pkgPath string) string {
   232  	const vendor = "/vendor/"
   233  	if i := strings.LastIndex(pkgPath, vendor); i != -1 {
   234  		return pkgPath[i+len(vendor):]
   235  	}
   236  	return pkgPath
   237  }
   238  
   239  func fixAliasName(alias string) string {
   240  	alias = strings.Replace(
   241  		strings.Replace(alias, ".", "_", -1),
   242  		"-",
   243  		"_",
   244  		-1,
   245  	)
   246  
   247  	if alias[0] == 'v' { // to void conflicting with var names, say v1
   248  		alias = "_" + alias
   249  	}
   250  	return alias
   251  }
   252  
   253  // pkgAlias creates and returns and import alias for a given package.
   254  func (g *Generator) pkgAlias(pkgPath string) string {
   255  	pkgPath = fixPkgPathVendoring(pkgPath)
   256  	if alias := g.imports[pkgPath]; alias != "" {
   257  		return alias
   258  	}
   259  
   260  	for i := 0; ; i++ {
   261  		alias := fixAliasName(path.Base(pkgPath))
   262  		if i > 0 {
   263  			alias += fmt.Sprint(i)
   264  		}
   265  
   266  		exists := false
   267  		for _, v := range g.imports {
   268  			if v == alias {
   269  				exists = true
   270  				break
   271  			}
   272  		}
   273  
   274  		if !exists {
   275  			g.imports[pkgPath] = alias
   276  			return alias
   277  		}
   278  	}
   279  }
   280  
   281  // getType return the textual type name of given type that can be used in generated code.
   282  func (g *Generator) getType(t reflect.Type) string {
   283  	if t.Name() == "" {
   284  		switch t.Kind() {
   285  		case reflect.Ptr:
   286  			return "*" + g.getType(t.Elem())
   287  		case reflect.Slice:
   288  			return "[]" + g.getType(t.Elem())
   289  		case reflect.Array:
   290  			return "[" + strconv.Itoa(t.Len()) + "]" + g.getType(t.Elem())
   291  		case reflect.Map:
   292  			return "map[" + g.getType(t.Key()) + "]" + g.getType(t.Elem())
   293  		}
   294  	}
   295  
   296  	if t.Name() == "" || t.PkgPath() == "" {
   297  		if t.Kind() == reflect.Struct {
   298  			// the fields of an anonymous struct can have named types,
   299  			// and t.String() will not be sufficient because it does not
   300  			// remove the package name when it matches g.pkgPath.
   301  			// so we convert by hand
   302  			nf := t.NumField()
   303  			lines := make([]string, 0, nf)
   304  			for i := 0; i < nf; i++ {
   305  				f := t.Field(i)
   306  				var line string
   307  				if !f.Anonymous {
   308  					line = f.Name + " "
   309  				} // else the field is anonymous (an embedded type)
   310  				line += g.getType(f.Type)
   311  				t := f.Tag
   312  				if t != "" {
   313  					line += " " + escapeTag(t)
   314  				}
   315  				lines = append(lines, line)
   316  			}
   317  			return strings.Join([]string{"struct { ", strings.Join(lines, "; "), " }"}, "")
   318  		}
   319  		return t.String()
   320  	} else if t.PkgPath() == g.pkgPath {
   321  		return t.Name()
   322  	}
   323  	return g.pkgAlias(t.PkgPath()) + "." + t.Name()
   324  }
   325  
   326  // escape a struct field tag string back to source code
   327  func escapeTag(tag reflect.StructTag) string {
   328  	t := string(tag)
   329  	if strings.ContainsRune(t, '`') {
   330  		// there are ` in the string; we can't use ` to enclose the string
   331  		return strconv.Quote(t)
   332  	}
   333  	return "`" + t + "`"
   334  }
   335  
   336  // uniqueVarName returns a file-unique name that can be used for generated variables.
   337  func (g *Generator) uniqueVarName() string {
   338  	g.varCounter++
   339  	return fmt.Sprint("v", g.varCounter)
   340  }
   341  
   342  // safeName escapes unsafe characters in pkg/type name and returns a string that can be used
   343  // in encoder/decoder names for the type.
   344  func (g *Generator) safeName(t reflect.Type) string {
   345  	name := t.PkgPath()
   346  	if t.Name() == "" {
   347  		name += "anonymous"
   348  	} else {
   349  		name += "." + t.Name()
   350  	}
   351  
   352  	parts := []string{}
   353  	part := []rune{}
   354  	for _, c := range name {
   355  		if unicode.IsLetter(c) || unicode.IsDigit(c) {
   356  			part = append(part, c)
   357  		} else if len(part) > 0 {
   358  			parts = append(parts, string(part))
   359  			part = []rune{}
   360  		}
   361  	}
   362  	return joinFunctionNameParts(false, parts...)
   363  }
   364  
   365  // functionName returns a function name for a given type with a given prefix. If a function
   366  // with this prefix already exists for a type, it is returned.
   367  //
   368  // Method is used to track encoder/decoder names for the type.
   369  func (g *Generator) functionName(prefix string, t reflect.Type) string {
   370  	prefix = joinFunctionNameParts(true, "easyjson", g.hashString, prefix)
   371  	name := joinFunctionNameParts(true, prefix, g.safeName(t))
   372  
   373  	// Most of the names will be unique, try a shortcut first.
   374  	if e, ok := g.functionNames[name]; !ok || e == t {
   375  		g.functionNames[name] = t
   376  		return name
   377  	}
   378  
   379  	// Search if the function already exists.
   380  	for name1, t1 := range g.functionNames {
   381  		if t1 == t && strings.HasPrefix(name1, prefix) {
   382  			return name1
   383  		}
   384  	}
   385  
   386  	// Create a new name in the case of a clash.
   387  	for i := 1; ; i++ {
   388  		nm := fmt.Sprint(name, i)
   389  		if _, ok := g.functionNames[nm]; ok {
   390  			continue
   391  		}
   392  		g.functionNames[nm] = t
   393  		return nm
   394  	}
   395  }
   396  
   397  // DefaultFieldsNamer implements trivial naming policy equivalent to encoding/json.
   398  type DefaultFieldNamer struct{}
   399  
   400  func (DefaultFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
   401  	jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
   402  	if jsonName != "" {
   403  		return jsonName
   404  	}
   405  
   406  	return f.Name
   407  }
   408  
   409  // LowerCamelCaseFieldNamer
   410  type LowerCamelCaseFieldNamer struct{}
   411  
   412  func isLower(b byte) bool {
   413  	return b <= 122 && b >= 97
   414  }
   415  
   416  func isUpper(b byte) bool {
   417  	return b >= 65 && b <= 90
   418  }
   419  
   420  // convert HTTPRestClient to httpRestClient
   421  func lowerFirst(s string) string {
   422  	if s == "" {
   423  		return ""
   424  	}
   425  
   426  	str := ""
   427  	strlen := len(s)
   428  
   429  	/**
   430  	  Loop each char
   431  	  If is uppercase:
   432  	    If is first char, LOWER it
   433  	    If the following char is lower, LEAVE it
   434  	    If the following char is upper OR numeric, LOWER it
   435  	    If is the end of string, LEAVE it
   436  	  Else lowercase
   437  	*/
   438  
   439  	foundLower := false
   440  	for i := range s {
   441  		ch := s[i]
   442  		if isUpper(ch) {
   443  			switch {
   444  			case i == 0:
   445  				str += string(ch + 32)
   446  			case !foundLower: // Currently just a stream of capitals, eg JSONRESTS[erver]
   447  				if strlen > (i+1) && isLower(s[i+1]) {
   448  					// Next char is lower, keep this a capital
   449  					str += string(ch)
   450  				} else {
   451  					// Either at end of string or next char is capital
   452  					str += string(ch + 32)
   453  				}
   454  			default:
   455  				str += string(ch)
   456  			}
   457  		} else {
   458  			foundLower = true
   459  			str += string(ch)
   460  		}
   461  	}
   462  
   463  	return str
   464  }
   465  
   466  func (LowerCamelCaseFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
   467  	jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
   468  	if jsonName != "" {
   469  		return jsonName
   470  	}
   471  
   472  	return lowerFirst(f.Name)
   473  }
   474  
   475  // SnakeCaseFieldNamer implements CamelCase to snake_case conversion for fields names.
   476  type SnakeCaseFieldNamer struct{}
   477  
// camelToSnake converts a CamelCase identifier to snake_case, e.g.
// "FooBar" -> "foo_bar" and "HTTPServer" -> "http_server".
//
// A run of uppercase letters is one word, except that the last letter of the
// run starts a new word when a lowercase letter follows it. A non-lowercase
// rune that follows an uppercase one is treated as uppercase too. An
// underscore copied verbatim from the input suppresses the delimiter that
// would otherwise be inserted right after it.
//
// NOTE(review): an underscore that directly follows an uppercase rune is
// buffered as "uppercase", and beforeUpper is not updated for it. So e.g.
// "AB_Cd" yields "ab__cd" (double underscore). Fixing this would change the
// JSON names SnakeCaseFieldNamer produces, which needs a compatibility decision.
func camelToSnake(name string) string {
	var ret bytes.Buffer

	// multipleUpper: the rune buffered in lastUpper is not the first one of
	// its uppercase run.
	multipleUpper := false
	// lastUpper: an "uppercase" rune whose output is deferred until the next
	// rune shows whether a delimiter must precede it; 0 means none buffered.
	var lastUpper rune
	// beforeUpper: the last rune written verbatim (a non-"uppercase" rune).
	var beforeUpper rune

	for _, c := range name {
		// Non-lowercase character after uppercase is considered to be uppercase too.
		isUpper := (unicode.IsUpper(c) || (lastUpper != 0 && !unicode.IsLower(c)))

		if lastUpper != 0 {
			// Output a delimiter if last character was either the first uppercase character
			// in a row, or the last one in a row (e.g. 'S' in "HTTPServer").
			// Do not output a delimiter at the beginning of the name.

			firstInRow := !multipleUpper
			lastInRow := !isUpper

			// beforeUpper != '_' avoids doubling an underscore copied from the input.
			if ret.Len() > 0 && (firstInRow || lastInRow) && beforeUpper != '_' {
				ret.WriteByte('_')
			}
			ret.WriteRune(unicode.ToLower(lastUpper))
		}

		// Buffer uppercase char, do not output it yet as a delimiter may be required if the
		// next character is lowercase.
		if isUpper {
			multipleUpper = (lastUpper != 0)
			lastUpper = c
			continue
		}

		ret.WriteRune(c)
		lastUpper = 0
		beforeUpper = c
		multipleUpper = false
	}

	// Flush an uppercase rune still buffered at the end of the name.
	if lastUpper != 0 {
		ret.WriteRune(unicode.ToLower(lastUpper))
	}
	return string(ret.Bytes())
}
   522  
   523  func (SnakeCaseFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
   524  	jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
   525  	if jsonName != "" {
   526  		return jsonName
   527  	}
   528  
   529  	return camelToSnake(f.Name)
   530  }
   531  
   532  func joinFunctionNameParts(keepFirst bool, parts ...string) string {
   533  	buf := bytes.NewBufferString("")
   534  	for i, part := range parts {
   535  		if i == 0 && keepFirst {
   536  			buf.WriteString(part)
   537  		} else {
   538  			if len(part) > 0 {
   539  				buf.WriteString(strings.ToUpper(string(part[0])))
   540  			}
   541  			if len(part) > 1 {
   542  				buf.WriteString(part[1:])
   543  			}
   544  		}
   545  	}
   546  	return buf.String()
   547  }
   548  

View as plain text