1 package validator
2
3 import (
4 "context"
5 "fmt"
6
7 "github.com/vektah/gqlparser/ast"
8 )
9
10 type Events struct {
11 operationVisitor []func(walker *Walker, operation *ast.OperationDefinition)
12 field []func(walker *Walker, field *ast.Field)
13 fragment []func(walker *Walker, fragment *ast.FragmentDefinition)
14 inlineFragment []func(walker *Walker, inlineFragment *ast.InlineFragment)
15 fragmentSpread []func(walker *Walker, fragmentSpread *ast.FragmentSpread)
16 directive []func(walker *Walker, directive *ast.Directive)
17 directiveList []func(walker *Walker, directives []*ast.Directive)
18 value []func(walker *Walker, value *ast.Value)
19 }
20
21 func (o *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) {
22 o.operationVisitor = append(o.operationVisitor, f)
23 }
24 func (o *Events) OnField(f func(walker *Walker, field *ast.Field)) {
25 o.field = append(o.field, f)
26 }
27 func (o *Events) OnFragment(f func(walker *Walker, fragment *ast.FragmentDefinition)) {
28 o.fragment = append(o.fragment, f)
29 }
30 func (o *Events) OnInlineFragment(f func(walker *Walker, inlineFragment *ast.InlineFragment)) {
31 o.inlineFragment = append(o.inlineFragment, f)
32 }
33 func (o *Events) OnFragmentSpread(f func(walker *Walker, fragmentSpread *ast.FragmentSpread)) {
34 o.fragmentSpread = append(o.fragmentSpread, f)
35 }
36 func (o *Events) OnDirective(f func(walker *Walker, directive *ast.Directive)) {
37 o.directive = append(o.directive, f)
38 }
39 func (o *Events) OnDirectiveList(f func(walker *Walker, directives []*ast.Directive)) {
40 o.directiveList = append(o.directiveList, f)
41 }
42 func (o *Events) OnValue(f func(walker *Walker, value *ast.Value)) {
43 o.value = append(o.value, f)
44 }
45
46 func Walk(schema *ast.Schema, document *ast.QueryDocument, observers *Events) {
47 w := Walker{
48 Observers: observers,
49 Schema: schema,
50 Document: document,
51 }
52
53 w.walk()
54 }
55
56 type Walker struct {
57 Context context.Context
58 Observers *Events
59 Schema *ast.Schema
60 Document *ast.QueryDocument
61
62 validatedFragmentSpreads map[string]bool
63 CurrentOperation *ast.OperationDefinition
64 }
65
66 func (w *Walker) walk() {
67 for _, child := range w.Document.Operations {
68 w.validatedFragmentSpreads = make(map[string]bool)
69 w.walkOperation(child)
70 }
71 for _, child := range w.Document.Fragments {
72 w.validatedFragmentSpreads = make(map[string]bool)
73 w.walkFragment(child)
74 }
75 }
76
77 func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
78 w.CurrentOperation = operation
79 for _, varDef := range operation.VariableDefinitions {
80 varDef.Definition = w.Schema.Types[varDef.Type.Name()]
81
82 if varDef.DefaultValue != nil {
83 varDef.DefaultValue.ExpectedType = varDef.Type
84 varDef.DefaultValue.Definition = w.Schema.Types[varDef.Type.Name()]
85 }
86 }
87
88 var def *ast.Definition
89 var loc ast.DirectiveLocation
90 switch operation.Operation {
91 case ast.Query, "":
92 def = w.Schema.Query
93 loc = ast.LocationQuery
94 case ast.Mutation:
95 def = w.Schema.Mutation
96 loc = ast.LocationMutation
97 case ast.Subscription:
98 def = w.Schema.Subscription
99 loc = ast.LocationSubscription
100 }
101
102 w.walkDirectives(def, operation.Directives, loc)
103
104 for _, varDef := range operation.VariableDefinitions {
105 if varDef.DefaultValue != nil {
106 w.walkValue(varDef.DefaultValue)
107 }
108 }
109
110 w.walkSelectionSet(def, operation.SelectionSet)
111
112 for _, v := range w.Observers.operationVisitor {
113 v(w, operation)
114 }
115 w.CurrentOperation = nil
116 }
117
118 func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
119 def := w.Schema.Types[it.TypeCondition]
120
121 it.Definition = def
122
123 w.walkDirectives(def, it.Directives, ast.LocationFragmentDefinition)
124 w.walkSelectionSet(def, it.SelectionSet)
125
126 for _, v := range w.Observers.fragment {
127 v(w, it)
128 }
129 }
130
131 func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) {
132 for _, dir := range directives {
133 def := w.Schema.Directives[dir.Name]
134 dir.Definition = def
135 dir.ParentDefinition = parentDef
136 dir.Location = location
137
138 for _, arg := range dir.Arguments {
139 var argDef *ast.ArgumentDefinition
140 if def != nil {
141 argDef = def.Arguments.ForName(arg.Name)
142 }
143
144 w.walkArgument(argDef, arg)
145 }
146
147 for _, v := range w.Observers.directive {
148 v(w, dir)
149 }
150 }
151
152 for _, v := range w.Observers.directiveList {
153 v(w, directives)
154 }
155 }
156
157 func (w *Walker) walkValue(value *ast.Value) {
158 if value.Kind == ast.Variable && w.CurrentOperation != nil {
159 value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw)
160 if value.VariableDefinition != nil {
161 value.VariableDefinition.Used = true
162 }
163 }
164
165 if value.Kind == ast.ObjectValue {
166 for _, child := range value.Children {
167 if value.Definition != nil {
168 fieldDef := value.Definition.Fields.ForName(child.Name)
169 if fieldDef != nil {
170 child.Value.ExpectedType = fieldDef.Type
171 child.Value.Definition = w.Schema.Types[fieldDef.Type.Name()]
172 }
173 }
174 w.walkValue(child.Value)
175 }
176 }
177
178 if value.Kind == ast.ListValue {
179 for _, child := range value.Children {
180 if value.ExpectedType != nil && value.ExpectedType.Elem != nil {
181 child.Value.ExpectedType = value.ExpectedType.Elem
182 child.Value.Definition = value.Definition
183 }
184
185 w.walkValue(child.Value)
186 }
187 }
188
189 for _, v := range w.Observers.value {
190 v(w, value)
191 }
192 }
193
194 func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
195 if argDef != nil {
196 arg.Value.ExpectedType = argDef.Type
197 arg.Value.Definition = w.Schema.Types[argDef.Type.Name()]
198 }
199
200 w.walkValue(arg.Value)
201 }
202
203 func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) {
204 for _, child := range it {
205 w.walkSelection(parentDef, child)
206 }
207 }
208
209 func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) {
210 switch it := it.(type) {
211 case *ast.Field:
212 var def *ast.FieldDefinition
213 if it.Name == "__typename" {
214 def = &ast.FieldDefinition{
215 Name: "__typename",
216 Type: ast.NamedType("String", nil),
217 }
218 } else if parentDef != nil {
219 def = parentDef.Fields.ForName(it.Name)
220 }
221
222 it.Definition = def
223 it.ObjectDefinition = parentDef
224
225 var nextParentDef *ast.Definition
226 if def != nil {
227 nextParentDef = w.Schema.Types[def.Type.Name()]
228 }
229
230 for _, arg := range it.Arguments {
231 var argDef *ast.ArgumentDefinition
232 if def != nil {
233 argDef = def.Arguments.ForName(arg.Name)
234 }
235
236 w.walkArgument(argDef, arg)
237 }
238
239 w.walkDirectives(nextParentDef, it.Directives, ast.LocationField)
240 w.walkSelectionSet(nextParentDef, it.SelectionSet)
241
242 for _, v := range w.Observers.field {
243 v(w, it)
244 }
245
246 case *ast.InlineFragment:
247 it.ObjectDefinition = parentDef
248
249 nextParentDef := parentDef
250 if it.TypeCondition != "" {
251 nextParentDef = w.Schema.Types[it.TypeCondition]
252 }
253
254 w.walkDirectives(nextParentDef, it.Directives, ast.LocationInlineFragment)
255 w.walkSelectionSet(nextParentDef, it.SelectionSet)
256
257 for _, v := range w.Observers.inlineFragment {
258 v(w, it)
259 }
260
261 case *ast.FragmentSpread:
262 def := w.Document.Fragments.ForName(it.Name)
263 it.Definition = def
264 it.ObjectDefinition = parentDef
265
266 var nextParentDef *ast.Definition
267 if def != nil {
268 nextParentDef = w.Schema.Types[def.TypeCondition]
269 }
270
271 w.walkDirectives(nextParentDef, it.Directives, ast.LocationFragmentSpread)
272
273 if def != nil && !w.validatedFragmentSpreads[def.Name] {
274
275 w.validatedFragmentSpreads[def.Name] = true
276 w.walkSelectionSet(nextParentDef, def.SelectionSet)
277 }
278
279 for _, v := range w.Observers.fragmentSpread {
280 v(w, it)
281 }
282
283 default:
284 panic(fmt.Errorf("unsupported %T", it))
285 }
286 }
287
View as plain text