...

Source file src/github.com/99designs/gqlgen/codegen/testserver/followschema/subscription_test.go

Documentation: github.com/99designs/gqlgen/codegen/testserver/followschema

     1  package followschema
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"runtime"
     7  	"sort"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/99designs/gqlgen/client"
    14  	"github.com/99designs/gqlgen/graphql"
    15  	"github.com/99designs/gqlgen/graphql/handler"
    16  	"github.com/99designs/gqlgen/graphql/handler/transport"
    17  )
    18  
    19  func TestSubscriptions(t *testing.T) {
    20  	tick := make(chan string, 1)
    21  
    22  	resolvers := &Stub{}
    23  
    24  	resolvers.SubscriptionResolver.InitPayload = func(ctx context.Context) (strings <-chan string, e error) {
    25  		payload := transport.GetInitPayload(ctx)
    26  		channel := make(chan string, len(payload)+1)
    27  
    28  		go func() {
    29  			<-ctx.Done()
    30  			close(channel)
    31  		}()
    32  
    33  		// Test the helper function separately
    34  		auth := payload.Authorization()
    35  		if auth != "" {
    36  			channel <- "AUTH:" + auth
    37  		} else {
    38  			channel <- "AUTH:NONE"
    39  		}
    40  
    41  		// Send them over the channel in alphabetic order
    42  		keys := make([]string, 0, len(payload))
    43  		for key := range payload {
    44  			keys = append(keys, key)
    45  		}
    46  		sort.Strings(keys)
    47  		for _, key := range keys {
    48  			channel <- fmt.Sprintf("%s = %#+v", key, payload[key])
    49  		}
    50  
    51  		return channel, nil
    52  	}
    53  
    54  	errorTick := make(chan *Error, 1)
    55  	resolvers.SubscriptionResolver.ErrorRequired = func(ctx context.Context) (<-chan *Error, error) {
    56  		res := make(chan *Error, 1)
    57  
    58  		go func() {
    59  			for {
    60  				select {
    61  				case t := <-errorTick:
    62  					res <- t
    63  				case <-ctx.Done():
    64  					close(res)
    65  					return
    66  				}
    67  			}
    68  		}()
    69  		return res, nil
    70  	}
    71  
    72  	resolvers.SubscriptionResolver.Updated = func(ctx context.Context) (<-chan string, error) {
    73  		res := make(chan string, 1)
    74  
    75  		go func() {
    76  			for {
    77  				select {
    78  				case t := <-tick:
    79  					res <- t
    80  				case <-ctx.Done():
    81  					close(res)
    82  					return
    83  				}
    84  			}
    85  		}()
    86  		return res, nil
    87  	}
    88  
    89  	srv := handler.NewDefaultServer(
    90  		NewExecutableSchema(Config{Resolvers: resolvers}),
    91  	)
    92  	srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    93  		path, _ := ctx.Value(ckey("path")).([]int)
    94  		return next(context.WithValue(ctx, ckey("path"), append(path, 1)))
    95  	})
    96  
    97  	srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
    98  		path, _ := ctx.Value(ckey("path")).([]int)
    99  		return next(context.WithValue(ctx, ckey("path"), append(path, 2)))
   100  	})
   101  
   102  	c := client.New(srv)
   103  
   104  	t.Run("wont leak goroutines", func(t *testing.T) {
   105  		runtime.GC() // ensure no go-routines left from preceding tests
   106  		initialGoroutineCount := runtime.NumGoroutine()
   107  
   108  		sub := c.Websocket(`subscription { updated }`)
   109  
   110  		tick <- "message"
   111  
   112  		var msg struct {
   113  			resp struct {
   114  				Updated string
   115  			}
   116  		}
   117  
   118  		err := sub.Next(&msg.resp)
   119  		require.NoError(t, err)
   120  		require.Equal(t, "message", msg.resp.Updated)
   121  		sub.Close()
   122  
   123  		// need a little bit of time for goroutines to settle
   124  		start := time.Now()
   125  		for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() {
   126  			time.Sleep(5 * time.Millisecond)
   127  		}
   128  
   129  		require.Equal(t, initialGoroutineCount, runtime.NumGoroutine())
   130  	})
   131  
   132  	t.Run("will parse init payload", func(t *testing.T) {
   133  		runtime.GC() // ensure no go-routines left from preceding tests
   134  		initialGoroutineCount := runtime.NumGoroutine()
   135  
   136  		sub := c.WebsocketWithPayload(`subscription { initPayload }`, map[string]interface{}{
   137  			"Authorization": "Bearer of the curse",
   138  			"number":        32,
   139  			"strings":       []string{"hello", "world"},
   140  		})
   141  
   142  		var msg struct {
   143  			resp struct {
   144  				InitPayload string
   145  			}
   146  		}
   147  
   148  		err := sub.Next(&msg.resp)
   149  		require.NoError(t, err)
   150  		require.Equal(t, "AUTH:Bearer of the curse", msg.resp.InitPayload)
   151  		err = sub.Next(&msg.resp)
   152  		require.NoError(t, err)
   153  		require.Equal(t, "Authorization = \"Bearer of the curse\"", msg.resp.InitPayload)
   154  		err = sub.Next(&msg.resp)
   155  		require.NoError(t, err)
   156  		require.Equal(t, "number = 32", msg.resp.InitPayload)
   157  		err = sub.Next(&msg.resp)
   158  		require.NoError(t, err)
   159  		require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload)
   160  		sub.Close()
   161  
   162  		// need a little bit of time for goroutines to settle
   163  		start := time.Now()
   164  		for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() {
   165  			time.Sleep(5 * time.Millisecond)
   166  		}
   167  
   168  		require.Equal(t, initialGoroutineCount, runtime.NumGoroutine())
   169  	})
   170  
   171  	t.Run("websocket gets errors", func(t *testing.T) {
   172  		runtime.GC() // ensure no go-routines left from preceding tests
   173  		initialGoroutineCount := runtime.NumGoroutine()
   174  
   175  		sub := c.Websocket(`subscription { errorRequired { id } }`)
   176  
   177  		errorTick <- &Error{ID: "ID1234"}
   178  
   179  		var msg struct {
   180  			resp struct {
   181  				ErrorRequired *struct {
   182  					Id string
   183  				}
   184  			}
   185  		}
   186  
   187  		err := sub.Next(&msg.resp)
   188  		require.NoError(t, err)
   189  		require.Equal(t, "ID1234", msg.resp.ErrorRequired.Id)
   190  
   191  		errorTick <- nil
   192  		err = sub.Next(&msg.resp)
   193  		require.Error(t, err)
   194  
   195  		sub.Close()
   196  
   197  		// need a little bit of time for goroutines to settle
   198  		start := time.Now()
   199  		for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() {
   200  			time.Sleep(5 * time.Millisecond)
   201  		}
   202  
   203  		require.Equal(t, initialGoroutineCount, runtime.NumGoroutine())
   204  	})
   205  }
   206  

View as plain text