...

// Source file: src/github.com/99designs/gqlgen/graphql/handler/server_test.go
//
// Documentation: github.com/99designs/gqlgen/graphql/handler

     1  package handler_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"net/url"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  	"github.com/vektah/gqlparser/v2/ast"
    14  	"github.com/vektah/gqlparser/v2/gqlerror"
    15  	"github.com/vektah/gqlparser/v2/parser"
    16  
    17  	"github.com/99designs/gqlgen/graphql"
    18  	"github.com/99designs/gqlgen/graphql/handler/testserver"
    19  	"github.com/99designs/gqlgen/graphql/handler/transport"
    20  )
    21  
    22  func TestServer(t *testing.T) {
    23  	srv := testserver.New()
    24  	srv.AddTransport(&transport.GET{})
    25  
    26  	t.Run("returns an error if no transport matches", func(t *testing.T) {
    27  		resp := post(srv, "/foo", "application/json")
    28  		assert.Equal(t, http.StatusBadRequest, resp.Code)
    29  		assert.Equal(t, `{"errors":[{"message":"transport not supported"}],"data":null}`, resp.Body.String())
    30  	})
    31  
    32  	t.Run("calls query on executable schema", func(t *testing.T) {
    33  		resp := get(srv, "/foo?query={name}")
    34  		assert.Equal(t, http.StatusOK, resp.Code)
    35  		assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
    36  	})
    37  
    38  	t.Run("mutations are forbidden", func(t *testing.T) {
    39  		resp := get(srv, "/foo?query=mutation{name}")
    40  		assert.Equal(t, http.StatusNotAcceptable, resp.Code)
    41  		assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String())
    42  	})
    43  
    44  	t.Run("subscriptions are forbidden", func(t *testing.T) {
    45  		resp := get(srv, "/foo?query=subscription{name}")
    46  		assert.Equal(t, http.StatusNotAcceptable, resp.Code)
    47  		assert.Equal(t, `{"errors":[{"message":"GET requests only allow query operations"}],"data":null}`, resp.Body.String())
    48  	})
    49  
    50  	t.Run("invokes operation middleware in order", func(t *testing.T) {
    51  		var calls []string
    52  		srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
    53  			calls = append(calls, "first")
    54  			return next(ctx)
    55  		})
    56  		srv.AroundOperations(func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
    57  			calls = append(calls, "second")
    58  			return next(ctx)
    59  		})
    60  
    61  		resp := get(srv, "/foo?query={name}")
    62  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    63  		assert.Equal(t, []string{"first", "second"}, calls)
    64  	})
    65  
    66  	t.Run("invokes response middleware in order", func(t *testing.T) {
    67  		var calls []string
    68  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    69  			calls = append(calls, "first")
    70  			return next(ctx)
    71  		})
    72  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
    73  			calls = append(calls, "second")
    74  			return next(ctx)
    75  		})
    76  
    77  		resp := get(srv, "/foo?query={name}")
    78  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    79  		assert.Equal(t, []string{"first", "second"}, calls)
    80  	})
    81  
    82  	t.Run("invokes field middleware in order", func(t *testing.T) {
    83  		var calls []string
    84  		srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    85  			calls = append(calls, "first")
    86  			return next(ctx)
    87  		})
    88  		srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    89  			calls = append(calls, "second")
    90  			return next(ctx)
    91  		})
    92  
    93  		resp := get(srv, "/foo?query={name}")
    94  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
    95  		assert.Equal(t, []string{"first", "second"}, calls)
    96  	})
    97  
    98  	t.Run("get query parse error in AroundResponses", func(t *testing.T) {
    99  		var errors1 gqlerror.List
   100  		var errors2 gqlerror.List
   101  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
   102  			resp := next(ctx)
   103  			errors1 = graphql.GetErrors(ctx)
   104  			errors2 = resp.Errors
   105  			return resp
   106  		})
   107  
   108  		resp := get(srv, "/foo?query=invalid")
   109  		assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
   110  		assert.Equal(t, 1, len(errors1))
   111  		assert.Equal(t, 1, len(errors2))
   112  	})
   113  
   114  	t.Run("query caching", func(t *testing.T) {
   115  		ctx := context.Background()
   116  		cache := &graphql.MapCache{}
   117  		srv.SetQueryCache(cache)
   118  		qry := `query Foo {name}`
   119  
   120  		t.Run("cache miss populates cache", func(t *testing.T) {
   121  			resp := get(srv, "/foo?query="+url.QueryEscape(qry))
   122  			assert.Equal(t, http.StatusOK, resp.Code)
   123  			assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
   124  
   125  			cacheDoc, ok := cache.Get(ctx, qry)
   126  			require.True(t, ok)
   127  			require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
   128  		})
   129  
   130  		t.Run("cache hits use document from cache", func(t *testing.T) {
   131  			doc, err := parser.ParseQuery(&ast.Source{Input: `query Bar {name}`})
   132  			require.Nil(t, err)
   133  			cache.Add(ctx, qry, doc)
   134  
   135  			resp := get(srv, "/foo?query="+url.QueryEscape(qry))
   136  			assert.Equal(t, http.StatusOK, resp.Code)
   137  			assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
   138  
   139  			cacheDoc, ok := cache.Get(ctx, qry)
   140  			require.True(t, ok)
   141  			require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
   142  		})
   143  	})
   144  }
   145  
   146  func TestErrorServer(t *testing.T) {
   147  	srv := testserver.NewError()
   148  	srv.AddTransport(&transport.GET{})
   149  
   150  	t.Run("get resolver error in AroundResponses", func(t *testing.T) {
   151  		var errors1 gqlerror.List
   152  		var errors2 gqlerror.List
   153  		srv.AroundResponses(func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
   154  			resp := next(ctx)
   155  			errors1 = graphql.GetErrors(ctx)
   156  			errors2 = resp.Errors
   157  			return resp
   158  		})
   159  
   160  		resp := get(srv, "/foo?query={name}")
   161  		assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
   162  		assert.Equal(t, 1, len(errors1))
   163  		assert.Equal(t, 1, len(errors2))
   164  	})
   165  }
   166  
   167  type panicTransport struct{}
   168  
   169  func (t panicTransport) Supports(r *http.Request) bool {
   170  	return true
   171  }
   172  
   173  func (t panicTransport) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
   174  	panic(fmt.Errorf("panic in transport"))
   175  }
   176  
   177  func TestRecover(t *testing.T) {
   178  	srv := testserver.New()
   179  	srv.AddTransport(&panicTransport{})
   180  
   181  	t.Run("recover from panic", func(t *testing.T) {
   182  		resp := get(srv, "/foo?query={name}")
   183  
   184  		assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
   185  	})
   186  }
   187  
   188  func get(handler http.Handler, target string) *httptest.ResponseRecorder {
   189  	r := httptest.NewRequest("GET", target, nil)
   190  	w := httptest.NewRecorder()
   191  
   192  	handler.ServeHTTP(w, r)
   193  	return w
   194  }
   195  
   196  func post(handler http.Handler, target, contentType string) *httptest.ResponseRecorder {
   197  	r := httptest.NewRequest("POST", target, nil)
   198  	r.Header.Set("Content-Type", contentType)
   199  	w := httptest.NewRecorder()
   200  
   201  	handler.ServeHTTP(w, r)
   202  	return w
   203  }
   204  

// View as plain text