1 package extension_test
2
3 import (
4 "context"
5 "net/http"
6 "net/http/httptest"
7 "strings"
8 "testing"
9
10 "github.com/stretchr/testify/require"
11
12 "github.com/99designs/gqlgen/graphql"
13 "github.com/99designs/gqlgen/graphql/handler/extension"
14 "github.com/99designs/gqlgen/graphql/handler/testserver"
15 "github.com/99designs/gqlgen/graphql/handler/transport"
16 )
17
18 func TestHandlerComplexity(t *testing.T) {
19 h := testserver.New()
20 h.Use(&extension.ComplexityLimit{
21 Func: func(ctx context.Context, rc *graphql.OperationContext) int {
22 if rc.RawQuery == "{ ok: name }" {
23 return 4
24 }
25 return 2
26 },
27 })
28 h.AddTransport(&transport.POST{})
29 var stats *extension.ComplexityStats
30 h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
31 stats = extension.GetComplexityStats(ctx)
32 return next(ctx)
33 })
34
35 t.Run("below complexity limit", func(t *testing.T) {
36 stats = nil
37 h.SetCalculatedComplexity(2)
38 resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
39 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
40 require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
41
42 require.Equal(t, 2, stats.ComplexityLimit)
43 require.Equal(t, 2, stats.Complexity)
44 })
45
46 t.Run("above complexity limit", func(t *testing.T) {
47 stats = nil
48 h.SetCalculatedComplexity(4)
49 resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
50 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
51 require.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2","extensions":{"code":"COMPLEXITY_LIMIT_EXCEEDED"}}],"data":null}`, resp.Body.String())
52
53 require.Equal(t, 2, stats.ComplexityLimit)
54 require.Equal(t, 4, stats.Complexity)
55 })
56
57 t.Run("within dynamic complexity limit", func(t *testing.T) {
58 stats = nil
59 h.SetCalculatedComplexity(4)
60 resp := doRequest(h, "POST", "/graphql", `{"query":"{ ok: name }"}`)
61 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
62 require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
63
64 require.Equal(t, 4, stats.ComplexityLimit)
65 require.Equal(t, 4, stats.Complexity)
66 })
67 }
68
69 func TestFixedComplexity(t *testing.T) {
70 h := testserver.New()
71 h.Use(extension.FixedComplexityLimit(2))
72 h.AddTransport(&transport.POST{})
73
74 var stats *extension.ComplexityStats
75 h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
76 stats = extension.GetComplexityStats(ctx)
77 return next(ctx)
78 })
79
80 t.Run("below complexity limit", func(t *testing.T) {
81 h.SetCalculatedComplexity(2)
82 resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
83 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
84 require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
85
86 require.Equal(t, 2, stats.ComplexityLimit)
87 require.Equal(t, 2, stats.Complexity)
88 })
89
90 t.Run("above complexity limit", func(t *testing.T) {
91 h.SetCalculatedComplexity(4)
92 resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
93 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
94 require.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2","extensions":{"code":"COMPLEXITY_LIMIT_EXCEEDED"}}],"data":null}`, resp.Body.String())
95
96 require.Equal(t, 2, stats.ComplexityLimit)
97 require.Equal(t, 4, stats.Complexity)
98 })
99
100 t.Run("bypass __schema field", func(t *testing.T) {
101 h.SetCalculatedComplexity(4)
102 resp := doRequest(h, "POST", "/graphql", `{ "operationName":"IntrospectionQuery", "query":"query IntrospectionQuery { __schema { queryType { name } mutationType { name }}}"}`)
103 require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
104 require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
105
106 require.Equal(t, 2, stats.ComplexityLimit)
107 require.Equal(t, 0, stats.Complexity)
108 })
109 }
110
// doRequest sends a single in-memory HTTP request to handler and returns the
// recorder holding the response. The body is sent verbatim and the request is
// always labelled with a Content-Type of application/json.
func doRequest(handler http.Handler, method string, target string, body string) *httptest.ResponseRecorder {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(method, target, strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	handler.ServeHTTP(rec, req)
	return rec
}
119
View as plain text