1 package graphqlhelpers
2
3 import (
4 "encoding/json"
5 "fmt"
6 "strings"
7
8 "github.com/99designs/gqlgen/graphql"
9 "github.com/vektah/gqlparser/v2/ast"
10 "github.com/vektah/gqlparser/v2/gqlerror"
11 "github.com/vektah/gqlparser/v2/parser"
12 )
13
// SensitiveMask is the placeholder written in place of the value of any
// parameter whose name appears in sensitiveParams.
const SensitiveMask = "******************"
17
// sensitiveParams is the set of parameter names whose values are replaced
// with SensitiveMask before being reported. Every entry maps to true, so
// callers may test either key presence or the boolean value.
var sensitiveParams = map[string]bool{
	"key":          true,
	"newPassword":  true,
	"oktaToken":    true,
	"password":     true,
	"refreshToken": true,
	"secret":       true,
	"secretValue":  true,
	"token":        true,
}
30
31
type (
	// Query describes one top-level field of a GraphQL operation together
	// with the tree of fields selected beneath it.
	Query struct {
		Name           string
		SelectedFields []*SelectedField
	}

	// SelectedField is one field inside a selection set; SubFields holds the
	// fields selected beneath it and is nil for leaf fields.
	SelectedField struct {
		Name      string
		SubFields []*SelectedField
	}
)
42
43
44 func GetOperation(rctx *graphql.OperationContext) *ast.Operation {
45 if rctx != nil && rctx.Operation != nil {
46 return &rctx.Operation.Operation
47 }
48 return nil
49 }
50
51
52 func GetRawQuery(rctx *graphql.OperationContext) string {
53 if rctx != nil {
54 return rctx.RawQuery
55 }
56 return ""
57 }
58
59
60 func GetVariables(rctx *graphql.OperationContext) map[string]interface{} {
61 vars := make(map[string]interface{})
62 if rctx != nil {
63 for k, v := range rctx.Variables {
64 vars[strings.ToUpper(k)] = v
65 }
66 }
67 return vars
68 }
69
70
71 func ParseQuery(query string) (*ast.QueryDocument, error) {
72 src := &ast.Source{
73 Input: query,
74 }
75 schema, err := parser.ParseQuery(src)
76 if err != nil {
77 return nil, gqlerror.List{&gqlerror.Error{
78 Err: err,
79 }}
80 }
81 return schema, nil
82 }
83
84
85 func (q Query) String() string {
86
87 res, _ := json.Marshal(q)
88 return string(res)
89 }
90
91
92 func GetOperations(schema *ast.QueryDocument) string {
93 if schema == nil {
94 return ""
95 }
96 var operations strings.Builder
97 for idx, re := range schema.Operations {
98 if idx == 0 {
99 operations.WriteString(string(re.Operation))
100 } else {
101 operations.WriteString(fmt.Sprintf(", %s", re.Operation))
102 }
103 }
104 return operations.String()
105 }
106
107
108 func GetQueries(schema *ast.QueryDocument) []*Query {
109 queries := make([]*Query, 0)
110 if schema == nil {
111 return queries
112 }
113 for _, op := range schema.Operations {
114 for _, selection := range op.SelectionSet {
115 field := selection.(*ast.Field)
116 query := &Query{
117 Name: field.Name,
118 SelectedFields: make([]*SelectedField, 0),
119 }
120 query.SelectedFields = recursiveField(field)
121 queries = append(queries, query)
122 }
123 }
124 return queries
125 }
126
127
128 func SanitizeDocument(schema *ast.QueryDocument) {
129 if schema == nil {
130 return
131 }
132 for _, re := range schema.Operations {
133 for _, selection := range re.SelectionSet {
134 if field, ok := selection.(*ast.Field); ok {
135 for _, args := range field.Arguments {
136 if _, exists := sensitiveParams[args.Name]; exists {
137 args.Value = &ast.Value{
138 Raw: SensitiveMask,
139 Kind: ast.StringValue,
140 }
141 }
142 }
143 }
144 }
145 }
146 }
147
148
149
150
151
152
153
154
155
156
157
158
159
160 func recursiveField(field *ast.Field) []*SelectedField {
161 out := make([]*SelectedField, 0)
162 for _, fss := range field.SelectionSet {
163 res := fss.(*ast.Field)
164 sf := &SelectedField{
165 Name: res.Name,
166 }
167 recursiveSelectionSet(res.SelectionSet, sf)
168 out = append(out, sf)
169 }
170 return out
171 }
172
173
174 func recursiveSelectionSet(ss ast.SelectionSet, sf *SelectedField) {
175 for _, fss := range ss {
176 res := fss.(*ast.Field)
177 field := SelectedField{
178 Name: res.Name,
179 }
180 if len(res.SelectionSet) > 0 {
181 field.SubFields = make([]*SelectedField, 0)
182 for _, fss := range res.SelectionSet {
183 res := fss.(*ast.Field)
184 field.SubFields = append(field.SubFields, &SelectedField{
185 Name: res.Name,
186 })
187 sf.SubFields = append(sf.SubFields, &field)
188 recursiveSelectionSet(res.SelectionSet, sf)
189 }
190 } else {
191 sf.SubFields = append(sf.SubFields, &field)
192 }
193 }
194 }
195
196
197 func GetQueryNames(schema *ast.QueryDocument) []string {
198 names := make([]string, 0)
199 if schema == nil {
200 return names
201 }
202 for _, re := range schema.Operations {
203 for _, selection := range re.SelectionSet {
204 field := selection.(*ast.Field)
205 names = append(names, field.Name)
206 }
207 }
208 return names
209 }
210
211
212 func GetParams(opctx *graphql.OperationContext, schema *ast.QueryDocument) map[string]interface{} {
213 params := make(map[string]interface{}, 0)
214 if schema == nil {
215 return params
216 }
217 for _, re := range schema.Operations {
218 for _, selection := range re.SelectionSet {
219 if field, ok := selection.(*ast.Field); ok {
220 for _, args := range field.Arguments {
221 params[args.Name] = args.Value.String()
222 }
223 }
224 }
225 }
226 if opctx != nil {
227 vars := GetVariables(opctx)
228 for key, value := range params {
229 v := value
230 if val, exists := vars[strings.ToUpper(key)]; exists {
231 v = val
232 }
233 if exists := sensitiveParams[key]; exists {
234 v = SensitiveMask
235 }
236 params[key] = v
237 }
238 }
239 return params
240 }
241
242
243
244
245
246 func GetResponseStatus(resp *graphql.Response) string {
247 const nullStr = "null"
248 switch {
249 case resp != nil && len(resp.Errors) > 0 && string(resp.Data) != nullStr:
250 return "Partial Failure"
251 case resp != nil && len(resp.Errors) > 0 && string(resp.Data) == nullStr:
252 return "Failure"
253 case resp != nil && len(resp.Errors) == 0 && string(resp.Data) != nullStr:
254 return "Success"
255 default:
256 return "Unknown"
257 }
258 }
259
260
261
262
263
264 func UpdateQueryWithVariables(doc *ast.QueryDocument, variables map[string]interface{}) {
265 if doc == nil {
266 return
267 }
268 for _, op := range doc.Operations {
269 for idx, selection := range op.SelectionSet {
270 field, ok := selection.(*ast.Field)
271 if ok {
272 for i, arg := range field.Arguments {
273 argName := strings.Trim(arg.Value.String(), "$")
274 val, exists := variables[strings.ToUpper(argName)]
275 if !exists && argName == "" {
276
277
278 arg.Value.Raw = "<nil>"
279 arg.Value.Kind = ast.NullValue
280 continue
281 }
282 kind := getAstKind(val)
283 arg.Value.Kind = kind
284 getInnerVal(kind, arg.Value, val, (*[]*ast.ChildValue)(&arg.Value.Children))
285 field.Arguments[i] = arg
286 }
287 }
288 op.SelectionSet[idx] = field
289 }
290 }
291 }
292
293
294
295 func getInnerVal(kind ast.ValueKind, arg *ast.Value, val interface{}, children *[]*ast.ChildValue) {
296 switch kind {
297 case ast.StringValue, ast.BlockValue, ast.EnumValue, ast.Variable:
298 arg.Raw = fmt.Sprintf("%s", val)
299 arg.Kind = ast.StringValue
300 case ast.IntValue:
301 arg.Raw = fmt.Sprintf("%v", val)
302 arg.Kind = ast.IntValue
303 case ast.FloatValue:
304 arg.Raw = fmt.Sprintf("%v", val)
305 arg.Kind = ast.FloatValue
306 case ast.BooleanValue:
307 arg.Raw = fmt.Sprintf("%v", val)
308 arg.Kind = ast.BooleanValue
309 case ast.NullValue:
310 arg.Kind = ast.NullValue
311 case ast.ObjectValue:
312 elem, ok := val.(map[string]interface{})
313 if ok {
314 for key, value := range elem {
315 newChild := &ast.ChildValue{
316 Name: key,
317 Value: &ast.Value{
318 Kind: ast.ObjectValue,
319 },
320 }
321 childKind := getAstKind(value)
322 arg.Children = append(arg.Children, newChild)
323 getInnerVal(childKind, newChild.Value, value, children)
324 }
325 } else {
326 arg.Raw = fmt.Sprintf("%v", val)
327 arg.Kind = ast.StringValue
328 }
329 case ast.ListValue:
330 switch val := val.(type) {
331 case []string:
332 getSubVal(val, children)
333 case []int:
334 getSubVal(val, children)
335 case []int8:
336 getSubVal(val, children)
337 case []int16:
338 getSubVal(val, children)
339 case []int32:
340 getSubVal(val, children)
341 case []int64:
342 getSubVal(val, children)
343 case []float32:
344 getSubVal(val, children)
345 case []float64:
346 getSubVal(val, children)
347 case []bool:
348 getSubVal(val, children)
349 case []any:
350 getSubVal(val, children)
351 }
352 }
353 }
354
355 func getSubVal[T any](arr []T, children *[]*ast.ChildValue) {
356 for _, elem := range arr {
357 childKind := getAstKind(elem)
358 newChild := &ast.ChildValue{Value: &ast.Value{Kind: childKind}}
359 *children = append(*children, newChild)
360 getInnerVal(childKind, newChild.Value, elem, (*[]*ast.ChildValue)(&newChild.Value.Children))
361 }
362 }
363
364 func getAstKind(_var interface{}) ast.ValueKind {
365 switch _var.(type) {
366 case int, int8, int16, int32, int64:
367 return ast.IntValue
368 case float32, float64:
369 return ast.FloatValue
370 case string:
371 return ast.StringValue
372 case bool:
373 return ast.BooleanValue
374 case nil:
375 return ast.NullValue
376 case []int, []int8, []int16, []int32, []int64, []string, []bool, []any:
377 return ast.ListValue
378 case struct{}, interface{}:
379 return ast.ObjectValue
380 default:
381 return ast.StringValue
382 }
383 }
384
385 func IsMutation(rctx *graphql.OperationContext) bool {
386 return rctx != nil && rctx.Operation != nil && rctx.Operation.Operation == ast.Mutation
387 }
388
View as plain text