...

Source file src/edge-infra.dev/pkg/edge/api/graphqlhelpers/helpers.go

Documentation: edge-infra.dev/pkg/edge/api/graphqlhelpers

     1  package graphqlhelpers
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"strings"
     7  
     8  	"github.com/99designs/gqlgen/graphql"
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  	"github.com/vektah/gqlparser/v2/gqlerror"
    11  	"github.com/vektah/gqlparser/v2/parser"
    12  )
    13  
    14  const (
    15  	SensitiveMask = "******************"
    16  )
    17  
    18  var (
    19  	sensitiveParams = map[string]bool{
    20  		"password":     true,
    21  		"key":          true,
    22  		"secret":       true,
    23  		"token":        true,
    24  		"oktaToken":    true,
    25  		"refreshToken": true,
    26  		"newPassword":  true,
    27  		"secretValue":  true,
    28  	}
    29  )
    30  
    31  // Query
    32  type Query struct {
    33  	Name           string
    34  	SelectedFields []*SelectedField
    35  }
    36  
    37  // SelectedField
    38  type SelectedField struct {
    39  	Name      string
    40  	SubFields []*SelectedField
    41  }
    42  
    43  // GetOperation returns the graphql operation mutation/query/subscription.
    44  func GetOperation(rctx *graphql.OperationContext) *ast.Operation {
    45  	if rctx != nil && rctx.Operation != nil {
    46  		return &rctx.Operation.Operation
    47  	}
    48  	return nil
    49  }
    50  
    51  // GetRawQuery returns the graphql query.
    52  func GetRawQuery(rctx *graphql.OperationContext) string {
    53  	if rctx != nil {
    54  		return rctx.RawQuery
    55  	}
    56  	return ""
    57  }
    58  
    59  // GetVariables
    60  func GetVariables(rctx *graphql.OperationContext) map[string]interface{} {
    61  	vars := make(map[string]interface{})
    62  	if rctx != nil {
    63  		for k, v := range rctx.Variables {
    64  			vars[strings.ToUpper(k)] = v
    65  		}
    66  	}
    67  	return vars
    68  }
    69  
    70  // ParseQuery parses the graphql query into a QueryDocument.
    71  func ParseQuery(query string) (*ast.QueryDocument, error) {
    72  	src := &ast.Source{
    73  		Input: query,
    74  	}
    75  	schema, err := parser.ParseQuery(src)
    76  	if err != nil {
    77  		return nil, gqlerror.List{&gqlerror.Error{
    78  			Err: err,
    79  		}}
    80  	}
    81  	return schema, nil
    82  }
    83  
    84  // TODO(pa250194_ncrvoyix): this function is supposed to ident and stringify the GraphQL queries.
    85  func (q Query) String() string {
    86  	// ident the query sufficiently
    87  	res, _ := json.Marshal(q)
    88  	return string(res)
    89  }
    90  
    91  // GetOperations
    92  func GetOperations(schema *ast.QueryDocument) string {
    93  	if schema == nil {
    94  		return ""
    95  	}
    96  	var operations strings.Builder
    97  	for idx, re := range schema.Operations {
    98  		if idx == 0 {
    99  			operations.WriteString(string(re.Operation))
   100  		} else {
   101  			operations.WriteString(fmt.Sprintf(", %s", re.Operation))
   102  		}
   103  	}
   104  	return operations.String()
   105  }
   106  
   107  // GetQueries
   108  func GetQueries(schema *ast.QueryDocument) []*Query {
   109  	queries := make([]*Query, 0)
   110  	if schema == nil {
   111  		return queries
   112  	}
   113  	for _, op := range schema.Operations {
   114  		for _, selection := range op.SelectionSet {
   115  			field := selection.(*ast.Field)
   116  			query := &Query{
   117  				Name:           field.Name,
   118  				SelectedFields: make([]*SelectedField, 0),
   119  			}
   120  			query.SelectedFields = recursiveField(field)
   121  			queries = append(queries, query)
   122  		}
   123  	}
   124  	return queries
   125  }
   126  
   127  // SanitizeDocument
   128  func SanitizeDocument(schema *ast.QueryDocument) {
   129  	if schema == nil {
   130  		return
   131  	}
   132  	for _, re := range schema.Operations {
   133  		for _, selection := range re.SelectionSet {
   134  			if field, ok := selection.(*ast.Field); ok {
   135  				for _, args := range field.Arguments {
   136  					if _, exists := sensitiveParams[args.Name]; exists {
   137  						args.Value = &ast.Value{
   138  							Raw:  SensitiveMask,
   139  							Kind: ast.StringValue,
   140  						}
   141  					}
   142  				}
   143  			}
   144  		}
   145  	}
   146  }
   147  
   148  // recursiveField recusively fetches the selected fields of a specific field
   149  // Example Query:
   150  //
   151  //	users {
   152  //		name
   153  //	 	contact {
   154  //			phone
   155  //			email
   156  //		}
   157  //	}
   158  //
   159  // Returns: [{Name: "name", SubFields: null}, {Name: "contact", SubFields: [{Name: "phone", SubFields: null}, {Name: "email", SubFields: null}]}]
   160  func recursiveField(field *ast.Field) []*SelectedField {
   161  	out := make([]*SelectedField, 0)
   162  	for _, fss := range field.SelectionSet {
   163  		res := fss.(*ast.Field)
   164  		sf := &SelectedField{
   165  			Name: res.Name,
   166  		}
   167  		recursiveSelectionSet(res.SelectionSet, sf)
   168  		out = append(out, sf)
   169  	}
   170  	return out
   171  }
   172  
   173  // recursiveSelectionSet
   174  func recursiveSelectionSet(ss ast.SelectionSet, sf *SelectedField) {
   175  	for _, fss := range ss {
   176  		res := fss.(*ast.Field)
   177  		field := SelectedField{
   178  			Name: res.Name,
   179  		}
   180  		if len(res.SelectionSet) > 0 {
   181  			field.SubFields = make([]*SelectedField, 0)
   182  			for _, fss := range res.SelectionSet {
   183  				res := fss.(*ast.Field)
   184  				field.SubFields = append(field.SubFields, &SelectedField{
   185  					Name: res.Name,
   186  				})
   187  				sf.SubFields = append(sf.SubFields, &field)
   188  				recursiveSelectionSet(res.SelectionSet, sf)
   189  			}
   190  		} else {
   191  			sf.SubFields = append(sf.SubFields, &field)
   192  		}
   193  	}
   194  }
   195  
   196  // GetQueryNames returns the graphql query names example: WhoAmI, Login, Logout etc.
   197  func GetQueryNames(schema *ast.QueryDocument) []string {
   198  	names := make([]string, 0)
   199  	if schema == nil {
   200  		return names
   201  	}
   202  	for _, re := range schema.Operations {
   203  		for _, selection := range re.SelectionSet {
   204  			field := selection.(*ast.Field)
   205  			names = append(names, field.Name)
   206  		}
   207  	}
   208  	return names
   209  }
   210  
   211  // GetParams returns the graphql query parameters and values.
   212  func GetParams(opctx *graphql.OperationContext, schema *ast.QueryDocument) map[string]interface{} {
   213  	params := make(map[string]interface{}, 0)
   214  	if schema == nil {
   215  		return params
   216  	}
   217  	for _, re := range schema.Operations {
   218  		for _, selection := range re.SelectionSet {
   219  			if field, ok := selection.(*ast.Field); ok {
   220  				for _, args := range field.Arguments {
   221  					params[args.Name] = args.Value.String()
   222  				}
   223  			}
   224  		}
   225  	}
   226  	if opctx != nil {
   227  		vars := GetVariables(opctx)
   228  		for key, value := range params {
   229  			v := value
   230  			if val, exists := vars[strings.ToUpper(key)]; exists {
   231  				v = val
   232  			}
   233  			if exists := sensitiveParams[key]; exists {
   234  				v = SensitiveMask
   235  			}
   236  			params[key] = v
   237  		}
   238  	}
   239  	return params
   240  }
   241  
   242  // GetResponseStatus returns the graphql query status.
   243  // Partial Failure if partial data and error(s) were returned.
   244  // Failure if error(s) were returned with no data.
   245  // Success if no error(s) were returned but data was returned.
   246  func GetResponseStatus(resp *graphql.Response) string {
   247  	const nullStr = "null"
   248  	switch {
   249  	case resp != nil && len(resp.Errors) > 0 && string(resp.Data) != nullStr:
   250  		return "Partial Failure"
   251  	case resp != nil && len(resp.Errors) > 0 && string(resp.Data) == nullStr:
   252  		return "Failure"
   253  	case resp != nil && len(resp.Errors) == 0 && string(resp.Data) != nullStr:
   254  		return "Success"
   255  	default:
   256  		return "Unknown"
   257  	}
   258  }
   259  
   260  // UpdateQueryWithVariables updates the graphql document with variables. these variables values are inlined in the query
   261  // Example:
   262  // Query: mutation login($username: String!, $password: String!, $organization: String!) {\n  login(username: $username, password: $password, organization: $organization) {\n    fullName\n    firstName\n    credentialsExpired\n    token\n    __typename\n  }\n}\n
   263  // Becomes: mutation login($username: String!, $password: String!, $organization: String!) {\n login(username: \"test-user\", password: \"123456\", organization: \"test-org\") {\n  fullName\n  firstName\n  credentialsExpired\n  token\n  __typename\n }\n}\n
   264  func UpdateQueryWithVariables(doc *ast.QueryDocument, variables map[string]interface{}) {
   265  	if doc == nil {
   266  		return
   267  	}
   268  	for _, op := range doc.Operations {
   269  		for idx, selection := range op.SelectionSet {
   270  			field, ok := selection.(*ast.Field)
   271  			if ok {
   272  				for i, arg := range field.Arguments {
   273  					argName := strings.Trim(arg.Value.String(), "$")
   274  					val, exists := variables[strings.ToUpper(argName)]
   275  					if !exists && argName == "" {
   276  						// if the parameter is not provided but is in the query
   277  						// set that param to null
   278  						arg.Value.Raw = "<nil>"
   279  						arg.Value.Kind = ast.NullValue
   280  						continue
   281  					}
   282  					kind := getAstKind(val)
   283  					arg.Value.Kind = kind
   284  					getInnerVal(kind, arg.Value, val, (*[]*ast.ChildValue)(&arg.Value.Children))
   285  					field.Arguments[i] = arg
   286  				}
   287  			}
   288  			op.SelectionSet[idx] = field
   289  		}
   290  	}
   291  }
   292  
   293  // getInnerVal is a helper function to set the field argument kind and raw value
   294  // if the type is object or array, we recursively loop through to get the sub value if any.
   295  func getInnerVal(kind ast.ValueKind, arg *ast.Value, val interface{}, children *[]*ast.ChildValue) {
   296  	switch kind {
   297  	case ast.StringValue, ast.BlockValue, ast.EnumValue, ast.Variable:
   298  		arg.Raw = fmt.Sprintf("%s", val)
   299  		arg.Kind = ast.StringValue
   300  	case ast.IntValue:
   301  		arg.Raw = fmt.Sprintf("%v", val)
   302  		arg.Kind = ast.IntValue
   303  	case ast.FloatValue:
   304  		arg.Raw = fmt.Sprintf("%v", val)
   305  		arg.Kind = ast.FloatValue
   306  	case ast.BooleanValue:
   307  		arg.Raw = fmt.Sprintf("%v", val)
   308  		arg.Kind = ast.BooleanValue
   309  	case ast.NullValue:
   310  		arg.Kind = ast.NullValue
   311  	case ast.ObjectValue:
   312  		elem, ok := val.(map[string]interface{})
   313  		if ok {
   314  			for key, value := range elem {
   315  				newChild := &ast.ChildValue{
   316  					Name: key,
   317  					Value: &ast.Value{
   318  						Kind: ast.ObjectValue,
   319  					},
   320  				}
   321  				childKind := getAstKind(value)
   322  				arg.Children = append(arg.Children, newChild)
   323  				getInnerVal(childKind, newChild.Value, value, children)
   324  			}
   325  		} else {
   326  			arg.Raw = fmt.Sprintf("%v", val)
   327  			arg.Kind = ast.StringValue
   328  		}
   329  	case ast.ListValue:
   330  		switch val := val.(type) {
   331  		case []string:
   332  			getSubVal(val, children)
   333  		case []int:
   334  			getSubVal(val, children)
   335  		case []int8:
   336  			getSubVal(val, children)
   337  		case []int16:
   338  			getSubVal(val, children)
   339  		case []int32:
   340  			getSubVal(val, children)
   341  		case []int64:
   342  			getSubVal(val, children)
   343  		case []float32:
   344  			getSubVal(val, children)
   345  		case []float64:
   346  			getSubVal(val, children)
   347  		case []bool:
   348  			getSubVal(val, children)
   349  		case []any:
   350  			getSubVal(val, children)
   351  		}
   352  	}
   353  }
   354  
   355  func getSubVal[T any](arr []T, children *[]*ast.ChildValue) {
   356  	for _, elem := range arr {
   357  		childKind := getAstKind(elem)
   358  		newChild := &ast.ChildValue{Value: &ast.Value{Kind: childKind}}
   359  		*children = append(*children, newChild)
   360  		getInnerVal(childKind, newChild.Value, elem, (*[]*ast.ChildValue)(&newChild.Value.Children))
   361  	}
   362  }
   363  
   364  func getAstKind(_var interface{}) ast.ValueKind {
   365  	switch _var.(type) {
   366  	case int, int8, int16, int32, int64:
   367  		return ast.IntValue
   368  	case float32, float64:
   369  		return ast.FloatValue
   370  	case string:
   371  		return ast.StringValue
   372  	case bool:
   373  		return ast.BooleanValue
   374  	case nil:
   375  		return ast.NullValue
   376  	case []int, []int8, []int16, []int32, []int64, []string, []bool, []any:
   377  		return ast.ListValue
   378  	case struct{}, interface{}:
   379  		return ast.ObjectValue
   380  	default:
   381  		return ast.StringValue
   382  	}
   383  }
   384  
   385  func IsMutation(rctx *graphql.OperationContext) bool {
   386  	return rctx != nil && rctx.Operation != nil && rctx.Operation.Operation == ast.Mutation
   387  }
   388  

View as plain text