     1  package gen
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"hash/fnv"
     7  	"io"
     8  	"path"
     9  	"reflect"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  	"unicode"
    14  )
    16  const pkgWriter = "github.com/mailru/easyjson/jwriter"
    17  const pkgLexer = "github.com/mailru/easyjson/jlexer"
    18  const pkgEasyJSON = "github.com/mailru/easyjson"
// FieldNamer defines a policy for generating names for struct fields.
//
// GetJSONFieldName returns the JSON object key to use for field f. t is
// presumably the struct type that declares f (TODO confirm against callers);
// the implementations in this file ignore it.
type FieldNamer interface {
	GetJSONFieldName(t reflect.Type, f reflect.StructField) string
}
// Generator generates the requested marshaler/unmarshalers.
//
// As used in this file: create it with NewGenerator, configure it with the
// setter methods, request types with Add, then call Run to emit the file.
type Generator struct {
	// out accumulates the generated code that follows the file header; Run
	// replaces it with a fresh buffer before generating.
	out *bytes.Buffer

	// Output package name and import path (see SetPkg), optional build
	// constraint (see SetBuildTags), and the per-file hash that prefixes every
	// generated function name (computed from the file name in NewGenerator).
	pkgName    string
	pkgPath    string
	buildTags  string
	hashString string

	// counter behind uniqueVarName, which hands out v1, v2, ...
	varCounter int

	// code generation options, set by the corresponding exported methods
	noStdMarshalers          bool
	omitEmpty                bool
	disallowUnknownFields    bool
	fieldNamer               FieldNamer
	simpleBytes              bool
	skipMemberNameUnescaping bool

	// package path to local alias map for tracking imports; it becomes the
	// import block of the generated file
	imports map[string]string

	// types that marshalers were requested for by user (via Add); Run also
	// generates marshaler/unmarshaler methods for these
	marshalers map[reflect.Type]bool

	// types that encoders were already generated for
	typesSeen map[reflect.Type]bool

	// types that encoders were requested for (e.g. by encoders of other types);
	// Run pops them from the end until the queue is empty
	typesUnseen []reflect.Type

	// function name to relevant type maps to track names of de-/encoders in
	// case of a name clash or unnamed structs
	functionNames map[string]reflect.Type
}
    60  // NewGenerator initializes and returns a Generator.
    61  func NewGenerator(filename string) *Generator {
    62  	ret := &Generator{
    63  		imports: map[string]string{
    64  			pkgWriter:       "jwriter",
    65  			pkgLexer:        "jlexer",
    66  			pkgEasyJSON:     "easyjson",
    67  			"encoding/json": "json",
    68  		},
    69  		fieldNamer:    DefaultFieldNamer{},
    70  		marshalers:    make(map[reflect.Type]bool),
    71  		typesSeen:     make(map[reflect.Type]bool),
    72  		functionNames: make(map[string]reflect.Type),
    73  	}
    75  	// Use a file-unique prefix on all auxiliary funcs to avoid
    76  	// name clashes.
    77  	hash := fnv.New32()
    78  	hash.Write([]byte(filename))
    79  	ret.hashString = fmt.Sprintf("%x", hash.Sum32())
    81  	return ret
    82  }
    84  // SetPkg sets the name and path of output package.
    85  func (g *Generator) SetPkg(name, path string) {
    86  	g.pkgName = name
    87  	g.pkgPath = path
    88  }
    90  // SetBuildTags sets build tags for the output file.
    91  func (g *Generator) SetBuildTags(tags string) {
    92  	g.buildTags = tags
    93  }
    95  // SetFieldNamer sets field naming strategy.
    96  func (g *Generator) SetFieldNamer(n FieldNamer) {
    97  	g.fieldNamer = n
    98  }
   100  // UseSnakeCase sets snake_case field naming strategy.
   101  func (g *Generator) UseSnakeCase() {
   102  	g.fieldNamer = SnakeCaseFieldNamer{}
   103  }
   105  // UseLowerCamelCase sets lowerCamelCase field naming strategy.
   106  func (g *Generator) UseLowerCamelCase() {
   107  	g.fieldNamer = LowerCamelCaseFieldNamer{}
   108  }
   110  // NoStdMarshalers instructs not to generate standard MarshalJSON/UnmarshalJSON
   111  // methods (only the custom interface).
   112  func (g *Generator) NoStdMarshalers() {
   113  	g.noStdMarshalers = true
   114  }
   116  // DisallowUnknownFields instructs not to skip unknown fields in json and return error.
   117  func (g *Generator) DisallowUnknownFields() {
   118  	g.disallowUnknownFields = true
   119  }
   121  // SkipMemberNameUnescaping instructs to skip member names unescaping to improve performance
   122  func (g *Generator) SkipMemberNameUnescaping() {
   123  	g.skipMemberNameUnescaping = true
   124  }
   126  // OmitEmpty triggers `json=",omitempty"` behaviour by default.
   127  func (g *Generator) OmitEmpty() {
   128  	g.omitEmpty = true
   129  }
   131  // SimpleBytes triggers generate output bytes as slice byte
   132  func (g *Generator) SimpleBytes() {
   133  	g.simpleBytes = true
   134  }
   136  // addTypes requests to generate encoding/decoding funcs for the given type.
   137  func (g *Generator) addType(t reflect.Type) {
   138  	if g.typesSeen[t] {
   139  		return
   140  	}
   141  	for _, t1 := range g.typesUnseen {
   142  		if t1 == t {
   143  			return
   144  		}
   145  	}
   146  	g.typesUnseen = append(g.typesUnseen, t)
   147  }
   149  // Add requests to generate marshaler/unmarshalers and encoding/decoding
   150  // funcs for the type of given object.
   151  func (g *Generator) Add(obj interface{}) {
   152  	t := reflect.TypeOf(obj)
   153  	if t.Kind() == reflect.Ptr {
   154  		t = t.Elem()
   155  	}
   156  	g.addType(t)
   157  	g.marshalers[t] = true
   158  }
   160  // printHeader prints package declaration and imports.
   161  func (g *Generator) printHeader() {
   162  	if g.buildTags != "" {
   163  		fmt.Println("// +build ", g.buildTags)
   164  		fmt.Println()
   165  	}
   166  	fmt.Println("// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT.")
   167  	fmt.Println()
   168  	fmt.Println("package ", g.pkgName)
   169  	fmt.Println()
   171  	byAlias := make(map[string]string, len(g.imports))
   172  	aliases := make([]string, 0, len(g.imports))
   174  	for path, alias := range g.imports {
   175  		aliases = append(aliases, alias)
   176  		byAlias[alias] = path
   177  	}
   179  	sort.Strings(aliases)
   180  	fmt.Println("import (")
   181  	for _, alias := range aliases {
   182  		fmt.Printf("  %s %q\n", alias, byAlias[alias])
   183  	}
   185  	fmt.Println(")")
   186  	fmt.Println("")
   187  	fmt.Println("// suppress unused package warning")
   188  	fmt.Println("var (")
   189  	fmt.Println("   _ *json.RawMessage")
   190  	fmt.Println("   _ *jlexer.Lexer")
   191  	fmt.Println("   _ *jwriter.Writer")
   192  	fmt.Println("   _ easyjson.Marshaler")
   193  	fmt.Println(")")
   195  	fmt.Println()
   196  }
   198  // Run runs the generator and outputs generated code to out.
   199  func (g *Generator) Run(out io.Writer) error {
   200  	g.out = &bytes.Buffer{}
   202  	for len(g.typesUnseen) > 0 {
   203  		t := g.typesUnseen[len(g.typesUnseen)-1]
   204  		g.typesUnseen = g.typesUnseen[:len(g.typesUnseen)-1]
   205  		g.typesSeen[t] = true
   207  		if err := g.genDecoder(t); err != nil {
   208  			return err
   209  		}
   210  		if err := g.genEncoder(t); err != nil {
   211  			return err
   212  		}
   214  		if !g.marshalers[t] {
   215  			continue
   216  		}
   218  		if err := g.genStructMarshaler(t); err != nil {
   219  			return err
   220  		}
   221  		if err := g.genStructUnmarshaler(t); err != nil {
   222  			return err
   223  		}
   224  	}
   225  	g.printHeader()
   226  	_, err := out.Write(g.out.Bytes())
   227  	return err
   228  }
   230  // fixes vendored paths
   231  func fixPkgPathVendoring(pkgPath string) string {
   232  	const vendor = "/vendor/"
   233  	if i := strings.LastIndex(pkgPath, vendor); i != -1 {
   234  		return pkgPath[i+len(vendor):]
   235  	}
   236  	return pkgPath
   237  }
   239  func fixAliasName(alias string) string {
   240  	alias = strings.Replace(
   241  		strings.Replace(alias, ".", "_", -1),
   242  		"-",
   243  		"_",
   244  		-1,
   245  	)
   247  	if alias[0] == 'v' { // to void conflicting with var names, say v1
   248  		alias = "_" + alias
   249  	}
   250  	return alias
   251  }
   253  // pkgAlias creates and returns and import alias for a given package.
   254  func (g *Generator) pkgAlias(pkgPath string) string {
   255  	pkgPath = fixPkgPathVendoring(pkgPath)
   256  	if alias := g.imports[pkgPath]; alias != "" {
   257  		return alias
   258  	}
   260  	for i := 0; ; i++ {
   261  		alias := fixAliasName(path.Base(pkgPath))
   262  		if i > 0 {
   263  			alias += fmt.Sprint(i)
   264  		}
   266  		exists := false
   267  		for _, v := range g.imports {
   268  			if v == alias {
   269  				exists = true
   270  				break
   271  			}
   272  		}
   274  		if !exists {
   275  			g.imports[pkgPath] = alias
   276  			return alias
   277  		}
   278  	}
   279  }
   281  // getType return the textual type name of given type that can be used in generated code.
   282  func (g *Generator) getType(t reflect.Type) string {
   283  	if t.Name() == "" {
   284  		switch t.Kind() {
   285  		case reflect.Ptr:
   286  			return "*" + g.getType(t.Elem())
   287  		case reflect.Slice:
   288  			return "[]" + g.getType(t.Elem())
   289  		case reflect.Array:
   290  			return "[" + strconv.Itoa(t.Len()) + "]" + g.getType(t.Elem())
   291  		case reflect.Map:
   292  			return "map[" + g.getType(t.Key()) + "]" + g.getType(t.Elem())
   293  		}
   294  	}
   296  	if t.Name() == "" || t.PkgPath() == "" {
   297  		if t.Kind() == reflect.Struct {
   298  			// the fields of an anonymous struct can have named types,
   299  			// and t.String() will not be sufficient because it does not
   300  			// remove the package name when it matches g.pkgPath.
   301  			// so we convert by hand
   302  			nf := t.NumField()
   303  			lines := make([]string, 0, nf)
   304  			for i := 0; i < nf; i++ {
   305  				f := t.Field(i)
   306  				var line string
   307  				if !f.Anonymous {
   308  					line = f.Name + " "
   309  				} // else the field is anonymous (an embedded type)
   310  				line += g.getType(f.Type)
   311  				t := f.Tag
   312  				if t != "" {
   313  					line += " " + escapeTag(t)
   314  				}
   315  				lines = append(lines, line)
   316  			}
   317  			return strings.Join([]string{"struct { ", strings.Join(lines, "; "), " }"}, "")
   318  		}
   319  		return t.String()
   320  	} else if t.PkgPath() == g.pkgPath {
   321  		return t.Name()
   322  	}
   323  	return g.pkgAlias(t.PkgPath()) + "." + t.Name()
   324  }
   326  // escape a struct field tag string back to source code
   327  func escapeTag(tag reflect.StructTag) string {
   328  	t := string(tag)
   329  	if strings.ContainsRune(t, '`') {
   330  		// there are ` in the string; we can't use ` to enclose the string
   331  		return strconv.Quote(t)
   332  	}
   333  	return "`" + t + "`"
   334  }
   336  // uniqueVarName returns a file-unique name that can be used for generated variables.
   337  func (g *Generator) uniqueVarName() string {
   338  	g.varCounter++
   339  	return fmt.Sprint("v", g.varCounter)
   340  }
   342  // safeName escapes unsafe characters in pkg/type name and returns a string that can be used
   343  // in encoder/decoder names for the type.
   344  func (g *Generator) safeName(t reflect.Type) string {
   345  	name := t.PkgPath()
   346  	if t.Name() == "" {
   347  		name += "anonymous"
   348  	} else {
   349  		name += "." + t.Name()
   350  	}
   352  	parts := []string{}
   353  	part := []rune{}
   354  	for _, c := range name {
   355  		if unicode.IsLetter(c) || unicode.IsDigit(c) {
   356  			part = append(part, c)
   357  		} else if len(part) > 0 {
   358  			parts = append(parts, string(part))
   359  			part = []rune{}
   360  		}
   361  	}
   362  	return joinFunctionNameParts(false, parts...)
   363  }
   365  // functionName returns a function name for a given type with a given prefix. If a function
   366  // with this prefix already exists for a type, it is returned.
   367  //
   368  // Method is used to track encoder/decoder names for the type.
   369  func (g *Generator) functionName(prefix string, t reflect.Type) string {
   370  	prefix = joinFunctionNameParts(true, "easyjson", g.hashString, prefix)
   371  	name := joinFunctionNameParts(true, prefix, g.safeName(t))
   373  	// Most of the names will be unique, try a shortcut first.
   374  	if e, ok := g.functionNames[name]; !ok || e == t {
   375  		g.functionNames[name] = t
   376  		return name
   377  	}
   379  	// Search if the function already exists.
   380  	for name1, t1 := range g.functionNames {
   381  		if t1 == t && strings.HasPrefix(name1, prefix) {
   382  			return name1
   383  		}
   384  	}
   386  	// Create a new name in the case of a clash.
   387  	for i := 1; ; i++ {
   388  		nm := fmt.Sprint(name, i)
   389  		if _, ok := g.functionNames[nm]; ok {
   390  			continue
   391  		}
   392  		g.functionNames[nm] = t
   393  		return nm
   394  	}
   395  }
   397  // DefaultFieldsNamer implements trivial naming policy equivalent to encoding/json.
   398  type DefaultFieldNamer struct{}
   400  func (DefaultFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
   401  	jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
   402  	if jsonName != "" {
   403  		return jsonName
   404  	}
   406  	return f.Name
   407  }
   409  // LowerCamelCaseFieldNamer
   410  type LowerCamelCaseFieldNamer struct{}
   412  func isLower(b byte) bool {
   413  	return b <= 122 && b >= 97
   414  }
   416  func isUpper(b byte) bool {
   417  	return b >= 65 && b <= 90
   418  }
   420  // convert HTTPRestClient to httpRestClient
   421  func lowerFirst(s string) string {
   422  	if s == "" {
   423  		return ""
   424  	}
   426  	str := ""
   427  	strlen := len(s)
   429  	/**
   430  	  Loop each char
   431  	  If is uppercase:
   432  	    If is first char, LOWER it
   433  	    If the following char is lower, LEAVE it
   434  	    If the following char is upper OR numeric, LOWER it
   435  	    If is the end of string, LEAVE it
   436  	  Else lowercase
   437  	*/
   439  	foundLower := false
   440  	for i := range s {
   441  		ch := s[i]
   442  		if isUpper(ch) {
   443  			switch {
   444  			case i == 0:
   445  				str += string(ch + 32)
   446  			case !foundLower: // Currently just a stream of capitals, eg JSONRESTS[erver]
   447  				if strlen > (i+1) && isLower(s[i+1]) {
   448  					// Next char is lower, keep this a capital
   449  					str += string(ch)
   450  				} else {
   451  					// Either at end of string or next char is capital
   452  					str += string(ch + 32)
   453  				}
   454  			default:
   455  				str += string(ch)
   456  			}
   457  		} else {
   458  			foundLower = true
   459  			str += string(ch)
   460  		}
   461  	}
   463  	return str
   464  }
   466  func (LowerCamelCaseFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
   467  	jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
   468  	if jsonName != "" {
   469  		return jsonName
   470  	}
   472  	return lowerFirst(f.Name)
   473  }
