...
1 package validator
2
3 import (
4 "strconv"
5 "strings"
6
7 "github.com/vektah/gqlparser/v2/ast"
8
9
10 . "github.com/vektah/gqlparser/v2/validator"
11 )
12
13 func init() {
14 AddRule("SingleFieldSubscriptions", func(observers *Events, addError AddErrFunc) {
15 observers.OnOperation(func(walker *Walker, operation *ast.OperationDefinition) {
16 if walker.Schema.Subscription == nil || operation.Operation != ast.Subscription {
17 return
18 }
19
20 fields := retrieveTopFieldNames(operation.SelectionSet)
21
22 name := "Anonymous Subscription"
23 if operation.Name != "" {
24 name = `Subscription ` + strconv.Quote(operation.Name)
25 }
26
27 if len(fields) > 1 {
28 addError(
29 Message(`%s must select only one top level field.`, name),
30 At(fields[1].position),
31 )
32 }
33
34 for _, field := range fields {
35 if strings.HasPrefix(field.name, "__") {
36 addError(
37 Message(`%s must not select an introspection top level field.`, name),
38 At(field.position),
39 )
40 }
41 }
42 })
43 })
44 }
45
46 type topField struct {
47 name string
48 position *ast.Position
49 }
50
51 func retrieveTopFieldNames(selectionSet ast.SelectionSet) []*topField {
52 fields := []*topField{}
53 inFragmentRecursive := map[string]bool{}
54 var walk func(selectionSet ast.SelectionSet)
55 walk = func(selectionSet ast.SelectionSet) {
56 for _, selection := range selectionSet {
57 switch selection := selection.(type) {
58 case *ast.Field:
59 fields = append(fields, &topField{
60 name: selection.Name,
61 position: selection.GetPosition(),
62 })
63 case *ast.InlineFragment:
64 walk(selection.SelectionSet)
65 case *ast.FragmentSpread:
66 if selection.Definition == nil {
67 return
68 }
69 fragment := selection.Definition.Name
70 if !inFragmentRecursive[fragment] {
71 inFragmentRecursive[fragment] = true
72 walk(selection.Definition.SelectionSet)
73 }
74 }
75 }
76 }
77 walk(selectionSet)
78
79 seen := make(map[string]bool, len(fields))
80 uniquedFields := make([]*topField, 0, len(fields))
81 for _, field := range fields {
82 if !seen[field.name] {
83 uniquedFields = append(uniquedFields, field)
84 }
85 seen[field.name] = true
86 }
87 return uniquedFields
88 }
89
View as plain text