...

Source file src/github.com/vektah/gqlparser/validator/walk.go

Documentation: github.com/vektah/gqlparser/validator

     1  package validator
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/vektah/gqlparser/ast"
     8  )
     9  
// Events holds the observer callbacks that a Walker invokes while traversing
// a query document. Each field is a list of callbacks for one kind of node;
// callbacks are added through the matching On* method and are invoked in the
// order in which they were registered.
//
//   - operationVisitor: operation definitions
//   - field:            field selections
//   - fragment:         fragment definitions
//   - inlineFragment:   inline fragments
//   - fragmentSpread:   fragment spreads
//   - directive:        individual directives
//   - directiveList:    a whole directive list attached to one node
//   - value:            argument and default values, including nested children
type Events struct {
	operationVisitor []func(walker *Walker, operation *ast.OperationDefinition)
	field            []func(walker *Walker, field *ast.Field)
	fragment         []func(walker *Walker, fragment *ast.FragmentDefinition)
	inlineFragment   []func(walker *Walker, inlineFragment *ast.InlineFragment)
	fragmentSpread   []func(walker *Walker, fragmentSpread *ast.FragmentSpread)
	directive        []func(walker *Walker, directive *ast.Directive)
	directiveList    []func(walker *Walker, directives []*ast.Directive)
	value            []func(walker *Walker, value *ast.Value)
}
    20  
// OnOperation registers f to be called for every operation definition. The
// walker invokes it after the operation's directives, variable default values
// and selection set have been walked.
func (o *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) {
	o.operationVisitor = append(o.operationVisitor, f)
}
// OnField registers f to be called for every field selection. The walker
// invokes it after the field's arguments, directives and nested selection set
// have been walked; the field's Definition may be nil when the parent type
// does not declare it.
func (o *Events) OnField(f func(walker *Walker, field *ast.Field)) {
	o.field = append(o.field, f)
}
// OnFragment registers f to be called for every fragment definition in the
// document. The walker invokes it after the fragment's directives and
// selection set have been walked.
func (o *Events) OnFragment(f func(walker *Walker, fragment *ast.FragmentDefinition)) {
	o.fragment = append(o.fragment, f)
}
// OnInlineFragment registers f to be called for every inline fragment. The
// walker invokes it after the inline fragment's directives and selection set
// have been walked.
func (o *Events) OnInlineFragment(f func(walker *Walker, inlineFragment *ast.InlineFragment)) {
	o.inlineFragment = append(o.inlineFragment, f)
}
// OnFragmentSpread registers f to be called for every fragment spread. The
// walker invokes it after the spread's directives have been walked and, the
// first time a given fragment is reached within the current top-level walk,
// after that fragment's selection set has been walked as well.
func (o *Events) OnFragmentSpread(f func(walker *Walker, fragmentSpread *ast.FragmentSpread)) {
	o.fragmentSpread = append(o.fragmentSpread, f)
}
// OnDirective registers f to be called once for each individual directive.
// The walker invokes it after the directive's Definition, ParentDefinition
// and Location have been set and its arguments have been walked.
func (o *Events) OnDirective(f func(walker *Walker, directive *ast.Directive)) {
	o.directive = append(o.directive, f)
}
// OnDirectiveList registers f to be called once per directive list, after
// every directive in that list has been walked. The callback also fires for
// nodes that carry no directives, in which case the list is empty.
func (o *Events) OnDirectiveList(f func(walker *Walker, directives []*ast.Directive)) {
	o.directiveList = append(o.directiveList, f)
}
// OnValue registers f to be called for every value the walker visits
// (argument values and variable default values). Children of object and list
// values are reported before the value that contains them.
func (o *Events) OnValue(f func(walker *Walker, value *ast.Value)) {
	o.value = append(o.value, f)
}
    45  
    46  func Walk(schema *ast.Schema, document *ast.QueryDocument, observers *Events) {
    47  	w := Walker{
    48  		Observers: observers,
    49  		Schema:    schema,
    50  		Document:  document,
    51  	}
    52  
    53  	w.walk()
    54  }
    55  
