1 package modelgen
2
3 import (
4 "errors"
5 "fmt"
6 "go/ast"
7 "go/parser"
8 "go/token"
9 "os"
10 "os/exec"
11 "path/filepath"
12 "reflect"
13 "sort"
14 "strings"
15 "testing"
16
17 "github.com/stretchr/testify/assert"
18 "github.com/stretchr/testify/require"
19
20 "github.com/99designs/gqlgen/codegen/config"
21 "github.com/99designs/gqlgen/graphql"
22 "github.com/99designs/gqlgen/plugin/modelgen/internal/extrafields"
23 "github.com/99designs/gqlgen/plugin/modelgen/out"
24 "github.com/99designs/gqlgen/plugin/modelgen/out_enable_model_json_omitempty_tag_false"
25 "github.com/99designs/gqlgen/plugin/modelgen/out_enable_model_json_omitempty_tag_nil"
26 "github.com/99designs/gqlgen/plugin/modelgen/out_enable_model_json_omitempty_tag_true"
27 "github.com/99designs/gqlgen/plugin/modelgen/out_nullable_input_omittable"
28 "github.com/99designs/gqlgen/plugin/modelgen/out_struct_pointers"
29 )
30
31 func TestModelGeneration(t *testing.T) {
32 cfg, err := config.LoadConfig("testdata/gqlgen.yml")
33 require.NoError(t, err)
34 require.NoError(t, cfg.Init())
35 p := Plugin{
36 MutateHook: mutateHook,
37 FieldHook: DefaultFieldMutateHook,
38 }
39 require.NoError(t, p.MutateConfig(cfg))
40 require.NoError(t, goBuild(t, "./out/"))
41
42 require.True(t, cfg.Models.UserDefined("MissingTypeNotNull"))
43 require.True(t, cfg.Models.UserDefined("MissingTypeNullable"))
44 require.True(t, cfg.Models.UserDefined("MissingEnum"))
45 require.True(t, cfg.Models.UserDefined("MissingUnion"))
46 require.True(t, cfg.Models.UserDefined("MissingInterface"))
47 require.True(t, cfg.Models.UserDefined("TypeWithDescription"))
48 require.True(t, cfg.Models.UserDefined("EnumWithDescription"))
49 require.True(t, cfg.Models.UserDefined("InterfaceWithDescription"))
50 require.True(t, cfg.Models.UserDefined("UnionWithDescription"))
51 require.True(t, cfg.Models.UserDefined("RenameFieldTest"))
52 require.True(t, cfg.Models.UserDefined("ExtraFieldsTest"))
53
54 t.Run("no pointer pointers", func(t *testing.T) {
55 generated, err := os.ReadFile("./out/generated.go")
56 require.NoError(t, err)
57 require.NotContains(t, string(generated), "**")
58 })
59
60 t.Run("description is generated", func(t *testing.T) {
61 node, err := parser.ParseFile(token.NewFileSet(), "./out/generated.go", nil, parser.ParseComments)
62 require.NoError(t, err)
63 for _, commentGroup := range node.Comments {
64 text := commentGroup.Text()
65 words := strings.Split(text, " ")
66 require.True(t, len(words) > 1, "expected description %q to have more than one word", text)
67 }
68 })
69
70 t.Run("tags are applied", func(t *testing.T) {
71 file, err := os.ReadFile("./out/generated.go")
72 require.NoError(t, err)
73
74 fileText := string(file)
75
76 expectedTags := []string{
77 `json:"missing2" database:"MissingTypeNotNullmissing2"`,
78 `json:"name,omitempty" database:"MissingInputname"`,
79 `json:"missing2,omitempty" database:"MissingTypeNullablemissing2"`,
80 `json:"name,omitempty" database:"TypeWithDescriptionname"`,
81 }
82
83 for _, tag := range expectedTags {
84 require.True(t, strings.Contains(fileText, tag), "\nexpected:\n"+tag+"\ngot\n"+fileText)
85 }
86 })
87
88 t.Run("field hooks are applied", func(t *testing.T) {
89 file, err := os.ReadFile("./out/generated.go")
90 require.NoError(t, err)
91
92 fileText := string(file)
93
94 expectedTags := []string{
95 `json:"name,omitempty" anotherTag:"tag"`,
96 `json:"enum,omitempty" yetAnotherTag:"12"`,
97 `json:"noVal,omitempty" yaml:"noVal" repeated:"true"`,
98 `json:"repeated,omitempty" someTag:"value" repeated:"true"`,
99 }
100
101 for _, tag := range expectedTags {
102 require.True(t, strings.Contains(fileText, tag), "\nexpected:\n"+tag+"\ngot\n"+fileText)
103 }
104 })
105
106 t.Run("concrete types implement interface", func(t *testing.T) {
107 var _ out.FooBarer = out.FooBarr{}
108 })
109
110 t.Run("implemented interfaces", func(t *testing.T) {
111 pkg, err := parseAst("out")
112 require.NoError(t, err)
113
114 path := filepath.Join("out", "generated.go")
115 generated := pkg.Files[path]
116
117 type field struct {
118 typ string
119 name string
120 }
121 cases := []struct {
122 name string
123 wantFields []field
124 }{
125 {
126 name: "A",
127 wantFields: []field{
128 {
129 typ: "method",
130 name: "IsA",
131 },
132 {
133 typ: "method",
134 name: "GetA",
135 },
136 },
137 },
138 {
139 name: "B",
140 wantFields: []field{
141 {
142 typ: "method",
143 name: "IsB",
144 },
145 {
146 typ: "method",
147 name: "GetB",
148 },
149 },
150 },
151 {
152 name: "C",
153 wantFields: []field{
154 {
155 typ: "method",
156 name: "IsA",
157 },
158 {
159 typ: "method",
160 name: "IsC",
161 },
162 {
163 typ: "method",
164 name: "GetA",
165 },
166 {
167 typ: "method",
168 name: "GetC",
169 },
170 },
171 },
172 {
173 name: "D",
174 wantFields: []field{
175 {
176 typ: "method",
177 name: "IsA",
178 },
179 {
180 typ: "method",
181 name: "IsB",
182 },
183 {
184 typ: "method",
185 name: "IsD",
186 },
187 {
188 typ: "method",
189 name: "GetA",
190 },
191 {
192 typ: "method",
193 name: "GetB",
194 },
195 {
196 typ: "method",
197 name: "GetD",
198 },
199 },
200 },
201 }
202 for _, tc := range cases {
203 tc := tc
204 t.Run(tc.name, func(t *testing.T) {
205 typeSpec, ok := generated.Scope.Lookup(tc.name).Decl.(*ast.TypeSpec)
206 require.True(t, ok)
207
208 fields := typeSpec.Type.(*ast.InterfaceType).Methods.List
209 for i, want := range tc.wantFields {
210 if want.typ == "ident" {
211 ident, ok := fields[i].Type.(*ast.Ident)
212 require.True(t, ok)
213 assert.Equal(t, want.name, ident.Name)
214 }
215 if want.typ == "method" {
216 require.GreaterOrEqual(t, 1, len(fields[i].Names))
217 name := fields[i].Names[0].Name
218 assert.Equal(t, want.name, name)
219 }
220 }
221 })
222 }
223 })
224
225 t.Run("implemented interfaces type CDImplemented", func(t *testing.T) {
226 pkg, err := parseAst("out")
227 require.NoError(t, err)
228
229 path := filepath.Join("out", "generated.go")
230 generated := pkg.Files[path]
231
232 wantMethods := []string{
233 "IsA",
234 "IsB",
235 "IsC",
236 "IsD",
237 }
238
239 gots := make([]string, 0, len(wantMethods))
240 for _, decl := range generated.Decls {
241 if funcDecl, ok := decl.(*ast.FuncDecl); ok {
242 switch funcDecl.Name.Name {
243 case "IsA", "IsB", "IsC", "IsD":
244 gots = append(gots, funcDecl.Name.Name)
245 require.Len(t, funcDecl.Recv.List, 1)
246 recvIdent, ok := funcDecl.Recv.List[0].Type.(*ast.Ident)
247 require.True(t, ok)
248 require.Equal(t, "CDImplemented", recvIdent.Name)
249 }
250 }
251 }
252
253 sort.Strings(gots)
254 require.Equal(t, wantMethods, gots)
255 })
256
257 t.Run("cyclical struct fields become pointers", func(t *testing.T) {
258 require.Nil(t, out.CyclicalA{}.FieldOne)
259 require.Nil(t, out.CyclicalA{}.FieldTwo)
260 require.Nil(t, out.CyclicalA{}.FieldThree)
261 require.NotNil(t, out.CyclicalA{}.FieldFour)
262 require.Nil(t, out.CyclicalB{}.FieldOne)
263 require.Nil(t, out.CyclicalB{}.FieldTwo)
264 require.Nil(t, out.CyclicalB{}.FieldThree)
265 require.Nil(t, out.CyclicalB{}.FieldFour)
266 require.NotNil(t, out.CyclicalB{}.FieldFive)
267 })
268
269 t.Run("non-cyclical struct fields become pointers", func(t *testing.T) {
270 require.NotNil(t, out.NotCyclicalB{}.FieldOne)
271 require.Nil(t, out.NotCyclicalB{}.FieldTwo)
272 })
273
274 t.Run("recursive struct fields become pointers", func(t *testing.T) {
275 require.Nil(t, out.Recursive{}.FieldOne)
276 require.Nil(t, out.Recursive{}.FieldTwo)
277 require.Nil(t, out.Recursive{}.FieldThree)
278 require.NotNil(t, out.Recursive{}.FieldFour)
279 })
280
281 t.Run("overridden struct field names use same capitalization as config", func(t *testing.T) {
282 require.NotNil(t, out.RenameFieldTest{}.GOODnaME)
283 })
284
285 t.Run("nullable input fields can be made omittable with goField", func(t *testing.T) {
286 require.IsType(t, out.MissingInput{}.NullString, graphql.Omittable[*string]{})
287 require.IsType(t, out.MissingInput{}.NullEnum, graphql.Omittable[*out.MissingEnum]{})
288 require.IsType(t, out.MissingInput{}.NullObject, graphql.Omittable[*out.ExistingInput]{})
289 })
290
291 t.Run("extra fields are present", func(t *testing.T) {
292 var m out.ExtraFieldsTest
293
294 require.IsType(t, m.FieldInt, int64(0))
295 require.IsType(t, m.FieldInternalType, extrafields.Type{})
296 require.IsType(t, m.FieldStringPtr, new(string))
297 require.IsType(t, m.FieldIntSlice, []int64{})
298 })
299 }
300
301 func TestModelGenerationOmitRootModels(t *testing.T) {
302 cfg, err := config.LoadConfig("testdata/gqlgen_omit_root_models.yml")
303 require.NoError(t, err)
304 require.NoError(t, cfg.Init())
305 p := Plugin{
306 MutateHook: mutateHook,
307 FieldHook: DefaultFieldMutateHook,
308 }
309 require.NoError(t, p.MutateConfig(cfg))
310 require.NoError(t, goBuild(t, "./out/"))
311 generated, err := os.ReadFile("./out/generated_omit_root_models.go")
312 require.NoError(t, err)
313 require.NotContains(t, string(generated), "type Mutation struct")
314 require.NotContains(t, string(generated), "type Query struct")
315 require.NotContains(t, string(generated), "type Subscription struct")
316 }
317
318 func TestModelGenerationOmitResolverFields(t *testing.T) {
319 cfg, err := config.LoadConfig("testdata/gqlgen_omit_resolver_fields.yml")
320 require.NoError(t, err)
321 require.NoError(t, cfg.Init())
322 p := Plugin{
323 MutateHook: mutateHook,
324 FieldHook: DefaultFieldMutateHook,
325 }
326 require.NoError(t, p.MutateConfig(cfg))
327 require.NoError(t, goBuild(t, "./out_omit_resolver_fields/"))
328 generated, err := os.ReadFile("./out_omit_resolver_fields/generated.go")
329 require.NoError(t, err)
330 require.Contains(t, string(generated), "type Base struct")
331 require.Contains(t, string(generated), "StandardField")
332 require.NotContains(t, string(generated), "ResolverField")
333 }
334
335 func TestModelGenerationStructFieldPointers(t *testing.T) {
336 cfg, err := config.LoadConfig("testdata/gqlgen_struct_field_pointers.yml")
337 require.NoError(t, err)
338 require.NoError(t, cfg.Init())
339 p := Plugin{
340 MutateHook: mutateHook,
341 FieldHook: DefaultFieldMutateHook,
342 }
343 require.NoError(t, p.MutateConfig(cfg))
344
345 t.Run("no pointer pointers", func(t *testing.T) {
346 generated, err := os.ReadFile("./out_struct_pointers/generated.go")
347 require.NoError(t, err)
348 require.NotContains(t, string(generated), "**")
349 })
350
351 t.Run("cyclical struct fields become pointers", func(t *testing.T) {
352 require.Nil(t, out_struct_pointers.CyclicalA{}.FieldOne)
353 require.Nil(t, out_struct_pointers.CyclicalA{}.FieldTwo)
354 require.Nil(t, out_struct_pointers.CyclicalA{}.FieldThree)
355 require.NotNil(t, out_struct_pointers.CyclicalA{}.FieldFour)
356 require.Nil(t, out_struct_pointers.CyclicalB{}.FieldOne)
357 require.Nil(t, out_struct_pointers.CyclicalB{}.FieldTwo)
358 require.Nil(t, out_struct_pointers.CyclicalB{}.FieldThree)
359 require.Nil(t, out_struct_pointers.CyclicalB{}.FieldFour)
360 require.NotNil(t, out_struct_pointers.CyclicalB{}.FieldFive)
361 })
362
363 t.Run("non-cyclical struct fields do not become pointers", func(t *testing.T) {
364 require.NotNil(t, out_struct_pointers.NotCyclicalB{}.FieldOne)
365 require.NotNil(t, out_struct_pointers.NotCyclicalB{}.FieldTwo)
366 })
367
368 t.Run("recursive struct fields become pointers", func(t *testing.T) {
369 require.Nil(t, out_struct_pointers.Recursive{}.FieldOne)
370 require.Nil(t, out_struct_pointers.Recursive{}.FieldTwo)
371 require.Nil(t, out_struct_pointers.Recursive{}.FieldThree)
372 require.NotNil(t, out_struct_pointers.Recursive{}.FieldFour)
373 })
374
375 t.Run("no getters", func(t *testing.T) {
376 generated, err := os.ReadFile("./out_struct_pointers/generated.go")
377 require.NoError(t, err)
378 require.NotContains(t, string(generated), "func (this")
379 })
380 }
381
382 func TestModelGenerationNullableInputOmittable(t *testing.T) {
383 cfg, err := config.LoadConfig("testdata/gqlgen_nullable_input_omittable.yml")
384 require.NoError(t, err)
385 require.NoError(t, cfg.Init())
386 p := Plugin{
387 MutateHook: mutateHook,
388 FieldHook: DefaultFieldMutateHook,
389 }
390 require.NoError(t, p.MutateConfig(cfg))
391
392 t.Run("nullable input fields are omittable", func(t *testing.T) {
393 require.IsType(t, out_nullable_input_omittable.MissingInput{}.Name, graphql.Omittable[*string]{})
394 require.IsType(t, out_nullable_input_omittable.MissingInput{}.Enum, graphql.Omittable[*out_nullable_input_omittable.MissingEnum]{})
395 require.IsType(t, out_nullable_input_omittable.MissingInput{}.NullString, graphql.Omittable[*string]{})
396 require.IsType(t, out_nullable_input_omittable.MissingInput{}.NullEnum, graphql.Omittable[*out_nullable_input_omittable.MissingEnum]{})
397 require.IsType(t, out_nullable_input_omittable.MissingInput{}.NullObject, graphql.Omittable[*out_nullable_input_omittable.ExistingInput]{})
398 })
399
400 t.Run("non-nullable input fields are not omittable", func(t *testing.T) {
401 require.IsType(t, out_nullable_input_omittable.MissingInput{}.NonNullString, "")
402 })
403 }
404
405 func TestModelGenerationOmitemptyConfig(t *testing.T) {
406 suites := []struct {
407 n string
408 cfg string
409 enabled bool
410 t any
411 }{
412 {
413 n: "nil",
414 cfg: "gqlgen_enable_model_json_omitempty_tag_nil.yml",
415 enabled: true,
416 t: out_enable_model_json_omitempty_tag_nil.OmitEmptyJSONTagTest{},
417 },
418 {
419 n: "true",
420 cfg: "gqlgen_enable_model_json_omitempty_tag_true.yml",
421 enabled: true,
422 t: out_enable_model_json_omitempty_tag_true.OmitEmptyJSONTagTest{},
423 },
424 {
425 n: "false",
426 cfg: "gqlgen_enable_model_json_omitempty_tag_false.yml",
427 enabled: false,
428 t: out_enable_model_json_omitempty_tag_false.OmitEmptyJSONTagTest{},
429 },
430 }
431
432 for _, s := range suites {
433 t.Run(s.n, func(t *testing.T) {
434 cfg, err := config.LoadConfig(fmt.Sprintf("testdata/%s", s.cfg))
435 require.NoError(t, err)
436 require.NoError(t, cfg.Init())
437 p := Plugin{
438 MutateHook: mutateHook,
439 FieldHook: DefaultFieldMutateHook,
440 }
441 require.NoError(t, p.MutateConfig(cfg))
442 rt := reflect.TypeOf(s.t)
443
444
445 sfn, ok := rt.FieldByName("ValueNonNil")
446 require.True(t, ok)
447 require.Equal(t, "ValueNonNil", sfn.Tag.Get("json"))
448
449
450 sf, ok := rt.FieldByName("Value")
451 require.True(t, ok)
452
453 var expected string
454 if s.enabled {
455 expected = "Value,omitempty"
456 } else {
457 expected = "Value"
458 }
459 require.Equal(t, expected, sf.Tag.Get("json"))
460 })
461 }
462 }
463
464 func mutateHook(b *ModelBuild) *ModelBuild {
465 for _, model := range b.Models {
466 for _, field := range model.Fields {
467 field.Tag += ` database:"` + model.Name + field.Name + `"`
468 }
469 }
470
471 return b
472 }
473
// parseAst parses every Go file in the directory at path (comments are not
// retained) and returns the package whose name equals the directory's base
// name, e.g. "out" for both "out" and "./out/".
//
// It returns an error if parsing fails or if no package with that name exists
// in the directory. Callers therefore never receive a nil package with a nil
// error.
func parseAst(path string) (*ast.Package, error) {
	fset := token.NewFileSet()
	pkgs, err := parser.ParseDir(fset, path, nil, parser.AllErrors)
	if err != nil {
		return nil, err
	}

	// filepath.Base strips trailing separators, so "./out/" yields "out".
	name := filepath.Base(path)
	pkg, ok := pkgs[name]
	if !ok {
		return nil, fmt.Errorf("package %q not found in %s", name, path)
	}
	return pkg, nil
}
483
// goBuild runs `go build` on path and returns nil when the build succeeds.
//
// On failure the error contains the path, the exec error (e.g. a non-zero exit
// status, or the go binary not being found) and the combined stdout/stderr of
// the build. The message is therefore never empty, even when the command
// produced no output.
func goBuild(t *testing.T, path string) error {
	t.Helper()
	// Named output (not out) so the imported package out is not shadowed.
	output, err := exec.Command("go", "build", path).CombinedOutput()
	if err != nil {
		return errors.New("go build " + path + ": " + err.Error() + "\n" + string(output))
	}
	return nil
}
494
495 func TestRemoveDuplicate(t *testing.T) {
496 type args struct {
497 t string
498 }
499 tests := []struct {
500 name string
501 args args
502 want string
503 wantPanic bool
504 }{
505 {
506 name: "Duplicate Test with 1",
507 args: args{
508 t: "json:\"name\"",
509 },
510 want: "json:\"name\"",
511 },
512 {
513 name: "Duplicate Test with 2",
514 args: args{
515 t: "json:\"name\" json:\"name2\"",
516 },
517 want: "json:\"name2\"",
518 },
519 {
520 name: "Duplicate Test with 3",
521 args: args{
522 t: "json:\"name\" json:\"name2\" json:\"name3\"",
523 },
524 want: "json:\"name3\"",
525 },
526 {
527 name: "Duplicate Test with 3 and 1 unrelated",
528 args: args{
529 t: "json:\"name\" something:\"name2\" json:\"name3\"",
530 },
531 want: "something:\"name2\" json:\"name3\"",
532 },
533 {
534 name: "Duplicate Test with 3 and 2 unrelated",
535 args: args{
536 t: "something:\"name1\" json:\"name\" something:\"name2\" json:\"name3\"",
537 },
538 want: "something:\"name2\" json:\"name3\"",
539 },
540 {
541 name: "Test tag value with leading empty space",
542 args: args{
543 t: "json:\"name, name2\"",
544 },
545 want: "json:\"name, name2\"",
546 wantPanic: true,
547 },
548 {
549 name: "Test tag value with trailing empty space",
550 args: args{
551 t: "json:\"name,name2 \"",
552 },
553 want: "json:\"name,name2 \"",
554 wantPanic: true,
555 },
556 {
557 name: "Test tag value with space in between",
558 args: args{
559 t: "gorm:\"unique;not null\"",
560 },
561 want: "gorm:\"unique;not null\"",
562 wantPanic: false,
563 },
564 {
565 name: "Test mix use of gorm and json tags",
566 args: args{
567 t: "gorm:\"unique;not null\" json:\"name,name2\"",
568 },
569 want: "gorm:\"unique;not null\" json:\"name,name2\"",
570 wantPanic: false,
571 },
572 {
573 name: "Test gorm tag with colon",
574 args: args{
575 t: "gorm:\"type:varchar(63);unique_index\"",
576 },
577 want: "gorm:\"type:varchar(63);unique_index\"",
578 wantPanic: false,
579 },
580 {
581 name: "Test mix use of gorm and duplicate json tags with colon",
582 args: args{
583 t: "json:\"name0\" gorm:\"type:varchar(63);unique_index\" json:\"name,name2\"",
584 },
585 want: "gorm:\"type:varchar(63);unique_index\" json:\"name,name2\"",
586 wantPanic: false,
587 },
588 }
589 for _, tt := range tests {
590 t.Run(tt.name, func(t *testing.T) {
591 if tt.wantPanic {
592 assert.Panics(t, func() { removeDuplicateTags(tt.args.t) }, "The code did not panic")
593 } else {
594 if got := removeDuplicateTags(tt.args.t); got != tt.want {
595 t.Errorf("removeDuplicate() = %v, want %v", got, tt.want)
596 }
597 }
598 })
599 }
600 }
601
602 func Test_containsInvalidSpace(t *testing.T) {
603 type args struct {
604 valuesString string
605 }
606 tests := []struct {
607 name string
608 args args
609 want bool
610 }{
611 {
612 name: "Test tag value with leading empty space",
613 args: args{
614 valuesString: "name, name2",
615 },
616 want: true,
617 },
618 {
619 name: "Test tag value with trailing empty space",
620 args: args{
621 valuesString: "name ,name2",
622 },
623 want: true,
624 },
625 {
626 name: "Test tag value with valid empty space in words",
627 args: args{
628 valuesString: "accept this,name2",
629 },
630 want: false,
631 },
632 }
633 for _, tt := range tests {
634 t.Run(tt.name, func(t *testing.T) {
635 assert.Equalf(t, tt.want, containsInvalidSpace(tt.args.valuesString), "containsInvalidSpace(%v)", tt.args.valuesString)
636 })
637 }
638 }
639
640 func Test_splitTagsBySpace(t *testing.T) {
641 type args struct {
642 tagsString string
643 }
644 tests := []struct {
645 name string
646 args args
647 want []string
648 }{
649 {
650 name: "multiple tags, single value",
651 args: args{
652 tagsString: "json:\"name\" something:\"name2\" json:\"name3\"",
653 },
654 want: []string{"json:\"name\"", "something:\"name2\"", "json:\"name3\""},
655 },
656 {
657 name: "multiple tag, multiple values",
658 args: args{
659 tagsString: "json:\"name\" something:\"name2\" json:\"name3,name4\"",
660 },
661 want: []string{"json:\"name\"", "something:\"name2\"", "json:\"name3,name4\""},
662 },
663 {
664 name: "single tag, single value",
665 args: args{
666 tagsString: "json:\"name\"",
667 },
668 want: []string{"json:\"name\""},
669 },
670 {
671 name: "single tag, multiple values",
672 args: args{
673 tagsString: "json:\"name,name2\"",
674 },
675 want: []string{"json:\"name,name2\""},
676 },
677 {
678 name: "space in value",
679 args: args{
680 tagsString: "gorm:\"not nul,name2\"",
681 },
682 want: []string{"gorm:\"not nul,name2\""},
683 },
684 }
685 for _, tt := range tests {
686 t.Run(tt.name, func(t *testing.T) {
687 assert.Equalf(t, tt.want, splitTagsBySpace(tt.args.tagsString), "splitTagsBySpace(%v)", tt.args.tagsString)
688 })
689 }
690 }
691
692 func TestCustomTemplate(t *testing.T) {
693 cfg, err := config.LoadConfig("testdata/gqlgen_custom_model_template.yml")
694 require.NoError(t, err)
695 require.NoError(t, cfg.Init())
696 p := Plugin{
697 MutateHook: mutateHook,
698 FieldHook: DefaultFieldMutateHook,
699 }
700 require.NoError(t, p.MutateConfig(cfg))
701 }
702
View as plain text