...

Source file src/github.com/99designs/gqlgen/plugin/federation/federation.go

Documentation: github.com/99designs/gqlgen/plugin/federation

     1  package federation
     2  
import (
	_ "embed"
	"fmt"
	"path/filepath"
	"sort"
	"strings"

	"github.com/vektah/gqlparser/v2/ast"

	"github.com/99designs/gqlgen/codegen"
	"github.com/99designs/gqlgen/codegen/config"
	"github.com/99designs/gqlgen/codegen/templates"
	"github.com/99designs/gqlgen/internal/rewrite"
	"github.com/99designs/gqlgen/plugin"
	"github.com/99designs/gqlgen/plugin/federation/fieldset"
)
    18  
// federationTemplate holds the contents of federation.gotpl, embedded at
// build time. GenerateCode renders it into the file named by
// data.Config.Federation.Filename.
//
//go:embed federation.gotpl
var federationTemplate string

// explicitRequiresTemplate holds the contents of requires.gotpl, embedded at
// build time. GenerateCode renders it into federation.requires.go only when
// the "explicit_requires" federation option is enabled and at least one
// entity has @requires fields.
//
//go:embed requires.gotpl
var explicitRequiresTemplate string
    24  
// federation is the gqlgen plugin implementing Apollo Federation support.
// It registers the federation models and directives in the config, injects
// the federation schema sources, and renders the federation Go code. The
// struct is also embedded by value into the template data, so its exported
// fields are visible to the templates.
type federation struct {
	// Entities lists the object types with @key, plus the interface types
	// with @key that have at least one implementation. setEntities fills it
	// and sorts it by name.
	Entities       []*Entity
	// Version is the federation spec version. New maps 0 to 1; the values 1
	// and 2 select which directive set is injected and registered.
	Version        int
	// PackageOptions is the federation options map from the gqlgen config,
	// copied here by GenerateCode for use by the templates.
	PackageOptions map[string]bool
}
    30  
    31  // New returns a federation plugin that injects
    32  // federated directives and types into the schema
    33  func New(version int) plugin.Plugin {
    34  	if version == 0 {
    35  		version = 1
    36  	}
    37  
    38  	return &federation{Version: version}
    39  }
    40  
    41  // Name returns the plugin name
    42  func (f *federation) Name() string {
    43  	return "federation"
    44  }
    45  
    46  // MutateConfig mutates the configuration
    47  func (f *federation) MutateConfig(cfg *config.Config) error {
    48  	builtins := config.TypeMap{
    49  		"_Service": {
    50  			Model: config.StringList{
    51  				"github.com/99designs/gqlgen/plugin/federation/fedruntime.Service",
    52  			},
    53  		},
    54  		"_Entity": {
    55  			Model: config.StringList{
    56  				"github.com/99designs/gqlgen/plugin/federation/fedruntime.Entity",
    57  			},
    58  		},
    59  		"Entity": {
    60  			Model: config.StringList{
    61  				"github.com/99designs/gqlgen/plugin/federation/fedruntime.Entity",
    62  			},
    63  		},
    64  		"_Any": {
    65  			Model: config.StringList{"github.com/99designs/gqlgen/graphql.Map"},
    66  		},
    67  		"federation__Scope": {
    68  			Model: config.StringList{"github.com/99designs/gqlgen/graphql.String"},
    69  		},
    70  		"federation__Policy": {
    71  			Model: config.StringList{"github.com/99designs/gqlgen/graphql.String"},
    72  		},
    73  	}
    74  
    75  	for typeName, entry := range builtins {
    76  		if cfg.Models.Exists(typeName) {
    77  			return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
    78  		}
    79  		cfg.Models[typeName] = entry
    80  	}
    81  	cfg.Directives["external"] = config.DirectiveConfig{SkipRuntime: true}
    82  	cfg.Directives["requires"] = config.DirectiveConfig{SkipRuntime: true}
    83  	cfg.Directives["provides"] = config.DirectiveConfig{SkipRuntime: true}
    84  	cfg.Directives["key"] = config.DirectiveConfig{SkipRuntime: true}
    85  	cfg.Directives["extends"] = config.DirectiveConfig{SkipRuntime: true}
    86  
    87  	// Federation 2 specific directives
    88  	if f.Version == 2 {
    89  		cfg.Directives["shareable"] = config.DirectiveConfig{SkipRuntime: true}
    90  		cfg.Directives["link"] = config.DirectiveConfig{SkipRuntime: true}
    91  		cfg.Directives["tag"] = config.DirectiveConfig{SkipRuntime: true}
    92  		cfg.Directives["override"] = config.DirectiveConfig{SkipRuntime: true}
    93  		cfg.Directives["inaccessible"] = config.DirectiveConfig{SkipRuntime: true}
    94  		cfg.Directives["authenticated"] = config.DirectiveConfig{SkipRuntime: true}
    95  		cfg.Directives["requiresScopes"] = config.DirectiveConfig{SkipRuntime: true}
    96  		cfg.Directives["policy"] = config.DirectiveConfig{SkipRuntime: true}
    97  		cfg.Directives["interfaceObject"] = config.DirectiveConfig{SkipRuntime: true}
    98  		cfg.Directives["composeDirective"] = config.DirectiveConfig{SkipRuntime: true}
    99  	}
   100  
   101  	return nil
   102  }
   103  
   104  func (f *federation) InjectSourceEarly() *ast.Source {
   105  	input := ``
   106  
   107  	// add version-specific changes on key directive, as well as adding the new directives for federation 2
   108  	if f.Version == 1 {
   109  		input += `
   110  	directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
   111  	directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
   112  	directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
   113  	directive @extends on OBJECT | INTERFACE
   114  	directive @external on FIELD_DEFINITION
   115  	scalar _Any
   116  	scalar _FieldSet
   117  `
   118  	} else if f.Version == 2 {
   119  		input += `
   120  	directive @authenticated on FIELD_DEFINITION | OBJECT | INTERFACE | SCALAR | ENUM
   121  	directive @composeDirective(name: String!) repeatable on SCHEMA
   122  	directive @extends on OBJECT | INTERFACE
   123  	directive @external on OBJECT | FIELD_DEFINITION
   124  	directive @key(fields: FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE
   125  	directive @inaccessible on
   126  	  | ARGUMENT_DEFINITION
   127  	  | ENUM
   128  	  | ENUM_VALUE
   129  	  | FIELD_DEFINITION
   130  	  | INPUT_FIELD_DEFINITION
   131  	  | INPUT_OBJECT
   132  	  | INTERFACE
   133  	  | OBJECT
   134  	  | SCALAR
   135  	  | UNION
   136  	directive @interfaceObject on OBJECT
   137  	directive @link(import: [String!], url: String!) repeatable on SCHEMA
   138  	directive @override(from: String!, label: String) on FIELD_DEFINITION
   139  	directive @policy(policies: [[federation__Policy!]!]!) on 
   140  	  | FIELD_DEFINITION
   141  	  | OBJECT
   142  	  | INTERFACE
   143  	  | SCALAR
   144  	  | ENUM
   145  	directive @provides(fields: FieldSet!) on FIELD_DEFINITION
   146  	directive @requires(fields: FieldSet!) on FIELD_DEFINITION
   147  	directive @requiresScopes(scopes: [[federation__Scope!]!]!) on 
   148  	  | FIELD_DEFINITION
   149  	  | OBJECT
   150  	  | INTERFACE
   151  	  | SCALAR
   152  	  | ENUM
   153  	directive @shareable repeatable on FIELD_DEFINITION | OBJECT
   154  	directive @tag(name: String!) repeatable on
   155  	  | ARGUMENT_DEFINITION
   156  	  | ENUM
   157  	  | ENUM_VALUE
   158  	  | FIELD_DEFINITION
   159  	  | INPUT_FIELD_DEFINITION
   160  	  | INPUT_OBJECT
   161  	  | INTERFACE
   162  	  | OBJECT
   163  	  | SCALAR
   164  	  | UNION
   165  	scalar _Any
   166  	scalar FieldSet
   167  	scalar federation__Policy
   168  	scalar federation__Scope
   169  `
   170  	}
   171  	return &ast.Source{
   172  		Name:    "federation/directives.graphql",
   173  		Input:   input,
   174  		BuiltIn: true,
   175  	}
   176  }
   177  
   178  // InjectSourceLate creates a GraphQL Entity type with all
   179  // the fields that had the @key directive
   180  func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source {
   181  	f.setEntities(schema)
   182  
   183  	var entities, resolvers, entityResolverInputDefinitions string
   184  	for _, e := range f.Entities {
   185  
   186  		if e.Def.Kind != ast.Interface {
   187  			if entities != "" {
   188  				entities += " | "
   189  			}
   190  			entities += e.Name
   191  		} else if len(schema.GetPossibleTypes(e.Def)) == 0 {
   192  			fmt.Println(
   193  				"skipping @key field on interface " + e.Def.Name + " as no types implement it",
   194  			)
   195  		}
   196  
   197  		for _, r := range e.Resolvers {
   198  			if e.Multi {
   199  				if entityResolverInputDefinitions != "" {
   200  					entityResolverInputDefinitions += "\n\n"
   201  				}
   202  				entityResolverInputDefinitions += "input " + r.InputTypeName + " {\n"
   203  				for _, keyField := range r.KeyFields {
   204  					entityResolverInputDefinitions += fmt.Sprintf(
   205  						"\t%s: %s\n",
   206  						keyField.Field.ToGo(),
   207  						keyField.Definition.Type.String(),
   208  					)
   209  				}
   210  				entityResolverInputDefinitions += "}"
   211  				resolvers += fmt.Sprintf("\t%s(reps: [%s]!): [%s]\n", r.ResolverName, r.InputTypeName, e.Name)
   212  			} else {
   213  				resolverArgs := ""
   214  				for _, keyField := range r.KeyFields {
   215  					resolverArgs += fmt.Sprintf("%s: %s,", keyField.Field.ToGoPrivate(), keyField.Definition.Type.String())
   216  				}
   217  				resolvers += fmt.Sprintf("\t%s(%s): %s!\n", r.ResolverName, resolverArgs, e.Name)
   218  			}
   219  		}
   220  	}
   221  
   222  	var blocks []string
   223  	if entities != "" {
   224  		entities = `# a union of all types that use the @key directive
   225  union _Entity = ` + entities
   226  		blocks = append(blocks, entities)
   227  	}
   228  
   229  	// resolvers can be empty if a service defines only "empty
   230  	// extend" types.  This should be rare.
   231  	if resolvers != "" {
   232  		if entityResolverInputDefinitions != "" {
   233  			blocks = append(blocks, entityResolverInputDefinitions)
   234  		}
   235  		resolvers = `# fake type to build resolver interfaces for users to implement
   236  type Entity {
   237  	` + resolvers + `
   238  }`
   239  		blocks = append(blocks, resolvers)
   240  	}
   241  
   242  	_serviceTypeDef := `type _Service {
   243    sdl: String
   244  }`
   245  	blocks = append(blocks, _serviceTypeDef)
   246  
   247  	var additionalQueryFields string
   248  	// Quote from the Apollo Federation subgraph specification:
   249  	// If no types are annotated with the key directive, then the
   250  	// _Entity union and _entities field should be removed from the schema
   251  	if len(f.Entities) > 0 {
   252  		additionalQueryFields += `  _entities(representations: [_Any!]!): [_Entity]!
   253  `
   254  	}
   255  	// _service field is required in any case
   256  	additionalQueryFields += `  _service: _Service!`
   257  
   258  	extendTypeQueryDef := `extend type ` + schema.Query.Name + ` {
   259  ` + additionalQueryFields + `
   260  }`
   261  	blocks = append(blocks, extendTypeQueryDef)
   262  
   263  	return &ast.Source{
   264  		Name:    "federation/entity.graphql",
   265  		BuiltIn: true,
   266  		Input:   "\n" + strings.Join(blocks, "\n\n") + "\n",
   267  	}
   268  }
   269  
   270  func (f *federation) GenerateCode(data *codegen.Data) error {
   271  	// requires imports
   272  	requiresImports := make(map[string]bool, 0)
   273  	requiresImports["context"] = true
   274  	requiresImports["fmt"] = true
   275  
   276  	requiresEntities := make(map[string]*Entity, 0)
   277  
   278  	// Save package options on f for template use
   279  	f.PackageOptions = data.Config.Federation.Options
   280  
   281  	if len(f.Entities) > 0 {
   282  		if data.Objects.ByName("Entity") != nil {
   283  			data.Objects.ByName("Entity").Root = true
   284  		}
   285  		for _, e := range f.Entities {
   286  			obj := data.Objects.ByName(e.Def.Name)
   287  
   288  			if e.Def.Kind == ast.Interface {
   289  				if len(data.Interfaces[e.Def.Name].Implementors) == 0 {
   290  					fmt.Println(
   291  						"skipping @key field on interface " + e.Def.Name + " as no types implement it",
   292  					)
   293  					continue
   294  				}
   295  				obj = data.Objects.ByName(data.Interfaces[e.Def.Name].Implementors[0].Name)
   296  			}
   297  
   298  			for _, r := range e.Resolvers {
   299  				// fill in types for key fields
   300  				//
   301  				for _, keyField := range r.KeyFields {
   302  					if len(keyField.Field) == 0 {
   303  						fmt.Println(
   304  							"skipping @key field " + keyField.Definition.Name + " in " + r.ResolverName + " in " + e.Def.Name,
   305  						)
   306  						continue
   307  					}
   308  					cgField := keyField.Field.TypeReference(obj, data.Objects)
   309  					keyField.Type = cgField.TypeReference
   310  				}
   311  			}
   312  
   313  			// fill in types for requires fields
   314  			//
   315  			for _, reqField := range e.Requires {
   316  				if len(reqField.Field) == 0 {
   317  					fmt.Println("skipping @requires field " + reqField.Name + " in " + e.Def.Name)
   318  					continue
   319  				}
   320  				// keep track of which entities have requires
   321  				requiresEntities[e.Def.Name] = e
   322  				// make a proper import path
   323  				typeString := strings.Split(obj.Type.String(), ".")
   324  				requiresImports[strings.Join(typeString[:len(typeString)-1], ".")] = true
   325  
   326  				cgField := reqField.Field.TypeReference(obj, data.Objects)
   327  				reqField.Type = cgField.TypeReference
   328  			}
   329  
   330  			// add type info to entity
   331  			e.Type = obj.Type
   332  
   333  		}
   334  	}
   335  
   336  	// fill in types for resolver inputs
   337  	//
   338  	for _, entity := range f.Entities {
   339  		if !entity.Multi {
   340  			continue
   341  		}
   342  
   343  		for _, resolver := range entity.Resolvers {
   344  			obj := data.Inputs.ByName(resolver.InputTypeName)
   345  			if obj == nil {
   346  				return fmt.Errorf("input object %s not found", resolver.InputTypeName)
   347  			}
   348  
   349  			resolver.InputType = obj.Type
   350  		}
   351  	}
   352  
   353  	if data.Config.Federation.Options["explicit_requires"] && len(requiresEntities) > 0 {
   354  		// check for existing requires functions
   355  		type Populator struct {
   356  			FuncName       string
   357  			Exists         bool
   358  			Comment        string
   359  			Implementation string
   360  			Entity         *Entity
   361  		}
   362  		populators := make([]Populator, 0)
   363  
   364  		rewriter, err := rewrite.New(data.Config.Federation.Dir())
   365  		if err != nil {
   366  			return err
   367  		}
   368  
   369  		for name, entity := range requiresEntities {
   370  			populator := Populator{
   371  				FuncName: fmt.Sprintf("Populate%sRequires", name),
   372  				Entity:   entity,
   373  			}
   374  
   375  			populator.Comment = strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment("executionContext", populator.FuncName), `\`))
   376  			populator.Implementation = strings.TrimSpace(rewriter.GetMethodBody("executionContext", populator.FuncName))
   377  
   378  			if populator.Implementation == "" {
   379  				populator.Exists = false
   380  				populator.Implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v\"))", populator.FuncName)
   381  			}
   382  			populators = append(populators, populator)
   383  		}
   384  
   385  		sort.Slice(populators, func(i, j int) bool {
   386  			return populators[i].FuncName < populators[j].FuncName
   387  		})
   388  
   389  		requiresFile := data.Config.Federation.Dir() + "/federation.requires.go"
   390  		existingImports := rewriter.ExistingImports(requiresFile)
   391  		for _, imp := range existingImports {
   392  			if imp.Alias == "" {
   393  				// import exists in both places, remove
   394  				delete(requiresImports, imp.ImportPath)
   395  			}
   396  		}
   397  
   398  		for k := range requiresImports {
   399  			existingImports = append(existingImports, rewrite.Import{ImportPath: k})
   400  		}
   401  
   402  		// render requires populators
   403  		err = templates.Render(templates.Options{
   404  			PackageName: data.Config.Federation.Package,
   405  			Filename:    requiresFile,
   406  			Data: struct {
   407  				federation
   408  				ExistingImports []rewrite.Import
   409  				Populators      []Populator
   410  				OriginalSource  string
   411  			}{*f, existingImports, populators, ""},
   412  			GeneratedHeader: false,
   413  			Packages:        data.Config.Packages,
   414  			Template:        explicitRequiresTemplate,
   415  		})
   416  		if err != nil {
   417  			return err
   418  		}
   419  
   420  	}
   421  
   422  	return templates.Render(templates.Options{
   423  		PackageName: data.Config.Federation.Package,
   424  		Filename:    data.Config.Federation.Filename,
   425  		Data: struct {
   426  			federation
   427  			UsePointers bool
   428  		}{*f, data.Config.ResolversAlwaysReturnPointers},
   429  		GeneratedHeader: true,
   430  		Packages:        data.Config.Packages,
   431  		Template:        federationTemplate,
   432  	})
   433  }
   434  
   435  func (f *federation) setEntities(schema *ast.Schema) {
   436  	for _, schemaType := range schema.Types {
   437  		keys, ok := isFederatedEntity(schemaType)
   438  		if !ok {
   439  			continue
   440  		}
   441  
   442  		if (schemaType.Kind == ast.Interface) && (len(schema.GetPossibleTypes(schemaType)) == 0) {
   443  			fmt.Printf("@key directive found on unused \"interface %s\". Will be ignored.\n", schemaType.Name)
   444  			continue
   445  		}
   446  
   447  		e := &Entity{
   448  			Name:      schemaType.Name,
   449  			Def:       schemaType,
   450  			Resolvers: nil,
   451  			Requires:  nil,
   452  		}
   453  
   454  		// Let's process custom entity resolver settings.
   455  		dir := schemaType.Directives.ForName("entityResolver")
   456  		if dir != nil {
   457  			if dirArg := dir.Arguments.ForName("multi"); dirArg != nil {
   458  				if dirVal, err := dirArg.Value.Value(nil); err == nil {
   459  					e.Multi = dirVal.(bool)
   460  				}
   461  			}
   462  		}
   463  
   464  		// If our schema has a field with a type defined in
   465  		// another service, then we need to define an "empty
   466  		// extend" of that type in this service, so this service
   467  		// knows what the type is like.  But the graphql-server
   468  		// will never ask us to actually resolve this "empty
   469  		// extend", so we don't require a resolver function for
   470  		// it.  (Well, it will never ask in practice; it's
   471  		// unclear whether the spec guarantees this.  See
   472  		// https://github.com/apollographql/apollo-server/issues/3852
   473  		// ).  Example:
   474  		//    type MyType {
   475  		//       myvar: TypeDefinedInOtherService
   476  		//    }
   477  		//    // Federation needs this type, but
   478  		//    // it doesn't need a resolver for it!
   479  		//    extend TypeDefinedInOtherService @key(fields: "id") {
   480  		//       id: ID @external
   481  		//    }
   482  		if !e.allFieldsAreExternal(f.Version) {
   483  			for _, dir := range keys {
   484  				if len(dir.Arguments) > 2 {
   485  					panic("More than two arguments provided for @key declaration.")
   486  				}
   487  				var arg *ast.Argument
   488  
   489  				// since keys are able to now have multiple arguments, we need to check both possible for a possible @key(fields="" fields="")
   490  				for _, a := range dir.Arguments {
   491  					if a.Name == "fields" {
   492  						if arg != nil {
   493  							panic("More than one `fields` provided for @key declaration.")
   494  						}
   495  						arg = a
   496  					}
   497  				}
   498  
   499  				keyFieldSet := fieldset.New(arg.Value.Raw, nil)
   500  
   501  				keyFields := make([]*KeyField, len(keyFieldSet))
   502  				resolverFields := []string{}
   503  				for i, field := range keyFieldSet {
   504  					def := field.FieldDefinition(schemaType, schema)
   505  
   506  					if def == nil {
   507  						panic(fmt.Sprintf("no field for %v", field))
   508  					}
   509  
   510  					keyFields[i] = &KeyField{Definition: def, Field: field}
   511  					resolverFields = append(resolverFields, keyFields[i].Field.ToGo())
   512  				}
   513  
   514  				resolverFieldsToGo := schemaType.Name + "By" + strings.Join(resolverFields, "And")
   515  				var resolverName string
   516  				if e.Multi {
   517  					resolverFieldsToGo += "s" // Pluralize for better API readability
   518  					resolverName = fmt.Sprintf("findMany%s", resolverFieldsToGo)
   519  				} else {
   520  					resolverName = fmt.Sprintf("find%s", resolverFieldsToGo)
   521  				}
   522  
   523  				e.Resolvers = append(e.Resolvers, &EntityResolver{
   524  					ResolverName:  resolverName,
   525  					KeyFields:     keyFields,
   526  					InputTypeName: resolverFieldsToGo + "Input",
   527  				})
   528  			}
   529  
   530  			e.Requires = []*Requires{}
   531  			for _, f := range schemaType.Fields {
   532  				dir := f.Directives.ForName("requires")
   533  				if dir == nil {
   534  					continue
   535  				}
   536  				if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" {
   537  					panic("Exactly one `fields` argument needed for @requires declaration.")
   538  				}
   539  				requiresFieldSet := fieldset.New(dir.Arguments[0].Value.Raw, nil)
   540  				for _, field := range requiresFieldSet {
   541  					e.Requires = append(e.Requires, &Requires{
   542  						Name:  field.ToGoPrivate(),
   543  						Field: field,
   544  					})
   545  				}
   546  			}
   547  		}
   548  		f.Entities = append(f.Entities, e)
   549  	}
   550  
   551  	// make sure order remains stable across multiple builds
   552  	sort.Slice(f.Entities, func(i, j int) bool {
   553  		return f.Entities[i].Name < f.Entities[j].Name
   554  	})
   555  }
   556  
   557  func isFederatedEntity(schemaType *ast.Definition) ([]*ast.Directive, bool) {
   558  	switch schemaType.Kind {
   559  	case ast.Object:
   560  		keys := schemaType.Directives.ForNames("key")
   561  		if len(keys) > 0 {
   562  			return keys, true
   563  		}
   564  	case ast.Interface:
   565  		keys := schemaType.Directives.ForNames("key")
   566  		if len(keys) > 0 {
   567  			return keys, true
   568  		}
   569  
   570  		// TODO: support @extends for interfaces
   571  		if dir := schemaType.Directives.ForName("extends"); dir != nil {
   572  			panic(
   573  				fmt.Sprintf(
   574  					"@extends directive is not currently supported for interfaces, use \"extend interface %s\" instead.",
   575  					schemaType.Name,
   576  				))
   577  		}
   578  	default:
   579  		// ignore
   580  	}
   581  	return nil, false
   582  }
   583  

View as plain text