...

Source file src/github.com/99designs/gqlgen/graphql/handler/extension/complexity_test.go

Documentation: github.com/99designs/gqlgen/graphql/handler/extension

     1  package extension_test
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"strings"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/99designs/gqlgen/graphql"
    13  	"github.com/99designs/gqlgen/graphql/handler/extension"
    14  	"github.com/99designs/gqlgen/graphql/handler/testserver"
    15  	"github.com/99designs/gqlgen/graphql/handler/transport"
    16  )
    17  
    18  func TestHandlerComplexity(t *testing.T) {
    19  	h := testserver.New()
    20  	h.Use(&extension.ComplexityLimit{
    21  		Func: func(ctx context.Context, rc *graphql.OperationContext) int {
    22  			if rc.RawQuery == "{ ok: name }" {
    23  				return 4
    24  			}
    25  			return 2
    26  		},
    27  	})
    28  	h.AddTransport(&transport.POST{})
    29  	var stats *extension.ComplexityStats
    30  	h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    31  		stats = extension.GetComplexityStats(ctx)
    32  		return next(ctx)
    33  	})
    34  
    35  	t.Run("below complexity limit", func(t *testing.T) {
    36  		stats = nil
    37  		h.SetCalculatedComplexity(2)
    38  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
    39  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    40  		require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
    41  
    42  		require.Equal(t, 2, stats.ComplexityLimit)
    43  		require.Equal(t, 2, stats.Complexity)
    44  	})
    45  
    46  	t.Run("above complexity limit", func(t *testing.T) {
    47  		stats = nil
    48  		h.SetCalculatedComplexity(4)
    49  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
    50  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    51  		require.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2","extensions":{"code":"COMPLEXITY_LIMIT_EXCEEDED"}}],"data":null}`, resp.Body.String())
    52  
    53  		require.Equal(t, 2, stats.ComplexityLimit)
    54  		require.Equal(t, 4, stats.Complexity)
    55  	})
    56  
    57  	t.Run("within dynamic complexity limit", func(t *testing.T) {
    58  		stats = nil
    59  		h.SetCalculatedComplexity(4)
    60  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ ok: name }"}`)
    61  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    62  		require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
    63  
    64  		require.Equal(t, 4, stats.ComplexityLimit)
    65  		require.Equal(t, 4, stats.Complexity)
    66  	})
    67  }
    68  
    69  func TestFixedComplexity(t *testing.T) {
    70  	h := testserver.New()
    71  	h.Use(extension.FixedComplexityLimit(2))
    72  	h.AddTransport(&transport.POST{})
    73  
    74  	var stats *extension.ComplexityStats
    75  	h.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    76  		stats = extension.GetComplexityStats(ctx)
    77  		return next(ctx)
    78  	})
    79  
    80  	t.Run("below complexity limit", func(t *testing.T) {
    81  		h.SetCalculatedComplexity(2)
    82  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
    83  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    84  		require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
    85  
    86  		require.Equal(t, 2, stats.ComplexityLimit)
    87  		require.Equal(t, 2, stats.Complexity)
    88  	})
    89  
    90  	t.Run("above complexity limit", func(t *testing.T) {
    91  		h.SetCalculatedComplexity(4)
    92  		resp := doRequest(h, "POST", "/graphql", `{"query":"{ name }"}`)
    93  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    94  		require.Equal(t, `{"errors":[{"message":"operation has complexity 4, which exceeds the limit of 2","extensions":{"code":"COMPLEXITY_LIMIT_EXCEEDED"}}],"data":null}`, resp.Body.String())
    95  
    96  		require.Equal(t, 2, stats.ComplexityLimit)
    97  		require.Equal(t, 4, stats.Complexity)
    98  	})
    99  
   100  	t.Run("bypass __schema field", func(t *testing.T) {
   101  		h.SetCalculatedComplexity(4)
   102  		resp := doRequest(h, "POST", "/graphql", `{ "operationName":"IntrospectionQuery", "query":"query IntrospectionQuery { __schema { queryType { name } mutationType { name }}}"}`)
   103  		require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
   104  		require.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
   105  
   106  		require.Equal(t, 2, stats.ComplexityLimit)
   107  		require.Equal(t, 0, stats.Complexity)
   108  	})
   109  }
   110  
   111  func doRequest(handler http.Handler, method string, target string, body string) *httptest.ResponseRecorder {
   112  	r := httptest.NewRequest(method, target, strings.NewReader(body))
   113  	r.Header.Set("Content-Type", "application/json")
   114  	w := httptest.NewRecorder()
   115  
   116  	handler.ServeHTTP(w, r)
   117  	return w
   118  }
   119  

View as plain text