1
2
3 package graphql
4
5 import (
6 "context"
7 "fmt"
8
9 "github.com/vektah/gqlparser/v2/ast"
10 )
11
12 type ExecutableSchema interface {
13 Schema() *ast.Schema
14
15 Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
16 Exec(ctx context.Context) ResponseHandler
17 }
18
19
20
21
22 func CollectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string) []CollectedField {
23 return collectFields(reqCtx, selSet, satisfies, map[string]bool{})
24 }
25
26 func collectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
27 groupedFields := make([]CollectedField, 0, len(selSet))
28
29 for _, sel := range selSet {
30 switch sel := sel.(type) {
31 case *ast.Field:
32 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
33 continue
34 }
35 f := getOrCreateAndAppendField(&groupedFields, sel.Name, sel.Alias, sel.ObjectDefinition, func() CollectedField {
36 return CollectedField{Field: sel}
37 })
38
39 f.Selections = append(f.Selections, sel.SelectionSet...)
40
41 case *ast.InlineFragment:
42 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
43 continue
44 }
45 if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
46 continue
47 }
48
49 shouldDefer, label := deferrable(sel.Directives, reqCtx.Variables)
50
51 for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
52 f := getOrCreateAndAppendField(
53 &groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition,
54 func() CollectedField { return childField })
55 f.Selections = append(f.Selections, childField.Selections...)
56 if shouldDefer {
57 f.Deferrable = &Deferrable{
58 Label: label,
59 }
60 }
61 }
62
63 case *ast.FragmentSpread:
64 if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
65 continue
66 }
67 fragmentName := sel.Name
68 if _, seen := visited[fragmentName]; seen {
69 continue
70 }
71 visited[fragmentName] = true
72
73 fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
74 if fragment == nil {
75
76 panic(fmt.Errorf("missing fragment %s", fragmentName))
77 }
78
79 if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) {
80 continue
81 }
82
83 shouldDefer, label := deferrable(sel.Directives, reqCtx.Variables)
84
85 for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
86 f := getOrCreateAndAppendField(&groupedFields,
87 childField.Name, childField.Alias, childField.ObjectDefinition,
88 func() CollectedField { return childField })
89 f.Selections = append(f.Selections, childField.Selections...)
90 if shouldDefer {
91 f.Deferrable = &Deferrable{Label: label}
92 }
93 }
94
95 default:
96 panic(fmt.Errorf("unsupported %T", sel))
97 }
98 }
99
100 return groupedFields
101 }
102
103 type CollectedField struct {
104 *ast.Field
105
106 Selections ast.SelectionSet
107 Deferrable *Deferrable
108 }
109
// instanceOf reports whether val is equal to at least one entry of
// satisfies. An empty or nil satisfies never matches.
func instanceOf(val string, satisfies []string) bool {
	found := false
	for i := 0; i < len(satisfies) && !found; i++ {
		found = satisfies[i] == val
	}
	return found
}
118
119 func getOrCreateAndAppendField(c *[]CollectedField, name string, alias string, objectDefinition *ast.Definition, creator func() CollectedField) *CollectedField {
120 for i, cf := range *c {
121 if cf.Name == name && cf.Alias == alias {
122 if cf.ObjectDefinition == objectDefinition {
123 return &(*c)[i]
124 }
125
126 if cf.ObjectDefinition == nil || objectDefinition == nil {
127 continue
128 }
129
130 if cf.ObjectDefinition.Name == objectDefinition.Name {
131 return &(*c)[i]
132 }
133
134 for _, ifc := range objectDefinition.Interfaces {
135 if ifc == cf.ObjectDefinition.Name {
136 return &(*c)[i]
137 }
138 }
139 for _, ifc := range cf.ObjectDefinition.Interfaces {
140 if ifc == objectDefinition.Name {
141 return &(*c)[i]
142 }
143 }
144 }
145 }
146
147 f := creator()
148
149 *c = append(*c, f)
150 return &(*c)[len(*c)-1]
151 }
152
153 func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
154 if len(directives) == 0 {
155 return true
156 }
157
158 skip, include := false, true
159
160 if d := directives.ForName("skip"); d != nil {
161 skip = resolveIfArgument(d, variables)
162 }
163
164 if d := directives.ForName("include"); d != nil {
165 include = resolveIfArgument(d, variables)
166 }
167
168 return !skip && include
169 }
170
171 func deferrable(directives ast.DirectiveList, variables map[string]interface{}) (shouldDefer bool, label string) {
172 d := directives.ForName("defer")
173 if d == nil {
174 return false, ""
175 }
176
177 shouldDefer = true
178
179 for _, arg := range d.Arguments {
180 switch arg.Name {
181 case "if":
182 if value, err := arg.Value.Value(variables); err == nil {
183 shouldDefer, _ = value.(bool)
184 }
185 case "label":
186 if value, err := arg.Value.Value(variables); err == nil {
187 label, _ = value.(string)
188 }
189 default:
190 panic(fmt.Sprintf("defer: argument '%s' not supported", arg.Name))
191 }
192 }
193
194 return shouldDefer, label
195 }
196
197 func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
198 arg := d.Arguments.ForName("if")
199 if arg == nil {
200 panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
201 }
202 value, err := arg.Value.Value(variables)
203 if err != nil {
204 panic(err)
205 }
206 ret, ok := value.(bool)
207 if !ok {
208 panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
209 }
210 return ret
211 }
212
View as plain text