// Walker carries the state of a single traversal started by Walk and is
// passed to every observer callback.
//
// Observers, Schema and Document are the values handed to Walk. Context is
// not populated by Walk. CurrentOperation is the operation currently being
// walked; it is nil while fragment definitions are walked on their own.
// validatedFragmentSpreads records the names of fragments whose selection
// sets have already been expanded during the current top-level operation or
// fragment, which stops cyclic fragment spreads from recursing forever.
type Walker struct {
	Context   context.Context
	Observers *Events
	Schema    *ast.Schema
	Document  *ast.QueryDocument

	validatedFragmentSpreads map[string]bool
	CurrentOperation         *ast.OperationDefinition
}
    65  
    66  func (w *Walker) walk() {
    67  	for _, child := range w.Document.Operations {
    68  		w.validatedFragmentSpreads = make(map[string]bool)
    69  		w.walkOperation(child)
    70  	}
    71  	for _, child := range w.Document.Fragments {
    72  		w.validatedFragmentSpreads = make(map[string]bool)
    73  		w.walkFragment(child)
    74  	}
    75  }
    76  
    77  func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
    78  	w.CurrentOperation = operation
    79  	for _, varDef := range operation.VariableDefinitions {
    80  		varDef.Definition = w.Schema.Types[varDef.Type.Name()]
    81  
    82  		if varDef.DefaultValue != nil {
    83  			varDef.DefaultValue.ExpectedType = varDef.Type
    84  			varDef.DefaultValue.Definition = w.Schema.Types[varDef.Type.Name()]
    85  		}
    86  	}
    87  
    88  	var def *ast.Definition
    89  	var loc ast.DirectiveLocation
    90  	switch operation.Operation {
    91  	case ast.Query, "":
    92  		def = w.Schema.Query
    93  		loc = ast.LocationQuery
    94  	case ast.Mutation:
    95  		def = w.Schema.Mutation
    96  		loc = ast.LocationMutation
    97  	case ast.Subscription:
    98  		def = w.Schema.Subscription
    99  		loc = ast.LocationSubscription
   100  	}
   101  
   102  	w.walkDirectives(def, operation.Directives, loc)
   103  
   104  	for _, varDef := range operation.VariableDefinitions {
   105  		if varDef.DefaultValue != nil {
   106  			w.walkValue(varDef.DefaultValue)
   107  		}
   108  	}
   109  
   110  	w.walkSelectionSet(def, operation.SelectionSet)
   111  
   112  	for _, v := range w.Observers.operationVisitor {
   113  		v(w, operation)
   114  	}
   115  	w.CurrentOperation = nil
   116  }
   117  
   118  func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
   119  	def := w.Schema.Types[it.TypeCondition]
   120  
   121  	it.Definition = def
   122  
   123  	w.walkDirectives(def, it.Directives, ast.LocationFragmentDefinition)
   124  	w.walkSelectionSet(def, it.SelectionSet)
   125  
   126  	for _, v := range w.Observers.fragment {
   127  		v(w, it)
   128  	}
   129  }
   130  
   131  func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) {
   132  	for _, dir := range directives {
   133  		def := w.Schema.Directives[dir.Name]
   134  		dir.Definition = def
   135  		dir.ParentDefinition = parentDef
   136  		dir.Location = location
   137  
   138  		for _, arg := range dir.Arguments {
   139  			var argDef *ast.ArgumentDefinition
   140  			if def != nil {
   141  				argDef = def.Arguments.ForName(arg.Name)
   142  			}
   143  
   144  			w.walkArgument(argDef, arg)
   145  		}
   146  
   147  		for _, v := range w.Observers.directive {
   148  			v(w, dir)
   149  		}
   150  	}
   151  
   152  	for _, v := range w.Observers.directiveList {
   153  		v(w, directives)
   154  	}
   155  }
   156  
   157  func (w *Walker) walkValue(value *ast.Value) {
   158  	if value.Kind == ast.Variable && w.CurrentOperation != nil {
   159  		value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw)
   160  		if value.VariableDefinition != nil {
   161  			value.VariableDefinition.Used = true
   162  		}
   163  	}
   164  
   165  	if value.Kind == ast.ObjectValue {
   166  		for _, child := range value.Children {
   167  			if value.Definition != nil {
   168  				fieldDef := value.Definition.Fields.ForName(child.Name)
   169  				if fieldDef != nil {
   170  					child.Value.ExpectedType = fieldDef.Type
   171  					child.Value.Definition = w.Schema.Types[fieldDef.Type.Name()]
   172  				}
   173  			}
   174  			w.walkValue(child.Value)
   175  		}
   176  	}
   177  
   178  	if value.Kind == ast.ListValue {
   179  		for _, child := range value.Children {
   180  			if value.ExpectedType != nil && value.ExpectedType.Elem != nil {
   181  				child.Value.ExpectedType = value.ExpectedType.Elem
   182  				child.Value.Definition = value.Definition
   183  			}
   184  
   185  			w.walkValue(child.Value)
   186  		}
   187  	}
   188  
   189  	for _, v := range w.Observers.value {
   190  		v(w, value)
   191  	}
   192  }
   193  
   194  func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
   195  	if argDef != nil {
   196  		arg.Value.ExpectedType = argDef.Type
   197  		arg.Value.Definition = w.Schema.Types[argDef.Type.Name()]
   198  	}
   199  
   200  	w.walkValue(arg.Value)
   201  }
   202  
   203  func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) {
   204  	for _, child := range it {
   205  		w.walkSelection(parentDef, child)
   206  	}
   207  }
   208  
   209  func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) {
   210  	switch it := it.(type) {
   211  	case *ast.Field:
   212  		var def *ast.FieldDefinition
   213  		if it.Name == "__typename" {
   214  			def = &ast.FieldDefinition{
   215  				Name: "__typename",
   216  				Type: ast.NamedType("String", nil),
   217  			}
   218  		} else if parentDef != nil {
   219  			def = parentDef.Fields.ForName(it.Name)
   220  		}
   221  
   222  		it.Definition = def
   223  		it.ObjectDefinition = parentDef
   224  
   225  		var nextParentDef *ast.Definition
   226  		if def != nil {
   227  			nextParentDef = w.Schema.Types[def.Type.Name()]
   228  		}
   229  
   230  		for _, arg := range it.Arguments {
   231  			var argDef *ast.ArgumentDefinition
   232  			if def != nil {
   233  				argDef = def.Arguments.ForName(arg.Name)
   234  			}
   235  
   236  			w.walkArgument(argDef, arg)
   237  		}
   238  
   239  		w.walkDirectives(nextParentDef, it.Directives, ast.LocationField)
   240  		w.walkSelectionSet(nextParentDef, it.SelectionSet)
   241  
   242  		for _, v := range w.Observers.field {
   243  			v(w, it)
   244  		}
   245  
   246  	case *ast.InlineFragment:
   247  		it.ObjectDefinition = parentDef
   248  
   249  		nextParentDef := parentDef
   250  		if it.TypeCondition != "" {
   251  			nextParentDef = w.Schema.Types[it.TypeCondition]
   252  		}
   253  
   254  		w.walkDirectives(nextParentDef, it.Directives, ast.LocationInlineFragment)
   255  		w.walkSelectionSet(nextParentDef, it.SelectionSet)
   256  
   257  		for _, v := range w.Observers.inlineFragment {
   258  			v(w, it)
   259  		}
   260  
   261  	case *ast.FragmentSpread:
   262  		def := w.Document.Fragments.ForName(it.Name)
   263  		it.Definition = def
   264  		it.ObjectDefinition = parentDef
   265  
   266  		var nextParentDef *ast.Definition
   267  		if def != nil {
   268  			nextParentDef = w.Schema.Types[def.TypeCondition]
   269  		}
   270  
   271  		w.walkDirectives(nextParentDef, it.Directives, ast.LocationFragmentSpread)
   272  
   273  		if def != nil && !w.validatedFragmentSpreads[def.Name] {
   274  			// prevent inifinite recursion
   275  			w.validatedFragmentSpreads[def.Name] = true
   276  			w.walkSelectionSet(nextParentDef, def.SelectionSet)
   277  		}
   278  
   279  		for _, v := range w.Observers.fragmentSpread {
   280  			v(w, it)
   281  		}
   282  
   283  	default:
   284  		panic(fmt.Errorf("unsupported %T", it))
   285  	}
   286  }
   287  

View as plain text