1 package validator
2
3 import (
4 "context"
5 "fmt"
6
7 "github.com/vektah/gqlparser/v2/ast"
8 )
9
10 type Events struct {
11 operationVisitor []func(walker *Walker, operation *ast.OperationDefinition)
12 field []func(walker *Walker, field *ast.Field)
13 fragment []func(walker *Walker, fragment *ast.FragmentDefinition)
14 inlineFragment []func(walker *Walker, inlineFragment *ast.InlineFragment)
15 fragmentSpread []func(walker *Walker, fragmentSpread *ast.FragmentSpread)
16 directive []func(walker *Walker, directive *ast.Directive)
17 directiveList []func(walker *Walker, directives []*ast.Directive)
18 value []func(walker *Walker, value *ast.Value)
19 variable []func(walker *Walker, variable *ast.VariableDefinition)
20 }
21
22 func (o *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) {
23 o.operationVisitor = append(o.operationVisitor, f)
24 }
25 func (o *Events) OnField(f func(walker *Walker, field *ast.Field)) {
26 o.field = append(o.field, f)
27 }
28 func (o *Events) OnFragment(f func(walker *Walker, fragment *ast.FragmentDefinition)) {
29 o.fragment = append(o.fragment, f)
30 }
31 func (o *Events) OnInlineFragment(f func(walker *Walker, inlineFragment *ast.InlineFragment)) {
32 o.inlineFragment = append(o.inlineFragment, f)
33 }
34 func (o *Events) OnFragmentSpread(f func(walker *Walker, fragmentSpread *ast.FragmentSpread)) {
35 o.fragmentSpread = append(o.fragmentSpread, f)
36 }
37 func (o *Events) OnDirective(f func(walker *Walker, directive *ast.Directive)) {
38 o.directive = append(o.directive, f)
39 }
40 func (o *Events) OnDirectiveList(f func(walker *Walker, directives []*ast.Directive)) {
41 o.directiveList = append(o.directiveList, f)
42 }
43 func (o *Events) OnValue(f func(walker *Walker, value *ast.Value)) {
44 o.value = append(o.value, f)
45 }
46 func (o *Events) OnVariable(f func(walker *Walker, variable *ast.VariableDefinition)) {
47 o.variable = append(o.variable, f)
48 }
49
50 func Walk(schema *ast.Schema, document *ast.QueryDocument, observers *Events) {
51 w := Walker{
52 Observers: observers,
53 Schema: schema,
54 Document: document,
55 }
56
57 w.walk()
58 }
59
60 type Walker struct {
61 Context context.Context
62 Observers *Events
63 Schema *ast.Schema
64 Document *ast.QueryDocument
65
66 validatedFragmentSpreads map[string]bool
67 CurrentOperation *ast.OperationDefinition
68 }
69
70 func (w *Walker) walk() {
71 for _, child := range w.Document.Operations {
72 w.validatedFragmentSpreads = make(map[string]bool)
73 w.walkOperation(child)
74 }
75 for _, child := range w.Document.Fragments {
76 w.validatedFragmentSpreads = make(map[string]bool)
77 w.walkFragment(child)
78 }
79 }
80
81 func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
82 w.CurrentOperation = operation
83 for _, varDef := range operation.VariableDefinitions {
84 varDef.Definition = w.Schema.Types[varDef.Type.Name()]
85 for _, v := range w.Observers.variable {
86 v(w, varDef)
87 }
88 if varDef.DefaultValue != nil {
89 varDef.DefaultValue.ExpectedType = varDef.Type
90 varDef.DefaultValue.Definition = w.Schema.Types[varDef.Type.Name()]
91 }
92 }
93
94 var def *ast.Definition
95 var loc ast.DirectiveLocation
96 switch operation.Operation {
97 case ast.Query, "":
98 def = w.Schema.Query
99 loc = ast.LocationQuery
100 case ast.Mutation:
101 def = w.Schema.Mutation
102 loc = ast.LocationMutation
103 case ast.Subscription:
104 def = w.Schema.Subscription
105 loc = ast.LocationSubscription
106 }
107
108 for _, varDef := range operation.VariableDefinitions {
109 if varDef.DefaultValue != nil {
110 w.walkValue(varDef.DefaultValue)
111 }
112 w.walkDirectives(varDef.Definition, varDef.Directives, ast.LocationVariableDefinition)
113 }
114
115 w.walkDirectives(def, operation.Directives, loc)
116 w.walkSelectionSet(def, operation.SelectionSet)
117
118 for _, v := range w.Observers.operationVisitor {
119 v(w, operation)
120 }
121 w.CurrentOperation = nil
122 }
123
124 func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
125 def := w.Schema.Types[it.TypeCondition]
126
127 it.Definition = def
128
129 w.walkDirectives(def, it.Directives, ast.LocationFragmentDefinition)
130 w.walkSelectionSet(def, it.SelectionSet)
131
132 for _, v := range w.Observers.fragment {
133 v(w, it)
134 }
135 }
136
137 func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) {
138 for _, dir := range directives {
139 def := w.Schema.Directives[dir.Name]
140 dir.Definition = def
141 dir.ParentDefinition = parentDef
142 dir.Location = location
143
144 for _, arg := range dir.Arguments {
145 var argDef *ast.ArgumentDefinition
146 if def != nil {
147 argDef = def.Arguments.ForName(arg.Name)
148 }
149
150 w.walkArgument(argDef, arg)
151 }
152
153 for _, v := range w.Observers.directive {
154 v(w, dir)
155 }
156 }
157
158 for _, v := range w.Observers.directiveList {
159 v(w, directives)
160 }
161 }
162
163 func (w *Walker) walkValue(value *ast.Value) {
164 if value.Kind == ast.Variable && w.CurrentOperation != nil {
165 value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw)
166 if value.VariableDefinition != nil {
167 value.VariableDefinition.Used = true
168 }
169 }
170
171 if value.Kind == ast.ObjectValue {
172 for _, child := range value.Children {
173 if value.Definition != nil {
174 fieldDef := value.Definition.Fields.ForName(child.Name)
175 if fieldDef != nil {
176 child.Value.ExpectedType = fieldDef.Type
177 child.Value.Definition = w.Schema.Types[fieldDef.Type.Name()]
178 }
179 }
180 w.walkValue(child.Value)
181 }
182 }
183
184 if value.Kind == ast.ListValue {
185 for _, child := range value.Children {
186 if value.ExpectedType != nil && value.ExpectedType.Elem != nil {
187 child.Value.ExpectedType = value.ExpectedType.Elem
188 child.Value.Definition = value.Definition
189 }
190
191 w.walkValue(child.Value)
192 }
193 }
194
195 for _, v := range w.Observers.value {
196 v(w, value)
197 }
198 }
199
200 func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
201 if argDef != nil {
202 arg.Value.ExpectedType = argDef.Type
203 arg.Value.Definition = w.Schema.Types[argDef.Type.Name()]
204 }
205
206 w.walkValue(arg.Value)
207 }
208
209 func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) {
210 for _, child := range it {
211 w.walkSelection(parentDef, child)
212 }
213 }
214
215 func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) {
216 switch it := it.(type) {
217 case *ast.Field:
218 var def *ast.FieldDefinition
219 if it.Name == "__typename" {
220 def = &ast.FieldDefinition{
221 Name: "__typename",
222 Type: ast.NamedType("String", nil),
223 }
224 } else if parentDef != nil {
225 def = parentDef.Fields.ForName(it.Name)
226 }
227
228 it.Definition = def
229 it.ObjectDefinition = parentDef
230
231 var nextParentDef *ast.Definition
232 if def != nil {
233 nextParentDef = w.Schema.Types[def.Type.Name()]
234 }
235
236 for _, arg := range it.Arguments {
237 var argDef *ast.ArgumentDefinition
238 if def != nil {
239 argDef = def.Arguments.ForName(arg.Name)
240 }
241
242 w.walkArgument(argDef, arg)
243 }
244
245 w.walkDirectives(nextParentDef, it.Directives, ast.LocationField)
246 w.walkSelectionSet(nextParentDef, it.SelectionSet)
247
248 for _, v := range w.Observers.field {
249 v(w, it)
250 }
251
252 case *ast.InlineFragment:
253 it.ObjectDefinition = parentDef
254
255 nextParentDef := parentDef
256 if it.TypeCondition != "" {
257 nextParentDef = w.Schema.Types[it.TypeCondition]
258 }
259
260 w.walkDirectives(nextParentDef, it.Directives, ast.LocationInlineFragment)
261 w.walkSelectionSet(nextParentDef, it.SelectionSet)
262
263 for _, v := range w.Observers.inlineFragment {
264 v(w, it)
265 }
266
267 case *ast.FragmentSpread:
268 def := w.Document.Fragments.ForName(it.Name)
269 it.Definition = def
270 it.ObjectDefinition = parentDef
271
272 var nextParentDef *ast.Definition
273 if def != nil {
274 nextParentDef = w.Schema.Types[def.TypeCondition]
275 }
276
277 w.walkDirectives(nextParentDef, it.Directives, ast.LocationFragmentSpread)
278
279 if def != nil && !w.validatedFragmentSpreads[def.Name] {
280
281 w.validatedFragmentSpreads[def.Name] = true
282 w.walkSelectionSet(nextParentDef, def.SelectionSet)
283 }
284
285 for _, v := range w.Observers.fragmentSpread {
286 v(w, it)
287 }
288
289 default:
290 panic(fmt.Errorf("unsupported %T", it))
291 }
292 }
293
View as plain text