// SnakeCaseFieldNamer implements CamelCase to snake_case conversion for fields names.
type SnakeCaseFieldNamer struct{}

// camelToSnake converts a CamelCase identifier to snake_case, e.g.
// "FooBar" -> "foo_bar", "HTTPServer" -> "http_server", "UserID" -> "user_id",
// "Foo_Bar" -> "foo_bar".
//
// Each uppercase rune is buffered in lastUpper and emitted (lowered) only when
// the next rune is seen, because whether a '_' goes before it depends on that
// next rune. A '_' is written before the buffered rune when it is the first
// of an uppercase run, or when it is the last of a run and the current rune is
// not uppercase (the 'S' in "HTTPServer"). No '_' is written at the start of
// the output, nor right after a '_' already copied from the input.
// A non-lowercase rune (e.g. a digit) directly after an uppercase rune is
// treated as part of the uppercase run ("HTTP2Server" -> "http2_server").
//
// NOTE(review): the final flush after the loop writes no delimiter, so a single
// trailing capital is glued on ("PointX" -> "pointx", not "point_x"). Changing
// this would alter generated JSON keys for existing users; confirm before
// fixing.
func camelToSnake(name string) string {
	var ret bytes.Buffer

	multipleUpper := false // buffered rune was preceded by another uppercase rune
	var lastUpper rune     // buffered uppercase rune not yet written; 0 if none
	var beforeUpper rune   // last non-uppercase rune written before the current run

	for _, c := range name {
		// Non-lowercase character after uppercase is considered to be uppercase too.
		isUpper := (unicode.IsUpper(c) || (lastUpper != 0 && !unicode.IsLower(c)))

		if lastUpper != 0 {
			// Output a delimiter if last character was either the first uppercase character
			// in a row, or the last one in a row (e.g. 'S' in "HTTPServer").
			// Do not output a delimiter at the beginning of the name.

			firstInRow := !multipleUpper
			lastInRow := !isUpper

			// beforeUpper != '_' avoids doubling an underscore already in the name.
			if ret.Len() > 0 && (firstInRow || lastInRow) && beforeUpper != '_' {
				ret.WriteByte('_')
			}
			ret.WriteRune(unicode.ToLower(lastUpper))
		}

		// Buffer uppercase char, do not output it yet as a delimiter may be required if the
		// next character is lowercase.
		if isUpper {
			multipleUpper = (lastUpper != 0)
			lastUpper = c
			continue
		}

		ret.WriteRune(c)
		lastUpper = 0
		beforeUpper = c
		multipleUpper = false
	}

	// Flush a trailing buffered uppercase rune (no delimiter is considered here).
	if lastUpper != 0 {
		ret.WriteRune(unicode.ToLower(lastUpper))
	}
	return string(ret.Bytes())
}
   523  func (SnakeCaseFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
   524  	jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
   525  	if jsonName != "" {
   526  		return jsonName
   527  	}
   529  	return camelToSnake(f.Name)
   530  }
   532  func joinFunctionNameParts(keepFirst bool, parts ...string) string {
   533  	buf := bytes.NewBufferString("")
   534  	for i, part := range parts {
   535  		if i == 0 && keepFirst {
   536  			buf.WriteString(part)
   537  		} else {
   538  			if len(part) > 0 {
   539  				buf.WriteString(strings.ToUpper(string(part[0])))
   540  			}
   541  			if len(part) > 1 {
   542  				buf.WriteString(part[1:])
   543  			}
   544  		}
   545  	}
   546  	return buf.String()
   547  }

