1 package followschema
2
3 import (
4 "context"
5 "fmt"
6 "runtime"
7 "sort"
8 "testing"
9 "time"
10
11 "github.com/stretchr/testify/require"
12
13 "github.com/99designs/gqlgen/client"
14 "github.com/99designs/gqlgen/graphql"
15 "github.com/99designs/gqlgen/graphql/handler"
16 "github.com/99designs/gqlgen/graphql/handler/transport"
17 )
18
19 func TestSubscriptions(t *testing.T) {
20 tick := make(chan string, 1)
21
22 resolvers := &Stub{}
23
24 resolvers.SubscriptionResolver.InitPayload = func(ctx context.Context) (strings <-chan string, e error) {
25 payload := transport.GetInitPayload(ctx)
26 channel := make(chan string, len(payload)+1)
27
28 go func() {
29 <-ctx.Done()
30 close(channel)
31 }()
32
33
34 auth := payload.Authorization()
35 if auth != "" {
36 channel <- "AUTH:" + auth
37 } else {
38 channel <- "AUTH:NONE"
39 }
40
41
42 keys := make([]string, 0, len(payload))
43 for key := range payload {
44 keys = append(keys, key)
45 }
46 sort.Strings(keys)
47 for _, key := range keys {
48 channel <- fmt.Sprintf("%s = %#+v", key, payload[key])
49 }
50
51 return channel, nil
52 }
53
54 errorTick := make(chan *Error, 1)
55 resolvers.SubscriptionResolver.ErrorRequired = func(ctx context.Context) (<-chan *Error, error) {
56 res := make(chan *Error, 1)
57
58 go func() {
59 for {
60 select {
61 case t := <-errorTick:
62 res <- t
63 case <-ctx.Done():
64 close(res)
65 return
66 }
67 }
68 }()
69 return res, nil
70 }
71
72 resolvers.SubscriptionResolver.Updated = func(ctx context.Context) (<-chan string, error) {
73 res := make(chan string, 1)
74
75 go func() {
76 for {
77 select {
78 case t := <-tick:
79 res <- t
80 case <-ctx.Done():
81 close(res)
82 return
83 }
84 }
85 }()
86 return res, nil
87 }
88
89 srv := handler.NewDefaultServer(
90 NewExecutableSchema(Config{Resolvers: resolvers}),
91 )
92 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
93 path, _ := ctx.Value(ckey("path")).([]int)
94 return next(context.WithValue(ctx, ckey("path"), append(path, 1)))
95 })
96
97 srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
98 path, _ := ctx.Value(ckey("path")).([]int)
99 return next(context.WithValue(ctx, ckey("path"), append(path, 2)))
100 })
101
102 c := client.New(srv)
103
104 t.Run("wont leak goroutines", func(t *testing.T) {
105 runtime.GC()
106 initialGoroutineCount := runtime.NumGoroutine()
107
108 sub := c.Websocket(`subscription { updated }`)
109
110 tick <- "message"
111
112 var msg struct {
113 resp struct {
114 Updated string
115 }
116 }
117
118 err := sub.Next(&msg.resp)
119 require.NoError(t, err)
120 require.Equal(t, "message", msg.resp.Updated)
121 sub.Close()
122
123
124 start := time.Now()
125 for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() {
126 time.Sleep(5 * time.Millisecond)
127 }
128
129 require.Equal(t, initialGoroutineCount, runtime.NumGoroutine())
130 })
131
132 t.Run("will parse init payload", func(t *testing.T) {
133 runtime.GC()
134 initialGoroutineCount := runtime.NumGoroutine()
135
136 sub := c.WebsocketWithPayload(`subscription { initPayload }`, map[string]interface{}{
137 "Authorization": "Bearer of the curse",
138 "number": 32,
139 "strings": []string{"hello", "world"},
140 })
141
142 var msg struct {
143 resp struct {
144 InitPayload string
145 }
146 }
147
148 err := sub.Next(&msg.resp)
149 require.NoError(t, err)
150 require.Equal(t, "AUTH:Bearer of the curse", msg.resp.InitPayload)
151 err = sub.Next(&msg.resp)
152 require.NoError(t, err)
153 require.Equal(t, "Authorization = \"Bearer of the curse\"", msg.resp.InitPayload)
154 err = sub.Next(&msg.resp)
155 require.NoError(t, err)
156 require.Equal(t, "number = 32", msg.resp.InitPayload)
157 err = sub.Next(&msg.resp)
158 require.NoError(t, err)
159 require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload)
160 sub.Close()
161
162
163 start := time.Now()
164 for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() {
165 time.Sleep(5 * time.Millisecond)
166 }
167
168 require.Equal(t, initialGoroutineCount, runtime.NumGoroutine())
169 })
170
171 t.Run("websocket gets errors", func(t *testing.T) {
172 runtime.GC()
173 initialGoroutineCount := runtime.NumGoroutine()
174
175 sub := c.Websocket(`subscription { errorRequired { id } }`)
176
177 errorTick <- &Error{ID: "ID1234"}
178
179 var msg struct {
180 resp struct {
181 ErrorRequired *struct {
182 Id string
183 }
184 }
185 }
186
187 err := sub.Next(&msg.resp)
188 require.NoError(t, err)
189 require.Equal(t, "ID1234", msg.resp.ErrorRequired.Id)
190
191 errorTick <- nil
192 err = sub.Next(&msg.resp)
193 require.Error(t, err)
194
195 sub.Close()
196
197
198 start := time.Now()
199 for time.Since(start).Seconds() < 2 && initialGoroutineCount != runtime.NumGoroutine() {
200 time.Sleep(5 * time.Millisecond)
201 }
202
203 require.Equal(t, initialGoroutineCount, runtime.NumGoroutine())
204 })
205 }
206
View as plain text