...
1 package graphql
2
3 import (
4 "context"
5 "fmt"
6 "sync"
7
8 "github.com/vektah/gqlparser/v2/gqlerror"
9 )
10
11 type responseContext struct {
12 errorPresenter ErrorPresenterFunc
13 recover RecoverFunc
14
15 errors gqlerror.List
16 errorsMu sync.Mutex
17
18 extensions map[string]interface{}
19 extensionsMu sync.Mutex
20 }
21
22 const resultCtx key = "result_context"
23
24 func getResponseContext(ctx context.Context) *responseContext {
25 val, ok := ctx.Value(resultCtx).(*responseContext)
26 if !ok {
27 panic("missing response context")
28 }
29 return val
30 }
31
32 func WithResponseContext(ctx context.Context, presenterFunc ErrorPresenterFunc, recoverFunc RecoverFunc) context.Context {
33 return context.WithValue(ctx, resultCtx, &responseContext{
34 errorPresenter: presenterFunc,
35 recover: recoverFunc,
36 })
37 }
38
39 func WithFreshResponseContext(ctx context.Context) context.Context {
40 e := getResponseContext(ctx)
41 return context.WithValue(ctx, resultCtx, &responseContext{
42 errorPresenter: e.errorPresenter,
43 recover: e.recover,
44 })
45 }
46
47
48 func AddErrorf(ctx context.Context, format string, args ...interface{}) {
49 AddError(ctx, fmt.Errorf(format, args...))
50 }
51
52
53 func AddError(ctx context.Context, err error) {
54 c := getResponseContext(ctx)
55
56 presentedError := c.errorPresenter(ctx, ErrorOnPath(ctx, err))
57
58 c.errorsMu.Lock()
59 defer c.errorsMu.Unlock()
60 c.errors = append(c.errors, presentedError)
61 }
62
63 func Recover(ctx context.Context, err interface{}) (userMessage error) {
64 c := getResponseContext(ctx)
65 return ErrorOnPath(ctx, c.recover(ctx, err))
66 }
67
68
69 func HasFieldError(ctx context.Context, rctx *FieldContext) bool {
70 c := getResponseContext(ctx)
71
72 c.errorsMu.Lock()
73 defer c.errorsMu.Unlock()
74
75 if len(c.errors) == 0 {
76 return false
77 }
78
79 path := rctx.Path()
80 for _, err := range c.errors {
81 if equalPath(err.Path, path) {
82 return true
83 }
84 }
85 return false
86 }
87
88
89 func GetFieldErrors(ctx context.Context, rctx *FieldContext) gqlerror.List {
90 c := getResponseContext(ctx)
91
92 c.errorsMu.Lock()
93 defer c.errorsMu.Unlock()
94
95 if len(c.errors) == 0 {
96 return nil
97 }
98
99 path := rctx.Path()
100 var errs gqlerror.List
101 for _, err := range c.errors {
102 if equalPath(err.Path, path) {
103 errs = append(errs, err)
104 }
105 }
106 return errs
107 }
108
109 func GetErrors(ctx context.Context) gqlerror.List {
110 resCtx := getResponseContext(ctx)
111 resCtx.errorsMu.Lock()
112 defer resCtx.errorsMu.Unlock()
113
114 if len(resCtx.errors) == 0 {
115 return nil
116 }
117
118 errs := resCtx.errors
119 cpy := make(gqlerror.List, len(errs))
120 for i := range errs {
121 errCpy := *errs[i]
122 cpy[i] = &errCpy
123 }
124 return cpy
125 }
126
127
128 func RegisterExtension(ctx context.Context, key string, value interface{}) {
129 c := getResponseContext(ctx)
130 c.extensionsMu.Lock()
131 defer c.extensionsMu.Unlock()
132
133 if c.extensions == nil {
134 c.extensions = make(map[string]interface{})
135 }
136
137 if _, ok := c.extensions[key]; ok {
138 panic(fmt.Errorf("extension already registered for key %s", key))
139 }
140
141 c.extensions[key] = value
142 }
143
144
145 func GetExtensions(ctx context.Context) map[string]interface{} {
146 ext := getResponseContext(ctx).extensions
147 if ext == nil {
148 return map[string]interface{}{}
149 }
150
151 return ext
152 }
153
154 func GetExtension(ctx context.Context, name string) interface{} {
155 ext := getResponseContext(ctx).extensions
156 if ext == nil {
157 return nil
158 }
159
160 return ext[name]
161 }
162
View as plain text