...
1 package graphql
2
3 import (
4 "context"
5 "errors"
6 "net/http"
7
8 "github.com/vektah/gqlparser/v2/ast"
9 "github.com/vektah/gqlparser/v2/gqlerror"
10 )
11
12
13 type RequestContext = OperationContext
14
15 type OperationContext struct {
16 RawQuery string
17 Variables map[string]interface{}
18 OperationName string
19 Doc *ast.QueryDocument
20 Headers http.Header
21
22 Operation *ast.OperationDefinition
23 DisableIntrospection bool
24 RecoverFunc RecoverFunc
25 ResolverMiddleware FieldMiddleware
26 RootResolverMiddleware RootFieldMiddleware
27
28 Stats Stats
29 }
30
31 func (c *OperationContext) Validate(ctx context.Context) error {
32 if c.Doc == nil {
33 return errors.New("field 'Doc'is required")
34 }
35 if c.RawQuery == "" {
36 return errors.New("field 'RawQuery' is required")
37 }
38 if c.Variables == nil {
39 c.Variables = make(map[string]interface{})
40 }
41 if c.ResolverMiddleware == nil {
42 return errors.New("field 'ResolverMiddleware' is required")
43 }
44 if c.RootResolverMiddleware == nil {
45 return errors.New("field 'RootResolverMiddleware' is required")
46 }
47 if c.RecoverFunc == nil {
48 c.RecoverFunc = DefaultRecover
49 }
50
51 return nil
52 }
53
54 const operationCtx key = "operation_context"
55
56
57 func GetRequestContext(ctx context.Context) *RequestContext {
58 return GetOperationContext(ctx)
59 }
60
61 func GetOperationContext(ctx context.Context) *OperationContext {
62 if val, ok := ctx.Value(operationCtx).(*OperationContext); ok && val != nil {
63 return val
64 }
65 panic("missing operation context")
66 }
67
68 func WithOperationContext(ctx context.Context, rc *OperationContext) context.Context {
69 return context.WithValue(ctx, operationCtx, rc)
70 }
71
72
73
74
75 func HasOperationContext(ctx context.Context) bool {
76 val, ok := ctx.Value(operationCtx).(*OperationContext)
77 return ok && val != nil
78 }
79
80
81 func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
82 resctx := GetFieldContext(ctx)
83 return CollectFields(GetOperationContext(ctx), resctx.Field.Selections, satisfies)
84 }
85
86
87
88 func CollectAllFields(ctx context.Context) []string {
89 resctx := GetFieldContext(ctx)
90 collected := CollectFields(GetOperationContext(ctx), resctx.Field.Selections, nil)
91 uniq := make([]string, 0, len(collected))
92 Next:
93 for _, f := range collected {
94 for _, name := range uniq {
95 if name == f.Name {
96 continue Next
97 }
98 }
99 uniq = append(uniq, f.Name)
100 }
101 return uniq
102 }
103
104
105
106 func (c *OperationContext) Errorf(ctx context.Context, format string, args ...interface{}) {
107 AddErrorf(ctx, format, args...)
108 }
109
110
111
112 func (c *OperationContext) Error(ctx context.Context, err error) {
113 if errList, ok := err.(gqlerror.List); ok {
114 for _, e := range errList {
115 AddError(ctx, e)
116 }
117 return
118 }
119
120 AddError(ctx, err)
121 }
122
123 func (c *OperationContext) Recover(ctx context.Context, err interface{}) error {
124 return ErrorOnPath(ctx, c.RecoverFunc(ctx, err))
125 }
126
View as plain text