...

Source file src/cloud.google.com/go/pubsub/streaming_pull_test.go

Documentation: cloud.google.com/go/pubsub

     1  // Copyright 2017 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package pubsub
    16  
    17  // TODO(jba): test keepalive
    18  // TODO(jba): test that expired messages are not kept alive
    19  // TODO(jba): test that when all messages expire, Stop returns.
    20  
    21  import (
    22  	"context"
    23  	"io"
    24  	"strconv"
    25  	"sync"
    26  	"sync/atomic"
    27  	"testing"
    28  	"time"
    29  
    30  	"cloud.google.com/go/internal/testutil"
    31  	pb "cloud.google.com/go/pubsub/apiv1/pubsubpb"
    32  	"github.com/google/go-cmp/cmp"
    33  	"github.com/google/go-cmp/cmp/cmpopts"
    34  	"google.golang.org/api/option"
    35  	"google.golang.org/grpc"
    36  	"google.golang.org/grpc/codes"
    37  	"google.golang.org/grpc/status"
    38  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    39  )
    40  
    41  var (
    42  	timestamp    = &tspb.Timestamp{}
    43  	testMessages = []*pb.ReceivedMessage{
    44  		{AckId: "0", Message: &pb.PubsubMessage{Data: []byte{1}, PublishTime: timestamp}},
    45  		{AckId: "1", Message: &pb.PubsubMessage{Data: []byte{2}, PublishTime: timestamp}},
    46  		{AckId: "2", Message: &pb.PubsubMessage{Data: []byte{3}, PublishTime: timestamp}},
    47  	}
    48  )
    49  
    50  func TestStreamingPullBasic(t *testing.T) {
    51  	client, server := newMock(t)
    52  	defer server.srv.Close()
    53  	defer client.Close()
    54  	server.addStreamingPullMessages(testMessages)
    55  	testStreamingPullIteration(t, client, server, testMessages)
    56  }
    57  
    58  func TestStreamingPullMultipleFetches(t *testing.T) {
    59  	client, server := newMock(t)
    60  	defer server.srv.Close()
    61  	defer client.Close()
    62  	server.addStreamingPullMessages(testMessages[:1])
    63  	server.addStreamingPullMessages(testMessages[1:])
    64  	testStreamingPullIteration(t, client, server, testMessages)
    65  }
    66  
    67  func testStreamingPullIteration(t *testing.T, client *Client, server *mockServer, msgs []*pb.ReceivedMessage) {
    68  	sub := client.Subscription("S")
    69  	gotMsgs, err := pullN(context.Background(), sub, len(msgs), 0, func(_ context.Context, m *Message) {
    70  		id, err := strconv.Atoi(msgAckID(m))
    71  		if err != nil {
    72  			t.Fatalf("pullN err: %v", err)
    73  		}
    74  		// ack evens, nack odds
    75  		if id%2 == 0 {
    76  			m.Ack()
    77  		} else {
    78  			m.Nack()
    79  		}
    80  	})
    81  	if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
    82  		t.Fatalf("Pull: %v", err)
    83  	}
    84  	gotMap := map[string]*Message{}
    85  	for _, m := range gotMsgs {
    86  		gotMap[msgAckID(m)] = m
    87  	}
    88  	for i, msg := range msgs {
    89  		want, err := toMessage(msg, time.Time{}, nil)
    90  		if err != nil {
    91  			t.Fatal(err)
    92  		}
    93  		wantAckh, _ := msgAckHandler(want, false)
    94  		wantAckh.calledDone = true
    95  		got := gotMap[wantAckh.ackID]
    96  		if got == nil {
    97  			t.Errorf("%d: no message for ackID %q", i, wantAckh.ackID)
    98  			continue
    99  		}
   100  		opts := []cmp.Option{
   101  			cmp.AllowUnexported(Message{}, psAckHandler{}),
   102  			cmpopts.IgnoreTypes(
   103  				time.Time{},
   104  				func(string, bool,
   105  					*AckResult, time.Time) {
   106  				},
   107  				AckResult{},
   108  			),
   109  		}
   110  		if !testutil.Equal(got, want, opts...) {
   111  			t.Errorf("%d: got\n%#v\nwant\n%#v", i, got, want)
   112  		}
   113  	}
   114  	server.wait()
   115  	for i := 0; i < len(msgs); i++ {
   116  		id := msgs[i].AckId
   117  		if i%2 == 0 {
   118  			if !server.Acked[id] {
   119  				t.Errorf("msg %q should have been acked but wasn't", id)
   120  			}
   121  		} else {
   122  			if dl, ok := server.Deadlines[id]; !ok || dl != 0 {
   123  				t.Errorf("msg %q should have been nacked but wasn't", id)
   124  			}
   125  		}
   126  	}
   127  }
   128  
   129  func TestStreamingPullError(t *testing.T) {
   130  	// If an RPC to the service returns a non-retryable error, Pull should
   131  	// return after all callbacks return, without waiting for messages to be
   132  	// acked.
   133  	client, server := newMock(t)
   134  	defer server.srv.Close()
   135  	defer client.Close()
   136  	server.addStreamingPullMessages(testMessages[:1])
   137  	server.addStreamingPullError(status.Errorf(codes.Unknown, ""))
   138  	sub := client.Subscription("S")
   139  	// Use only one goroutine, since the fake server is configured to
   140  	// return only one error.
   141  	sub.ReceiveSettings.NumGoroutines = 1
   142  	callbackDone := make(chan struct{})
   143  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   144  	defer cancel()
   145  	err := sub.Receive(ctx, func(ctx context.Context, m *Message) {
   146  		defer close(callbackDone)
   147  		<-ctx.Done()
   148  	})
   149  	select {
   150  	case <-callbackDone:
   151  	default:
   152  		t.Fatal("Receive returned but callback was not done")
   153  	}
   154  	if want := codes.Unknown; status.Code(err) != want {
   155  		t.Fatalf("got <%v>, want code %v", err, want)
   156  	}
   157  }
   158  
   159  func TestStreamingPullCancel(t *testing.T) {
   160  	// If Receive's context is canceled, it should return after all callbacks
   161  	// return and all messages have been acked.
   162  	client, server := newMock(t)
   163  	defer server.srv.Close()
   164  	defer client.Close()
   165  	server.addStreamingPullMessages(testMessages)
   166  	sub := client.Subscription("S")
   167  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   168  	var n int32
   169  	err := sub.Receive(ctx, func(ctx2 context.Context, m *Message) {
   170  		atomic.AddInt32(&n, 1)
   171  		defer atomic.AddInt32(&n, -1)
   172  		cancel()
   173  		m.Ack()
   174  	})
   175  	if got := atomic.LoadInt32(&n); got != 0 {
   176  		t.Fatalf("Receive returned with %d callbacks still running", got)
   177  	}
   178  	if err != nil {
   179  		t.Fatalf("Receive got <%v>, want nil", err)
   180  	}
   181  }
   182  
   183  func TestStreamingPullRetry(t *testing.T) {
   184  	// Check that we retry on io.EOF or Unavailable.
   185  	t.Parallel()
   186  	client, server := newMock(t)
   187  	defer server.srv.Close()
   188  	defer client.Close()
   189  	server.addStreamingPullMessages(testMessages[:1])
   190  	server.addStreamingPullError(io.EOF)
   191  	server.addStreamingPullError(io.EOF)
   192  	server.addStreamingPullMessages(testMessages[1:2])
   193  	server.addStreamingPullError(status.Errorf(codes.Unavailable, ""))
   194  	server.addStreamingPullError(status.Errorf(codes.Unavailable, ""))
   195  	server.addStreamingPullMessages(testMessages[2:])
   196  
   197  	sub := client.Subscription("S")
   198  	sub.ReceiveSettings.NumGoroutines = 1
   199  	gotMsgs, err := pullN(context.Background(), sub, len(testMessages), 0, func(_ context.Context, m *Message) {
   200  		id, err := strconv.Atoi(msgAckID(m))
   201  		if err != nil {
   202  			t.Fatalf("pullN err: %v", err)
   203  		}
   204  		// ack evens, nack odds
   205  		if id%2 == 0 {
   206  			m.Ack()
   207  		} else {
   208  			m.Nack()
   209  		}
   210  	})
   211  	if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
   212  		t.Fatalf("Pull: %v", err)
   213  	}
   214  	gotMap := map[string]*Message{}
   215  	for _, m := range gotMsgs {
   216  		gotMap[msgAckID(m)] = m
   217  	}
   218  	for i, msg := range testMessages {
   219  		want, err := toMessage(msg, time.Time{}, nil)
   220  		if err != nil {
   221  			t.Fatal(err)
   222  		}
   223  		wantAckh, _ := msgAckHandler(want, false)
   224  		wantAckh.calledDone = true
   225  		got := gotMap[wantAckh.ackID]
   226  		if got == nil {
   227  			t.Errorf("%d: no message for ackID %q", i, wantAckh.ackID)
   228  			continue
   229  		}
   230  		opts := []cmp.Option{
   231  			cmp.AllowUnexported(Message{}, psAckHandler{}),
   232  			cmpopts.IgnoreTypes(
   233  				time.Time{},
   234  				func(string, bool,
   235  					*AckResult, time.Time) {
   236  				},
   237  				AckResult{},
   238  			),
   239  		}
   240  		if !testutil.Equal(got, want, opts...) {
   241  			t.Errorf("%d: got\n%#v\nwant\n%#v", i, got, want)
   242  		}
   243  	}
   244  	server.wait()
   245  	for i := 0; i < len(testMessages); i++ {
   246  		id := testMessages[i].AckId
   247  		if i%2 == 0 {
   248  			if !server.Acked[id] {
   249  				t.Errorf("msg %q should have been acked but wasn't", id)
   250  			}
   251  		} else {
   252  			if dl, ok := server.Deadlines[id]; !ok || dl != 0 {
   253  				t.Errorf("msg %q should have been nacked but wasn't", id)
   254  			}
   255  		}
   256  	}
   257  }
   258  
   259  func TestStreamingPullOneActive(t *testing.T) {
   260  	// Only one call to Pull can be active at a time.
   261  	client, srv := newMock(t)
   262  	defer client.Close()
   263  	defer srv.srv.Close()
   264  	srv.addStreamingPullMessages(testMessages[:1])
   265  	sub := client.Subscription("S")
   266  	ctx, cancel := context.WithCancel(context.Background())
   267  	err := sub.Receive(ctx, func(ctx context.Context, m *Message) {
   268  		m.Ack()
   269  		err := sub.Receive(ctx, func(context.Context, *Message) {})
   270  		if err != errReceiveInProgress {
   271  			t.Errorf("got <%v>, want <%v>", err, errReceiveInProgress)
   272  		}
   273  		cancel()
   274  	})
   275  	if err != nil {
   276  		t.Fatalf("got <%v>, want nil", err)
   277  	}
   278  }
   279  
   280  func TestStreamingPullConcurrent(t *testing.T) {
   281  	newMsg := func(i int) *pb.ReceivedMessage {
   282  		return &pb.ReceivedMessage{
   283  			AckId:   strconv.Itoa(i),
   284  			Message: &pb.PubsubMessage{Data: []byte{byte(i)}, PublishTime: timestamp},
   285  		}
   286  	}
   287  
   288  	// Multiple goroutines should be able to read from the same iterator.
   289  	client, server := newMock(t)
   290  	defer server.srv.Close()
   291  	defer client.Close()
   292  	// Add a lot of messages, a few at a time, to make sure both threads get a chance.
   293  	nMessages := 100
   294  	for i := 0; i < nMessages; i += 2 {
   295  		server.addStreamingPullMessages([]*pb.ReceivedMessage{newMsg(i), newMsg(i + 1)})
   296  	}
   297  	sub := client.Subscription("S")
   298  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   299  	defer cancel()
   300  	gotMsgs, err := pullN(ctx, sub, nMessages, 0, func(ctx context.Context, m *Message) {
   301  		m.Ack()
   302  	})
   303  	if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
   304  		t.Fatalf("Pull: %v", err)
   305  	}
   306  	seen := map[string]bool{}
   307  	for _, gm := range gotMsgs {
   308  		if seen[msgAckID(gm)] {
   309  			t.Fatalf("duplicate ID %q", msgAckID(gm))
   310  		}
   311  		seen[msgAckID(gm)] = true
   312  	}
   313  	if len(seen) != nMessages {
   314  		t.Fatalf("got %d messages, want %d", len(seen), nMessages)
   315  	}
   316  }
   317  
   318  func TestStreamingPullFlowControl(t *testing.T) {
   319  	// Callback invocations should not occur if flow control limits are exceeded.
   320  	client, server := newMock(t)
   321  	defer server.srv.Close()
   322  	defer client.Close()
   323  	server.addStreamingPullMessages(testMessages)
   324  	sub := client.Subscription("S")
   325  	sub.ReceiveSettings.MaxOutstandingMessages = 2
   326  	ctx, cancel := context.WithCancel(context.Background())
   327  	activec := make(chan int)
   328  	waitc := make(chan int)
   329  	errc := make(chan error)
   330  	go func() {
   331  		errc <- sub.Receive(ctx, func(_ context.Context, m *Message) {
   332  			activec <- 1
   333  			<-waitc
   334  			m.Ack()
   335  		})
   336  	}()
   337  	// Here, two callbacks are active. Receive should be blocked in the flow
   338  	// control acquire method on the third message.
   339  	for i := 0; i < 2; i++ {
   340  		select {
   341  		case <-activec:
   342  		case <-time.After(time.Second):
   343  			t.Fatalf("timed out waiting for message %d", i+1)
   344  		}
   345  	}
   346  	select {
   347  	case <-activec:
   348  		t.Fatal("third callback in progress")
   349  	case <-time.After(100 * time.Millisecond):
   350  	}
   351  	cancel()
   352  	// Receive still has not returned, because both callbacks are still blocked on waitc.
   353  	select {
   354  	case err := <-errc:
   355  		t.Fatalf("Receive returned early with error %v", err)
   356  	case <-time.After(100 * time.Millisecond):
   357  	}
   358  	// Let both callbacks proceed.
   359  	waitc <- 1
   360  	waitc <- 1
   361  	// The third callback will never run, because acquire returned a non-nil
   362  	// error, causing Receive to return. So now Receive should end.
   363  	if err := <-errc; err != nil {
   364  		t.Fatalf("got %v from Receive, want nil", err)
   365  	}
   366  }
   367  
   368  func TestStreamingPull_ClosedClient(t *testing.T) {
   369  	ctx := context.Background()
   370  	client, server := newMock(t)
   371  	defer server.srv.Close()
   372  	defer client.Close()
   373  	server.addStreamingPullMessages(testMessages)
   374  	sub := client.Subscription("S")
   375  	sub.ReceiveSettings.MaxOutstandingBytes = 1
   376  	recvFinished := make(chan error)
   377  
   378  	go func() {
   379  		err := sub.Receive(ctx, func(_ context.Context, m *Message) {
   380  			m.Ack()
   381  		})
   382  		recvFinished <- err
   383  	}()
   384  
   385  	// wait for receives to happen
   386  	time.Sleep(100 * time.Millisecond)
   387  
   388  	if err := client.Close(); err != nil {
   389  		t.Fatalf("Got error while closing client: %v", err)
   390  	}
   391  
   392  	// wait for things to close
   393  	time.Sleep(100 * time.Millisecond)
   394  
   395  	select {
   396  	case recvErr := <-recvFinished:
   397  		s, ok := status.FromError(recvErr)
   398  		if !ok {
   399  			t.Fatalf("Expected a gRPC failure, got %v", recvErr)
   400  		}
   401  		if s.Code() != codes.Canceled {
   402  			t.Fatalf("Expected canceled, got %v", s.Code())
   403  		}
   404  	case <-time.After(time.Second):
   405  		t.Fatal("Receive should have exited immediately after the client was closed, but it did not")
   406  	}
   407  }
   408  
   409  func TestStreamingPull_RetriesAfterUnavailable(t *testing.T) {
   410  	ctx := context.Background()
   411  	client, server := newMock(t)
   412  	defer server.srv.Close()
   413  	defer client.Close()
   414  
   415  	unavail := status.Error(codes.Unavailable, "There is no connection available")
   416  	server.addStreamingPullMessages(testMessages)
   417  	server.addStreamingPullError(unavail)
   418  	server.addAckResponse(unavail)
   419  	server.addModAckResponse(unavail)
   420  	server.addStreamingPullMessages(testMessages)
   421  	server.addStreamingPullError(unavail)
   422  
   423  	sub := client.Subscription("S")
   424  	sub.ReceiveSettings.MaxOutstandingBytes = 1
   425  	recvErr := make(chan error, 1)
   426  	recvdMsgs := make(chan *Message, len(testMessages)*2)
   427  
   428  	go func() {
   429  		recvErr <- sub.Receive(ctx, func(_ context.Context, m *Message) {
   430  			m.Ack()
   431  			recvdMsgs <- m
   432  		})
   433  	}()
   434  
   435  	// wait for receive to happen
   436  	var n int
   437  	for {
   438  		select {
   439  		case <-time.After(10 * time.Second):
   440  			t.Fatalf("timed out waiting for all message to arrive. got %d messages total", n)
   441  		case err := <-recvErr:
   442  			t.Fatal(err)
   443  		case <-recvdMsgs:
   444  			n++
   445  			if n == len(testMessages)*2 {
   446  				return
   447  			}
   448  		}
   449  	}
   450  }
   451  
   452  func TestStreamingPull_ReconnectsAfterServerDies(t *testing.T) {
   453  	ctx := context.Background()
   454  	client, server := newMock(t)
   455  	defer server.srv.Close()
   456  	defer client.Close()
   457  	server.addStreamingPullMessages(testMessages)
   458  	sub := client.Subscription("S")
   459  	sub.ReceiveSettings.MaxOutstandingBytes = 1
   460  	recvErr := make(chan error, 1)
   461  	recvdMsgs := make(chan interface{}, len(testMessages)*2)
   462  
   463  	go func() {
   464  		recvErr <- sub.Receive(ctx, func(_ context.Context, m *Message) {
   465  			m.Ack()
   466  			recvdMsgs <- struct{}{}
   467  		})
   468  	}()
   469  
   470  	// wait for receive to happen
   471  	var n int
   472  	for {
   473  		select {
   474  		case <-time.After(5 * time.Second):
   475  			t.Fatalf("timed out waiting for all message to arrive. got %d messages total", n)
   476  		case err := <-recvErr:
   477  			t.Fatal(err)
   478  		case <-recvdMsgs:
   479  			n++
   480  			if n == len(testMessages) {
   481  				// Restart the server
   482  				server.srv.Close()
   483  				server2, err := newMockServer(server.srv.Port)
   484  				if err != nil {
   485  					t.Fatal(err)
   486  				}
   487  				defer server2.srv.Close()
   488  				server2.addStreamingPullMessages(testMessages)
   489  			}
   490  
   491  			if n == len(testMessages)*2 {
   492  				return
   493  			}
   494  		}
   495  	}
   496  }
   497  
   498  func newMock(t *testing.T) (*Client, *mockServer) {
   499  	srv, err := newMockServer(0)
   500  	if err != nil {
   501  		t.Fatal(err)
   502  	}
   503  	conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
   504  	if err != nil {
   505  		t.Fatal(err)
   506  	}
   507  	opts := withGRPCHeadersAssertion(t, option.WithGRPCConn(conn))
   508  	client, err := NewClient(context.Background(), "P", opts...)
   509  	if err != nil {
   510  		t.Fatal(err)
   511  	}
   512  	return client, srv
   513  }
   514  
   515  // pullN calls sub.Receive until at least n messages are received.
   516  // Wait a provided duration before cancelling.
   517  func pullN(ctx context.Context, sub *Subscription, n int, wait time.Duration, f func(context.Context, *Message)) ([]*Message, error) {
   518  	var (
   519  		mu   sync.Mutex
   520  		msgs []*Message
   521  	)
   522  	cctx, cancel := context.WithCancel(ctx)
   523  	err := sub.Receive(cctx, func(ctx context.Context, m *Message) {
   524  		mu.Lock()
   525  		msgs = append(msgs, m)
   526  		nSeen := len(msgs)
   527  		mu.Unlock()
   528  		f(ctx, m)
   529  		if nSeen >= n {
   530  			// Wait a specified amount of time so that for exactly once delivery,
   531  			// Acks aren't cancelled immediately.
   532  			time.Sleep(wait)
   533  			cancel()
   534  		}
   535  	})
   536  	if err != nil {
   537  		return msgs, err
   538  	}
   539  	return msgs, nil
   540  }
   541  

View as plain text