1 package gen
2
3 import (
4 "bytes"
5 "fmt"
6 "hash/fnv"
7 "io"
8 "path"
9 "reflect"
10 "sort"
11 "strconv"
12 "strings"
13 "unicode"
14 )
15
// Import paths of the easyjson runtime packages that every generated file
// references. They are pre-registered as imports by NewGenerator.
const (
	pkgWriter   = "github.com/mailru/easyjson/jwriter"
	pkgLexer    = "github.com/mailru/easyjson/jlexer"
	pkgEasyJSON = "github.com/mailru/easyjson"
)
19
20
// FieldNamer decides which JSON member name is used for a struct field.
//
// GetJSONFieldName receives a reflect.Type t (presumably the struct type that
// declares f — confirm against callers) and the field f itself, and returns
// the JSON key for f. The implementations in this file ignore t and derive
// the name from f's json tag or from f.Name.
type FieldNamer interface {
	GetJSONFieldName(t reflect.Type, f reflect.StructField) string
}
24
25
// Generator holds the state for producing one easyjson output file: the
// target package, generation options, the import table, and the queue of
// types still waiting for encoder/decoder generation.
type Generator struct {
	// out accumulates generated declarations; Run resets it and writes its
	// contents after the file header.
	out *bytes.Buffer

	// pkgName and pkgPath identify the package being generated (see SetPkg).
	// pkgName is used in the package clause; types whose PkgPath equals
	// pkgPath are referenced without a qualifier by getType.
	pkgName string
	pkgPath string
	// buildTags, when non-empty, is emitted as a "+build" line in the header.
	buildTags string
	// hashString is the hex FNV-32 hash of the output file name, embedded in
	// every generated function name by functionName.
	hashString string

	// varCounter is incremented by uniqueVarName to produce v1, v2, ...
	varCounter int

	// Generation options set through the exported setters. Apart from
	// fieldNamer, they are read by code outside this file — TODO confirm
	// their exact effect there.
	noStdMarshalers          bool
	omitEmpty                bool
	disallowUnknownFields    bool
	fieldNamer               FieldNamer
	simpleBytes              bool
	skipMemberNameUnescaping bool

	// imports maps an import path to the alias used for it in generated code
	// (seeded by NewGenerator, extended by pkgAlias).
	imports map[string]string

	// marshalers records the types registered through Add; only these get
	// the extra genStructMarshaler/genStructUnmarshaler output in Run.
	marshalers map[reflect.Type]bool

	// typesSeen marks types that Run has already taken from the queue.
	typesSeen map[reflect.Type]bool

	// typesUnseen holds types queued by addType and not yet processed; Run
	// pops from the end of this slice.
	typesUnseen []reflect.Type

	// functionNames maps every generated function name to the type it was
	// allocated for, so functionName can keep names unique and reuse them.
	functionNames map[string]reflect.Type
}
59
60
// NewGenerator returns a Generator for the output file named filename.
//
// The jwriter, jlexer, easyjson and encoding/json packages are pre-registered
// as imports (the header always references them), DefaultFieldNamer is the
// initial field namer, and the lowercase-hex FNV-32 hash of filename is kept
// for use as a per-file component of generated function names.
func NewGenerator(filename string) *Generator {
	h := fnv.New32()
	h.Write([]byte(filename)) // hash.Hash.Write never returns an error

	return &Generator{
		imports: map[string]string{
			pkgWriter:       "jwriter",
			pkgLexer:        "jlexer",
			pkgEasyJSON:     "easyjson",
			"encoding/json": "json",
		},
		fieldNamer:    DefaultFieldNamer{},
		hashString:    strconv.FormatUint(uint64(h.Sum32()), 16),
		marshalers:    map[reflect.Type]bool{},
		typesSeen:     map[reflect.Type]bool{},
		functionNames: map[string]reflect.Type{},
	}
}
83
84
85 func (g *Generator) SetPkg(name, path string) {
86 g.pkgName = name
87 g.pkgPath = path
88 }
89
90
91 func (g *Generator) SetBuildTags(tags string) {
92 g.buildTags = tags
93 }
94
95
96 func (g *Generator) SetFieldNamer(n FieldNamer) {
97 g.fieldNamer = n
98 }
99
100
101 func (g *Generator) UseSnakeCase() {
102 g.fieldNamer = SnakeCaseFieldNamer{}
103 }
104
105
106 func (g *Generator) UseLowerCamelCase() {
107 g.fieldNamer = LowerCamelCaseFieldNamer{}
108 }
109
110
111
112 func (g *Generator) NoStdMarshalers() {
113 g.noStdMarshalers = true
114 }
115
116
117 func (g *Generator) DisallowUnknownFields() {
118 g.disallowUnknownFields = true
119 }
120
121
122 func (g *Generator) SkipMemberNameUnescaping() {
123 g.skipMemberNameUnescaping = true
124 }
125
126
127 func (g *Generator) OmitEmpty() {
128 g.omitEmpty = true
129 }
130
131
132 func (g *Generator) SimpleBytes() {
133 g.simpleBytes = true
134 }
135
136
// addType queues t for code generation unless it has already been processed
// by Run or is already waiting in the queue.
func (g *Generator) addType(t reflect.Type) {
	known := g.typesSeen[t]
	for i := 0; !known && i < len(g.typesUnseen); i++ {
		known = g.typesUnseen[i] == t
	}
	if !known {
		g.typesUnseen = append(g.typesUnseen, t)
	}
}
148
149
150
// Add registers the type of obj for generation, including the extra marshaler
// and unmarshaler methods that Run emits for explicitly added types.
//
// One level of pointer is stripped, so Add(T{}) and Add(&T{}) register the
// same type. obj must not be an untyped nil: reflect.TypeOf(nil) is nil.
func (g *Generator) Add(obj interface{}) {
	typ := reflect.TypeOf(obj)
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	g.marshalers[typ] = true
	g.addType(typ)
}
159
160
161 func (g *Generator) printHeader() {
162 if g.buildTags != "" {
163 fmt.Println("// +build ", g.buildTags)
164 fmt.Println()
165 }
166 fmt.Println("// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT.")
167 fmt.Println()
168 fmt.Println("package ", g.pkgName)
169 fmt.Println()
170
171 byAlias := make(map[string]string, len(g.imports))
172 aliases := make([]string, 0, len(g.imports))
173
174 for path, alias := range g.imports {
175 aliases = append(aliases, alias)
176 byAlias[alias] = path
177 }
178
179 sort.Strings(aliases)
180 fmt.Println("import (")
181 for _, alias := range aliases {
182 fmt.Printf(" %s %q\n", alias, byAlias[alias])
183 }
184
185 fmt.Println(")")
186 fmt.Println("")
187 fmt.Println("// suppress unused package warning")
188 fmt.Println("var (")
189 fmt.Println(" _ *json.RawMessage")
190 fmt.Println(" _ *jlexer.Lexer")
191 fmt.Println(" _ *jwriter.Writer")
192 fmt.Println(" _ easyjson.Marshaler")
193 fmt.Println(")")
194
195 fmt.Println()
196 }
197
198
199 func (g *Generator) Run(out io.Writer) error {
200 g.out = &bytes.Buffer{}
201
202 for len(g.typesUnseen) > 0 {
203 t := g.typesUnseen[len(g.typesUnseen)-1]
204 g.typesUnseen = g.typesUnseen[:len(g.typesUnseen)-1]
205 g.typesSeen[t] = true
206
207 if err := g.genDecoder(t); err != nil {
208 return err
209 }
210 if err := g.genEncoder(t); err != nil {
211 return err
212 }
213
214 if !g.marshalers[t] {
215 continue
216 }
217
218 if err := g.genStructMarshaler(t); err != nil {
219 return err
220 }
221 if err := g.genStructUnmarshaler(t); err != nil {
222 return err
223 }
224 }
225 g.printHeader()
226 _, err := out.Write(g.out.Bytes())
227 return err
228 }
229
230
// fixPkgPathVendoring strips everything up to and including the last
// "/vendor/" element of pkgPath, yielding the path the package is imported
// by. Paths without a vendor element are returned unchanged.
func fixPkgPathVendoring(pkgPath string) string {
	const marker = "/vendor/"
	idx := strings.LastIndex(pkgPath, marker)
	if idx < 0 {
		return pkgPath
	}
	return pkgPath[idx+len(marker):]
}
238
// fixAliasName turns the last element of an import path into a valid Go
// identifier usable as an import alias.
//
// Every rune that cannot appear in an identifier ('.', '-', '~', '+', ...)
// is replaced with '_'. The result is prefixed with '_' when it:
//   - starts with 'v', so it cannot collide with the v1, v2, ... names
//     produced by uniqueVarName;
//   - starts with a digit, or is a Go keyword or the blank identifier, all of
//     which would be invalid as an alias.
//
// An empty input, which would otherwise produce an empty alias, yields "pkg".
func fixAliasName(alias string) string {
	alias = strings.Map(func(r rune) rune {
		if r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) {
			return r
		}
		return '_'
	}, alias)

	if alias == "" {
		return "pkg"
	}

	needsPrefix := alias == "_" ||
		strings.HasPrefix(alias, "v") ||
		strings.IndexFunc(alias, unicode.IsDigit) == 0
	switch alias {
	case "break", "case", "chan", "const", "continue", "default", "defer",
		"else", "fallthrough", "for", "func", "go", "goto", "if", "import",
		"interface", "map", "package", "range", "return", "select", "struct",
		"switch", "type", "var":
		needsPrefix = true
	}
	if needsPrefix {
		alias = "_" + alias
	}
	return alias
}
252
253
// pkgAlias returns the import alias for pkgPath (after vendor stripping),
// registering a new one if the package has not been imported yet.
//
// A new alias is derived from the last path element via fixAliasName; if that
// name is already used by another import, the numeric suffixes 1, 2, ... are
// tried in order until a free alias is found.
func (g *Generator) pkgAlias(pkgPath string) string {
	pkgPath = fixPkgPathVendoring(pkgPath)
	if existing := g.imports[pkgPath]; existing != "" {
		return existing
	}

	taken := func(candidate string) bool {
		for _, used := range g.imports {
			if used == candidate {
				return true
			}
		}
		return false
	}

	base := fixAliasName(path.Base(pkgPath))
	candidate := base
	for n := 1; taken(candidate); n++ {
		candidate = base + strconv.Itoa(n)
	}
	g.imports[pkgPath] = candidate
	return candidate
}
280
281
282 func (g *Generator) getType(t reflect.Type) string {
283 if t.Name() == "" {
284 switch t.Kind() {
285 case reflect.Ptr:
286 return "*" + g.getType(t.Elem())
287 case reflect.Slice:
288 return "[]" + g.getType(t.Elem())
289 case reflect.Array:
290 return "[" + strconv.Itoa(t.Len()) + "]" + g.getType(t.Elem())
291 case reflect.Map:
292 return "map[" + g.getType(t.Key()) + "]" + g.getType(t.Elem())
293 }
294 }
295
296 if t.Name() == "" || t.PkgPath() == "" {
297 if t.Kind() == reflect.Struct {
298
299
300
301
302 nf := t.NumField()
303 lines := make([]string, 0, nf)
304 for i := 0; i < nf; i++ {
305 f := t.Field(i)
306 var line string
307 if !f.Anonymous {
308 line = f.Name + " "
309 }
310 line += g.getType(f.Type)
311 t := f.Tag
312 if t != "" {
313 line += " " + escapeTag(t)
314 }
315 lines = append(lines, line)
316 }
317 return strings.Join([]string{"struct { ", strings.Join(lines, "; "), " }"}, "")
318 }
319 return t.String()
320 } else if t.PkgPath() == g.pkgPath {
321 return t.Name()
322 }
323 return g.pkgAlias(t.PkgPath()) + "." + t.Name()
324 }
325
326
// escapeTag renders a struct tag as a Go string literal: a raw (backquoted)
// literal normally, or an interpreted, quoted literal when the tag itself
// contains a backquote, which a raw literal cannot hold.
func escapeTag(tag reflect.StructTag) string {
	text := string(tag)
	if !strings.Contains(text, "`") {
		return "`" + text + "`"
	}
	return strconv.Quote(text)
}
335
336
// uniqueVarName returns a fresh identifier of the form v1, v2, ... by
// advancing the generator's counter.
func (g *Generator) uniqueVarName() string {
	g.varCounter++
	return "v" + strconv.Itoa(g.varCounter)
}
341
342
343
// safeName derives an identifier fragment for t from its package path and
// type name, e.g. "github.com/foo/bar".Baz -> "GithubComFooBarBaz".
//
// Unnamed types use the word "anonymous" instead of a type name. The full
// name is split on every run of characters that are neither letters nor
// digits, and the pieces are joined with their first letter upper-cased.
func (g *Generator) safeName(t reflect.Type) string {
	name := t.PkgPath()
	if t.Name() == "" {
		name += "anonymous"
	} else {
		name += "." + t.Name()
	}

	// FieldsFunc also keeps the final segment. The previous hand-written loop
	// only flushed a segment when it hit a separator, so the type name itself
	// was silently dropped.
	parts := strings.FieldsFunc(name, func(c rune) bool {
		return !unicode.IsLetter(c) && !unicode.IsDigit(c)
	})
	return joinFunctionNameParts(false, parts...)
}
364
365
366
367
368
// functionName returns the name of the generated function with the given
// role prefix (e.g. decode/encode) for type t.
//
// Names have the form easyjson<hash><Prefix><SafeName>. The same (prefix, t)
// pair always yields the same name. When the natural name is already owned
// by a different type, a numeric suffix (1, 2, ...) is appended; a suffixed
// name previously handed out for t is reused.
func (g *Generator) functionName(prefix string, t reflect.Type) string {
	fullPrefix := joinFunctionNameParts(true, "easyjson", g.hashString, prefix)
	candidate := joinFunctionNameParts(true, fullPrefix, g.safeName(t))

	owner, taken := g.functionNames[candidate]
	if !taken || owner == t {
		g.functionNames[candidate] = t
		return candidate
	}

	// t may already have received a suffixed variant under this prefix.
	for existing, other := range g.functionNames {
		if other == t && strings.HasPrefix(existing, fullPrefix) {
			return existing
		}
	}

	for suffix := 1; ; suffix++ {
		numbered := candidate + strconv.Itoa(suffix)
		if _, used := g.functionNames[numbered]; !used {
			g.functionNames[numbered] = t
			return numbered
		}
	}
}
396
397
398 type DefaultFieldNamer struct{}
399
400 func (DefaultFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
401 jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
402 if jsonName != "" {
403 return jsonName
404 }
405
406 return f.Name
407 }
408
409
// LowerCamelCaseFieldNamer is a FieldNamer that uses the name from the
// field's json tag when present and otherwise converts the Go field name with
// lowerFirst (e.g. "FooBar" -> "fooBar").
type LowerCamelCaseFieldNamer struct{}
411
// isLower reports whether b is an ASCII lower-case letter.
func isLower(b byte) bool {
	return 'a' <= b && b <= 'z'
}
415
// isUpper reports whether b is an ASCII upper-case letter.
func isUpper(b byte) bool {
	return 'A' <= b && b <= 'Z'
}
419
420
// lowerFirst converts a Go field name to lowerCamelCase.
//
// The leading run of upper-case letters is lower-cased, except for the last
// letter of that run when it is directly followed by a lower-case letter,
// because that letter starts the next word:
//
//	"Foo"          -> "foo"
//	"ID"           -> "id"
//	"URLValue"     -> "urlValue"
//	"HTTPServerID" -> "httpServerID"
//
// After the first non-upper-case rune, the rest is copied unchanged. The
// string is processed rune by rune, so names containing multi-byte letters
// (e.g. "Ñame" -> "ñame") stay valid UTF-8. The previous byte-wise version
// re-encoded each byte of such names separately.
func lowerFirst(s string) string {
	if s == "" {
		return ""
	}

	runes := []rune(s)
	foundLower := false
	for i, r := range runes {
		if !unicode.IsUpper(r) {
			foundLower = true
			continue
		}
		switch {
		case i == 0:
			runes[i] = unicode.ToLower(r)
		case !foundLower:
			// Keep an upper-case letter that begins a new word ("V" in "URLValue").
			if i+1 < len(runes) && unicode.IsLower(runes[i+1]) {
				continue
			}
			runes[i] = unicode.ToLower(r)
		}
	}
	return string(runes)
}
465
// GetJSONFieldName returns the part of f's json tag before the first comma,
// or f.Name converted by lowerFirst when that part is empty. t is not used.
func (LowerCamelCaseFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
	tagName := f.Tag.Get("json")
	if comma := strings.IndexByte(tagName, ','); comma >= 0 {
		tagName = tagName[:comma]
	}
	if tagName == "" {
		return lowerFirst(f.Name)
	}
	return tagName
}
474
475
// SnakeCaseFieldNamer is a FieldNamer that uses the name from the field's
// json tag when present and otherwise converts the Go field name with
// camelToSnake (e.g. "FooBar" -> "foo_bar").
type SnakeCaseFieldNamer struct{}
477
// camelToSnake converts a CamelCase Go identifier to snake_case.
//
// Upper-case runes are held back in lastUpper and written lower-cased one step
// later, because whether a '_' must precede them depends on the next rune. A
// '_' goes before the first upper-case rune of a run, and before the last one
// of a run when a lower-case rune follows, so "HTTPServer" -> "http_server".
// Any non-lower-case rune that follows an upper-case one (digits, '_') counts
// as part of the upper-case run, e.g. "ID1" -> "id1".
//
// No delimiter is written at the start of the result or directly after a '_'
// that is already in the output, so "Foo_Bar" -> "foo_bar" and
// "ID_Code" -> "id_code". The previous version tracked only runes written
// through the lower-case path, which produced "id__code" and "a_bcde".
func camelToSnake(name string) string {
	var ret bytes.Buffer

	multipleUpper := false
	var lastUpper rune
	// lastWritten is the most recent rune appended to ret.
	var lastWritten rune

	for _, c := range name {
		isUpper := unicode.IsUpper(c) || (lastUpper != 0 && !unicode.IsLower(c))

		if lastUpper != 0 {
			firstInRow := !multipleUpper
			lastInRow := !isUpper

			if ret.Len() > 0 && (firstInRow || lastInRow) && lastWritten != '_' {
				ret.WriteByte('_')
			}
			lastWritten = unicode.ToLower(lastUpper)
			ret.WriteRune(lastWritten)
		}

		// Hold upper-case runes back: the delimiter decision needs the next rune.
		if isUpper {
			multipleUpper = lastUpper != 0
			lastUpper = c
			continue
		}

		ret.WriteRune(c)
		lastWritten = c
		lastUpper = 0
		multipleUpper = false
	}

	if lastUpper != 0 {
		ret.WriteRune(unicode.ToLower(lastUpper))
	}
	return ret.String()
}
522
// GetJSONFieldName returns the part of f's json tag before the first comma,
// or f.Name converted by camelToSnake when that part is empty. t is not used.
func (SnakeCaseFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
	if tagName := strings.SplitN(f.Tag.Get("json"), ",", 2)[0]; tagName != "" {
		return tagName
	}
	return camelToSnake(f.Name)
}
531
532 func joinFunctionNameParts(keepFirst bool, parts ...string) string {
533 buf := bytes.NewBufferString("")
534 for i, part := range parts {
535 if i == 0 && keepFirst {
536 buf.WriteString(part)
537 } else {
538 if len(part) > 0 {
539 buf.WriteString(strings.ToUpper(string(part[0])))
540 }
541 if len(part) > 1 {
542 buf.WriteString(part[1:])
543 }
544 }
545 }
546 return buf.String()
547 }
548
View as plain text