1 package federation
2
3 import (
4 _ "embed"
5 "fmt"
6 "sort"
7 "strings"
8
9 "github.com/vektah/gqlparser/v2/ast"
10
11 "github.com/99designs/gqlgen/codegen"
12 "github.com/99designs/gqlgen/codegen/config"
13 "github.com/99designs/gqlgen/codegen/templates"
14 "github.com/99designs/gqlgen/internal/rewrite"
15 "github.com/99designs/gqlgen/plugin"
16 "github.com/99designs/gqlgen/plugin/federation/fieldset"
17 )
18
19
20 var federationTemplate string
21
22
23 var explicitRequiresTemplate string
24
// federation is the gqlgen plugin that adds GraphQL federation support
// (@key entities, _entities/_service query fields) to generated code.
// Instances are created by New.
type federation struct {
	// Entities are the federated types (objects and implemented interfaces
	// carrying @key) collected by setEntities, sorted by name.
	Entities []*Entity
	// Version selects the federation directive set (1 or 2); New maps 0 to 1.
	Version int
	// PackageOptions is copied from data.Config.Federation.Options in
	// GenerateCode so templates can read options such as "explicit_requires".
	PackageOptions map[string]bool
}
30
31
32
33 func New(version int) plugin.Plugin {
34 if version == 0 {
35 version = 1
36 }
37
38 return &federation{Version: version}
39 }
40
41
42 func (f *federation) Name() string {
43 return "federation"
44 }
45
46
47 func (f *federation) MutateConfig(cfg *config.Config) error {
48 builtins := config.TypeMap{
49 "_Service": {
50 Model: config.StringList{
51 "github.com/99designs/gqlgen/plugin/federation/fedruntime.Service",
52 },
53 },
54 "_Entity": {
55 Model: config.StringList{
56 "github.com/99designs/gqlgen/plugin/federation/fedruntime.Entity",
57 },
58 },
59 "Entity": {
60 Model: config.StringList{
61 "github.com/99designs/gqlgen/plugin/federation/fedruntime.Entity",
62 },
63 },
64 "_Any": {
65 Model: config.StringList{"github.com/99designs/gqlgen/graphql.Map"},
66 },
67 "federation__Scope": {
68 Model: config.StringList{"github.com/99designs/gqlgen/graphql.String"},
69 },
70 "federation__Policy": {
71 Model: config.StringList{"github.com/99designs/gqlgen/graphql.String"},
72 },
73 }
74
75 for typeName, entry := range builtins {
76 if cfg.Models.Exists(typeName) {
77 return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
78 }
79 cfg.Models[typeName] = entry
80 }
81 cfg.Directives["external"] = config.DirectiveConfig{SkipRuntime: true}
82 cfg.Directives["requires"] = config.DirectiveConfig{SkipRuntime: true}
83 cfg.Directives["provides"] = config.DirectiveConfig{SkipRuntime: true}
84 cfg.Directives["key"] = config.DirectiveConfig{SkipRuntime: true}
85 cfg.Directives["extends"] = config.DirectiveConfig{SkipRuntime: true}
86
87
88 if f.Version == 2 {
89 cfg.Directives["shareable"] = config.DirectiveConfig{SkipRuntime: true}
90 cfg.Directives["link"] = config.DirectiveConfig{SkipRuntime: true}
91 cfg.Directives["tag"] = config.DirectiveConfig{SkipRuntime: true}
92 cfg.Directives["override"] = config.DirectiveConfig{SkipRuntime: true}
93 cfg.Directives["inaccessible"] = config.DirectiveConfig{SkipRuntime: true}
94 cfg.Directives["authenticated"] = config.DirectiveConfig{SkipRuntime: true}
95 cfg.Directives["requiresScopes"] = config.DirectiveConfig{SkipRuntime: true}
96 cfg.Directives["policy"] = config.DirectiveConfig{SkipRuntime: true}
97 cfg.Directives["interfaceObject"] = config.DirectiveConfig{SkipRuntime: true}
98 cfg.Directives["composeDirective"] = config.DirectiveConfig{SkipRuntime: true}
99 }
100
101 return nil
102 }
103
104 func (f *federation) InjectSourceEarly() *ast.Source {
105 input := ``
106
107
108 if f.Version == 1 {
109 input += `
110 directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
111 directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
112 directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
113 directive @extends on OBJECT | INTERFACE
114 directive @external on FIELD_DEFINITION
115 scalar _Any
116 scalar _FieldSet
117 `
118 } else if f.Version == 2 {
119 input += `
120 directive @authenticated on FIELD_DEFINITION | OBJECT | INTERFACE | SCALAR | ENUM
121 directive @composeDirective(name: String!) repeatable on SCHEMA
122 directive @extends on OBJECT | INTERFACE
123 directive @external on OBJECT | FIELD_DEFINITION
124 directive @key(fields: FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE
125 directive @inaccessible on
126 | ARGUMENT_DEFINITION
127 | ENUM
128 | ENUM_VALUE
129 | FIELD_DEFINITION
130 | INPUT_FIELD_DEFINITION
131 | INPUT_OBJECT
132 | INTERFACE
133 | OBJECT
134 | SCALAR
135 | UNION
136 directive @interfaceObject on OBJECT
137 directive @link(import: [String!], url: String!) repeatable on SCHEMA
138 directive @override(from: String!, label: String) on FIELD_DEFINITION
139 directive @policy(policies: [[federation__Policy!]!]!) on
140 | FIELD_DEFINITION
141 | OBJECT
142 | INTERFACE
143 | SCALAR
144 | ENUM
145 directive @provides(fields: FieldSet!) on FIELD_DEFINITION
146 directive @requires(fields: FieldSet!) on FIELD_DEFINITION
147 directive @requiresScopes(scopes: [[federation__Scope!]!]!) on
148 | FIELD_DEFINITION
149 | OBJECT
150 | INTERFACE
151 | SCALAR
152 | ENUM
153 directive @shareable repeatable on FIELD_DEFINITION | OBJECT
154 directive @tag(name: String!) repeatable on
155 | ARGUMENT_DEFINITION
156 | ENUM
157 | ENUM_VALUE
158 | FIELD_DEFINITION
159 | INPUT_FIELD_DEFINITION
160 | INPUT_OBJECT
161 | INTERFACE
162 | OBJECT
163 | SCALAR
164 | UNION
165 scalar _Any
166 scalar FieldSet
167 scalar federation__Policy
168 scalar federation__Scope
169 `
170 }
171 return &ast.Source{
172 Name: "federation/directives.graphql",
173 Input: input,
174 BuiltIn: true,
175 }
176 }
177
178
179
180 func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source {
181 f.setEntities(schema)
182
183 var entities, resolvers, entityResolverInputDefinitions string
184 for _, e := range f.Entities {
185
186 if e.Def.Kind != ast.Interface {
187 if entities != "" {
188 entities += " | "
189 }
190 entities += e.Name
191 } else if len(schema.GetPossibleTypes(e.Def)) == 0 {
192 fmt.Println(
193 "skipping @key field on interface " + e.Def.Name + " as no types implement it",
194 )
195 }
196
197 for _, r := range e.Resolvers {
198 if e.Multi {
199 if entityResolverInputDefinitions != "" {
200 entityResolverInputDefinitions += "\n\n"
201 }
202 entityResolverInputDefinitions += "input " + r.InputTypeName + " {\n"
203 for _, keyField := range r.KeyFields {
204 entityResolverInputDefinitions += fmt.Sprintf(
205 "\t%s: %s\n",
206 keyField.Field.ToGo(),
207 keyField.Definition.Type.String(),
208 )
209 }
210 entityResolverInputDefinitions += "}"
211 resolvers += fmt.Sprintf("\t%s(reps: [%s]!): [%s]\n", r.ResolverName, r.InputTypeName, e.Name)
212 } else {
213 resolverArgs := ""
214 for _, keyField := range r.KeyFields {
215 resolverArgs += fmt.Sprintf("%s: %s,", keyField.Field.ToGoPrivate(), keyField.Definition.Type.String())
216 }
217 resolvers += fmt.Sprintf("\t%s(%s): %s!\n", r.ResolverName, resolverArgs, e.Name)
218 }
219 }
220 }
221
222 var blocks []string
223 if entities != "" {
224 entities = `# a union of all types that use the @key directive
225 union _Entity = ` + entities
226 blocks = append(blocks, entities)
227 }
228
229
230
231 if resolvers != "" {
232 if entityResolverInputDefinitions != "" {
233 blocks = append(blocks, entityResolverInputDefinitions)
234 }
235 resolvers = `# fake type to build resolver interfaces for users to implement
236 type Entity {
237 ` + resolvers + `
238 }`
239 blocks = append(blocks, resolvers)
240 }
241
242 _serviceTypeDef := `type _Service {
243 sdl: String
244 }`
245 blocks = append(blocks, _serviceTypeDef)
246
247 var additionalQueryFields string
248
249
250
251 if len(f.Entities) > 0 {
252 additionalQueryFields += ` _entities(representations: [_Any!]!): [_Entity]!
253 `
254 }
255
256 additionalQueryFields += ` _service: _Service!`
257
258 extendTypeQueryDef := `extend type ` + schema.Query.Name + ` {
259 ` + additionalQueryFields + `
260 }`
261 blocks = append(blocks, extendTypeQueryDef)
262
263 return &ast.Source{
264 Name: "federation/entity.graphql",
265 BuiltIn: true,
266 Input: "\n" + strings.Join(blocks, "\n\n") + "\n",
267 }
268 }
269
270 func (f *federation) GenerateCode(data *codegen.Data) error {
271
272 requiresImports := make(map[string]bool, 0)
273 requiresImports["context"] = true
274 requiresImports["fmt"] = true
275
276 requiresEntities := make(map[string]*Entity, 0)
277
278
279 f.PackageOptions = data.Config.Federation.Options
280
281 if len(f.Entities) > 0 {
282 if data.Objects.ByName("Entity") != nil {
283 data.Objects.ByName("Entity").Root = true
284 }
285 for _, e := range f.Entities {
286 obj := data.Objects.ByName(e.Def.Name)
287
288 if e.Def.Kind == ast.Interface {
289 if len(data.Interfaces[e.Def.Name].Implementors) == 0 {
290 fmt.Println(
291 "skipping @key field on interface " + e.Def.Name + " as no types implement it",
292 )
293 continue
294 }
295 obj = data.Objects.ByName(data.Interfaces[e.Def.Name].Implementors[0].Name)
296 }
297
298 for _, r := range e.Resolvers {
299
300
301 for _, keyField := range r.KeyFields {
302 if len(keyField.Field) == 0 {
303 fmt.Println(
304 "skipping @key field " + keyField.Definition.Name + " in " + r.ResolverName + " in " + e.Def.Name,
305 )
306 continue
307 }
308 cgField := keyField.Field.TypeReference(obj, data.Objects)
309 keyField.Type = cgField.TypeReference
310 }
311 }
312
313
314
315 for _, reqField := range e.Requires {
316 if len(reqField.Field) == 0 {
317 fmt.Println("skipping @requires field " + reqField.Name + " in " + e.Def.Name)
318 continue
319 }
320
321 requiresEntities[e.Def.Name] = e
322
323 typeString := strings.Split(obj.Type.String(), ".")
324 requiresImports[strings.Join(typeString[:len(typeString)-1], ".")] = true
325
326 cgField := reqField.Field.TypeReference(obj, data.Objects)
327 reqField.Type = cgField.TypeReference
328 }
329
330
331 e.Type = obj.Type
332
333 }
334 }
335
336
337
338 for _, entity := range f.Entities {
339 if !entity.Multi {
340 continue
341 }
342
343 for _, resolver := range entity.Resolvers {
344 obj := data.Inputs.ByName(resolver.InputTypeName)
345 if obj == nil {
346 return fmt.Errorf("input object %s not found", resolver.InputTypeName)
347 }
348
349 resolver.InputType = obj.Type
350 }
351 }
352
353 if data.Config.Federation.Options["explicit_requires"] && len(requiresEntities) > 0 {
354
355 type Populator struct {
356 FuncName string
357 Exists bool
358 Comment string
359 Implementation string
360 Entity *Entity
361 }
362 populators := make([]Populator, 0)
363
364 rewriter, err := rewrite.New(data.Config.Federation.Dir())
365 if err != nil {
366 return err
367 }
368
369 for name, entity := range requiresEntities {
370 populator := Populator{
371 FuncName: fmt.Sprintf("Populate%sRequires", name),
372 Entity: entity,
373 }
374
375 populator.Comment = strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment("executionContext", populator.FuncName), `\`))
376 populator.Implementation = strings.TrimSpace(rewriter.GetMethodBody("executionContext", populator.FuncName))
377
378 if populator.Implementation == "" {
379 populator.Exists = false
380 populator.Implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v\"))", populator.FuncName)
381 }
382 populators = append(populators, populator)
383 }
384
385 sort.Slice(populators, func(i, j int) bool {
386 return populators[i].FuncName < populators[j].FuncName
387 })
388
389 requiresFile := data.Config.Federation.Dir() + "/federation.requires.go"
390 existingImports := rewriter.ExistingImports(requiresFile)
391 for _, imp := range existingImports {
392 if imp.Alias == "" {
393
394 delete(requiresImports, imp.ImportPath)
395 }
396 }
397
398 for k := range requiresImports {
399 existingImports = append(existingImports, rewrite.Import{ImportPath: k})
400 }
401
402
403 err = templates.Render(templates.Options{
404 PackageName: data.Config.Federation.Package,
405 Filename: requiresFile,
406 Data: struct {
407 federation
408 ExistingImports []rewrite.Import
409 Populators []Populator
410 OriginalSource string
411 }{*f, existingImports, populators, ""},
412 GeneratedHeader: false,
413 Packages: data.Config.Packages,
414 Template: explicitRequiresTemplate,
415 })
416 if err != nil {
417 return err
418 }
419
420 }
421
422 return templates.Render(templates.Options{
423 PackageName: data.Config.Federation.Package,
424 Filename: data.Config.Federation.Filename,
425 Data: struct {
426 federation
427 UsePointers bool
428 }{*f, data.Config.ResolversAlwaysReturnPointers},
429 GeneratedHeader: true,
430 Packages: data.Config.Packages,
431 Template: federationTemplate,
432 })
433 }
434
435 func (f *federation) setEntities(schema *ast.Schema) {
436 for _, schemaType := range schema.Types {
437 keys, ok := isFederatedEntity(schemaType)
438 if !ok {
439 continue
440 }
441
442 if (schemaType.Kind == ast.Interface) && (len(schema.GetPossibleTypes(schemaType)) == 0) {
443 fmt.Printf("@key directive found on unused \"interface %s\". Will be ignored.\n", schemaType.Name)
444 continue
445 }
446
447 e := &Entity{
448 Name: schemaType.Name,
449 Def: schemaType,
450 Resolvers: nil,
451 Requires: nil,
452 }
453
454
455 dir := schemaType.Directives.ForName("entityResolver")
456 if dir != nil {
457 if dirArg := dir.Arguments.ForName("multi"); dirArg != nil {
458 if dirVal, err := dirArg.Value.Value(nil); err == nil {
459 e.Multi = dirVal.(bool)
460 }
461 }
462 }
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482 if !e.allFieldsAreExternal(f.Version) {
483 for _, dir := range keys {
484 if len(dir.Arguments) > 2 {
485 panic("More than two arguments provided for @key declaration.")
486 }
487 var arg *ast.Argument
488
489
490 for _, a := range dir.Arguments {
491 if a.Name == "fields" {
492 if arg != nil {
493 panic("More than one `fields` provided for @key declaration.")
494 }
495 arg = a
496 }
497 }
498
499 keyFieldSet := fieldset.New(arg.Value.Raw, nil)
500
501 keyFields := make([]*KeyField, len(keyFieldSet))
502 resolverFields := []string{}
503 for i, field := range keyFieldSet {
504 def := field.FieldDefinition(schemaType, schema)
505
506 if def == nil {
507 panic(fmt.Sprintf("no field for %v", field))
508 }
509
510 keyFields[i] = &KeyField{Definition: def, Field: field}
511 resolverFields = append(resolverFields, keyFields[i].Field.ToGo())
512 }
513
514 resolverFieldsToGo := schemaType.Name + "By" + strings.Join(resolverFields, "And")
515 var resolverName string
516 if e.Multi {
517 resolverFieldsToGo += "s"
518 resolverName = fmt.Sprintf("findMany%s", resolverFieldsToGo)
519 } else {
520 resolverName = fmt.Sprintf("find%s", resolverFieldsToGo)
521 }
522
523 e.Resolvers = append(e.Resolvers, &EntityResolver{
524 ResolverName: resolverName,
525 KeyFields: keyFields,
526 InputTypeName: resolverFieldsToGo + "Input",
527 })
528 }
529
530 e.Requires = []*Requires{}
531 for _, f := range schemaType.Fields {
532 dir := f.Directives.ForName("requires")
533 if dir == nil {
534 continue
535 }
536 if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" {
537 panic("Exactly one `fields` argument needed for @requires declaration.")
538 }
539 requiresFieldSet := fieldset.New(dir.Arguments[0].Value.Raw, nil)
540 for _, field := range requiresFieldSet {
541 e.Requires = append(e.Requires, &Requires{
542 Name: field.ToGoPrivate(),
543 Field: field,
544 })
545 }
546 }
547 }
548 f.Entities = append(f.Entities, e)
549 }
550
551
552 sort.Slice(f.Entities, func(i, j int) bool {
553 return f.Entities[i].Name < f.Entities[j].Name
554 })
555 }
556
557 func isFederatedEntity(schemaType *ast.Definition) ([]*ast.Directive, bool) {
558 switch schemaType.Kind {
559 case ast.Object:
560 keys := schemaType.Directives.ForNames("key")
561 if len(keys) > 0 {
562 return keys, true
563 }
564 case ast.Interface:
565 keys := schemaType.Directives.ForNames("key")
566 if len(keys) > 0 {
567 return keys, true
568 }
569
570
571 if dir := schemaType.Directives.ForName("extends"); dir != nil {
572 panic(
573 fmt.Sprintf(
574 "@extends directive is not currently supported for interfaces, use \"extend interface %s\" instead.",
575 schemaType.Name,
576 ))
577 }
578 default:
579
580 }
581 return nil, false
582 }
583