1 package templates
2
3 import (
4 "bytes"
5 "fmt"
6 "go/types"
7 "io/fs"
8 "os"
9 "path/filepath"
10 "reflect"
11 "regexp"
12 "runtime"
13 "sort"
14 "strconv"
15 "strings"
16 "sync"
17 "text/template"
18 "unicode"
19
20 "github.com/99designs/gqlgen/internal/code"
21 "github.com/99designs/gqlgen/internal/imports"
22 )
23
24
25
// CurrentImports is the import set of the render that is currently running.
// Render assigns it on entry (and panics if it is already set) and clears it
// once the output has been assembled; the template helpers in this file
// (ref, Call, and the reserveImport/lookupImport entries of Funcs) read it
// while templates execute.
var CurrentImports *Imports
27
28
// Options configures a single call to Render.
type Options struct {
	// PackageName is written after the "package" keyword of the generated
	// file.
	PackageName string

	// Template, when non-empty, is parsed under the name "template.gotpl"
	// and is the only template used; TemplateFS is then not consulted.
	Template string

	// TemplateFS, when non-nil (and Template is empty), is the file system
	// searched for "*.gotpl" files. When it is nil as well, the directory of
	// the source file two stack frames above parseTemplates (normally the
	// caller of Render) is searched instead.
	TemplateFS fs.FS

	// Filename is the path the rendered output is written to. Its directory
	// is also used as the destination directory of the import set.
	Filename string

	// RegionTags, when true, wraps each root template's output in
	// "// region" and "// endregion" comment lines.
	RegionTags bool

	// GeneratedHeader, when true, starts the file with the
	// "Code generated by ... DO NOT EDIT." comment.
	GeneratedHeader bool

	// PackageDoc, when non-empty, is written (followed by a newline)
	// directly before the package clause.
	PackageDoc string

	// FileNotice, when non-empty, is written between the package clause and
	// the import block.
	FileNotice string

	// Data is the value every root template is executed with.
	Data interface{}

	// Funcs is merged over the map returned by Funcs(); an entry with the
	// same name replaces the default.
	Funcs template.FuncMap

	// Packages is handed to the import set and to the final write step.
	// After a successful write the package of the output directory is
	// evicted from it.
	Packages *code.Packages
}
61
var (
	// modelNamesMu guards modelNames.
	modelNamesMu sync.Mutex

	// modelNames maps the key built by buildGoModelNameKey to the Go name
	// that goModelName chose for it.
	modelNames = make(map[string]string, 0)

	// goNameRe matches every character that is not an ASCII letter, an
	// ASCII digit or an underscore.
	goNameRe = regexp.MustCompile("[^a-zA-Z0-9_]")
)
67
68
69
70
71
72 func Render(cfg Options) error {
73 if CurrentImports != nil {
74 panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
75 }
76 CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)}
77
78 funcs := Funcs()
79 for n, f := range cfg.Funcs {
80 funcs[n] = f
81 }
82
83 t := template.New("").Funcs(funcs)
84 t, err := parseTemplates(cfg, t)
85 if err != nil {
86 return err
87 }
88
89 roots := make([]string, 0, len(t.Templates()))
90 for _, template := range t.Templates() {
91
92 if strings.HasSuffix(template.Name(), "_.gotpl") ||
93
94 !strings.HasSuffix(template.Name(), ".gotpl") {
95 continue
96 }
97
98 roots = append(roots, template.Name())
99 }
100
101
102 sort.Slice(roots, func(i, j int) bool {
103
104 if strings.HasSuffix(roots[i], "!.gotpl") {
105 return true
106 }
107 if strings.HasSuffix(roots[j], "!.gotpl") {
108 return false
109 }
110 return roots[i] < roots[j]
111 })
112
113 var buf bytes.Buffer
114 for _, root := range roots {
115 if cfg.RegionTags {
116 buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n")
117 }
118 err := t.Lookup(root).Execute(&buf, cfg.Data)
119 if err != nil {
120 return fmt.Errorf("%s: %w", root, err)
121 }
122 if cfg.RegionTags {
123 buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
124 }
125 }
126
127 var result bytes.Buffer
128 if cfg.GeneratedHeader {
129 result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
130 }
131 if cfg.PackageDoc != "" {
132 result.WriteString(cfg.PackageDoc + "\n")
133 }
134 result.WriteString("package ")
135 result.WriteString(cfg.PackageName)
136 result.WriteString("\n\n")
137 if cfg.FileNotice != "" {
138 result.WriteString(cfg.FileNotice)
139 result.WriteString("\n\n")
140 }
141 result.WriteString("import (\n")
142 result.WriteString(CurrentImports.String())
143 result.WriteString(")\n")
144 _, err = buf.WriteTo(&result)
145 if err != nil {
146 return err
147 }
148 CurrentImports = nil
149
150 err = write(cfg.Filename, result.Bytes(), cfg.Packages)
151 if err != nil {
152 return err
153 }
154
155 cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename)))
156 return nil
157 }
158
159 func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) {
160 if cfg.Template != "" {
161 var err error
162 t, err = t.New("template.gotpl").Parse(cfg.Template)
163 if err != nil {
164 return nil, fmt.Errorf("error with provided template: %w", err)
165 }
166 return t, nil
167 }
168
169 var fileSystem fs.FS
170 if cfg.TemplateFS != nil {
171 fileSystem = cfg.TemplateFS
172 } else {
173
174 _, callerFile, _, _ := runtime.Caller(2)
175 rootDir := filepath.Dir(callerFile)
176 fileSystem = os.DirFS(rootDir)
177 }
178
179 t, err := t.ParseFS(fileSystem, "*.gotpl")
180 if err != nil {
181 return nil, fmt.Errorf("locating templates: %w", err)
182 }
183
184 return t, nil
185 }
186
// center surrounds s with copies of pad so that, for a one-byte pad, the
// result is width bytes long. When the padding is uneven the extra copy goes
// on the right. s is returned unchanged when fewer than two bytes of room
// remain. Lengths are measured in bytes.
func center(width int, pad string, s string) string {
	room := width - len(s)
	if room < 2 {
		return s
	}
	left := room / 2
	right := room - left
	return strings.Repeat(pad, left) + s + strings.Repeat(pad, right)
}
195
196 func Funcs() template.FuncMap {
197 return template.FuncMap{
198 "ucFirst": UcFirst,
199 "lcFirst": LcFirst,
200 "quote": strconv.Quote,
201 "rawQuote": rawQuote,
202 "dump": Dump,
203 "ref": ref,
204 "ts": TypeIdentifier,
205 "call": Call,
206 "prefixLines": prefixLines,
207 "notNil": notNil,
208 "reserveImport": CurrentImports.Reserve,
209 "lookupImport": CurrentImports.Lookup,
210 "go": ToGo,
211 "goPrivate": ToGoPrivate,
212 "goModelName": ToGoModelName,
213 "goPrivateModelName": ToGoPrivateModelName,
214 "add": func(a, b int) int {
215 return a + b
216 },
217 "render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
218 return render(resolveName(filename, 0), tpldata)
219 },
220 }
221 }
222
// UcFirst returns s with its first rune converted to upper case. The empty
// string is returned unchanged.
func UcFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}
231
// LcFirst returns s with its first rune converted to lower case. The empty
// string is returned unchanged.
func LcFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	runes[0] = unicode.ToLower(runes[0])
	return string(runes)
}
241
// isDelimiter reports whether c separates words in a name: a hyphen, an
// underscore, or any Unicode white space.
func isDelimiter(c rune) bool {
	switch c {
	case '-', '_':
		return true
	}
	return unicode.IsSpace(c)
}
245
246 func ref(p types.Type) string {
247 return CurrentImports.LookupType(p)
248 }
249
250 func Call(p *types.Func) string {
251 pkg := CurrentImports.Lookup(p.Pkg().Path())
252
253 if pkg != "" {
254 pkg += "."
255 }
256
257 if p.Type() != nil {
258
259 ref(p.Type().(*types.Signature).Results().At(0).Type())
260 }
261
262 return pkg + p.Name()
263 }
264
265 func resetModelNames() {
266 modelNamesMu.Lock()
267 defer modelNamesMu.Unlock()
268 modelNames = make(map[string]string, 0)
269 }
270
// buildGoModelNameKey joins parts with ":" to form the key under which
// goModelName stores the generated name.
func buildGoModelNameKey(parts []string) string {
	return strings.Join(parts, ":")
}
275
276 func goModelName(primaryToGoFunc func(string) string, parts []string) string {
277 modelNamesMu.Lock()
278 defer modelNamesMu.Unlock()
279
280 var (
281 goNameKey string
282 partLen int
283
284 nameExists = func(n string) bool {
285 for _, v := range modelNames {
286 if n == v {
287 return true
288 }
289 }
290 return false
291 }
292
293 applyToGoFunc = func(parts []string) string {
294 var out string
295 switch len(parts) {
296 case 0:
297 return ""
298 case 1:
299 return primaryToGoFunc(parts[0])
300 default:
301 out = primaryToGoFunc(parts[0])
302 }
303 for _, p := range parts[1:] {
304 out = fmt.Sprintf("%s%s", out, ToGo(p))
305 }
306 return out
307 }
308
309 applyValidGoName = func(parts []string) string {
310 var out string
311 for _, p := range parts {
312 out = fmt.Sprintf("%s%s", out, replaceInvalidCharacters(p))
313 }
314 return out
315 }
316 )
317
318
319 goNameKey = buildGoModelNameKey(parts)
320
321
322 if goName, ok := modelNames[goNameKey]; ok {
323 return goName
324 }
325
326
327 if goName := applyToGoFunc(parts); !nameExists(goName) {
328 modelNames[goNameKey] = goName
329 return goName
330 }
331
332
333 partLen = len(parts)
334
335
336 if partLen == 1 {
337 base := applyToGoFunc(parts)
338 for i := 0; ; i++ {
339 tmp := fmt.Sprintf("%s%d", base, i)
340 if !nameExists(tmp) {
341 modelNames[goNameKey] = tmp
342 return tmp
343 }
344 }
345 }
346
347
348 for i := partLen - 1; i >= 1; i-- {
349 tmp := fmt.Sprintf("%s%s", applyToGoFunc(parts[0:i]), applyValidGoName(parts[i:]))
350 if !nameExists(tmp) {
351 modelNames[goNameKey] = tmp
352 return tmp
353 }
354 }
355
356
357 base := applyToGoFunc(parts)
358 for i := 0; ; i++ {
359 tmp := fmt.Sprintf("%s%d", base, i)
360 if !nameExists(tmp) {
361 modelNames[goNameKey] = tmp
362 return tmp
363 }
364 }
365 }
366
367 func ToGoModelName(parts ...string) string {
368 return goModelName(ToGo, parts)
369 }
370
371 func ToGoPrivateModelName(parts ...string) string {
372 return goModelName(ToGoPrivate, parts)
373 }
374
375 func replaceInvalidCharacters(in string) string {
376 return goNameRe.ReplaceAllLiteralString(in, "_")
377 }
378
379 func wordWalkerFunc(private bool, nameRunes *[]rune) func(*wordInfo) {
380 return func(info *wordInfo) {
381 word := info.Word
382
383 switch {
384 case private && info.WordOffset == 0:
385 if strings.ToUpper(word) == word || strings.ToLower(word) == word {
386
387 word = strings.ToLower(info.Word)
388 } else {
389
390 word = LcFirst(info.Word)
391 }
392
393 case info.MatchCommonInitial:
394 word = strings.ToUpper(word)
395
396 case !info.HasCommonInitial && (strings.ToUpper(word) == word || strings.ToLower(word) == word):
397
398
399 word = UcFirst(strings.ToLower(word))
400 }
401
402 *nameRunes = append(*nameRunes, []rune(word)...)
403 }
404 }
405
406 func ToGo(name string) string {
407 if name == "_" {
408 return "_"
409 }
410 runes := make([]rune, 0, len(name))
411
412 wordWalker(name, wordWalkerFunc(false, &runes))
413
414 return string(runes)
415 }
416
417 func ToGoPrivate(name string) string {
418 if name == "_" {
419 return "_"
420 }
421 runes := make([]rune, 0, len(name))
422
423 wordWalker(name, wordWalkerFunc(true, &runes))
424
425 return sanitizeKeywords(string(runes))
426 }
427
// wordInfo describes one word that wordWalker reports to its callback.
type wordInfo struct {
	// WordOffset is the zero-based index of this word among the words
	// reported for the current string.
	WordOffset int

	// Word is the text of the word, with delimiters already removed.
	Word string

	// MatchCommonInitial is true when the upper-cased word is present in the
	// map returned by GetInitialisms.
	MatchCommonInitial bool

	// HasCommonInitial is true when the word matches an initialism, or when
	// a prefix of it did while the word was being scanned.
	HasCommonInitial bool
}
434
435
436
437 func wordWalker(str string, f func(*wordInfo)) {
438 runes := []rune(strings.TrimFunc(str, isDelimiter))
439 w, i, wo := 0, 0, 0
440 hasCommonInitial := false
441 for i+1 <= len(runes) {
442 eow := false
443 switch {
444 case i+1 == len(runes):
445 eow = true
446 case isDelimiter(runes[i+1]):
447
448 eow = true
449 n := 1
450 for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
451 n++
452 }
453
454
455 if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
456 n--
457 }
458
459 copy(runes[i+1:], runes[i+n+1:])
460 runes = runes[:len(runes)-n]
461 case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
462
463 eow = true
464 }
465 i++
466
467 initialisms := GetInitialisms()
468
469 word := string(runes[w:i])
470 if !eow && initialisms[word] && !unicode.IsLower(runes[i]) {
471
472
473
474 } else if !eow {
475 if initialisms[word] {
476 hasCommonInitial = true
477 }
478 continue
479 }
480
481 matchCommonInitial := false
482 upperWord := strings.ToUpper(word)
483 if initialisms[upperWord] {
484
485
486
487
488
489
490
491
492
493
494
495 switch upperWord {
496 case "ID", "IP":
497 if word == str[:2] && !eow && len(str) > 3 && unicode.IsUpper(runes[3]) {
498 continue
499 }
500 }
501 hasCommonInitial = true
502 matchCommonInitial = true
503 }
504
505 f(&wordInfo{
506 WordOffset: wo,
507 Word: word,
508 MatchCommonInitial: matchCommonInitial,
509 HasCommonInitial: hasCommonInitial,
510 })
511 hasCommonInitial = false
512 w = i
513 wo++
514 }
515 }
516
// keywords lists the names sanitizeKeywords will not return unchanged: the
// Go reserved words plus the blank identifier "_".
var keywords = []string{
	"break",
	"default",
	"func",
	"interface",
	"select",
	"case",
	"defer",
	"go",
	"map",
	"struct",
	"chan",
	"else",
	"goto",
	"package",
	"switch",
	"const",
	"fallthrough",
	"if",
	"range",
	"type",
	"continue",
	"for",
	"import",
	"return",
	"var",
	"_",
}
545
546
547 func sanitizeKeywords(name string) string {
548 for _, k := range keywords {
549 if name == k {
550 return name + "Arg"
551 }
552 }
553 return name
554 }
555
// rawQuote returns Go source for a raw (backquoted) string literal holding s.
// A raw literal cannot contain a backquote, so at each backquote in s the raw
// literal is closed, the backquote is emitted as an interpreted string
// literal, and a new raw literal is opened, all joined with "+".
func rawQuote(s string) string {
	const tick = "`"
	const splice = tick + `+"` + tick + `"+` + tick

	pieces := strings.Split(s, tick)
	return tick + strings.Join(pieces, splice) + tick
}
559
// notNil reports whether the struct field named field of data holds a
// non-nil value.
//
// data may be a struct or a pointer to a struct; anything else (including a
// nil pointer or a nil interface) yields false, as does a field name that
// does not exist. A field whose kind can never be nil (string, int, struct,
// ...) is reported as not nil instead of panicking in reflect.Value.IsNil.
func notNil(field string, data interface{}) bool {
	v := reflect.ValueOf(data)

	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return false
	}

	val := v.FieldByName(field)
	if !val.IsValid() {
		return false
	}

	// IsNil is only defined for these kinds and panics for all others.
	switch val.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return !val.IsNil()
	default:
		return true
	}
}
573
// Dump renders val as Go source text.
//
// Supported values are int, int64, float64 (formatted with "%f"), string,
// bool, nil, []interface{} and map[string]interface{}; slices and maps are
// rendered recursively, and map entries are emitted in sorted key order, each
// followed by a comma. Dump panics with an error for any other type.
func Dump(val interface{}) string {
	switch v := val.(type) {
	case nil:
		return "nil"
	case bool:
		return strconv.FormatBool(v)
	case int:
		return strconv.Itoa(v)
	case int64:
		return strconv.FormatInt(v, 10)
	case float64:
		return fmt.Sprintf("%f", v)
	case string:
		return strconv.Quote(v)
	case []interface{}:
		elems := make([]string, len(v))
		for i, elem := range v {
			elems[i] = Dump(elem)
		}
		return "[]interface{}{" + strings.Join(elems, ",") + "}"
	case map[string]interface{}:
		// Sort the keys so the output does not depend on map iteration order.
		keys := make([]string, 0, len(v))
		for k := range v {
			keys = append(keys, k)
		}
		sort.Strings(keys)

		var out bytes.Buffer
		out.WriteString("map[string]interface{}{")
		for _, k := range keys {
			out.WriteString(strconv.Quote(k))
			out.WriteString(":")
			out.WriteString(Dump(v[k]))
			out.WriteString(",")
		}
		out.WriteString("}")
		return out.String()
	default:
		panic(fmt.Errorf("unsupported type %T", val))
	}
}
617
// prefixLines returns s with prefix inserted at the start of every line,
// where lines are separated by "\n". An empty s yields just prefix.
func prefixLines(prefix, s string) string {
	lines := strings.Split(s, "\n")
	for idx, line := range lines {
		lines[idx] = prefix + line
	}
	return strings.Join(lines, "\n")
}
621
// resolveName turns a template name into a file path.
//
// A name starting with "." is resolved against the directory of the source
// file skip+1 stack frames above this function (skip 0 is the direct
// caller); the leading "." is dropped before joining. Any other name,
// including the empty string, is resolved against the directory of the
// source file that contains resolveName itself.
func resolveName(name string, skip int) string {
	// HasPrefix rather than name[0]: indexing an empty name would panic.
	if strings.HasPrefix(name, ".") {
		// Relative to the caller's file.
		_, callerFile, _, _ := runtime.Caller(skip + 1)
		return filepath.Join(filepath.Dir(callerFile), name[1:])
	}

	// Relative to this file.
	_, callerFile, _, _ := runtime.Caller(0)
	return filepath.Join(filepath.Dir(callerFile), name)
}
633
634 func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
635 t := template.New("").Funcs(Funcs())
636
637 b, err := os.ReadFile(filename)
638 if err != nil {
639 return nil, err
640 }
641
642 t, err = t.New(filepath.Base(filename)).Parse(string(b))
643 if err != nil {
644 panic(err)
645 }
646
647 buf := &bytes.Buffer{}
648 return buf, t.Execute(buf, tpldata)
649 }
650
651 func write(filename string, b []byte, packages *code.Packages) error {
652 err := os.MkdirAll(filepath.Dir(filename), 0o755)
653 if err != nil {
654 return fmt.Errorf("failed to create directory: %w", err)
655 }
656
657 formatted, err := imports.Prune(filename, b, packages)
658 if err != nil {
659 fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
660 formatted = b
661 }
662
663 err = os.WriteFile(filename, formatted, 0o644)
664 if err != nil {
665 return fmt.Errorf("failed to write %s: %w", filename, err)
666 }
667
668 return nil
669 }
670
// pkgReplacer rewrites the characters "/", ".", "-" and "~" of an import path
// to other runes. TypeIdentifier applies it to package paths, presumably so
// that the result can be embedded in a Go identifier.
var pkgReplacer = strings.NewReplacer(
	"/", "ᚋ",
	".", "ᚗ",
	"-", "ᚑ",
	"~", "א",
)
677
678 func TypeIdentifier(t types.Type) string {
679 res := ""
680 for {
681 switch it := t.(type) {
682 case *types.Pointer:
683 t.Underlying()
684 res += "ᚖ"
685 t = it.Elem()
686 case *types.Slice:
687 res += "ᚕ"
688 t = it.Elem()
689 case *types.Named:
690 res += pkgReplacer.Replace(it.Obj().Pkg().Path())
691 res += "ᚐ"
692 res += it.Obj().Name()
693 return res
694 case *types.Basic:
695 res += it.Name()
696 return res
697 case *types.Map:
698 res += "map"
699 return res
700 case *types.Interface:
701 res += "interface"
702 return res
703 default:
704 panic(fmt.Errorf("unexpected type %T", it))
705 }
706 }
707 }
708
709
710
711
// CommonInitialisms is the default set of initialisms, keyed by their
// upper-case spelling. It is what GetInitialisms returns unless that variable
// has been replaced.
var CommonInitialisms = map[string]bool{
	"ACL":   true,
	"API":   true,
	"ASCII": true,
	"CPU":   true,
	"CSS":   true,
	"CSV":   true,
	"DNS":   true,
	"EOF":   true,
	"GUID":  true,
	"HTML":  true,
	"HTTP":  true,
	"HTTPS": true,
	"ICMP":  true,
	"ID":    true,
	"IP":    true,
	"JSON":  true,
	"KVK":   true,
	"LHS":   true,
	"PDF":   true,
	"PGP":   true,
	"QPS":   true,
	"QR":    true,
	"RAM":   true,
	"RHS":   true,
	"RPC":   true,
	"SLA":   true,
	"SMTP":  true,
	"SQL":   true,
	"SSH":   true,
	"SVG":   true,
	"TCP":   true,
	"TLS":   true,
	"TTL":   true,
	"UDP":   true,
	"UI":    true,
	"UID":   true,
	"URI":   true,
	"URL":   true,
	"UTF8":  true,
	"UUID":  true,
	"VM":    true,
	"XML":   true,
	"XMPP":  true,
	"XSRF":  true,
	"XSS":   true,
}
759
760
// GetInitialisms returns the set of initialisms that wordWalker consults. It
// is a variable, so it can be reassigned to supply a different set; the
// default implementation returns CommonInitialisms.
var GetInitialisms = func() map[string]bool {
	return CommonInitialisms
}
764
View as plain text