
Source file src/cloud.google.com/go/pubsub/mock_test.go

Documentation: cloud.google.com/go/pubsub

     1  // Copyright 2017 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package pubsub
    17  // This file provides a mock in-memory pubsub server for streaming pull testing.
    19  import (
    20  	"context"
    21  	"io"
    22  	"sync"
    23  	"time"
    25  	"cloud.google.com/go/internal/testutil"
    26  	pb "cloud.google.com/go/pubsub/apiv1/pubsubpb"
    27  	"google.golang.org/protobuf/types/known/emptypb"
    28  )
    30  type mockServer struct {
    31  	srv *testutil.Server
    33  	pb.SubscriberServer
    35  	Addr string
    37  	mu            sync.Mutex
    38  	Acked         map[string]bool  // acked message IDs
    39  	Deadlines     map[string]int32 // deadlines by message ID
    40  	pullResponses []*pullResponse
    41  	ackErrs       []error
    42  	modAckErrs    []error
    43  	wg            sync.WaitGroup
    44  	sub           *pb.Subscription
    45  }
// pullResponse is one scripted result for StreamingPull: either a batch of
// messages to send to the client or an error that ends the stream.
type pullResponse struct {
	msgs []*pb.ReceivedMessage // sent as one StreamingPullResponse when err is nil
	err  error                 // if non-nil, ends the stream (io.EOF ends it cleanly)
}
    52  func newMockServer(port int) (*mockServer, error) {
    53  	srv, err := testutil.NewServerWithPort(port)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	mock := &mockServer{
    58  		srv:       srv,
    59  		Addr:      srv.Addr,
    60  		Acked:     map[string]bool{},
    61  		Deadlines: map[string]int32{},
    62  		sub: &pb.Subscription{
    63  			AckDeadlineSeconds: 10,
    64  			PushConfig:         &pb.PushConfig{},
    65  		},
    66  	}
    67  	pb.RegisterSubscriberServer(srv.Gsrv, mock)
    68  	srv.Start()
    69  	return mock, nil
    70  }
    72  // Each call to addStreamingPullMessages results in one StreamingPullResponse.
    73  func (s *mockServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) {
    74  	s.mu.Lock()
    75  	s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil})
    76  	s.mu.Unlock()
    77  }
    79  func (s *mockServer) addStreamingPullError(err error) {
    80  	s.mu.Lock()
    81  	s.pullResponses = append(s.pullResponses, &pullResponse{nil, err})
    82  	s.mu.Unlock()
    83  }
    85  func (s *mockServer) addAckResponse(err error) {
    86  	s.mu.Lock()
    87  	s.ackErrs = append(s.ackErrs, err)
    88  	s.mu.Unlock()
    89  }
    91  func (s *mockServer) addModAckResponse(err error) {
    92  	s.mu.Lock()
    93  	s.modAckErrs = append(s.modAckErrs, err)
    94  	s.mu.Unlock()
    95  }
// wait blocks until s.wg drops to zero. StreamingPull adds both each handler
// invocation and the receive goroutine it starts to wg, so this waits for
// all streams to finish.
func (s *mockServer) wait() {
	s.wg.Wait()
}
   101  func (s *mockServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error {
   102  	s.wg.Add(1)
   103  	defer s.wg.Done()
   104  	errc := make(chan error, 1)
   105  	s.wg.Add(1)
   106  	go func() {
   107  		defer s.wg.Done()
   108  		for {
   109  			req, err := stream.Recv()
   110  			if err != nil {
   111  				errc <- err
   112  				return
   113  			}
   114  			s.mu.Lock()
   115  			for _, id := range req.AckIds {
   116  				s.Acked[id] = true
   117  			}
   118  			for i, id := range req.ModifyDeadlineAckIds {
   119  				s.Deadlines[id] = req.ModifyDeadlineSeconds[i]
   120  			}
   121  			s.mu.Unlock()
   122  		}
   123  	}()
   124  	// Send responses.
   125  	for {
   126  		s.mu.Lock()
   127  		if len(s.pullResponses) == 0 {
   128  			s.mu.Unlock()
   129  			// Nothing to send, so wait for the client to shut down the stream.
   130  			err := <-errc // a real error, or at least EOF
   131  			if err == io.EOF {
   132  				return nil
   133  			}
   134  			return err
   135  		}
   136  		pr := s.pullResponses[0]
   137  		s.pullResponses = s.pullResponses[1:]
   138  		s.mu.Unlock()
   139  		if pr.err != nil {
   140  			// Add a slight delay to ensure the server receives any
   141  			// messages en route from the client before shutting down the stream.
   142  			// This reduces flakiness of tests involving retry.
   143  			time.Sleep(200 * time.Millisecond)
   144  		}
   145  		if pr.err == io.EOF {
   146  			return nil
   147  		}
   148  		if pr.err != nil {
   149  			return pr.err
   150  		}
   151  		// Return any error from Recv.
   152  		select {
   153  		case err := <-errc:
   154  			return err
   155  		default:
   156  		}
   157  		res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs}
   158  		if err := stream.Send(res); err != nil {
   159  			return err
   160  		}
   161  	}
   162  }
   164  func (s *mockServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
   165  	var err error
   166  	s.mu.Lock()
   167  	if len(s.ackErrs) > 0 {
   168  		err = s.ackErrs[0]
   169  		s.ackErrs = s.ackErrs[1:]
   170  	}
   171  	if err != nil {
   172  		s.mu.Unlock()
   173  		return nil, err
   174  	}
   175  	for _, id := range req.AckIds {
   176  		s.Acked[id] = true
   177  	}
   178  	s.mu.Unlock()
   179  	return &emptypb.Empty{}, nil
   180  }
   182  func (s *mockServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
   183  	var err error
   184  	s.mu.Lock()
   185  	if len(s.modAckErrs) > 0 {
   186  		err = s.modAckErrs[0]
   187  		s.modAckErrs = s.modAckErrs[1:]
   188  	}
   189  	if err != nil {
   190  		s.mu.Unlock()
   191  		return nil, err
   192  	}
   193  	for _, id := range req.AckIds {
   194  		s.Deadlines[id] = req.AckDeadlineSeconds
   195  	}
   196  	s.mu.Unlock()
   197  	return &emptypb.Empty{}, nil
   198  }
// GetSubscription implements pb.SubscriberServer. It ignores the request and
// always returns s.sub (the same pointer on every call) with a nil error.
func (s *mockServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
	return s.sub, nil
}

View as plain text