...

Source file src/github.com/vektah/gqlparser/v2/validator/walk.go

Documentation: github.com/vektah/gqlparser/v2/validator

     1  package validator
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/vektah/gqlparser/v2/ast"
     8  )
     9  
    10  type Events struct {
    11  	operationVisitor []func(walker *Walker, operation *ast.OperationDefinition)
    12  	field            []func(walker *Walker, field *ast.Field)
    13  	fragment         []func(walker *Walker, fragment *ast.FragmentDefinition)
    14  	inlineFragment   []func(walker *Walker, inlineFragment *ast.InlineFragment)
    15  	fragmentSpread   []func(walker *Walker, fragmentSpread *ast.FragmentSpread)
    16  	directive        []func(walker *Walker, directive *ast.Directive)
    17  	directiveList    []func(walker *Walker, directives []*ast.Directive)
    18  	value            []func(walker *Walker, value *ast.Value)
    19  	variable         []func(walker *Walker, variable *ast.VariableDefinition)
    20  }
    21  
    22  func (o *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) {
    23  	o.operationVisitor = append(o.operationVisitor, f)
    24  }
    25  func (o *Events) OnField(f func(walker *Walker, field *ast.Field)) {
    26  	o.field = append(o.field, f)
    27  }
    28  func (o *Events) OnFragment(f func(walker *Walker, fragment *ast.FragmentDefinition)) {
    29  	o.fragment = append(o.fragment, f)
    30  }
    31  func (o *Events) OnInlineFragment(f func(walker *Walker, inlineFragment *ast.InlineFragment)) {
    32  	o.inlineFragment = append(o.inlineFragment, f)
    33  }
    34  func (o *Events) OnFragmentSpread(f func(walker *Walker, fragmentSpread *ast.FragmentSpread)) {
    35  	o.fragmentSpread = append(o.fragmentSpread, f)
    36  }
    37  func (o *Events) OnDirective(f func(walker *Walker, directive *ast.Directive)) {
    38  	o.directive = append(o.directive, f)
    39  }
    40  func (o *Events) OnDirectiveList(f func(walker *Walker, directives []*ast.Directive)) {
    41  	o.directiveList = append(o.directiveList, f)
    42  }
    43  func (o *Events) OnValue(f func(walker *Walker, value *ast.Value)) {
    44  	o.value = append(o.value, f)
    45  }
    46  func (o *Events) OnVariable(f func(walker *Walker, variable *ast.VariableDefinition)) {
    47  	o.variable = append(o.variable, f)
    48  }
    49  
    50  func Walk(schema *ast.Schema, document *ast.QueryDocument, observers *Events) {
    51  	w := Walker{
    52  		Observers: observers,
    53  		Schema:    schema,
    54  		Document:  document,
    55  	}
    56  
    57  	w.walk()
    58  }
    59  
    60  type Walker struct {
    61  	Context   context.Context
    62  	Observers *Events
    63  	Schema    *ast.Schema
    64  	Document  *ast.QueryDocument
    65  
    66  	validatedFragmentSpreads map[string]bool
    67  	CurrentOperation         *ast.OperationDefinition
    68  }
    69  
    70  func (w *Walker) walk() {
    71  	for _, child := range w.Document.Operations {
    72  		w.validatedFragmentSpreads = make(map[string]bool)
    73  		w.walkOperation(child)
    74  	}
    75  	for _, child := range w.Document.Fragments {
    76  		w.validatedFragmentSpreads = make(map[string]bool)
    77  		w.walkFragment(child)
    78  	}
    79  }
    80  
    81  func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
    82  	w.CurrentOperation = operation
    83  	for _, varDef := range operation.VariableDefinitions {
    84  		varDef.Definition = w.Schema.Types[varDef.Type.Name()]
    85  		for _, v := range w.Observers.variable {
    86  			v(w, varDef)
    87  		}
    88  		if varDef.DefaultValue != nil {
    89  			varDef.DefaultValue.ExpectedType = varDef.Type
    90  			varDef.DefaultValue.Definition = w.Schema.Types[varDef.Type.Name()]
    91  		}
    92  	}
    93  
    94  	var def *ast.Definition
    95  	var loc ast.DirectiveLocation
    96  	switch operation.Operation {
    97  	case ast.Query, "":
    98  		def = w.Schema.Query
    99  		loc = ast.LocationQuery
   100  	case ast.Mutation:
   101  		def = w.Schema.Mutation
   102  		loc = ast.LocationMutation
   103  	case ast.Subscription:
   104  		def = w.Schema.Subscription
   105  		loc = ast.LocationSubscription
   106  	}
   107  
   108  	for _, varDef := range operation.VariableDefinitions {
   109  		if varDef.DefaultValue != nil {
   110  			w.walkValue(varDef.DefaultValue)
   111  		}
   112  		w.walkDirectives(varDef.Definition, varDef.Directives, ast.LocationVariableDefinition)
   113  	}
   114  
   115  	w.walkDirectives(def, operation.Directives, loc)
   116  	w.walkSelectionSet(def, operation.SelectionSet)
   117  
   118  	for _, v := range w.Observers.operationVisitor {
   119  		v(w, operation)
   120  	}
   121  	w.CurrentOperation = nil
   122  }
   123  
   124  func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
   125  	def := w.Schema.Types[it.TypeCondition]
   126  
   127  	it.Definition = def
   128  
   129  	w.walkDirectives(def, it.Directives, ast.LocationFragmentDefinition)
   130  	w.walkSelectionSet(def, it.SelectionSet)
   131  
   132  	for _, v := range w.Observers.fragment {
   133  		v(w, it)
   134  	}
   135  }
   136  
   137  func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) {
   138  	for _, dir := range directives {
   139  		def := w.Schema.Directives[dir.Name]
   140  		dir.Definition = def
   141  		dir.ParentDefinition = parentDef
   142  		dir.Location = location
   143  
   144  		for _, arg := range dir.Arguments {
   145  			var argDef *ast.ArgumentDefinition
   146  			if def != nil {
   147  				argDef = def.Arguments.ForName(arg.Name)
   148  			}
   149  
   150  			w.walkArgument(argDef, arg)
   151  		}
   152  
   153  		for _, v := range w.Observers.directive {
   154  			v(w, dir)
   155  		}
   156  	}
   157  
   158  	for _, v := range w.Observers.directiveList {
   159  		v(w, directives)
   160  	}
   161  }
   162  
   163  func (w *Walker) walkValue(value *ast.Value) {
   164  	if value.Kind == ast.Variable && w.CurrentOperation != nil {
   165  		value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw)
   166  		if value.VariableDefinition != nil {
   167  			value.VariableDefinition.Used = true
   168  		}
   169  	}
   170  
   171  	if value.Kind == ast.ObjectValue {
   172  		for _, child := range value.Children {
   173  			if value.Definition != nil {
   174  				fieldDef := value.Definition.Fields.ForName(child.Name)
   175  				if fieldDef != nil {
   176  					child.Value.ExpectedType = fieldDef.Type
   177  					child.Value.Definition = w.Schema.Types[fieldDef.Type.Name()]
   178  				}
   179  			}
   180  			w.walkValue(child.Value)
   181  		}
   182  	}
   183  
   184  	if value.Kind == ast.ListValue {
   185  		for _, child := range value.Children {
   186  			if value.ExpectedType != nil && value.ExpectedType.Elem != nil {
   187  				child.Value.ExpectedType = value.ExpectedType.Elem
   188  				child.Value.Definition = value.Definition
   189  			}
   190  
   191  			w.walkValue(child.Value)
   192  		}
   193  	}
   194  
   195  	for _, v := range w.Observers.value {
   196  		v(w, value)
   197  	}
   198  }
   199  
   200  func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
   201  	if argDef != nil {
   202  		arg.Value.ExpectedType = argDef.Type
   203  		arg.Value.Definition = w.Schema.Types[argDef.Type.Name()]
   204  	}
   205  
   206  	w.walkValue(arg.Value)
   207  }
   208  
   209  func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) {
   210  	for _, child := range it {
   211  		w.walkSelection(parentDef, child)
   212  	}
   213  }
   214  
   215  func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) {
   216  	switch it := it.(type) {
   217  	case *ast.Field:
   218  		var def *ast.FieldDefinition
   219  		if it.Name == "__typename" {
   220  			def = &ast.FieldDefinition{
   221  				Name: "__typename",
   222  				Type: ast.NamedType("String", nil),
   223  			}
   224  		} else if parentDef != nil {
   225  			def = parentDef.Fields.ForName(it.Name)
   226  		}
   227  
   228  		it.Definition = def
   229  		it.ObjectDefinition = parentDef
   230  
   231  		var nextParentDef *ast.Definition
   232  		if def != nil {
   233  			nextParentDef = w.Schema.Types[def.Type.Name()]
   234  		}
   235  
   236  		for _, arg := range it.Arguments {
   237  			var argDef *ast.ArgumentDefinition
   238  			if def != nil {
   239  				argDef = def.Arguments.ForName(arg.Name)
   240  			}
   241  
   242  			w.walkArgument(argDef, arg)
   243  		}
   244  
   245  		w.walkDirectives(nextParentDef, it.Directives, ast.LocationField)
   246  		w.walkSelectionSet(nextParentDef, it.SelectionSet)
   247  
   248  		for _, v := range w.Observers.field {
   249  			v(w, it)
   250  		}
   251  
   252  	case *ast.InlineFragment:
   253  		it.ObjectDefinition = parentDef
   254  
   255  		nextParentDef := parentDef
   256  		if it.TypeCondition != "" {
   257  			nextParentDef = w.Schema.Types[it.TypeCondition]
   258  		}
   259  
   260  		w.walkDirectives(nextParentDef, it.Directives, ast.LocationInlineFragment)
   261  		w.walkSelectionSet(nextParentDef, it.SelectionSet)
   262  
   263  		for _, v := range w.Observers.inlineFragment {
   264  			v(w, it)
   265  		}
   266  
   267  	case *ast.FragmentSpread:
   268  		def := w.Document.Fragments.ForName(it.Name)
   269  		it.Definition = def
   270  		it.ObjectDefinition = parentDef
   271  
   272  		var nextParentDef *ast.Definition
   273  		if def != nil {
   274  			nextParentDef = w.Schema.Types[def.TypeCondition]
   275  		}
   276  
   277  		w.walkDirectives(nextParentDef, it.Directives, ast.LocationFragmentSpread)
   278  
   279  		if def != nil && !w.validatedFragmentSpreads[def.Name] {
   280  			// prevent infinite recursion
   281  			w.validatedFragmentSpreads[def.Name] = true
   282  			w.walkSelectionSet(nextParentDef, def.SelectionSet)
   283  		}
   284  
   285  		for _, v := range w.Observers.fragmentSpread {
   286  			v(w, it)
   287  		}
   288  
   289  	default:
   290  		panic(fmt.Errorf("unsupported %T", it))
   291  	}
   292  }
   293  

View as plain text