1 package validator
2
3 import (
4 "bytes"
5 "fmt"
6 "reflect"
7
8 "github.com/vektah/gqlparser/v2/ast"
9
10
11 . "github.com/vektah/gqlparser/v2/validator"
12 )
13
14 func init() {
15
16 AddRule("OverlappingFieldsCanBeMerged", func(observers *Events, addError AddErrFunc) {
17
71
72 m := &overlappingFieldsCanBeMergedManager{
73 comparedFragmentPairs: pairSet{data: make(map[string]map[string]bool)},
74 }
75
76 observers.OnOperation(func(walker *Walker, operation *ast.OperationDefinition) {
77 m.walker = walker
78 conflicts := m.findConflictsWithinSelectionSet(operation.SelectionSet)
79 for _, conflict := range conflicts {
80 conflict.addFieldsConflictMessage(addError)
81 }
82 })
83 observers.OnField(func(walker *Walker, field *ast.Field) {
84 if walker.CurrentOperation == nil {
85
86 return
87 }
88 m.walker = walker
89 conflicts := m.findConflictsWithinSelectionSet(field.SelectionSet)
90 for _, conflict := range conflicts {
91 conflict.addFieldsConflictMessage(addError)
92 }
93 })
94 observers.OnInlineFragment(func(walker *Walker, inlineFragment *ast.InlineFragment) {
95 m.walker = walker
96 conflicts := m.findConflictsWithinSelectionSet(inlineFragment.SelectionSet)
97 for _, conflict := range conflicts {
98 conflict.addFieldsConflictMessage(addError)
99 }
100 })
101 observers.OnFragment(func(walker *Walker, fragment *ast.FragmentDefinition) {
102 m.walker = walker
103 conflicts := m.findConflictsWithinSelectionSet(fragment.SelectionSet)
104 for _, conflict := range conflicts {
105 conflict.addFieldsConflictMessage(addError)
106 }
107 })
108 })
109 }
110
111 type pairSet struct {
112 data map[string]map[string]bool
113 }
114
115 func (pairSet *pairSet) Add(a *ast.FragmentSpread, b *ast.FragmentSpread, areMutuallyExclusive bool) {
116 add := func(a *ast.FragmentSpread, b *ast.FragmentSpread) {
117 m := pairSet.data[a.Name]
118 if m == nil {
119 m = make(map[string]bool)
120 pairSet.data[a.Name] = m
121 }
122 m[b.Name] = areMutuallyExclusive
123 }
124 add(a, b)
125 add(b, a)
126 }
127
128 func (pairSet *pairSet) Has(a *ast.FragmentSpread, b *ast.FragmentSpread, areMutuallyExclusive bool) bool {
129 am, ok := pairSet.data[a.Name]
130 if !ok {
131 return false
132 }
133 result, ok := am[b.Name]
134 if !ok {
135 return false
136 }
137
138
139
140
141 if !areMutuallyExclusive {
142 return !result
143 }
144
145 return true
146 }
147
148 type sequentialFieldsMap struct {
149
150 seq []string
151 data map[string][]*ast.Field
152 }
153
154 type fieldIterateEntry struct {
155 ResponseName string
156 Fields []*ast.Field
157 }
158
159 func (m *sequentialFieldsMap) Push(responseName string, field *ast.Field) {
160 fields, ok := m.data[responseName]
161 if !ok {
162 m.seq = append(m.seq, responseName)
163 }
164 fields = append(fields, field)
165 m.data[responseName] = fields
166 }
167
168 func (m *sequentialFieldsMap) Get(responseName string) ([]*ast.Field, bool) {
169 fields, ok := m.data[responseName]
170 return fields, ok
171 }
172
173 func (m *sequentialFieldsMap) Iterator() [][]*ast.Field {
174 fieldsList := make([][]*ast.Field, 0, len(m.seq))
175 for _, responseName := range m.seq {
176 fields := m.data[responseName]
177 fieldsList = append(fieldsList, fields)
178 }
179 return fieldsList
180 }
181
182 func (m *sequentialFieldsMap) KeyValueIterator() []*fieldIterateEntry {
183 fieldEntriesList := make([]*fieldIterateEntry, 0, len(m.seq))
184 for _, responseName := range m.seq {
185 fields := m.data[responseName]
186 fieldEntriesList = append(fieldEntriesList, &fieldIterateEntry{
187 ResponseName: responseName,
188 Fields: fields,
189 })
190 }
191 return fieldEntriesList
192 }
193
194 type conflictMessageContainer struct {
195 Conflicts []*ConflictMessage
196 }
197
198 type ConflictMessage struct {
199 Message string
200 ResponseName string
201 Names []string
202 SubMessage []*ConflictMessage
203 Position *ast.Position
204 }
205
206 func (m *ConflictMessage) String(buf *bytes.Buffer) {
207 if len(m.SubMessage) == 0 {
208 buf.WriteString(m.Message)
209 return
210 }
211
212 for idx, subMessage := range m.SubMessage {
213 buf.WriteString(`subfields "`)
214 buf.WriteString(subMessage.ResponseName)
215 buf.WriteString(`" conflict because `)
216 subMessage.String(buf)
217 if idx != len(m.SubMessage)-1 {
218 buf.WriteString(" and ")
219 }
220 }
221 }
222
223 func (m *ConflictMessage) addFieldsConflictMessage(addError AddErrFunc) {
224 var buf bytes.Buffer
225 m.String(&buf)
226 addError(
227 Message(`Fields "%s" conflict because %s. Use different aliases on the fields to fetch both if this was intentional.`, m.ResponseName, buf.String()),
228 At(m.Position),
229 )
230 }
231
232 type overlappingFieldsCanBeMergedManager struct {
233 walker *Walker
234
235
236 comparedFragmentPairs pairSet
237
238
239
240 comparedFragments map[string]bool
241 }
242
243 func (m *overlappingFieldsCanBeMergedManager) findConflictsWithinSelectionSet(selectionSet ast.SelectionSet) []*ConflictMessage {
244 if len(selectionSet) == 0 {
245 return nil
246 }
247
248 fieldsMap, fragmentSpreads := getFieldsAndFragmentNames(selectionSet)
249
250 var conflicts conflictMessageContainer
251
252
253
254 m.collectConflictsWithin(&conflicts, fieldsMap)
255
256 m.comparedFragments = make(map[string]bool)
257 for idx, fragmentSpreadA := range fragmentSpreads {
258
259
260 m.collectConflictsBetweenFieldsAndFragment(&conflicts, false, fieldsMap, fragmentSpreadA)
261
262 for _, fragmentSpreadB := range fragmentSpreads[idx+1:] {
263
264
265
266
267 m.collectConflictsBetweenFragments(&conflicts, false, fragmentSpreadA, fragmentSpreadB)
268 }
269 }
270
271 return conflicts.Conflicts
272 }
273
// collectConflictsBetweenFieldsAndFragment compares fieldsMap against the
// fields of the fragment referenced by fragmentSpread and then, recursively,
// against every fragment that fragment spreads (a direct self-spread is
// skipped). m.comparedFragments serves as the visited set: a fragment name
// already in it is not examined again. Spreads without a resolved Definition
// are marked visited and otherwise ignored.
func (m *overlappingFieldsCanBeMergedManager) collectConflictsBetweenFieldsAndFragment(conflicts *conflictMessageContainer, areMutuallyExclusive bool, fieldsMap *sequentialFieldsMap, fragmentSpread *ast.FragmentSpread) {
	name := fragmentSpread.Name
	if m.comparedFragments[name] {
		return
	}
	m.comparedFragments[name] = true

	definition := fragmentSpread.Definition
	if definition == nil {
		return
	}

	fragmentFields, nestedSpreads := getFieldsAndFragmentNames(definition.SelectionSet)

	// Nothing to compare when the fragment's fields are deeply equal to
	// fieldsMap (e.g. fieldsMap was built from this same fragment).
	if reflect.DeepEqual(fieldsMap, fragmentFields) {
		return
	}

	m.collectConflictsBetween(conflicts, areMutuallyExclusive, fieldsMap, fragmentFields)

	for _, nested := range nestedSpreads {
		if nested.Name != name {
			m.collectConflictsBetweenFieldsAndFragment(conflicts, areMutuallyExclusive, fieldsMap, nested)
		}
	}
}
305
306 func (m *overlappingFieldsCanBeMergedManager) collectConflictsBetweenFragments(conflicts *conflictMessageContainer, areMutuallyExclusive bool, fragmentSpreadA *ast.FragmentSpread, fragmentSpreadB *ast.FragmentSpread) {
307
308 var check func(fragmentSpreadA *ast.FragmentSpread, fragmentSpreadB *ast.FragmentSpread)
309 check = func(fragmentSpreadA *ast.FragmentSpread, fragmentSpreadB *ast.FragmentSpread) {
310
311 if fragmentSpreadA.Name == fragmentSpreadB.Name {
312 return
313 }
314
315 if m.comparedFragmentPairs.Has(fragmentSpreadA, fragmentSpreadB, areMutuallyExclusive) {
316 return
317 }
318 m.comparedFragmentPairs.Add(fragmentSpreadA, fragmentSpreadB, areMutuallyExclusive)
319
320 if fragmentSpreadA.Definition == nil {
321 return
322 }
323 if fragmentSpreadB.Definition == nil {
324 return
325 }
326
327 fieldsMapA, fragmentSpreadsA := getFieldsAndFragmentNames(fragmentSpreadA.Definition.SelectionSet)
328 fieldsMapB, fragmentSpreadsB := getFieldsAndFragmentNames(fragmentSpreadB.Definition.SelectionSet)
329
330
331
332 m.collectConflictsBetween(conflicts, areMutuallyExclusive, fieldsMapA, fieldsMapB)
333
334
335
336 for _, fragmentSpread := range fragmentSpreadsB {
337 check(fragmentSpreadA, fragmentSpread)
338 }
339
340
341 for _, fragmentSpread := range fragmentSpreadsA {
342 check(fragmentSpread, fragmentSpreadB)
343 }
344 }
345
346 check(fragmentSpreadA, fragmentSpreadB)
347 }
348
349 func (m *overlappingFieldsCanBeMergedManager) findConflictsBetweenSubSelectionSets(areMutuallyExclusive bool, selectionSetA ast.SelectionSet, selectionSetB ast.SelectionSet) *conflictMessageContainer {
350 var conflicts conflictMessageContainer
351
352 fieldsMapA, fragmentSpreadsA := getFieldsAndFragmentNames(selectionSetA)
353 fieldsMapB, fragmentSpreadsB := getFieldsAndFragmentNames(selectionSetB)
354
355
356 m.collectConflictsBetween(&conflicts, areMutuallyExclusive, fieldsMapA, fieldsMapB)
357
358
359
360 for _, fragmentSpread := range fragmentSpreadsB {
361 m.comparedFragments = make(map[string]bool)
362 m.collectConflictsBetweenFieldsAndFragment(&conflicts, areMutuallyExclusive, fieldsMapA, fragmentSpread)
363 }
364
365
366
367 for _, fragmentSpread := range fragmentSpreadsA {
368 m.comparedFragments = make(map[string]bool)
369 m.collectConflictsBetweenFieldsAndFragment(&conflicts, areMutuallyExclusive, fieldsMapB, fragmentSpread)
370 }
371
372
373
374
375 for _, fragmentSpreadA := range fragmentSpreadsA {
376 for _, fragmentSpreadB := range fragmentSpreadsB {
377 m.collectConflictsBetweenFragments(&conflicts, areMutuallyExclusive, fragmentSpreadA, fragmentSpreadB)
378 }
379 }
380
381 if len(conflicts.Conflicts) == 0 {
382 return nil
383 }
384
385 return &conflicts
386 }
387
388 func (m *overlappingFieldsCanBeMergedManager) collectConflictsWithin(conflicts *conflictMessageContainer, fieldsMap *sequentialFieldsMap) {
389 for _, fields := range fieldsMap.Iterator() {
390 for idx, fieldA := range fields {
391 for _, fieldB := range fields[idx+1:] {
392 conflict := m.findConflict(false, fieldA, fieldB)
393 if conflict != nil {
394 conflicts.Conflicts = append(conflicts.Conflicts, conflict)
395 }
396 }
397 }
398 }
399 }
400
401 func (m *overlappingFieldsCanBeMergedManager) collectConflictsBetween(conflicts *conflictMessageContainer, parentFieldsAreMutuallyExclusive bool, fieldsMapA *sequentialFieldsMap, fieldsMapB *sequentialFieldsMap) {
402 for _, fieldsEntryA := range fieldsMapA.KeyValueIterator() {
403 fieldsB, ok := fieldsMapB.Get(fieldsEntryA.ResponseName)
404 if !ok {
405 continue
406 }
407 for _, fieldA := range fieldsEntryA.Fields {
408 for _, fieldB := range fieldsB {
409 conflict := m.findConflict(parentFieldsAreMutuallyExclusive, fieldA, fieldB)
410 if conflict != nil {
411 conflicts.Conflicts = append(conflicts.Conflicts, conflict)
412 }
413 }
414 }
415 }
416 }
417
// findConflict decides whether fieldA and fieldB, which share a response name,
// can be merged. It returns nil when they can; otherwise it returns a
// ConflictMessage positioned at fieldB. Fields without an ObjectDefinition are
// skipped.
//
// The pair is treated as mutually exclusive when the caller says the parents
// are, or when both fields have definitions and live on two different object
// types. Mutually exclusive fields may differ in name and arguments, but
// their return types and sub-selections are still compared.
func (m *overlappingFieldsCanBeMergedManager) findConflict(parentFieldsAreMutuallyExclusive bool, fieldA *ast.Field, fieldB *ast.Field) *ConflictMessage {
	parentA, parentB := fieldA.ObjectDefinition, fieldB.ObjectDefinition
	if parentA == nil || parentB == nil {
		return nil
	}

	bothDefined := fieldA.Definition != nil && fieldB.Definition != nil
	areMutuallyExclusive := parentFieldsAreMutuallyExclusive ||
		(parentA.Name != parentB.Name &&
			parentA.Kind == ast.Object &&
			parentB.Kind == ast.Object &&
			bothDefined)

	responseName := fieldA.Alias
	if responseName == "" {
		responseName = fieldA.Name
	}

	conflictAt := func(reason string) *ConflictMessage {
		return &ConflictMessage{
			ResponseName: responseName,
			Message:      reason,
			Position:     fieldB.Position,
		}
	}

	if !areMutuallyExclusive {
		if fieldA.Name != fieldB.Name {
			return conflictAt(fmt.Sprintf(`"%s" and "%s" are different fields`, fieldA.Name, fieldB.Name))
		}
		if !sameArguments(fieldA.Arguments, fieldB.Arguments) {
			return conflictAt("they have differing arguments")
		}
	}

	if bothDefined && doTypesConflict(m.walker, fieldA.Definition.Type, fieldB.Definition.Type) {
		return conflictAt(fmt.Sprintf(`they return conflicting types "%s" and "%s"`, fieldA.Definition.Type.String(), fieldB.Definition.Type.String()))
	}

	// Identical shapes so far: the fields conflict only if their
	// sub-selections do.
	nested := m.findConflictsBetweenSubSelectionSets(areMutuallyExclusive, fieldA.SelectionSet, fieldB.SelectionSet)
	if nested == nil {
		return nil
	}
	return &ConflictMessage{
		ResponseName: responseName,
		SubMessage:   nested.Conflicts,
		Position:     fieldB.Position,
	}
}
478
479 func sameArguments(args1 []*ast.Argument, args2 []*ast.Argument) bool {
480 if len(args1) != len(args2) {
481 return false
482 }
483 for _, arg1 := range args1 {
484 var matched bool
485 for _, arg2 := range args2 {
486 if arg1.Name == arg2.Name && sameValue(arg1.Value, arg2.Value) {
487 matched = true
488 break
489 }
490 }
491 if !matched {
492 return false
493 }
494 }
495 return true
496 }
497
498 func sameValue(value1 *ast.Value, value2 *ast.Value) bool {
499 if value1.Kind != value2.Kind {
500 return false
501 }
502 if value1.Raw != value2.Raw {
503 return false
504 }
505 return true
506 }
507
508 func doTypesConflict(walker *Walker, type1 *ast.Type, type2 *ast.Type) bool {
509 if type1.Elem != nil {
510 if type2.Elem != nil {
511 return doTypesConflict(walker, type1.Elem, type2.Elem)
512 }
513 return true
514 }
515 if type2.Elem != nil {
516 return true
517 }
518 if type1.NonNull && !type2.NonNull {
519 return true
520 }
521 if !type1.NonNull && type2.NonNull {
522 return true
523 }
524
525 t1 := walker.Schema.Types[type1.NamedType]
526 t2 := walker.Schema.Types[type2.NamedType]
527 if (t1.Kind == ast.Scalar || t1.Kind == ast.Enum) && (t2.Kind == ast.Scalar || t2.Kind == ast.Enum) {
528 return t1.Name != t2.Name
529 }
530
531 return false
532 }
533
534 func getFieldsAndFragmentNames(selectionSet ast.SelectionSet) (*sequentialFieldsMap, []*ast.FragmentSpread) {
535 fieldsMap := sequentialFieldsMap{
536 data: make(map[string][]*ast.Field),
537 }
538 var fragmentSpreads []*ast.FragmentSpread
539
540 var walk func(selectionSet ast.SelectionSet)
541 walk = func(selectionSet ast.SelectionSet) {
542 for _, selection := range selectionSet {
543 switch selection := selection.(type) {
544 case *ast.Field:
545 responseName := selection.Name
546 if selection.Alias != "" {
547 responseName = selection.Alias
548 }
549 fieldsMap.Push(responseName, selection)
550
551 case *ast.InlineFragment:
552 walk(selection.SelectionSet)
553
554 case *ast.FragmentSpread:
555 fragmentSpreads = append(fragmentSpreads, selection)
556 }
557 }
558 }
559 walk(selectionSet)
560
561 return &fieldsMap, fragmentSpreads
562 }
563
View as plain text