1 package config
2
3 import (
4 "errors"
5 "fmt"
6 "go/token"
7 "go/types"
8
9 "github.com/vektah/gqlparser/v2/ast"
10 "golang.org/x/tools/go/packages"
11
12 "github.com/99designs/gqlgen/codegen/templates"
13 "github.com/99designs/gqlgen/internal/code"
14 )
15
// ErrTypeNotFound is the sentinel error wrapped by FindObject when neither the
// requested name nor its "Marshal"-prefixed variant exists among the package's
// indexed definitions. Match it with errors.Is.
var ErrTypeNotFound = errors.New("unable to find type")
17
18
// Binder looks up Go types and objects by package and name and records the
// TypeReferences that are produced while binding schema types to them.
type Binder struct {
	// pkgs is the package loader, taken from Config.Packages by NewBinder.
	pkgs *code.Packages
	// schema is the schema whose Types map is consulted by TypeReference.
	schema *ast.Schema
	// cfg is the configuration the binder was created from.
	cfg *Config
	// tctx is created lazily by InstantiateType and reused across calls.
	tctx *types.Context
	// References accumulates every reference added via PushRef or PointerTo.
	References []*TypeReference
	// SawInvalid is set when TypeReference is asked to bind an invalid Go type.
	SawInvalid bool
	// objectCache maps package name -> definition name -> object; it is
	// populated on demand by FindObject.
	objectCache map[string]map[string]types.Object
}
28
29 func (c *Config) NewBinder() *Binder {
30 return &Binder{
31 pkgs: c.Packages,
32 schema: c.Schema,
33 cfg: c,
34 }
35 }
36
37 func (b *Binder) TypePosition(typ types.Type) token.Position {
38 named, isNamed := typ.(*types.Named)
39 if !isNamed {
40 return token.Position{
41 Filename: "unknown",
42 }
43 }
44
45 return b.ObjectPosition(named.Obj())
46 }
47
48 func (b *Binder) ObjectPosition(typ types.Object) token.Position {
49 if typ == nil {
50 return token.Position{
51 Filename: "unknown",
52 }
53 }
54 pkg := b.pkgs.Load(typ.Pkg().Path())
55 return pkg.Fset.Position(typ.Pos())
56 }
57
58 func (b *Binder) FindTypeFromName(name string) (types.Type, error) {
59 pkgName, typeName := code.PkgAndType(name)
60 return b.FindType(pkgName, typeName)
61 }
62
63 func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
64 if pkgName == "" {
65 if typeName == "map[string]interface{}" {
66 return MapType, nil
67 }
68
69 if typeName == "interface{}" {
70 return InterfaceType, nil
71 }
72 }
73
74 obj, err := b.FindObject(pkgName, typeName)
75 if err != nil {
76 return nil, err
77 }
78
79 if fun, isFunc := obj.(*types.Func); isFunc {
80 return fun.Type().(*types.Signature).Params().At(0).Type(), nil
81 }
82 return obj.Type(), nil
83 }
84
85 func (b *Binder) InstantiateType(orig types.Type, targs []types.Type) (types.Type, error) {
86 if b.tctx == nil {
87 b.tctx = types.NewContext()
88 }
89
90 return types.Instantiate(b.tctx, orig, targs, false)
91 }
92
93 var (
94 MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
95 InterfaceType = types.NewInterfaceType(nil, nil)
96 )
97
98 func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
99 models := b.cfg.Models[name].Model
100 if len(models) == 0 {
101 return nil, fmt.Errorf(name + " not found in typemap")
102 }
103
104 if models[0] == "map[string]interface{}" {
105 return MapType, nil
106 }
107
108 if models[0] == "interface{}" {
109 return InterfaceType, nil
110 }
111
112 pkgName, typeName := code.PkgAndType(models[0])
113 if pkgName == "" {
114 return nil, fmt.Errorf("missing package name for %s", name)
115 }
116
117 obj, err := b.FindObject(pkgName, typeName)
118 if err != nil {
119 return nil, err
120 }
121
122 return obj.Type(), nil
123 }
124
125 func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
126 if pkgName == "" {
127 return nil, fmt.Errorf("package cannot be nil")
128 }
129
130 pkg := b.pkgs.LoadWithTypes(pkgName)
131 if pkg == nil {
132 err := b.pkgs.Errors()
133 if err != nil {
134 return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err)
135 }
136 return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName)
137 }
138
139 if b.objectCache == nil {
140 b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count())
141 }
142
143 defsIndex, ok := b.objectCache[pkgName]
144 if !ok {
145 defsIndex = indexDefs(pkg)
146 b.objectCache[pkgName] = defsIndex
147 }
148
149
150 if val, ok := defsIndex["Marshal"+typeName]; ok {
151 return val, nil
152 }
153
154 if val, ok := defsIndex[typeName]; ok {
155 return val, nil
156 }
157
158 return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName)
159 }
160
161 func indexDefs(pkg *packages.Package) map[string]types.Object {
162 res := make(map[string]types.Object)
163
164 scope := pkg.Types.Scope()
165 for astNode, def := range pkg.TypesInfo.Defs {
166
167 if def == nil {
168 continue
169 }
170 parent := def.Parent()
171 if parent == nil || parent != scope {
172 continue
173 }
174
175 if _, ok := res[astNode.Name]; !ok {
176
177
178
179 res[astNode.Name] = def
180 }
181 }
182
183 return res
184 }
185
186 func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
187 newRef := *ref
188 newRef.GO = types.NewPointer(ref.GO)
189 b.References = append(b.References, &newRef)
190 return &newRef
191 }
192
193
// TypeReference pairs a GraphQL type with the Go type it is bound to, together
// with the information about how values of that type are marshaled.
type TypeReference struct {
	// Definition is the schema definition looked up by the GraphQL type's name.
	Definition *ast.Definition
	// GQL is the GraphQL type, including its list and non-null modifiers.
	GQL *ast.Type
	// GO is the bound Go type: the bind target when one was supplied,
	// otherwise the model type with the schema modifiers applied.
	GO types.Type
	// Target is the model's Go type before schema modifiers are applied.
	Target types.Type
	// CastType is the basic underlying type, set for string-backed leaf types.
	CastType types.Type
	// Marshaler is the function-based marshaler, when the model is bound to one.
	Marshaler *types.Func
	// Unmarshaler is the counterpart of Marshaler.
	Unmarshaler *types.Func
	// IsMarshaler is true when the Go type itself has the marshal/unmarshal methods.
	IsMarshaler bool
	// IsOmittable is true when the bind target was wrapped in graphql.Omittable.
	IsOmittable bool
	// IsContext is true when the context-aware marshaling variants are used.
	IsContext bool
	// PointersInUmarshalInput mirrors Config.ReturnPointersInUmarshalInput.
	PointersInUmarshalInput bool
	// IsRoot holds the result of Config.IsRoot for Definition.
	IsRoot bool
}
208
209 func (ref *TypeReference) Elem() *TypeReference {
210 if p, isPtr := ref.GO.(*types.Pointer); isPtr {
211 newRef := *ref
212 newRef.GO = p.Elem()
213 return &newRef
214 }
215
216 if ref.IsSlice() {
217 newRef := *ref
218 newRef.GO = ref.GO.(*types.Slice).Elem()
219 newRef.GQL = ref.GQL.Elem
220 return &newRef
221 }
222 return nil
223 }
224
225 func (ref *TypeReference) IsPtr() bool {
226 _, isPtr := ref.GO.(*types.Pointer)
227 return isPtr
228 }
229
230
231 func (ref *TypeReference) IsPtrToPtr() bool {
232 if p, isPtr := ref.GO.(*types.Pointer); isPtr {
233 _, isPtr := p.Elem().(*types.Pointer)
234 return isPtr
235 }
236 return false
237 }
238
239 func (ref *TypeReference) IsNilable() bool {
240 return IsNilable(ref.GO)
241 }
242
243 func (ref *TypeReference) IsSlice() bool {
244 _, isSlice := ref.GO.(*types.Slice)
245 return ref.GQL.Elem != nil && isSlice
246 }
247
248 func (ref *TypeReference) IsPtrToSlice() bool {
249 if ref.IsPtr() {
250 _, isPointerToSlice := ref.GO.(*types.Pointer).Elem().(*types.Slice)
251 return isPointerToSlice
252 }
253 return false
254 }
255
256 func (ref *TypeReference) IsPtrToIntf() bool {
257 if ref.IsPtr() {
258 _, isPointerToInterface := ref.GO.(*types.Pointer).Elem().(*types.Interface)
259 return isPointerToInterface
260 }
261 return false
262 }
263
264 func (ref *TypeReference) IsNamed() bool {
265 _, isSlice := ref.GO.(*types.Named)
266 return isSlice
267 }
268
269 func (ref *TypeReference) IsStruct() bool {
270 _, isStruct := ref.GO.Underlying().(*types.Struct)
271 return isStruct
272 }
273
274 func (ref *TypeReference) IsScalar() bool {
275 return ref.Definition.Kind == ast.Scalar
276 }
277
278 func (ref *TypeReference) IsMap() bool {
279 return ref.GO == MapType
280 }
281
282 func (ref *TypeReference) UniquenessKey() string {
283 nullability := "O"
284 if ref.GQL.NonNull {
285 nullability = "N"
286 }
287
288 elemNullability := ""
289 if ref.GQL.Elem != nil && ref.GQL.Elem.NonNull {
290
291 elemNullability = "áš„"
292 }
293 return nullability + ref.Definition.Name + "2" + templates.TypeIdentifier(ref.GO) + elemNullability
294 }
295
296 func (ref *TypeReference) MarshalFunc() string {
297 if ref.Definition == nil {
298 panic(errors.New("Definition missing for " + ref.GQL.Name()))
299 }
300
301 if ref.Definition.Kind == ast.InputObject {
302 return ""
303 }
304
305 return "marshal" + ref.UniquenessKey()
306 }
307
308 func (ref *TypeReference) UnmarshalFunc() string {
309 if ref.Definition == nil {
310 panic(errors.New("Definition missing for " + ref.GQL.Name()))
311 }
312
313 if !ref.Definition.IsInputType() {
314 return ""
315 }
316
317 return "unmarshal" + ref.UniquenessKey()
318 }
319
320 func (ref *TypeReference) IsTargetNilable() bool {
321 return IsNilable(ref.Target)
322 }
323
324 func (b *Binder) PushRef(ret *TypeReference) {
325 b.References = append(b.References, ret)
326 }
327
328 func isMap(t types.Type) bool {
329 if t == nil {
330 return true
331 }
332 _, ok := t.(*types.Map)
333 return ok
334 }
335
336 func isIntf(t types.Type) bool {
337 if t == nil {
338 return true
339 }
340 _, ok := t.(*types.Interface)
341 return ok
342 }
343
344 func unwrapOmittable(t types.Type) (types.Type, bool) {
345 if t == nil {
346 return t, false
347 }
348 named, ok := t.(*types.Named)
349 if !ok {
350 return t, false
351 }
352 if named.Origin().String() != "github.com/99designs/gqlgen/graphql.Omittable[T any]" {
353 return t, false
354 }
355 return named.TypeArgs().At(0), true
356 }
357
// TypeReference builds the TypeReference that binds schemaType to a Go type.
//
// The models configured for the schema type's name are tried in order and the
// first one that fits is used. bindTarget, when non-nil, is the Go type the
// reference must be compatible with and becomes the reference's GO type; when
// nil, the first model is used as is.
//
// Every successfully built reference is also recorded on the binder through
// PushRef (see the deferred function below).
//
// An error is returned when: the bind target is wrapped in Omittable but the
// schema type is non-null; the bind target is an invalid type (SawInvalid is
// set as well); no model is configured; a model lacks a package name; the
// model's object cannot be found; or no model is compatible with bindTarget.
func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
	// An Omittable wrapper is bound as its inner type and then flagged. The
	// recursive call has already pushed the reference, which is a pointer, so
	// setting the flag afterwards is visible in References too.
	if innerType, ok := unwrapOmittable(bindTarget); ok {
		if schemaType.NonNull {
			return nil, fmt.Errorf("%s is wrapped with Omittable but non-null", schemaType.Name())
		}

		ref, err := b.TypeReference(schemaType, innerType)
		if err != nil {
			return nil, err
		}

		ref.IsOmittable = true
		return ref, err
	}

	if !isValid(bindTarget) {
		b.SawInvalid = true
		return nil, fmt.Errorf("%s has an invalid type", schemaType.Name())
	}

	var pkgName, typeName string
	def := b.schema.Types[schemaType.Name()]
	// Registered after the early returns above, so it only runs for the paths
	// below; it reads the named results to record successful references.
	defer func() {
		if err == nil && ret != nil {
			b.PushRef(ret)
		}
	}()

	if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
		return nil, fmt.Errorf("%s was not found", schemaType.Name())
	}

	for _, model := range b.cfg.Models[schemaType.Name()].Model {
		// The two literal models have no package to load. isMap and isIntf
		// accept a nil bind target, so an unconstrained caller matches here.
		if model == "map[string]interface{}" {
			if !isMap(bindTarget) {
				continue
			}
			return &TypeReference{
				Definition: def,
				GQL:        schemaType,
				GO:         MapType,
				IsRoot:     b.cfg.IsRoot(def),
			}, nil
		}

		if model == "interface{}" {
			if !isIntf(bindTarget) {
				continue
			}
			return &TypeReference{
				Definition: def,
				GQL:        schemaType,
				GO:         InterfaceType,
				IsRoot:     b.cfg.IsRoot(def),
			}, nil
		}

		pkgName, typeName = code.PkgAndType(model)
		if pkgName == "" {
			return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
		}

		ref := &TypeReference{
			Definition: def,
			GQL:        schemaType,
			IsRoot:     b.cfg.IsRoot(def),
		}

		// This err shadows the named result for the rest of the iteration;
		// every return below passes its values explicitly.
		obj, err := b.FindObject(pkgName, typeName)
		if err != nil {
			return nil, err
		}

		if fun, isFunc := obj.(*types.Func); isFunc {
			// Function-based marshaler: the Go type is the function's first
			// parameter, and the unmarshaler is a synthesized function object
			// named "Unmarshal"+typeName in the same package.
			ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
			ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/99designs/gqlgen/graphql.ContextMarshaler"
			ref.Marshaler = fun
			ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
		} else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") {
			// The type marshals itself through the context-aware methods.
			ref.GO = obj.Type()
			ref.IsContext = true
			ref.IsMarshaler = true
		} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
			// The type marshals itself through the plain methods.
			ref.GO = obj.Type()
			ref.IsMarshaler = true
		} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
			// A leaf type backed by a string without marshal methods: reuse
			// the marshalers bound to the schema's String type and record the
			// basic type in CastType.
			ref.GO = obj.Type()
			ref.CastType = underlying

			underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
			if err != nil {
				return nil, err
			}

			ref.Marshaler = underlyingRef.Marshaler
			ref.Unmarshaler = underlyingRef.Unmarshaler
		} else {
			ref.GO = obj.Type()
		}

		// Target keeps the model type; GO additionally gets the list and
		// pointer modifiers implied by the schema type.
		ref.Target = ref.GO
		ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)

		if bindTarget != nil {
			// An incompatible model is skipped so the next one can be tried.
			if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
				continue
			}
			ref.GO = bindTarget
		}

		ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput

		return ref, nil
	}

	// Only reachable through the continue statements above. For the two
	// literal models and for CompatibleTypes that implies a non-nil bindTarget.
	return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String())
}
477
478 func isValid(t types.Type) bool {
479 basic, isBasic := t.(*types.Basic)
480 if !isBasic {
481 return true
482 }
483 return basic.Kind() != types.Invalid
484 }
485
486 func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
487 if t.Elem != nil {
488 child := b.CopyModifiersFromAst(t.Elem, base)
489 if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers {
490 child = types.NewPointer(child)
491 }
492 return types.NewSlice(child)
493 }
494
495 var isInterface bool
496 if named, ok := base.(*types.Named); ok {
497 _, isInterface = named.Underlying().(*types.Interface)
498 }
499
500 if !isInterface && !IsNilable(base) && !t.NonNull {
501 return types.NewPointer(base)
502 }
503
504 return base
505 }
506
507 func IsNilable(t types.Type) bool {
508 if namedType, isNamed := t.(*types.Named); isNamed {
509 return IsNilable(namedType.Underlying())
510 }
511 _, isPtr := t.(*types.Pointer)
512 _, isMap := t.(*types.Map)
513 _, isInterface := t.(*types.Interface)
514 _, isSlice := t.(*types.Slice)
515 _, isChan := t.(*types.Chan)
516 return isPtr || isMap || isInterface || isSlice || isChan
517 }
518
519 func hasMethod(it types.Type, name string) bool {
520 if ptr, isPtr := it.(*types.Pointer); isPtr {
521 it = ptr.Elem()
522 }
523 namedType, ok := it.(*types.Named)
524 if !ok {
525 return false
526 }
527
528 for i := 0; i < namedType.NumMethods(); i++ {
529 if namedType.Method(i).Name() == name {
530 return true
531 }
532 }
533 return false
534 }
535
536 func basicUnderlying(it types.Type) *types.Basic {
537 if ptr, isPtr := it.(*types.Pointer); isPtr {
538 it = ptr.Elem()
539 }
540 namedType, ok := it.(*types.Named)
541 if !ok {
542 return nil
543 }
544
545 if basic, ok := namedType.Underlying().(*types.Basic); ok {
546 return basic
547 }
548
549 return nil
550 }
551
View as plain text