1 package handler_test
2
3 import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/http/httptest"
8 "net/url"
9 "testing"
10
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13 "github.com/vektah/gqlparser/v2/ast"
14 "github.com/vektah/gqlparser/v2/gqlerror"
15 "github.com/vektah/gqlparser/v2/parser"
16
17 "github.com/99designs/gqlgen/graphql"
18 "github.com/99designs/gqlgen/graphql/handler/testserver"
19 "github.com/99designs/gqlgen/graphql/handler/transport"
20 )
21
22 func TestServer(t *testing.T) {
23 srv := testserver.New()
24 srv.AddTransport(&transport.GET{})
25
26 t.Run("returns an error if no transport matches", func(t *testing.T) {
27 resp := post(srv, "/foo", "application/json")
28 assert.Equal(t, http.StatusBadRequest, resp.Code)
29 assert.Equal(t, `{"errors":[{"message":"transport not supported"}],"data":null}`, resp.Body.String())
30 })
31
32 t.Run("calls query on executable schema", func(t *testing.T) {
33 resp := get(srv, "/foo?query={name}")
34 assert.Equal(t, http.StatusOK, resp.Code)
35 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
36 })
37
38 t.Run("mutations are forbidden", func(t *testing.T) {
39 resp := get(srv, "/foo?query=mutation{name}")
40 assert.Equal(t, http.StatusNotAcceptable, resp.Code)
41 assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String())
42 })
43
44 t.Run("subscriptions are forbidden", func(t *testing.T) {
45 resp := get(srv, "/foo?query=subscription{name}")
46 assert.Equal(t, http.StatusNotAcceptable, resp.Code)
47 assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String())
48 })
49
50 t.Run("invokes operation middleware in order", func(t *testing.T) {
51 var calls []string
52 srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
53 calls = append(calls, "first")
54 return next(ctx)
55 })
56 srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
57 calls = append(calls, "second")
58 return next(ctx)
59 })
60
61 resp := get(srv, "/foo?query={name}")
62 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
63 assert.Equal(t, []string{"first", "second"}, calls)
64 })
65
66 t.Run("invokes response middleware in order", func(t *testing.T) {
67 var calls []string
68 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
69 calls = append(calls, "first")
70 return next(ctx)
71 })
72 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
73 calls = append(calls, "second")
74 return next(ctx)
75 })
76
77 resp := get(srv, "/foo?query={name}")
78 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
79 assert.Equal(t, []string{"first", "second"}, calls)
80 })
81
82 t.Run("invokes field middleware in order", func(t *testing.T) {
83 var calls []string
84 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
85 calls = append(calls, "first")
86 return next(ctx)
87 })
88 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
89 calls = append(calls, "second")
90 return next(ctx)
91 })
92
93 resp := get(srv, "/foo?query={name}")
94 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
95 assert.Equal(t, []string{"first", "second"}, calls)
96 })
97
98 t.Run("get query parse error in AroundResponses", func(t *testing.T) {
99 var errors1 gqlerror.List
100 var errors2 gqlerror.List
101 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
102 resp := next(ctx)
103 errors1 = graphql.GetErrors(ctx)
104 errors2 = resp.Errors
105 return resp
106 })
107
108 resp := get(srv, "/foo?query=invalid")
109 assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
110 assert.Equal(t, 1, len(errors1))
111 assert.Equal(t, 1, len(errors2))
112 })
113
114 t.Run("query caching", func(t *testing.T) {
115 ctx := context.Background()
116 cache := &graphql.MapCache{}
117 srv.SetQueryCache(cache)
118 qry := `query Foo {name}`
119
120 t.Run("cache miss populates cache", func(t *testing.T) {
121 resp := get(srv, "/foo?query="+url.QueryEscape(qry))
122 assert.Equal(t, http.StatusOK, resp.Code)
123 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
124
125 cacheDoc, ok := cache.Get(ctx, qry)
126 require.True(t, ok)
127 require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
128 })
129
130 t.Run("cache hits use document from cache", func(t *testing.T) {
131 doc, err := parser.ParseQuery(&ast.Source{Input: `query Bar {name}`})
132 require.Nil(t, err)
133 cache.Add(ctx, qry, doc)
134
135 resp := get(srv, "/foo?query="+url.QueryEscape(qry))
136 assert.Equal(t, http.StatusOK, resp.Code)
137 assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
138
139 cacheDoc, ok := cache.Get(ctx, qry)
140 require.True(t, ok)
141 require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
142 })
143 })
144 }
145
146 func TestErrorServer(t *testing.T) {
147 srv := testserver.NewError()
148 srv.AddTransport(&transport.GET{})
149
150 t.Run("get resolver error in AroundResponses", func(t *testing.T) {
151 var errors1 gqlerror.List
152 var errors2 gqlerror.List
153 srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
154 resp := next(ctx)
155 errors1 = graphql.GetErrors(ctx)
156 errors2 = resp.Errors
157 return resp
158 })
159
160 resp := get(srv, "/foo?query={name}")
161 assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
162 assert.Equal(t, 1, len(errors1))
163 assert.Equal(t, 1, len(errors2))
164 })
165 }
166
167 type panicTransport struct{}
168
169 func (t panicTransport) Supports(r *http.Request) bool {
170 return true
171 }
172
173 func (t panicTransport) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
174 panic(fmt.Errorf("panic in transport"))
175 }
176
177 func TestRecover(t *testing.T) {
178 srv := testserver.New()
179 srv.AddTransport(&panicTransport{})
180
181 t.Run("recover from panic", func(t *testing.T) {
182 resp := get(srv, "/foo?query={name}")
183
184 assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
185 })
186 }
187
// get performs an in-process GET request for target (a path, optionally with
// a query string such as "/foo?query={name}") against handler and returns the
// recorded response. httptest.NewRequest panics if target cannot be parsed.
func get(handler http.Handler, target string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodGet, target, nil)
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)
	return rec
}
195
// post performs an in-process POST request with no body for target against
// handler, setting the Content-Type header to contentType, and returns the
// recorded response. httptest.NewRequest panics if target cannot be parsed.
func post(handler http.Handler, target, contentType string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodPost, target, nil)
	req.Header.Set("Content-Type", contentType)
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)
	return rec
}
204
View as plain text