1 package config
2
3 import (
4 "bytes"
5 "fmt"
6 "go/types"
7 "io"
8 "os"
9 "path/filepath"
10 "regexp"
11 "sort"
12 "strings"
13
14 "github.com/vektah/gqlparser/v2"
15 "github.com/vektah/gqlparser/v2/ast"
16 "golang.org/x/tools/go/packages"
17 "gopkg.in/yaml.v3"
18
19 "github.com/99designs/gqlgen/codegen/templates"
20 "github.com/99designs/gqlgen/internal/code"
21 )
22
// Config is gqlgen's configuration. ReadConfig decodes it from YAML on top of
// DefaultConfig (unknown keys are rejected) and then calls CompleteConfig.
type Config struct {
	// SchemaFilename lists schema files or glob patterns (including "**");
	// CompleteConfig replaces the patterns with the matching file names.
	SchemaFilename StringList     `yaml:"schema,omitempty"`
	Exec           ExecConfig     `yaml:"exec"`
	Model          PackageConfig  `yaml:"model,omitempty"`
	Federation     PackageConfig  `yaml:"federation,omitempty"`
	Resolver       ResolverConfig `yaml:"resolver,omitempty"`

	// AutoBind lists Go import paths searched for objects named like schema
	// types (see autobind). Models maps GraphQL type names to Go models and
	// per-field settings. Directives holds per-directive settings keyed by
	// directive name.
	AutoBind   []string                   `yaml:"autobind"`
	Models     TypeMap                    `yaml:"models,omitempty"`
	StructTag  string                     `yaml:"struct_tag,omitempty"`
	Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`

	// GoBuildTags are passed to the package loader through
	// code.WithBuildTags; CompleteConfig applies GoInitialisms.
	GoBuildTags   StringList          `yaml:"go_build_tags,omitempty"`
	GoInitialisms GoInitialismsConfig `yaml:"go_initialisms,omitempty"`

	// The generation options below are not read in this file. DefaultConfig
	// turns on StructFieldsAlwaysPointers and ResolversAlwaysReturnPointers.
	OmitSliceElementPointers      bool  `yaml:"omit_slice_element_pointers,omitempty"`
	OmitGetters                   bool  `yaml:"omit_getters,omitempty"`
	OmitInterfaceChecks           bool  `yaml:"omit_interface_checks,omitempty"`
	OmitComplexity                bool  `yaml:"omit_complexity,omitempty"`
	OmitGQLGenFileNotice          bool  `yaml:"omit_gqlgen_file_notice,omitempty"`
	OmitGQLGenVersionInFileNotice bool  `yaml:"omit_gqlgen_version_in_file_notice,omitempty"`
	OmitRootModels                bool  `yaml:"omit_root_models,omitempty"`
	OmitResolverFields            bool  `yaml:"omit_resolver_fields,omitempty"`
	StructFieldsAlwaysPointers    bool  `yaml:"struct_fields_always_pointers,omitempty"`
	ReturnPointersInUmarshalInput bool  `yaml:"return_pointers_in_unmarshalinput,omitempty"`
	ResolversAlwaysReturnPointers bool  `yaml:"resolvers_always_return_pointers,omitempty"`
	NullableInputOmittable        bool  `yaml:"nullable_input_omittable,omitempty"`
	EnableModelJsonOmitemptyTag   *bool `yaml:"enable_model_json_omitempty_tag,omitempty"`
	SkipValidation                bool  `yaml:"skip_validation,omitempty"`
	SkipModTidy                   bool  `yaml:"skip_mod_tidy,omitempty"`

	// Runtime state, never read from YAML: Sources holds the schema file
	// contents, Packages the Go package loader (Init creates it when nil) and
	// Schema the parsed schema (set by LoadSchema).
	Sources  []*ast.Source  `yaml:"-"`
	Packages *code.Packages `yaml:"-"`
	Schema   *ast.Schema    `yaml:"-"`

	// Federated is a removed option: check returns an error when it is set,
	// pointing users at the federation section instead.
	Federated bool `yaml:"federated,omitempty"`
}
57
// cfgFilenames are the config file names findCfgInDir probes in each
// directory; the first one that exists wins.
var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
59
60
61 func DefaultConfig() *Config {
62 return &Config{
63 SchemaFilename: StringList{"schema.graphql"},
64 Model: PackageConfig{Filename: "models_gen.go"},
65 Exec: ExecConfig{Filename: "generated.go"},
66 Directives: map[string]DirectiveConfig{},
67 Models: TypeMap{},
68 StructFieldsAlwaysPointers: true,
69 ReturnPointersInUmarshalInput: false,
70 ResolversAlwaysReturnPointers: true,
71 NullableInputOmittable: false,
72 }
73 }
74
75
76 func LoadDefaultConfig() (*Config, error) {
77 config := DefaultConfig()
78
79 for _, filename := range config.SchemaFilename {
80 filename = filepath.ToSlash(filename)
81 var err error
82 var schemaRaw []byte
83 schemaRaw, err = os.ReadFile(filename)
84 if err != nil {
85 return nil, fmt.Errorf("unable to open schema: %w", err)
86 }
87
88 config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
89 }
90
91 return config, nil
92 }
93
94
95
96 func LoadConfigFromDefaultLocations() (*Config, error) {
97 cfgFile, err := findCfg()
98 if err != nil {
99 return nil, err
100 }
101
102 err = os.Chdir(filepath.Dir(cfgFile))
103 if err != nil {
104 return nil, fmt.Errorf("unable to enter config dir: %w", err)
105 }
106 return LoadConfig(cfgFile)
107 }
108
// path2regex turns the part of a schema pattern that follows "**" into a
// regular expression (CompleteConfig appends the trailing "$"): "." is
// escaped, "*" becomes ".+" (so, unlike a shell glob, it needs at least one
// character and also crosses path separators), and both "/" and "\" match
// either separator. Other regular-expression metacharacters are passed
// through unescaped.
var path2regex = strings.NewReplacer(
	`.`, `\.`,
	`*`, `.+`,
	`\`, `[\\/]`,
	`/`, `[\\/]`,
)
115
116
117 func LoadConfig(filename string) (*Config, error) {
118 b, err := os.ReadFile(filename)
119 if err != nil {
120 return nil, fmt.Errorf("unable to read config: %w", err)
121 }
122
123 return ReadConfig(bytes.NewReader(b))
124 }
125
126 func ReadConfig(cfgFile io.Reader) (*Config, error) {
127 config := DefaultConfig()
128
129 dec := yaml.NewDecoder(cfgFile)
130 dec.KnownFields(true)
131
132 if err := dec.Decode(config); err != nil {
133 return nil, fmt.Errorf("unable to parse config: %w", err)
134 }
135
136 if err := CompleteConfig(config); err != nil {
137 return nil, err
138 }
139
140 return config, nil
141 }
142
143
144
145 func CompleteConfig(config *Config) error {
146 defaultDirectives := map[string]DirectiveConfig{
147 "skip": {SkipRuntime: true},
148 "include": {SkipRuntime: true},
149 "deprecated": {SkipRuntime: true},
150 "specifiedBy": {SkipRuntime: true},
151 }
152
153 for key, value := range defaultDirectives {
154 if _, defined := config.Directives[key]; !defined {
155 config.Directives[key] = value
156 }
157 }
158
159 preGlobbing := config.SchemaFilename
160 config.SchemaFilename = StringList{}
161 for _, f := range preGlobbing {
162 var matches []string
163
164
165
166 if strings.Contains(f, "**") {
167 pathParts := strings.SplitN(f, "**", 2)
168 rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
169
170
171 globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
172
173 if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
174 if err != nil {
175 return err
176 }
177
178 if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
179 matches = append(matches, path)
180 }
181
182 return nil
183 }); err != nil {
184 return fmt.Errorf("failed to walk schema at root %s: %w", pathParts[0], err)
185 }
186 } else {
187 var err error
188 matches, err = filepath.Glob(f)
189 if err != nil {
190 return fmt.Errorf("failed to glob schema filename %s: %w", f, err)
191 }
192 }
193
194 for _, m := range matches {
195 if config.SchemaFilename.Has(m) {
196 continue
197 }
198 config.SchemaFilename = append(config.SchemaFilename, m)
199 }
200 }
201
202 for _, filename := range config.SchemaFilename {
203 filename = filepath.ToSlash(filename)
204 var err error
205 var schemaRaw []byte
206 schemaRaw, err = os.ReadFile(filename)
207 if err != nil {
208 return fmt.Errorf("unable to open schema: %w", err)
209 }
210
211 config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
212 }
213
214 config.GoInitialisms.setInitialisms()
215
216 return nil
217 }
218
219 func (c *Config) Init() error {
220 if c.Packages == nil {
221 c.Packages = code.NewPackages(
222 code.WithBuildTags(c.GoBuildTags...),
223 )
224 }
225
226 if c.Schema == nil {
227 if err := c.LoadSchema(); err != nil {
228 return err
229 }
230 }
231
232 err := c.injectTypesFromSchema()
233 if err != nil {
234 return err
235 }
236
237 err = c.autobind()
238 if err != nil {
239 return err
240 }
241
242 c.injectBuiltins()
243
244 c.Packages.LoadAll(c.packageList()...)
245
246
247 err = c.check()
248 if err != nil {
249 return err
250 }
251
252 return nil
253 }
254
255 func (c *Config) packageList() []string {
256 pkgs := []string{
257 "github.com/99designs/gqlgen/graphql",
258 "github.com/99designs/gqlgen/graphql/introspection",
259 }
260 pkgs = append(pkgs, c.Models.ReferencedPackages()...)
261 pkgs = append(pkgs, c.AutoBind...)
262 return pkgs
263 }
264
265 func (c *Config) ReloadAllPackages() {
266 c.Packages.ReloadAll(c.packageList()...)
267 }
268
269 func (c *Config) IsRoot(def *ast.Definition) bool {
270 return def == c.Schema.Query || def == c.Schema.Mutation || def == c.Schema.Subscription
271 }
272
273 func (c *Config) injectTypesFromSchema() error {
274 c.Directives["goModel"] = DirectiveConfig{
275 SkipRuntime: true,
276 }
277
278 c.Directives["goField"] = DirectiveConfig{
279 SkipRuntime: true,
280 }
281
282 c.Directives["goTag"] = DirectiveConfig{
283 SkipRuntime: true,
284 }
285
286 for _, schemaType := range c.Schema.Types {
287 if c.IsRoot(schemaType) {
288 continue
289 }
290
291 if bd := schemaType.Directives.ForName("goModel"); bd != nil {
292 if ma := bd.Arguments.ForName("model"); ma != nil {
293 if mv, err := ma.Value.Value(nil); err == nil {
294 c.Models.Add(schemaType.Name, mv.(string))
295 }
296 }
297
298 if ma := bd.Arguments.ForName("models"); ma != nil {
299 if mvs, err := ma.Value.Value(nil); err == nil {
300 for _, mv := range mvs.([]interface{}) {
301 c.Models.Add(schemaType.Name, mv.(string))
302 }
303 }
304 }
305
306 if fg := bd.Arguments.ForName("forceGenerate"); fg != nil {
307 if mv, err := fg.Value.Value(nil); err == nil {
308 c.Models.ForceGenerate(schemaType.Name, mv.(bool))
309 }
310 }
311 }
312
313 if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject {
314 for _, field := range schemaType.Fields {
315 if fd := field.Directives.ForName("goField"); fd != nil {
316 forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver
317 fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName
318
319 if ra := fd.Arguments.ForName("forceResolver"); ra != nil {
320 if fr, err := ra.Value.Value(nil); err == nil {
321 forceResolver = fr.(bool)
322 }
323 }
324
325 if na := fd.Arguments.ForName("name"); na != nil {
326 if fr, err := na.Value.Value(nil); err == nil {
327 fieldName = fr.(string)
328 }
329 }
330
331 if c.Models[schemaType.Name].Fields == nil {
332 c.Models[schemaType.Name] = TypeMapEntry{
333 Model: c.Models[schemaType.Name].Model,
334 ExtraFields: c.Models[schemaType.Name].ExtraFields,
335 Fields: map[string]TypeMapField{},
336 }
337 }
338
339 c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{
340 FieldName: fieldName,
341 Resolver: forceResolver,
342 }
343 }
344 }
345 }
346 }
347
348 return nil
349 }
350
// TypeMapEntry is the configuration of one GraphQL type in Config.Models.
type TypeMapEntry struct {
	// Model lists the Go types ("import/path.Type") the GraphQL type maps to;
	// TypeMap.Check rejects entries that name only a package.
	Model         StringList              `yaml:"model,omitempty"`
	// ForceGenerate makes autobind skip the type (it is also set by
	// @goModel(forceGenerate:)); presumably it forces the model to be
	// generated elsewhere — not visible in this file.
	ForceGenerate bool                    `yaml:"forceGenerate,omitempty"`
	// Fields holds per-field settings keyed by GraphQL field name.
	Fields        map[string]TypeMapField `yaml:"fields,omitempty"`

	// ExtraFields is not interpreted in this file; presumably it adds extra
	// fields to the generated model, keyed by field name — TODO confirm.
	ExtraFields map[string]ModelExtraField `yaml:"extraFields,omitempty"`
}
359
// TypeMapField holds the settings of one GraphQL field in TypeMapEntry.Fields.
type TypeMapField struct {
	// Resolver is also set from @goField(forceResolver:); presumably it
	// requests a resolver for the field.
	Resolver        bool   `yaml:"resolver"`
	// FieldName is also set from @goField(name:); presumably it overrides the
	// Go field name.
	FieldName       string `yaml:"fieldName"`
	// GeneratedMethod is never read from YAML and is not set in this file.
	GeneratedMethod string `yaml:"-"`
}
365
// ModelExtraField describes one entry of TypeMapEntry.ExtraFields. Its
// consumers are not in this file; the notes below are hedged accordingly.
type ModelExtraField struct {
	// Type is the Go type of the extra field — presumably a builtin or a
	// fully qualified "import/path.Type"; TODO confirm.
	Type string `yaml:"type"`

	// OverrideTags presumably replaces the struct tags generated for the
	// field; TODO confirm.
	OverrideTags string `yaml:"overrideTags"`

	// Description presumably becomes the field's doc comment; TODO confirm.
	Description string `yaml:"description"`
}
391
// StringList is a list of strings that YAML may spell either as a single
// scalar or as a sequence.
type StringList []string

// UnmarshalYAML accepts both `key: value` and `key: [a, b]`. A value that
// decodes as a single string yields a one-element list; otherwise the value
// is decoded as a list of strings, and that decoding error is returned on
// failure (leaving the receiver unchanged).
func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var one string
	if unmarshal(&one) == nil {
		*a = []string{one}
		return nil
	}

	var many []string
	if err := unmarshal(&many); err != nil {
		return err
	}
	*a = many
	return nil
}

// Has reports whether file is an element of the list.
func (a StringList) Has(file string) bool {
	for i := range a {
		if a[i] == file {
			return true
		}
	}
	return false
}
420
421 func (c *Config) check() error {
422 if c.Models == nil {
423 c.Models = TypeMap{}
424 }
425
426 type FilenamePackage struct {
427 Filename string
428 Package string
429 Declaree string
430 }
431
432 fileList := map[string][]FilenamePackage{}
433
434 if err := c.Models.Check(); err != nil {
435 return fmt.Errorf("config.models: %w", err)
436 }
437 if err := c.Exec.Check(); err != nil {
438 return fmt.Errorf("config.exec: %w", err)
439 }
440 fileList[c.Exec.ImportPath()] = append(fileList[c.Exec.ImportPath()], FilenamePackage{
441 Filename: c.Exec.Filename,
442 Package: c.Exec.Package,
443 Declaree: "exec",
444 })
445
446 if c.Model.IsDefined() {
447 if err := c.Model.Check(); err != nil {
448 return fmt.Errorf("config.model: %w", err)
449 }
450 fileList[c.Model.ImportPath()] = append(fileList[c.Model.ImportPath()], FilenamePackage{
451 Filename: c.Model.Filename,
452 Package: c.Model.Package,
453 Declaree: "model",
454 })
455 }
456 if c.Resolver.IsDefined() {
457 if err := c.Resolver.Check(); err != nil {
458 return fmt.Errorf("config.resolver: %w", err)
459 }
460 fileList[c.Resolver.ImportPath()] = append(fileList[c.Resolver.ImportPath()], FilenamePackage{
461 Filename: c.Resolver.Filename,
462 Package: c.Resolver.Package,
463 Declaree: "resolver",
464 })
465 }
466 if c.Federation.IsDefined() {
467 if err := c.Federation.Check(); err != nil {
468 return fmt.Errorf("config.federation: %w", err)
469 }
470 fileList[c.Federation.ImportPath()] = append(fileList[c.Federation.ImportPath()], FilenamePackage{
471 Filename: c.Federation.Filename,
472 Package: c.Federation.Package,
473 Declaree: "federation",
474 })
475 if c.Federation.ImportPath() != c.Exec.ImportPath() {
476 return fmt.Errorf("federation and exec must be in the same package")
477 }
478 }
479 if c.Federated {
480 return fmt.Errorf("federated has been removed, instead use\nfederation:\n filename: path/to/federated.go")
481 }
482
483 for importPath, pkg := range fileList {
484 for _, file1 := range pkg {
485 for _, file2 := range pkg {
486 if file1.Package != file2.Package {
487 return fmt.Errorf("%s and %s define the same import path (%s) with different package names (%s vs %s)",
488 file1.Declaree,
489 file2.Declaree,
490 importPath,
491 file1.Package,
492 file2.Package,
493 )
494 }
495 }
496 }
497 }
498
499 return nil
500 }
501
// TypeMap maps GraphQL type names to their model configuration (the type of
// Config.Models).
type TypeMap map[string]TypeMapEntry
503
504 func (tm TypeMap) Exists(typeName string) bool {
505 _, ok := tm[typeName]
506 return ok
507 }
508
509 func (tm TypeMap) UserDefined(typeName string) bool {
510 m, ok := tm[typeName]
511 return ok && len(m.Model) > 0
512 }
513
514 func (tm TypeMap) Check() error {
515 for typeName, entry := range tm {
516 for _, model := range entry.Model {
517 if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
518 return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
519 }
520 }
521 }
522 return nil
523 }
524
525 func (tm TypeMap) ReferencedPackages() []string {
526 var pkgs []string
527
528 for _, typ := range tm {
529 for _, model := range typ.Model {
530 if model == "map[string]interface{}" || model == "interface{}" {
531 continue
532 }
533 pkg, _ := code.PkgAndType(model)
534 if pkg == "" || inStrSlice(pkgs, pkg) {
535 continue
536 }
537 pkgs = append(pkgs, code.QualifyPackagePath(pkg))
538 }
539 }
540
541 sort.Slice(pkgs, func(i, j int) bool {
542 return pkgs[i] > pkgs[j]
543 })
544 return pkgs
545 }
546
547 func (tm TypeMap) Add(name string, goType string) {
548 modelCfg := tm[name]
549 modelCfg.Model = append(modelCfg.Model, goType)
550 tm[name] = modelCfg
551 }
552
553 func (tm TypeMap) ForceGenerate(name string, forceGenerate bool) {
554 modelCfg := tm[name]
555 modelCfg.ForceGenerate = forceGenerate
556 tm[name] = modelCfg
557 }
558
// DirectiveConfig holds the settings of one directive in Config.Directives.
type DirectiveConfig struct {
	// SkipRuntime is set by default for skip, include, deprecated,
	// specifiedBy, goModel, goField and goTag; presumably it excludes the
	// directive from runtime handling — TODO confirm in codegen.
	SkipRuntime bool `yaml:"skip_runtime"`
}
562
// inStrSlice reports whether needle is an element of haystack.
func inStrSlice(haystack []string, needle string) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
572
573
574
575 func findCfg() (string, error) {
576 dir, err := os.Getwd()
577 if err != nil {
578 return "", fmt.Errorf("unable to get working dir to findCfg: %w", err)
579 }
580
581 cfg := findCfgInDir(dir)
582
583 for cfg == "" && dir != filepath.Dir(dir) {
584 dir = filepath.Dir(dir)
585 cfg = findCfgInDir(dir)
586 }
587
588 if cfg == "" {
589 return "", os.ErrNotExist
590 }
591
592 return cfg, nil
593 }
594
595 func findCfgInDir(dir string) string {
596 for _, cfgName := range cfgFilenames {
597 path := filepath.Join(dir, cfgName)
598 if _, err := os.Stat(path); err == nil {
599 return path
600 }
601 }
602 return ""
603 }
604
605 func (c *Config) autobind() error {
606 if len(c.AutoBind) == 0 {
607 return nil
608 }
609
610 ps := c.Packages.LoadAll(c.AutoBind...)
611
612 for _, t := range c.Schema.Types {
613 if c.Models.UserDefined(t.Name) || c.Models[t.Name].ForceGenerate {
614 continue
615 }
616
617 for i, p := range ps {
618 if p == nil || p.Module == nil {
619 return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i])
620 }
621
622 autobindType := c.lookupAutobindType(p, t)
623 if autobindType != nil {
624 c.Models.Add(t.Name, autobindType.Pkg().Path()+"."+autobindType.Name())
625 break
626 }
627 }
628 }
629
630 for i, t := range c.Models {
631 if t.ForceGenerate {
632 continue
633 }
634
635 for j, m := range t.Model {
636 pkg, typename := code.PkgAndType(m)
637
638
639 if strings.Contains(pkg, "/") {
640 continue
641 }
642
643 for _, p := range ps {
644 if p.Name != pkg {
645 continue
646 }
647 if t := p.Types.Scope().Lookup(typename); t != nil {
648 c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name()
649 break
650 }
651 }
652 }
653 }
654
655 return nil
656 }
657
658 func (c *Config) lookupAutobindType(p *packages.Package, schemaType *ast.Definition) types.Object {
659
660 for _, lookupName := range []string{schemaType.Name, templates.ToGo(schemaType.Name)} {
661 if t := p.Types.Scope().Lookup(lookupName); t != nil {
662 return t
663 }
664 }
665
666 return nil
667 }
668
669 func (c *Config) injectBuiltins() {
670 builtins := TypeMap{
671 "__Directive": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
672 "__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
673 "__Type": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
674 "__TypeKind": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
675 "__Field": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
676 "__EnumValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
677 "__InputValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
678 "__Schema": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
679 "Float": {Model: StringList{"github.com/99designs/gqlgen/graphql.FloatContext"}},
680 "String": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
681 "Boolean": {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
682 "Int": {Model: StringList{
683 "github.com/99designs/gqlgen/graphql.Int",
684 "github.com/99designs/gqlgen/graphql.Int32",
685 "github.com/99designs/gqlgen/graphql.Int64",
686 }},
687 "ID": {
688 Model: StringList{
689 "github.com/99designs/gqlgen/graphql.ID",
690 "github.com/99designs/gqlgen/graphql.IntID",
691 },
692 },
693 }
694
695 for typeName, entry := range builtins {
696 if !c.Models.Exists(typeName) {
697 c.Models[typeName] = entry
698 }
699 }
700
701
702 extraBuiltins := TypeMap{
703 "Time": {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
704 "Map": {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
705 "Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
706 "Any": {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
707 }
708
709 for typeName, entry := range extraBuiltins {
710 if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
711 c.Models[typeName] = entry
712 }
713 }
714 }
715
716 func (c *Config) LoadSchema() error {
717 if c.Packages != nil {
718 c.Packages = code.NewPackages(
719 code.WithBuildTags(c.GoBuildTags...),
720 )
721 }
722
723 if err := c.check(); err != nil {
724 return err
725 }
726
727 schema, err := gqlparser.LoadSchema(c.Sources...)
728 if err != nil {
729 return err
730 }
731
732 if schema.Query == nil {
733 schema.Query = &ast.Definition{
734 Kind: ast.Object,
735 Name: "Query",
736 }
737 schema.Types["Query"] = schema.Query
738 }
739
740 c.Schema = schema
741 return nil
742 }
743
// abs returns the absolute, slash-separated form of path. It panics when
// filepath.Abs reports an error (for example, when the working directory
// cannot be determined for a relative path).
func abs(path string) string {
	resolved, err := filepath.Abs(path)
	if err != nil {
		panic(err)
	}
	return filepath.ToSlash(resolved)
}
751
View as plain text