1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package pubsub
16
17
18
19
20
21 import (
22 "context"
23 "io"
24 "strconv"
25 "sync"
26 "sync/atomic"
27 "testing"
28 "time"
29
30 "cloud.google.com/go/internal/testutil"
31 pb "cloud.google.com/go/pubsub/apiv1/pubsubpb"
32 "github.com/google/go-cmp/cmp"
33 "github.com/google/go-cmp/cmp/cmpopts"
34 "google.golang.org/api/option"
35 "google.golang.org/grpc"
36 "google.golang.org/grpc/codes"
37 "google.golang.org/grpc/status"
38 tspb "google.golang.org/protobuf/types/known/timestamppb"
39 )
40
41 var (
42 timestamp = &tspb.Timestamp{}
43 testMessages = []*pb.ReceivedMessage{
44 {AckId: "0", Message: &pb.PubsubMessage{Data: []byte{1}, PublishTime: timestamp}},
45 {AckId: "1", Message: &pb.PubsubMessage{Data: []byte{2}, PublishTime: timestamp}},
46 {AckId: "2", Message: &pb.PubsubMessage{Data: []byte{3}, PublishTime: timestamp}},
47 }
48 )
49
50 func TestStreamingPullBasic(t *testing.T) {
51 client, server := newMock(t)
52 defer server.srv.Close()
53 defer client.Close()
54 server.addStreamingPullMessages(testMessages)
55 testStreamingPullIteration(t, client, server, testMessages)
56 }
57
58 func TestStreamingPullMultipleFetches(t *testing.T) {
59 client, server := newMock(t)
60 defer server.srv.Close()
61 defer client.Close()
62 server.addStreamingPullMessages(testMessages[:1])
63 server.addStreamingPullMessages(testMessages[1:])
64 testStreamingPullIteration(t, client, server, testMessages)
65 }
66
67 func testStreamingPullIteration(t *testing.T, client *Client, server *mockServer, msgs []*pb.ReceivedMessage) {
68 sub := client.Subscription("S")
69 gotMsgs, err := pullN(context.Background(), sub, len(msgs), 0, func(_ context.Context, m *Message) {
70 id, err := strconv.Atoi(msgAckID(m))
71 if err != nil {
72 t.Fatalf("pullN err: %v", err)
73 }
74
75 if id%2 == 0 {
76 m.Ack()
77 } else {
78 m.Nack()
79 }
80 })
81 if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
82 t.Fatalf("Pull: %v", err)
83 }
84 gotMap := map[string]*Message{}
85 for _, m := range gotMsgs {
86 gotMap[msgAckID(m)] = m
87 }
88 for i, msg := range msgs {
89 want, err := toMessage(msg, time.Time{}, nil)
90 if err != nil {
91 t.Fatal(err)
92 }
93 wantAckh, _ := msgAckHandler(want, false)
94 wantAckh.calledDone = true
95 got := gotMap[wantAckh.ackID]
96 if got == nil {
97 t.Errorf("%d: no message for ackID %q", i, wantAckh.ackID)
98 continue
99 }
100 opts := []cmp.Option{
101 cmp.AllowUnexported(Message{}, psAckHandler{}),
102 cmpopts.IgnoreTypes(
103 time.Time{},
104 func(string, bool,
105 *AckResult, time.Time) {
106 },
107 AckResult{},
108 ),
109 }
110 if !testutil.Equal(got, want, opts...) {
111 t.Errorf("%d: got\n%#v\nwant\n%#v", i, got, want)
112 }
113 }
114 server.wait()
115 for i := 0; i < len(msgs); i++ {
116 id := msgs[i].AckId
117 if i%2 == 0 {
118 if !server.Acked[id] {
119 t.Errorf("msg %q should have been acked but wasn't", id)
120 }
121 } else {
122 if dl, ok := server.Deadlines[id]; !ok || dl != 0 {
123 t.Errorf("msg %q should have been nacked but wasn't", id)
124 }
125 }
126 }
127 }
128
129 func TestStreamingPullError(t *testing.T) {
130
131
132
133 client, server := newMock(t)
134 defer server.srv.Close()
135 defer client.Close()
136 server.addStreamingPullMessages(testMessages[:1])
137 server.addStreamingPullError(status.Errorf(codes.Unknown, ""))
138 sub := client.Subscription("S")
139
140
141 sub.ReceiveSettings.NumGoroutines = 1
142 callbackDone := make(chan struct{})
143 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
144 defer cancel()
145 err := sub.Receive(ctx, func(ctx context.Context, m *Message) {
146 defer close(callbackDone)
147 <-ctx.Done()
148 })
149 select {
150 case <-callbackDone:
151 default:
152 t.Fatal("Receive returned but callback was not done")
153 }
154 if want := codes.Unknown; status.Code(err) != want {
155 t.Fatalf("got <%v>, want code %v", err, want)
156 }
157 }
158
159 func TestStreamingPullCancel(t *testing.T) {
160
161
162 client, server := newMock(t)
163 defer server.srv.Close()
164 defer client.Close()
165 server.addStreamingPullMessages(testMessages)
166 sub := client.Subscription("S")
167 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
168 var n int32
169 err := sub.Receive(ctx, func(ctx2 context.Context, m *Message) {
170 atomic.AddInt32(&n, 1)
171 defer atomic.AddInt32(&n, -1)
172 cancel()
173 m.Ack()
174 })
175 if got := atomic.LoadInt32(&n); got != 0 {
176 t.Fatalf("Receive returned with %d callbacks still running", got)
177 }
178 if err != nil {
179 t.Fatalf("Receive got <%v>, want nil", err)
180 }
181 }
182
183 func TestStreamingPullRetry(t *testing.T) {
184
185 t.Parallel()
186 client, server := newMock(t)
187 defer server.srv.Close()
188 defer client.Close()
189 server.addStreamingPullMessages(testMessages[:1])
190 server.addStreamingPullError(io.EOF)
191 server.addStreamingPullError(io.EOF)
192 server.addStreamingPullMessages(testMessages[1:2])
193 server.addStreamingPullError(status.Errorf(codes.Unavailable, ""))
194 server.addStreamingPullError(status.Errorf(codes.Unavailable, ""))
195 server.addStreamingPullMessages(testMessages[2:])
196
197 sub := client.Subscription("S")
198 sub.ReceiveSettings.NumGoroutines = 1
199 gotMsgs, err := pullN(context.Background(), sub, len(testMessages), 0, func(_ context.Context, m *Message) {
200 id, err := strconv.Atoi(msgAckID(m))
201 if err != nil {
202 t.Fatalf("pullN err: %v", err)
203 }
204
205 if id%2 == 0 {
206 m.Ack()
207 } else {
208 m.Nack()
209 }
210 })
211 if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
212 t.Fatalf("Pull: %v", err)
213 }
214 gotMap := map[string]*Message{}
215 for _, m := range gotMsgs {
216 gotMap[msgAckID(m)] = m
217 }
218 for i, msg := range testMessages {
219 want, err := toMessage(msg, time.Time{}, nil)
220 if err != nil {
221 t.Fatal(err)
222 }
223 wantAckh, _ := msgAckHandler(want, false)
224 wantAckh.calledDone = true
225 got := gotMap[wantAckh.ackID]
226 if got == nil {
227 t.Errorf("%d: no message for ackID %q", i, wantAckh.ackID)
228 continue
229 }
230 opts := []cmp.Option{
231 cmp.AllowUnexported(Message{}, psAckHandler{}),
232 cmpopts.IgnoreTypes(
233 time.Time{},
234 func(string, bool,
235 *AckResult, time.Time) {
236 },
237 AckResult{},
238 ),
239 }
240 if !testutil.Equal(got, want, opts...) {
241 t.Errorf("%d: got\n%#v\nwant\n%#v", i, got, want)
242 }
243 }
244 server.wait()
245 for i := 0; i < len(testMessages); i++ {
246 id := testMessages[i].AckId
247 if i%2 == 0 {
248 if !server.Acked[id] {
249 t.Errorf("msg %q should have been acked but wasn't", id)
250 }
251 } else {
252 if dl, ok := server.Deadlines[id]; !ok || dl != 0 {
253 t.Errorf("msg %q should have been nacked but wasn't", id)
254 }
255 }
256 }
257 }
258
259 func TestStreamingPullOneActive(t *testing.T) {
260
261 client, srv := newMock(t)
262 defer client.Close()
263 defer srv.srv.Close()
264 srv.addStreamingPullMessages(testMessages[:1])
265 sub := client.Subscription("S")
266 ctx, cancel := context.WithCancel(context.Background())
267 err := sub.Receive(ctx, func(ctx context.Context, m *Message) {
268 m.Ack()
269 err := sub.Receive(ctx, func(context.Context, *Message) {})
270 if err != errReceiveInProgress {
271 t.Errorf("got <%v>, want <%v>", err, errReceiveInProgress)
272 }
273 cancel()
274 })
275 if err != nil {
276 t.Fatalf("got <%v>, want nil", err)
277 }
278 }
279
280 func TestStreamingPullConcurrent(t *testing.T) {
281 newMsg := func(i int) *pb.ReceivedMessage {
282 return &pb.ReceivedMessage{
283 AckId: strconv.Itoa(i),
284 Message: &pb.PubsubMessage{Data: []byte{byte(i)}, PublishTime: timestamp},
285 }
286 }
287
288
289 client, server := newMock(t)
290 defer server.srv.Close()
291 defer client.Close()
292
293 nMessages := 100
294 for i := 0; i < nMessages; i += 2 {
295 server.addStreamingPullMessages([]*pb.ReceivedMessage{newMsg(i), newMsg(i + 1)})
296 }
297 sub := client.Subscription("S")
298 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
299 defer cancel()
300 gotMsgs, err := pullN(ctx, sub, nMessages, 0, func(ctx context.Context, m *Message) {
301 m.Ack()
302 })
303 if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
304 t.Fatalf("Pull: %v", err)
305 }
306 seen := map[string]bool{}
307 for _, gm := range gotMsgs {
308 if seen[msgAckID(gm)] {
309 t.Fatalf("duplicate ID %q", msgAckID(gm))
310 }
311 seen[msgAckID(gm)] = true
312 }
313 if len(seen) != nMessages {
314 t.Fatalf("got %d messages, want %d", len(seen), nMessages)
315 }
316 }
317
// TestStreamingPullFlowControl exercises MaxOutstandingMessages with a limit
// of 2 and three queued messages. It checks the following:
//   - two callbacks start and stay blocked;
//   - no third callback starts while they are held;
//   - canceling ctx does not make Receive return while those callbacks are
//     still running;
//   - once they are released, Receive returns nil.
func TestStreamingPullFlowControl(t *testing.T) {
	// NOTE(review): if any t.Fatal below fires, the Receive goroutine may stay
	// blocked (callbacks wait on the unbuffered activec/waitc, and errc is
	// unbuffered), leaking it past the end of this test.
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages)
	sub := client.Subscription("S")
	sub.ReceiveSettings.MaxOutstandingMessages = 2
	ctx, cancel := context.WithCancel(context.Background())
	// activec: each callback sends once when it starts.
	// waitc: the test sends once per callback to let it finish and ack.
	// errc: carries Receive's return value back to the test goroutine.
	activec := make(chan int)
	waitc := make(chan int)
	errc := make(chan error)
	go func() {
		errc <- sub.Receive(ctx, func(_ context.Context, m *Message) {
			activec <- 1
			<-waitc
			m.Ack()
		})
	}()

	// Wait for the first two callbacks to start; both then block on waitc.
	for i := 0; i < 2; i++ {
		select {
		case <-activec:
		case <-time.After(time.Second):
			t.Fatalf("timed out waiting for message %d", i+1)
		}
	}
	// With two messages outstanding, a third callback must not start.
	select {
	case <-activec:
		t.Fatal("third callback in progress")
	case <-time.After(100 * time.Millisecond):
	}
	cancel()

	// Even though ctx is canceled, Receive must not return while the two
	// callbacks are still running.
	select {
	case err := <-errc:
		t.Fatalf("Receive returned early with error %v", err)
	case <-time.After(100 * time.Millisecond):
	}

	// Release both blocked callbacks so they can ack and return.
	waitc <- 1
	waitc <- 1

	// With ctx canceled and no callbacks outstanding, Receive should now
	// return nil.
	if err := <-errc; err != nil {
		t.Fatalf("got %v from Receive, want nil", err)
	}
}
367
368 func TestStreamingPull_ClosedClient(t *testing.T) {
369 ctx := context.Background()
370 client, server := newMock(t)
371 defer server.srv.Close()
372 defer client.Close()
373 server.addStreamingPullMessages(testMessages)
374 sub := client.Subscription("S")
375 sub.ReceiveSettings.MaxOutstandingBytes = 1
376 recvFinished := make(chan error)
377
378 go func() {
379 err := sub.Receive(ctx, func(_ context.Context, m *Message) {
380 m.Ack()
381 })
382 recvFinished <- err
383 }()
384
385
386 time.Sleep(100 * time.Millisecond)
387
388 if err := client.Close(); err != nil {
389 t.Fatalf("Got error while closing client: %v", err)
390 }
391
392
393 time.Sleep(100 * time.Millisecond)
394
395 select {
396 case recvErr := <-recvFinished:
397 s, ok := status.FromError(recvErr)
398 if !ok {
399 t.Fatalf("Expected a gRPC failure, got %v", recvErr)
400 }
401 if s.Code() != codes.Canceled {
402 t.Fatalf("Expected canceled, got %v", s.Code())
403 }
404 case <-time.After(time.Second):
405 t.Fatal("Receive should have exited immediately after the client was closed, but it did not")
406 }
407 }
408
409 func TestStreamingPull_RetriesAfterUnavailable(t *testing.T) {
410 ctx := context.Background()
411 client, server := newMock(t)
412 defer server.srv.Close()
413 defer client.Close()
414
415 unavail := status.Error(codes.Unavailable, "There is no connection available")
416 server.addStreamingPullMessages(testMessages)
417 server.addStreamingPullError(unavail)
418 server.addAckResponse(unavail)
419 server.addModAckResponse(unavail)
420 server.addStreamingPullMessages(testMessages)
421 server.addStreamingPullError(unavail)
422
423 sub := client.Subscription("S")
424 sub.ReceiveSettings.MaxOutstandingBytes = 1
425 recvErr := make(chan error, 1)
426 recvdMsgs := make(chan *Message, len(testMessages)*2)
427
428 go func() {
429 recvErr <- sub.Receive(ctx, func(_ context.Context, m *Message) {
430 m.Ack()
431 recvdMsgs <- m
432 })
433 }()
434
435
436 var n int
437 for {
438 select {
439 case <-time.After(10 * time.Second):
440 t.Fatalf("timed out waiting for all message to arrive. got %d messages total", n)
441 case err := <-recvErr:
442 t.Fatal(err)
443 case <-recvdMsgs:
444 n++
445 if n == len(testMessages)*2 {
446 return
447 }
448 }
449 }
450 }
451
// TestStreamingPull_ReconnectsAfterServerDies proceeds in two phases:
//   - it receives one batch of testMessages, then stops the mock server;
//   - it starts a replacement server on the same port holding a second batch.
//
// It checks that the same Receive call delivers both batches
// (2*len(testMessages) messages).
func TestStreamingPull_ReconnectsAfterServerDies(t *testing.T) {
	ctx := context.Background()
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages)
	sub := client.Subscription("S")
	sub.ReceiveSettings.MaxOutstandingBytes = 1
	recvErr := make(chan error, 1)
	recvdMsgs := make(chan interface{}, len(testMessages)*2)

	// ctx is never canceled, so Receive keeps running until something else
	// stops it (presumably the deferred client.Close).
	go func() {
		recvErr <- sub.Receive(ctx, func(_ context.Context, m *Message) {
			m.Ack()
			recvdMsgs <- struct{}{}
		})
	}()

	// Each pass through the select arms a fresh 5s timer. The timeout bounds
	// the wait for the next message, not the whole test.
	var n int
	for {
		select {
		case <-time.After(5 * time.Second):
			t.Fatalf("timed out waiting for all message to arrive. got %d messages total", n)
		case err := <-recvErr:
			// Any return from Receive here is unexpected, even a nil one.
			t.Fatal(err)
		case <-recvdMsgs:
			n++
			if n == len(testMessages) {
				// First batch done: stop the server, then bring up a
				// replacement on the same port holding the second batch.
				server.srv.Close()
				server2, err := newMockServer(server.srv.Port)
				if err != nil {
					t.Fatal(err)
				}
				// Runs when the test function returns, not at the end of
				// this case; this branch is taken only once.
				defer server2.srv.Close()
				server2.addStreamingPullMessages(testMessages)
			}

			if n == len(testMessages)*2 {
				return
			}
		}
	}
}
497
498 func newMock(t *testing.T) (*Client, *mockServer) {
499 srv, err := newMockServer(0)
500 if err != nil {
501 t.Fatal(err)
502 }
503 conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
504 if err != nil {
505 t.Fatal(err)
506 }
507 opts := withGRPCHeadersAssertion(t, option.WithGRPCConn(conn))
508 client, err := NewClient(context.Background(), "P", opts...)
509 if err != nil {
510 t.Fatal(err)
511 }
512 return client, srv
513 }
514
515
516
517 func pullN(ctx context.Context, sub *Subscription, n int, wait time.Duration, f func(context.Context, *Message)) ([]*Message, error) {
518 var (
519 mu sync.Mutex
520 msgs []*Message
521 )
522 cctx, cancel := context.WithCancel(ctx)
523 err := sub.Receive(cctx, func(ctx context.Context, m *Message) {
524 mu.Lock()
525 msgs = append(msgs, m)
526 nSeen := len(msgs)
527 mu.Unlock()
528 f(ctx, m)
529 if nSeen >= n {
530
531
532 time.Sleep(wait)
533 cancel()
534 }
535 })
536 if err != nil {
537 return msgs, err
538 }
539 return msgs, nil
540 }
541
View as plain text