...

Source file src/cloud.google.com/go/pubsub/pstest/fake.go

Documentation: cloud.google.com/go/pubsub/pstest

     1  // Copyright 2017 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package pstest provides a fake Cloud PubSub service for testing. It implements a
    16  // simplified form of the service, suitable for unit tests. It may behave
    17  // differently from the actual service in ways in which the service is
    18  // non-deterministic or unspecified: timing, delivery order, etc.
    19  //
    20  // This package is EXPERIMENTAL and is subject to change without notice.
    21  //
    22  // See the example for usage.
    23  package pstest
    24  
    25  import (
    26  	"context"
    27  	"fmt"
    28  	"io"
    29  	"math/rand"
    30  	"path"
    31  	"sort"
    32  	"strings"
    33  	"sync"
    34  	"sync/atomic"
    35  	"time"
    36  
    37  	"cloud.google.com/go/internal/testutil"
    38  	pb "cloud.google.com/go/pubsub/apiv1/pubsubpb"
    39  	"go.einride.tech/aip/filtering"
    40  	"google.golang.org/grpc/codes"
    41  	"google.golang.org/grpc/status"
    42  	durpb "google.golang.org/protobuf/types/known/durationpb"
    43  	"google.golang.org/protobuf/types/known/emptypb"
    44  	"google.golang.org/protobuf/types/known/timestamppb"
    45  )
    46  
// ReactorOptions is a map that Server uses to look up reactors.
// The key is the RPC/function name passed to the handlers' reactor lookup
// (for example "CreateTopic" or "Publish"); the value is the list of reactors
// registered for that name, in the order the options were supplied to
// NewServer / NewServerWithPort.
type ReactorOptions map[string][]Reactor
    50  
// Reactor lets a test intercept calls to one of the fake server's RPCs and
// supply its own result. Reactors are registered per RPC name through
// ServerReactorOption.
type Reactor interface {
	// React handles the message types and returns results. If "handled" is false,
	// then the test server will ignore the results and continue to the next reactor
	// or the original handler.
	//
	// The RPC handlers in this file return immediately whenever the reactor
	// lookup reports handled OR a non-nil err, and they type-assert ret to the
	// RPC's response type (e.g. *pb.Topic for CreateTopic), so ret should be of
	// that type. The argument is presumably the RPC's request message, which
	// the handlers pass to runReactor — TODO confirm it is forwarded unchanged.
	// NOTE(review): whether a nil ret is replaced by a default response is
	// decided in runReactor, which is not visible here.
	React(_ interface{}) (handled bool, ret interface{}, err error)
}
    58  
// ServerReactorOption is options passed to the server for reactor creation.
// NewServerWithPort appends Reactor to the list kept for FuncName; several
// options with the same FuncName keep the order in which they were given.
type ServerReactorOption struct {
	FuncName string  // RPC name to intercept, e.g. "CreateTopic" or "Publish"
	Reactor  Reactor // reactor consulted for calls to FuncName
}
    64  
// publishResponse is a canned result queued by Server.AddPublishResponse and
// consumed by GServer.Publish while automatic publish responses are disabled.
// AddPublishResponse sets exactly one of the two fields.
type publishResponse struct {
	resp *pb.PublishResponse // returned to the caller when err is nil
	err  error               // returned to the caller instead of resp when non-nil
}
    69  
// Server is a fake Pub/Sub server. Create one with NewServer or
// NewServerWithPort, point clients at Addr, and shut it down with Close.
type Server struct {
	srv     *testutil.Server
	Addr    string  // The address that the server is listening on.
	GServer GServer // Not intended to be used directly.
}
    76  
// GServer is the underlying service implementor. It is registered by
// NewServerWithPort as the Publisher, Subscriber and SchemaService server.
// It is not intended to be used directly.
type GServer struct {
	pb.UnimplementedPublisherServer
	pb.UnimplementedSubscriberServer
	pb.UnimplementedSchemaServiceServer

	// timeNowFunc holds the func() time.Time used as the server clock:
	// time.Now by default, replaceable with Server.SetTimeNowFunc.
	timeNowFunc atomic.Value

	// mu is the server mutex. The handlers in this file hold it while they
	// read or modify the fields below; subscriptions receive a pointer to it
	// when they are created.
	mu             sync.Mutex
	topics         map[string]*topic        // keyed by full topic name
	subs           map[string]*subscription // keyed by full subscription name
	msgs           []*Message               // messages published since start or the last ClearMessages
	msgsByID       map[string]*Message      // msgs indexed by message ID
	wg             sync.WaitGroup           // handed to each subscription's start; awaited by Server.Wait
	nextID         int                      // counter used to build message IDs "m0", "m1", ...
	streamTimeout  time.Duration            // set by Server.SetStreamTimeout; zero means no timeout
	reactorOptions ReactorOptions           // reactors registered through NewServerWithPort options
	// schemas is a map of schemaIDs to a slice of schema revisions.
	// the last element in the slice is the most recent schema.
	schemas map[string][]*pb.Schema

	// publishResponses is a channel of responses to use for Publish while
	// autoPublishResponse is false. Filled by Server.AddPublishResponse and
	// replaced by Server.ResetPublishResponses.
	publishResponses chan *publishResponse
	// autoPublishResponse enables the server to automatically generate
	// PublishResponse when publish is called. Otherwise, responses
	// are generated from the publishResponses channel.
	autoPublishResponse bool
}
   106  
   107  // NewServer creates a new fake server running in the current process.
   108  func NewServer(opts ...ServerReactorOption) *Server {
   109  	return NewServerWithPort(0, opts...)
   110  }
   111  
   112  // NewServerWithPort creates a new fake server running in the current process at the specified port.
   113  func NewServerWithPort(port int, opts ...ServerReactorOption) *Server {
   114  	srv, err := testutil.NewServerWithPort(port)
   115  	if err != nil {
   116  		panic(fmt.Sprintf("pstest.NewServerWithPort: %v", err))
   117  	}
   118  	reactorOptions := ReactorOptions{}
   119  	for _, opt := range opts {
   120  		reactorOptions[opt.FuncName] = append(reactorOptions[opt.FuncName], opt.Reactor)
   121  	}
   122  	s := &Server{
   123  		srv:  srv,
   124  		Addr: srv.Addr,
   125  		GServer: GServer{
   126  			topics:              map[string]*topic{},
   127  			subs:                map[string]*subscription{},
   128  			msgsByID:            map[string]*Message{},
   129  			reactorOptions:      reactorOptions,
   130  			publishResponses:    make(chan *publishResponse, 100),
   131  			autoPublishResponse: true,
   132  			schemas:             map[string][]*pb.Schema{},
   133  		},
   134  	}
   135  	s.GServer.timeNowFunc.Store(time.Now)
   136  	pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
   137  	pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer)
   138  	pb.RegisterSchemaServiceServer(srv.Gsrv, &s.GServer)
   139  	srv.Start()
   140  	return s
   141  }
   142  
   143  // SetTimeNowFunc registers f as a function to
   144  // be used instead of time.Now for this server.
   145  func (s *Server) SetTimeNowFunc(f func() time.Time) {
   146  	s.GServer.timeNowFunc.Store(f)
   147  }
   148  
   149  func (s *GServer) now() time.Time {
   150  	return s.timeNowFunc.Load().(func() time.Time)()
   151  }
   152  
   153  // Publish behaves as if the Publish RPC was called with a message with the given
   154  // data and attrs. It returns the ID of the message.
   155  // The topic will be created if it doesn't exist.
   156  //
   157  // Publish panics if there is an error, which is appropriate for testing.
   158  func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
   159  	return s.PublishOrdered(topic, data, attrs, "")
   160  }
   161  
   162  // PublishOrdered behaves as if the Publish RPC was called with a message with the given
   163  // data, attrs and ordering key. It returns the ID of the message.
   164  // The topic will be created if it doesn't exist.
   165  //
   166  // PublishOrdered panics if there is an error, which is appropriate for testing.
   167  func (s *Server) PublishOrdered(topic string, data []byte, attrs map[string]string, orderingKey string) string {
   168  	const topicPattern = "projects/*/topics/*"
   169  	ok, err := path.Match(topicPattern, topic)
   170  	if err != nil {
   171  		panic(err)
   172  	}
   173  	if !ok {
   174  		panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
   175  	}
   176  	s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
   177  	req := &pb.PublishRequest{
   178  		Topic:    topic,
   179  		Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs, OrderingKey: orderingKey}},
   180  	}
   181  	res, err := s.GServer.Publish(context.TODO(), req)
   182  	if err != nil {
   183  		panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
   184  	}
   185  	return res.MessageIds[0]
   186  }
   187  
   188  // AddPublishResponse adds a new publish response to the channel used for
   189  // responding to publish requests.
   190  func (s *Server) AddPublishResponse(pbr *pb.PublishResponse, err error) {
   191  	pr := &publishResponse{}
   192  	if err != nil {
   193  		pr.err = err
   194  	} else {
   195  		pr.resp = pbr
   196  	}
   197  	s.GServer.publishResponses <- pr
   198  }
   199  
   200  // SetAutoPublishResponse controls whether to automatically respond
   201  // to messages published or to use user-added responses from the
   202  // publishResponses channel.
   203  func (s *Server) SetAutoPublishResponse(autoPublishResponse bool) {
   204  	s.GServer.mu.Lock()
   205  	defer s.GServer.mu.Unlock()
   206  	s.GServer.autoPublishResponse = autoPublishResponse
   207  }
   208  
   209  // ResetPublishResponses resets the buffered publishResponses channel
   210  // with a new buffered channel with the given size.
   211  func (s *Server) ResetPublishResponses(size int) {
   212  	s.GServer.mu.Lock()
   213  	defer s.GServer.mu.Unlock()
   214  	s.GServer.publishResponses = make(chan *publishResponse, size)
   215  }
   216  
   217  // SetStreamTimeout sets the amount of time a stream will be active before it shuts
   218  // itself down. This mimics the real service's behavior of closing streams after 30
   219  // minutes. If SetStreamTimeout is never called or is passed zero, streams never shut
   220  // down.
   221  func (s *Server) SetStreamTimeout(d time.Duration) {
   222  	s.GServer.mu.Lock()
   223  	defer s.GServer.mu.Unlock()
   224  	s.GServer.streamTimeout = d
   225  }
   226  
// A Message is a message that was published to the server.
//
// Deliveries, Acks and Modacks are snapshots: Server.Messages and
// Server.Message copy the unexported, mutex-protected counters into them
// before returning the message.
type Message struct {
	ID          string
	Data        []byte
	Attributes  map[string]string
	PublishTime time.Time
	Deliveries  int      // number of times delivery of the message was attempted
	Acks        int      // number of acks received from clients
	Modacks     []Modack // modacks received by server for this message
	OrderingKey string

	// protected by server mutex. deliveries and acks are shared by pointer
	// with every subscription's copy of the message (see topic.publish).
	deliveries int
	acks       int
	modacks    []Modack
}
   243  
// Modack represents a modack (ack-deadline modification) sent to the server.
type Modack struct {
	AckID       string    // ack ID the modack referred to
	AckDeadline int32     // requested deadline; presumably seconds, as in ModifyAckDeadlineRequest — TODO confirm
	ReceivedAt  time.Time // when the modack was received; recorded by code not visible here
}
   250  
   251  // Messages returns information about all messages ever published.
   252  func (s *Server) Messages() []*Message {
   253  	s.GServer.mu.Lock()
   254  	defer s.GServer.mu.Unlock()
   255  
   256  	var msgs []*Message
   257  	for _, m := range s.GServer.msgs {
   258  		m.Deliveries = m.deliveries
   259  		m.Acks = m.acks
   260  		m.Modacks = append([]Modack(nil), m.modacks...)
   261  		msgs = append(msgs, m)
   262  	}
   263  	return msgs
   264  }
   265  
   266  // Message returns the message with the given ID, or nil if no message
   267  // with that ID was published.
   268  func (s *Server) Message(id string) *Message {
   269  	s.GServer.mu.Lock()
   270  	defer s.GServer.mu.Unlock()
   271  
   272  	m := s.GServer.msgsByID[id]
   273  	if m != nil {
   274  		m.Deliveries = m.deliveries
   275  		m.Acks = m.acks
   276  		m.Modacks = append([]Modack(nil), m.modacks...)
   277  	}
   278  	return m
   279  }
   280  
   281  // Wait blocks until all server activity has completed.
   282  func (s *Server) Wait() {
   283  	s.GServer.wg.Wait()
   284  }
   285  
   286  // ClearMessages removes all published messages
   287  // from internal containers.
   288  func (s *Server) ClearMessages() {
   289  	s.GServer.mu.Lock()
   290  	s.GServer.msgs = nil
   291  	s.GServer.msgsByID = make(map[string]*Message)
   292  	for _, sub := range s.GServer.subs {
   293  		sub.msgs = map[string]*message{}
   294  	}
   295  	s.GServer.mu.Unlock()
   296  }
   297  
   298  // Close shuts down the server and releases all resources.
   299  func (s *Server) Close() error {
   300  	s.srv.Close()
   301  	s.GServer.mu.Lock()
   302  	defer s.GServer.mu.Unlock()
   303  	for _, sub := range s.GServer.subs {
   304  		sub.stop()
   305  	}
   306  	return nil
   307  }
   308  
   309  func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
   310  	s.mu.Lock()
   311  	defer s.mu.Unlock()
   312  
   313  	if handled, ret, err := s.runReactor(t, "CreateTopic", &pb.Topic{}); handled || err != nil {
   314  		return ret.(*pb.Topic), err
   315  	}
   316  
   317  	if s.topics[t.Name] != nil {
   318  		return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
   319  	}
   320  	if err := checkTopicMessageRetention(t.MessageRetentionDuration); err != nil {
   321  		return nil, err
   322  	}
   323  	// Take any ingestion setting to mean the topic is active.
   324  	if t.IngestionDataSourceSettings != nil {
   325  		t.State = pb.Topic_ACTIVE
   326  	}
   327  	top := newTopic(t)
   328  	s.topics[t.Name] = top
   329  	return top.proto, nil
   330  }
   331  
   332  func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
   333  	s.mu.Lock()
   334  	defer s.mu.Unlock()
   335  
   336  	if handled, ret, err := s.runReactor(req, "GetTopic", &pb.Topic{}); handled || err != nil {
   337  		return ret.(*pb.Topic), err
   338  	}
   339  
   340  	if t := s.topics[req.Topic]; t != nil {
   341  		return t.proto, nil
   342  	}
   343  	return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
   344  }
   345  
   346  func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
   347  	s.mu.Lock()
   348  	defer s.mu.Unlock()
   349  
   350  	if handled, ret, err := s.runReactor(req, "UpdateTopic", &pb.Topic{}); handled || err != nil {
   351  		return ret.(*pb.Topic), err
   352  	}
   353  
   354  	t := s.topics[req.Topic.Name]
   355  	if t == nil {
   356  		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
   357  	}
   358  	for _, path := range req.UpdateMask.Paths {
   359  		switch path {
   360  		case "labels":
   361  			t.proto.Labels = req.Topic.Labels
   362  		case "message_storage_policy":
   363  			t.proto.MessageStoragePolicy = req.Topic.MessageStoragePolicy
   364  		case "message_retention_duration":
   365  			if err := checkTopicMessageRetention(req.Topic.MessageRetentionDuration); err != nil {
   366  				return nil, err
   367  			}
   368  			t.proto.MessageRetentionDuration = req.Topic.MessageRetentionDuration
   369  		case "schema_settings":
   370  			t.proto.SchemaSettings = req.Topic.SchemaSettings
   371  		case "schema_settings.schema":
   372  			if t.proto.SchemaSettings == nil {
   373  				t.proto.SchemaSettings = &pb.SchemaSettings{}
   374  			}
   375  			t.proto.SchemaSettings.Schema = req.Topic.SchemaSettings.Schema
   376  		case "schema_settings.encoding":
   377  			if t.proto.SchemaSettings == nil {
   378  				t.proto.SchemaSettings = &pb.SchemaSettings{}
   379  			}
   380  			t.proto.SchemaSettings.Encoding = req.Topic.SchemaSettings.Encoding
   381  		case "schema_settings.first_revision_id":
   382  			if t.proto.SchemaSettings == nil {
   383  				t.proto.SchemaSettings = &pb.SchemaSettings{}
   384  			}
   385  			t.proto.SchemaSettings.FirstRevisionId = req.Topic.SchemaSettings.FirstRevisionId
   386  		case "schema_settings.last_revision_id":
   387  			if t.proto.SchemaSettings == nil {
   388  				t.proto.SchemaSettings = &pb.SchemaSettings{}
   389  			}
   390  			t.proto.SchemaSettings.LastRevisionId = req.Topic.SchemaSettings.LastRevisionId
   391  		case "ingestion_data_source_settings":
   392  			if t.proto.IngestionDataSourceSettings == nil {
   393  				t.proto.IngestionDataSourceSettings = &pb.IngestionDataSourceSettings{}
   394  			}
   395  			t.proto.IngestionDataSourceSettings = req.Topic.IngestionDataSourceSettings
   396  			// Take any ingestion setting to mean the topic is active.
   397  			if t.proto.IngestionDataSourceSettings != nil {
   398  				t.proto.State = pb.Topic_ACTIVE
   399  			}
   400  		default:
   401  			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
   402  		}
   403  	}
   404  	return t.proto, nil
   405  }
   406  
   407  func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
   408  	s.mu.Lock()
   409  	defer s.mu.Unlock()
   410  
   411  	if handled, ret, err := s.runReactor(req, "ListTopics", &pb.ListTopicsResponse{}); handled || err != nil {
   412  		return ret.(*pb.ListTopicsResponse), err
   413  	}
   414  
   415  	var names []string
   416  	for n := range s.topics {
   417  		if strings.HasPrefix(n, req.Project) {
   418  			names = append(names, n)
   419  		}
   420  	}
   421  	sort.Strings(names)
   422  	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
   423  	if err != nil {
   424  		return nil, err
   425  	}
   426  	res := &pb.ListTopicsResponse{NextPageToken: nextToken}
   427  	for i := from; i < to; i++ {
   428  		res.Topics = append(res.Topics, s.topics[names[i]].proto)
   429  	}
   430  	return res, nil
   431  }
   432  
   433  func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
   434  	s.mu.Lock()
   435  	defer s.mu.Unlock()
   436  
   437  	if handled, ret, err := s.runReactor(req, "ListTopicSubscriptions", &pb.ListTopicSubscriptionsResponse{}); handled || err != nil {
   438  		return ret.(*pb.ListTopicSubscriptionsResponse), err
   439  	}
   440  
   441  	var names []string
   442  	for name, sub := range s.subs {
   443  		if sub.topic.proto.Name == req.Topic {
   444  			names = append(names, name)
   445  		}
   446  	}
   447  	sort.Strings(names)
   448  	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
   449  	if err != nil {
   450  		return nil, err
   451  	}
   452  	return &pb.ListTopicSubscriptionsResponse{
   453  		Subscriptions: names[from:to],
   454  		NextPageToken: nextToken,
   455  	}, nil
   456  }
   457  
   458  func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
   459  	s.mu.Lock()
   460  	defer s.mu.Unlock()
   461  
   462  	if handled, ret, err := s.runReactor(req, "DeleteTopic", &emptypb.Empty{}); handled || err != nil {
   463  		return ret.(*emptypb.Empty), err
   464  	}
   465  
   466  	t := s.topics[req.Topic]
   467  	if t == nil {
   468  		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
   469  	}
   470  	for _, sub := range s.subs {
   471  		if sub.deadLetterTopic == nil {
   472  			continue
   473  		}
   474  		if req.Topic == sub.deadLetterTopic.proto.Name {
   475  			return nil, status.Errorf(codes.FailedPrecondition, "topic %q used as deadLetter for %s", req.Topic, sub.proto.Name)
   476  		}
   477  	}
   478  	t.stop()
   479  	delete(s.topics, req.Topic)
   480  	return &emptypb.Empty{}, nil
   481  }
   482  
   483  func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
   484  	s.mu.Lock()
   485  	defer s.mu.Unlock()
   486  
   487  	if handled, ret, err := s.runReactor(ps, "CreateSubscription", &pb.Subscription{}); handled || err != nil {
   488  		return ret.(*pb.Subscription), err
   489  	}
   490  
   491  	if ps.Name == "" {
   492  		return nil, status.Errorf(codes.InvalidArgument, "missing name")
   493  	}
   494  	if s.subs[ps.Name] != nil {
   495  		return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
   496  	}
   497  	if ps.Topic == "" {
   498  		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
   499  	}
   500  	top := s.topics[ps.Topic]
   501  	if top == nil {
   502  		return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
   503  	}
   504  	if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil {
   505  		return nil, err
   506  	}
   507  	if ps.MessageRetentionDuration == nil {
   508  		ps.MessageRetentionDuration = defaultMessageRetentionDuration
   509  	}
   510  	if err := checkSubMessageRetention(ps.MessageRetentionDuration); err != nil {
   511  		return nil, err
   512  	}
   513  	if ps.PushConfig == nil {
   514  		ps.PushConfig = &pb.PushConfig{}
   515  	} else if ps.PushConfig.Wrapper == nil {
   516  		// Wrapper should default to PubsubWrapper.
   517  		ps.PushConfig.Wrapper = &pb.PushConfig_PubsubWrapper_{
   518  			PubsubWrapper: &pb.PushConfig_PubsubWrapper{},
   519  		}
   520  	}
   521  	// Consider any table set to mean the config is active.
   522  	// We don't convert nil config to empty like with PushConfig above
   523  	// as this mimics the live service behavior.
   524  	if ps.GetBigqueryConfig() != nil && ps.GetBigqueryConfig().GetTable() != "" {
   525  		ps.BigqueryConfig.State = pb.BigQueryConfig_ACTIVE
   526  	}
   527  	if ps.CloudStorageConfig != nil && ps.CloudStorageConfig.Bucket != "" {
   528  		ps.CloudStorageConfig.State = pb.CloudStorageConfig_ACTIVE
   529  	}
   530  	ps.TopicMessageRetentionDuration = top.proto.MessageRetentionDuration
   531  	var deadLetterTopic *topic
   532  	if ps.DeadLetterPolicy != nil {
   533  		dlTopic, ok := s.topics[ps.DeadLetterPolicy.DeadLetterTopic]
   534  		if !ok {
   535  			return nil, status.Errorf(codes.NotFound, "deadLetter topic %q", ps.DeadLetterPolicy.DeadLetterTopic)
   536  		}
   537  		deadLetterTopic = dlTopic
   538  	}
   539  
   540  	sub := newSubscription(top, &s.mu, s.now, deadLetterTopic, ps)
   541  	top.subs[ps.Name] = sub
   542  	s.subs[ps.Name] = sub
   543  	sub.start(&s.wg)
   544  	return ps, nil
   545  }
   546  
// minAckDeadlineSecs is the smallest ack_deadline_seconds accepted by
// checkAckDeadline. Can be set for testing, via SetMinAckDeadline; the default
// of 10 is restored by ResetMinAckDeadline.
// NOTE(review): this declaration alone leaves it at 0 — confirm the default of
// 10 is established elsewhere (e.g. an init calling ResetMinAckDeadline).
var minAckDeadlineSecs int32
   549  
   550  // SetMinAckDeadline changes the minack deadline to n. Must be
   551  // greater than or equal to 1 second. Remember to reset this value
   552  // to the default after your test changes it. Example usage:
   553  //
   554  //	pstest.SetMinAckDeadlineSecs(1)
   555  //	defer pstest.ResetMinAckDeadlineSecs()
   556  func SetMinAckDeadline(n time.Duration) {
   557  	if n < time.Second {
   558  		panic("SetMinAckDeadline expects a value greater than 1 second")
   559  	}
   560  
   561  	minAckDeadlineSecs = int32(n / time.Second)
   562  }
   563  
   564  // ResetMinAckDeadline resets the minack deadline to the default.
   565  func ResetMinAckDeadline() {
   566  	minAckDeadlineSecs = 10
   567  }
   568  
   569  func checkAckDeadline(ads int32) error {
   570  	if ads < minAckDeadlineSecs || ads > 600 {
   571  		// PubSub service returns Unknown.
   572  		return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
   573  	}
   574  	return nil
   575  }
   576  
   577  const (
   578  	minTopicMessageRetentionDuration = 10 * time.Minute
   579  	// 31 days is the maximum topic supported duration (https://cloud.google.com/pubsub/docs/replay-overview#configuring_message_retention)
   580  	maxTopicMessageRetentionDuration = 31 * 24 * time.Hour
   581  	minSubMessageRetentionDuration   = 10 * time.Minute
   582  	// 7 days is the maximum subscription supported duration (https://cloud.google.com/pubsub/docs/replay-overview#configuring_message_retention)
   583  	maxSubMessageRetentionDuration = 7 * 24 * time.Hour
   584  )
   585  
   586  var defaultMessageRetentionDuration = durpb.New(168 * time.Hour) // default is 7 days
   587  
   588  func checkTopicMessageRetention(pmrd *durpb.Duration) error {
   589  	if pmrd == nil {
   590  		return nil
   591  	}
   592  	mrd := pmrd.AsDuration()
   593  	if mrd < minTopicMessageRetentionDuration || mrd > maxTopicMessageRetentionDuration {
   594  		return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
   595  	}
   596  	return nil
   597  }
   598  
   599  func checkSubMessageRetention(pmrd *durpb.Duration) error {
   600  	if pmrd == nil {
   601  		return nil
   602  	}
   603  	mrd := pmrd.AsDuration()
   604  	if mrd < minSubMessageRetentionDuration || mrd > maxSubMessageRetentionDuration {
   605  		return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
   606  	}
   607  	return nil
   608  }
   609  
   610  func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
   611  	s.mu.Lock()
   612  	defer s.mu.Unlock()
   613  
   614  	if handled, ret, err := s.runReactor(req, "GetSubscription", &pb.Subscription{}); handled || err != nil {
   615  		return ret.(*pb.Subscription), err
   616  	}
   617  
   618  	sub, err := s.findSubscription(req.Subscription)
   619  	if err != nil {
   620  		return nil, err
   621  	}
   622  	return sub.proto, nil
   623  }
   624  
   625  func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
   626  	if req.Subscription == nil {
   627  		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
   628  	}
   629  	s.mu.Lock()
   630  	defer s.mu.Unlock()
   631  
   632  	if handled, ret, err := s.runReactor(req, "UpdateSubscription", &pb.Subscription{}); handled || err != nil {
   633  		return ret.(*pb.Subscription), err
   634  	}
   635  
   636  	sub, err := s.findSubscription(req.Subscription.Name)
   637  	if err != nil {
   638  		return nil, err
   639  	}
   640  	for _, path := range req.UpdateMask.Paths {
   641  		switch path {
   642  		case "push_config":
   643  			sub.proto.PushConfig = req.Subscription.PushConfig
   644  
   645  		case "bigquery_config":
   646  			// If bq config is nil here, it will be cleared.
   647  			// Otherwise, we'll consider the subscription active if any table is set.
   648  			sub.proto.BigqueryConfig = req.GetSubscription().GetBigqueryConfig()
   649  			if sub.proto.GetBigqueryConfig() != nil {
   650  				if sub.proto.GetBigqueryConfig().GetTable() != "" {
   651  					sub.proto.BigqueryConfig.State = pb.BigQueryConfig_ACTIVE
   652  				} else {
   653  					return nil, status.Errorf(codes.InvalidArgument, "table must be provided")
   654  				}
   655  			}
   656  
   657  		case "cloud_storage_config":
   658  			sub.proto.CloudStorageConfig = req.GetSubscription().GetCloudStorageConfig()
   659  			// As long as the storage config is not nil, we assume it's valid
   660  			// without additional checks.
   661  			if sub.proto.GetCloudStorageConfig() != nil {
   662  				sub.proto.CloudStorageConfig.State = pb.CloudStorageConfig_ACTIVE
   663  			}
   664  
   665  		case "ack_deadline_seconds":
   666  			a := req.Subscription.AckDeadlineSeconds
   667  			if err := checkAckDeadline(a); err != nil {
   668  				return nil, err
   669  			}
   670  			sub.proto.AckDeadlineSeconds = a
   671  
   672  		case "retain_acked_messages":
   673  			sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages
   674  
   675  		case "message_retention_duration":
   676  			if err := checkSubMessageRetention(req.Subscription.MessageRetentionDuration); err != nil {
   677  				return nil, err
   678  			}
   679  			sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
   680  
   681  		case "labels":
   682  			sub.proto.Labels = req.Subscription.Labels
   683  
   684  		case "expiration_policy":
   685  			sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy
   686  
   687  		case "dead_letter_policy":
   688  			sub.proto.DeadLetterPolicy = req.Subscription.DeadLetterPolicy
   689  			if sub.proto.DeadLetterPolicy != nil {
   690  				dlTopic, ok := s.topics[sub.proto.DeadLetterPolicy.DeadLetterTopic]
   691  				if !ok {
   692  					return nil, status.Errorf(codes.NotFound, "topic %q", sub.proto.DeadLetterPolicy.DeadLetterTopic)
   693  				}
   694  				sub.deadLetterTopic = dlTopic
   695  			}
   696  
   697  		case "retry_policy":
   698  			sub.proto.RetryPolicy = req.Subscription.RetryPolicy
   699  
   700  		case "filter":
   701  			filter, err := parseFilter(req.Subscription.Filter)
   702  			if err != nil {
   703  				return nil, status.Errorf(codes.InvalidArgument, "bad filter: %v", err)
   704  			}
   705  			sub.filter = &filter
   706  			sub.proto.Filter = req.Subscription.Filter
   707  
   708  		case "enable_exactly_once_delivery":
   709  			sub.proto.EnableExactlyOnceDelivery = req.Subscription.EnableExactlyOnceDelivery
   710  			for _, st := range sub.streams {
   711  				st.enableExactlyOnceDelivery = req.Subscription.EnableExactlyOnceDelivery
   712  			}
   713  
   714  		default:
   715  			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
   716  		}
   717  	}
   718  	return sub.proto, nil
   719  }
   720  
   721  func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
   722  	s.mu.Lock()
   723  	defer s.mu.Unlock()
   724  
   725  	if handled, ret, err := s.runReactor(req, "ListSubscriptions", &pb.ListSubscriptionsResponse{}); handled || err != nil {
   726  		return ret.(*pb.ListSubscriptionsResponse), err
   727  	}
   728  
   729  	var names []string
   730  	for name := range s.subs {
   731  		if strings.HasPrefix(name, req.Project) {
   732  			names = append(names, name)
   733  		}
   734  	}
   735  	sort.Strings(names)
   736  	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
   737  	if err != nil {
   738  		return nil, err
   739  	}
   740  	res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
   741  	for i := from; i < to; i++ {
   742  		res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
   743  	}
   744  	return res, nil
   745  }
   746  
   747  func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
   748  	s.mu.Lock()
   749  	defer s.mu.Unlock()
   750  
   751  	if handled, ret, err := s.runReactor(req, "DeleteSubscription", &emptypb.Empty{}); handled || err != nil {
   752  		return ret.(*emptypb.Empty), err
   753  	}
   754  
   755  	sub, err := s.findSubscription(req.Subscription)
   756  	if err != nil {
   757  		return nil, err
   758  	}
   759  	sub.stop()
   760  	delete(s.subs, req.Subscription)
   761  	sub.topic.deleteSub(sub)
   762  	return &emptypb.Empty{}, nil
   763  }
   764  
   765  func (s *GServer) DetachSubscription(_ context.Context, req *pb.DetachSubscriptionRequest) (*pb.DetachSubscriptionResponse, error) {
   766  	s.mu.Lock()
   767  	defer s.mu.Unlock()
   768  
   769  	if handled, ret, err := s.runReactor(req, "DetachSubscription", &pb.DetachSubscriptionResponse{}); handled || err != nil {
   770  		return ret.(*pb.DetachSubscriptionResponse), err
   771  	}
   772  
   773  	sub, err := s.findSubscription(req.Subscription)
   774  	if err != nil {
   775  		return nil, err
   776  	}
   777  	sub.topic.deleteSub(sub)
   778  	return &pb.DetachSubscriptionResponse{}, nil
   779  }
   780  
   781  func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
   782  	s.mu.Lock()
   783  	defer s.mu.Unlock()
   784  
   785  	if handled, ret, err := s.runReactor(req, "Publish", &pb.PublishResponse{}); handled || err != nil {
   786  		return ret.(*pb.PublishResponse), err
   787  	}
   788  
   789  	if req.Topic == "" {
   790  		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
   791  	}
   792  	top := s.topics[req.Topic]
   793  	if top == nil {
   794  		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
   795  	}
   796  
   797  	if !s.autoPublishResponse {
   798  		r := <-s.publishResponses
   799  		if r.err != nil {
   800  			return nil, r.err
   801  		}
   802  		return r.resp, nil
   803  	}
   804  
   805  	var ids []string
   806  	for _, pm := range req.Messages {
   807  		id := fmt.Sprintf("m%d", s.nextID)
   808  		s.nextID++
   809  		pm.MessageId = id
   810  		pubTime := s.now()
   811  		tsPubTime := timestamppb.New(pubTime)
   812  		pm.PublishTime = tsPubTime
   813  		m := &Message{
   814  			ID:          id,
   815  			Data:        pm.Data,
   816  			Attributes:  pm.Attributes,
   817  			PublishTime: pubTime,
   818  			OrderingKey: pm.OrderingKey,
   819  		}
   820  		top.publish(pm, m)
   821  		ids = append(ids, id)
   822  		s.msgs = append(s.msgs, m)
   823  		s.msgsByID[id] = m
   824  	}
   825  	return &pb.PublishResponse{MessageIds: ids}, nil
   826  }
   827  
   828  type topic struct {
   829  	proto *pb.Topic
   830  	subs  map[string]*subscription
   831  }
   832  
   833  func newTopic(pt *pb.Topic) *topic {
   834  	return &topic{
   835  		proto: pt,
   836  		subs:  map[string]*subscription{},
   837  	}
   838  }
   839  
   840  func (t *topic) stop() {
   841  	for _, sub := range t.subs {
   842  		sub.proto.Topic = "_deleted-topic_"
   843  	}
   844  }
   845  
   846  func (t *topic) deleteSub(sub *subscription) {
   847  	delete(t.subs, sub.proto.Name)
   848  }
   849  
   850  func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
   851  	for _, s := range t.subs {
   852  		s.msgs[pm.MessageId] = &message{
   853  			publishTime: m.PublishTime,
   854  			proto: &pb.ReceivedMessage{
   855  				AckId:   pm.MessageId,
   856  				Message: pm,
   857  			},
   858  			deliveries:  &m.deliveries,
   859  			acks:        &m.acks,
   860  			streamIndex: -1,
   861  		}
   862  	}
   863  }
   864  
// subscription is the server-side state of one subscription.
//
// Fields, as used in this file:
//   - topic: the topic the subscription is attached to; topic.publish enqueues
//     new messages into msgs.
//   - deadLetterTopic: where pull and deliver republish messages that have
//     reached the dead-letter policy's MaxDeliveryAttempts.
//   - ackTimeout: lease length for Pull, and the initial lease length of new
//     streams (10s when the proto's AckDeadlineSeconds is 0).
//   - streams: open StreamingPull streams, in round-robin delivery order.
//   - done: closed by stop to end the goroutine started by start.
//   - timeNowFunc: clock used for leases and message retention.
//   - filter: parsed subscription filter, or nil when there is none.
type subscription struct {
	topic           *topic
	deadLetterTopic *topic
	mu              *sync.Mutex // the server mutex, here for convenience
	proto           *pb.Subscription
	ackTimeout      time.Duration
	msgs            map[string]*message // unacked messages by message ID
	streams         []*stream
	done            chan struct{}
	timeNowFunc     func() time.Time
	filter          *filtering.Filter
}
   877  
   878  func newSubscription(t *topic, mu *sync.Mutex, timeNowFunc func() time.Time, deadLetterTopic *topic, ps *pb.Subscription) *subscription {
   879  	at := time.Duration(ps.AckDeadlineSeconds) * time.Second
   880  	if at == 0 {
   881  		at = 10 * time.Second
   882  	}
   883  	ps.State = pb.Subscription_ACTIVE
   884  	sub := &subscription{
   885  		topic:           t,
   886  		deadLetterTopic: deadLetterTopic,
   887  		mu:              mu,
   888  		proto:           ps,
   889  		ackTimeout:      at,
   890  		msgs:            map[string]*message{},
   891  		done:            make(chan struct{}),
   892  		timeNowFunc:     timeNowFunc,
   893  	}
   894  	if ps.Filter != "" {
   895  		filter, err := parseFilter(ps.Filter)
   896  		if err != nil {
   897  			panic(fmt.Sprintf("pstest: bad filter: %v", err))
   898  		}
   899  		sub.filter = &filter
   900  	}
   901  	return sub
   902  }
   903  
   904  func (s *subscription) start(wg *sync.WaitGroup) {
   905  	wg.Add(1)
   906  	go func() {
   907  		defer wg.Done()
   908  		for {
   909  			select {
   910  			case <-s.done:
   911  				return
   912  			case <-time.After(10 * time.Millisecond):
   913  				s.deliver()
   914  			}
   915  		}
   916  	}()
   917  }
   918  
   919  func (s *subscription) stop() {
   920  	close(s.done)
   921  }
   922  
   923  func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
   924  	s.mu.Lock()
   925  	defer s.mu.Unlock()
   926  
   927  	if handled, ret, err := s.runReactor(req, "Acknowledge", &emptypb.Empty{}); handled || err != nil {
   928  		return ret.(*emptypb.Empty), err
   929  	}
   930  
   931  	sub, err := s.findSubscription(req.Subscription)
   932  	if err != nil {
   933  		return nil, err
   934  	}
   935  	for _, id := range req.AckIds {
   936  		sub.ack(id)
   937  	}
   938  	return &emptypb.Empty{}, nil
   939  }
   940  
   941  func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
   942  	s.mu.Lock()
   943  	defer s.mu.Unlock()
   944  
   945  	if handled, ret, err := s.runReactor(req, "ModifyAckDeadline", &emptypb.Empty{}); handled || err != nil {
   946  		return ret.(*emptypb.Empty), err
   947  	}
   948  
   949  	sub, err := s.findSubscription(req.Subscription)
   950  	if err != nil {
   951  		return nil, err
   952  	}
   953  	now := time.Now()
   954  	for _, id := range req.AckIds {
   955  		s.msgsByID[id].modacks = append(s.msgsByID[id].modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
   956  	}
   957  	dur := secsToDur(req.AckDeadlineSeconds)
   958  	for _, id := range req.AckIds {
   959  		sub.modifyAckDeadline(id, dur)
   960  	}
   961  	return &emptypb.Empty{}, nil
   962  }
   963  
   964  func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
   965  	s.mu.Lock()
   966  
   967  	if handled, ret, err := s.runReactor(req, "Pull", &pb.PullResponse{}); handled || err != nil {
   968  		s.mu.Unlock()
   969  		return ret.(*pb.PullResponse), err
   970  	}
   971  
   972  	sub, err := s.findSubscription(req.Subscription)
   973  	if err != nil {
   974  		s.mu.Unlock()
   975  		return nil, err
   976  	}
   977  	max := int(req.MaxMessages)
   978  	if max < 0 {
   979  		s.mu.Unlock()
   980  		return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative")
   981  	}
   982  	if max == 0 { // MaxMessages not specified; use a default.
   983  		max = 1000
   984  	}
   985  	msgs := sub.pull(max)
   986  	s.mu.Unlock()
   987  	// Implement the spec from the pubsub proto:
   988  	// "If ReturnImmediately set to true, the system will respond immediately even if
   989  	// it there are no messages available to return in the `Pull` response.
   990  	// Otherwise, the system may wait (for a bounded amount of time) until at
   991  	// least one message is available, rather than returning no messages."
   992  	if len(msgs) == 0 && !req.ReturnImmediately {
   993  		// Wait for a short amount of time for a message.
   994  		// TODO: signal when a message arrives, so we don't wait the whole time.
   995  		select {
   996  		case <-ctx.Done():
   997  			return nil, ctx.Err()
   998  		case <-time.After(500 * time.Millisecond):
   999  			s.mu.Lock()
  1000  			msgs = sub.pull(max)
  1001  			s.mu.Unlock()
  1002  		}
  1003  	}
  1004  	return &pb.PullResponse{ReceivedMessages: msgs}, nil
  1005  }
  1006  
  1007  func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
  1008  	// Receive initial message configuring the pull.
  1009  	req, err := sps.Recv()
  1010  	if err != nil {
  1011  		return err
  1012  	}
  1013  	s.mu.Lock()
  1014  	sub, err := s.findSubscription(req.Subscription)
  1015  	s.mu.Unlock()
  1016  	if err != nil {
  1017  		return err
  1018  	}
  1019  	// Create a new stream to handle the pull.
  1020  	st := sub.newStream(sps, s.streamTimeout)
  1021  	st.ackTimeout = time.Duration(req.StreamAckDeadlineSeconds) * time.Second
  1022  	err = st.pull(&s.wg)
  1023  	sub.deleteStream(st)
  1024  	return err
  1025  }
  1026  
  1027  func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
  1028  	// Only handle time-based seeking for now.
  1029  	// This fake doesn't deal with snapshots.
  1030  	var target time.Time
  1031  	switch v := req.Target.(type) {
  1032  	case nil:
  1033  		return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
  1034  	case *pb.SeekRequest_Time:
  1035  		target = v.Time.AsTime()
  1036  	default:
  1037  		return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v)
  1038  	}
  1039  
  1040  	// The entire server must be locked while doing the work below,
  1041  	// because the messages don't have any other synchronization.
  1042  	s.mu.Lock()
  1043  	defer s.mu.Unlock()
  1044  
  1045  	if handled, ret, err := s.runReactor(req, "Seek", &pb.SeekResponse{}); handled || err != nil {
  1046  		return ret.(*pb.SeekResponse), err
  1047  	}
  1048  
  1049  	sub, err := s.findSubscription(req.Subscription)
  1050  	if err != nil {
  1051  		return nil, err
  1052  	}
  1053  	// Drop all messages from sub that were published before the target time.
  1054  	for id, m := range sub.msgs {
  1055  		if m.publishTime.Before(target) {
  1056  			delete(sub.msgs, id)
  1057  			(*m.acks)++
  1058  		}
  1059  	}
  1060  	// Un-ack any already-acked messages after this time;
  1061  	// redelivering them to the subscription is the closest analogue here.
  1062  	for _, m := range s.msgs {
  1063  		if m.PublishTime.Before(target) {
  1064  			continue
  1065  		}
  1066  		sub.msgs[m.ID] = &message{
  1067  			publishTime: m.PublishTime,
  1068  			proto: &pb.ReceivedMessage{
  1069  				AckId: m.ID,
  1070  				// This was not preserved!
  1071  				//Message: pm,
  1072  			},
  1073  			deliveries:  &m.deliveries,
  1074  			acks:        &m.acks,
  1075  			streamIndex: -1,
  1076  		}
  1077  	}
  1078  	return &pb.SeekResponse{}, nil
  1079  }
  1080  
  1081  // Gets a subscription that must exist.
  1082  // Must be called with the lock held.
  1083  func (s *GServer) findSubscription(name string) (*subscription, error) {
  1084  	if name == "" {
  1085  		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
  1086  	}
  1087  	sub := s.subs[name]
  1088  	if sub == nil {
  1089  		return nil, status.Errorf(codes.NotFound, "subscription %s", name)
  1090  	}
  1091  	return sub, nil
  1092  }
  1093  
  1094  // Must be called with the lock held.
  1095  func (s *subscription) pull(max int) []*pb.ReceivedMessage {
  1096  	now := s.timeNowFunc()
  1097  	s.maintainMessages(now)
  1098  	var msgs []*pb.ReceivedMessage
  1099  	filterMsgs(s.msgs, s.filter)
  1100  	for id, m := range orderMsgs(s.msgs, s.proto.EnableMessageOrdering) {
  1101  		if m.outstanding() {
  1102  			continue
  1103  		}
  1104  		if s.deadLetterCandidate(m) {
  1105  			s.ack(id)
  1106  			s.publishToDeadLetter(m)
  1107  			continue
  1108  		}
  1109  		(*m.deliveries)++
  1110  		if s.proto.DeadLetterPolicy != nil {
  1111  			m.proto.DeliveryAttempt = int32(*m.deliveries)
  1112  		}
  1113  		m.ackDeadline = now.Add(s.ackTimeout)
  1114  		msgs = append(msgs, m.proto)
  1115  		if len(msgs) >= max {
  1116  			break
  1117  		}
  1118  	}
  1119  	return msgs
  1120  }
  1121  
  1122  func orderMsgs(msgs map[string]*message, enableMessageOrdering bool) map[string]*message {
  1123  	if !enableMessageOrdering {
  1124  		return msgs
  1125  	}
  1126  	result := make(map[string]*message)
  1127  
  1128  	type msg struct {
  1129  		id string
  1130  		m  *message
  1131  	}
  1132  	orderingKeyMap := make(map[string]msg)
  1133  	for id, m := range msgs {
  1134  		orderingKey := m.proto.Message.OrderingKey
  1135  		if orderingKey == "" {
  1136  			orderingKey = id
  1137  		}
  1138  		if val, ok := orderingKeyMap[orderingKey]; !ok || m.proto.Message.PublishTime.AsTime().Before(val.m.proto.Message.PublishTime.AsTime()) {
  1139  			orderingKeyMap[orderingKey] = msg{m: m, id: id}
  1140  		}
  1141  	}
  1142  	for _, val := range orderingKeyMap {
  1143  		result[val.id] = val.m
  1144  	}
  1145  	return result
  1146  }
  1147  
  1148  func filterMsgs(msgs map[string]*message, filter *filtering.Filter) {
  1149  	if filter == nil {
  1150  		return
  1151  	}
  1152  
  1153  	filterByAttrs(msgs, filter, func(m *message) messageAttrs {
  1154  		return m.proto.Message.Attributes
  1155  	})
  1156  }
  1157  
  1158  func (s *subscription) deliver() {
  1159  	s.mu.Lock()
  1160  	defer s.mu.Unlock()
  1161  
  1162  	now := s.timeNowFunc()
  1163  	s.maintainMessages(now)
  1164  	// Try to deliver each remaining message.
  1165  	curIndex := 0
  1166  	filterMsgs(s.msgs, s.filter)
  1167  	for id, m := range orderMsgs(s.msgs, s.proto.EnableMessageOrdering) {
  1168  		if m.outstanding() {
  1169  			continue
  1170  		}
  1171  		if s.deadLetterCandidate(m) {
  1172  			s.ack(id)
  1173  			s.publishToDeadLetter(m)
  1174  			continue
  1175  		}
  1176  		// If the message was never delivered before, start with the stream at
  1177  		// curIndex. If it was delivered before, start with the stream after the one
  1178  		// that owned it.
  1179  		if m.streamIndex < 0 {
  1180  			delIndex, ok := s.tryDeliverMessage(m, curIndex, now)
  1181  			if !ok {
  1182  				break
  1183  			}
  1184  			curIndex = delIndex + 1
  1185  			m.streamIndex = curIndex
  1186  		} else {
  1187  			delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now)
  1188  			if !ok {
  1189  				break
  1190  			}
  1191  			m.streamIndex = delIndex
  1192  		}
  1193  	}
  1194  }
  1195  
// tryDeliverMessage offers m to the subscription's streams. It starts with the
// stream at index start (taken modulo the number of streams) and wraps around
// until every stream has been tried once.
//
// The hand-off is a non-blocking send on the stream's msgc channel, which
// newStream creates unbuffered. A stream therefore accepts m only if its
// sendLoop is ready to receive at that moment. Streams whose done channel is
// closed are removed from s.streams as they are encountered.
//
// On success it increments m's shared delivery count, leases m for the
// receiving stream's ackTimeout, and returns that stream's index and true.
// Otherwise it returns 0, false and leaves m's delivery count unchanged.
//
// Must be called with the lock held.
func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) {
	// Optimistically increment DeliveryAttempt assuming we'll be able to deliver the message.  This is
	// safe since the lock is held for the duration of this function, and the channel receiver does not
	// modify the message.
	if s.proto.DeadLetterPolicy != nil {
		m.proto.DeliveryAttempt = int32(*m.deliveries) + 1
	}

	for i := 0; i < len(s.streams); i++ {
		idx := (i + start) % len(s.streams)

		st := s.streams[idx]
		select {
		case <-st.done:
			// The stream has finished: drop it. Decrementing i keeps the
			// number of attempts in step with the now-shorter slice.
			// NOTE(review): idx is recomputed modulo the new length, so the
			// visiting order after a removal is not a plain rotation.
			s.streams = deleteStreamAt(s.streams, idx)
			i--

		case st.msgc <- m.proto:
			(*m.deliveries)++
			m.ackDeadline = now.Add(st.ackTimeout)
			return idx, true

		default:
			// Stream is not ready to receive right now; try the next one.
		}
	}
	// Restore the correct value of DeliveryAttempt if we were not able to deliver the message.
	if s.proto.DeadLetterPolicy != nil {
		m.proto.DeliveryAttempt = int32(*m.deliveries)
	}
	return 0, false
}
  1235  
  1236  const retentionDuration = 10 * time.Minute
  1237  
  1238  // Must be called with the lock held.
  1239  func (s *subscription) maintainMessages(now time.Time) {
  1240  	for id, m := range s.msgs {
  1241  		// Mark a message as re-deliverable if its ack deadline has expired.
  1242  		if m.outstanding() && now.After(m.ackDeadline) {
  1243  			m.makeAvailable()
  1244  		}
  1245  		pubTime := m.proto.Message.PublishTime.AsTime()
  1246  		// Remove messages that have been undelivered for a long time.
  1247  		if !m.outstanding() && now.Sub(pubTime) > retentionDuration {
  1248  			delete(s.msgs, id)
  1249  		}
  1250  	}
  1251  }
  1252  
  1253  func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
  1254  	st := &stream{
  1255  		sub:                       s,
  1256  		done:                      make(chan struct{}),
  1257  		msgc:                      make(chan *pb.ReceivedMessage),
  1258  		gstream:                   gs,
  1259  		ackTimeout:                s.ackTimeout,
  1260  		timeout:                   timeout,
  1261  		enableExactlyOnceDelivery: s.proto.EnableExactlyOnceDelivery,
  1262  		enableOrdering:            s.proto.EnableMessageOrdering,
  1263  	}
  1264  	s.mu.Lock()
  1265  	s.streams = append(s.streams, st)
  1266  	s.mu.Unlock()
  1267  	return st
  1268  }
  1269  
  1270  func (s *subscription) deleteStream(st *stream) {
  1271  	s.mu.Lock()
  1272  	defer s.mu.Unlock()
  1273  	var i int
  1274  	for i = 0; i < len(s.streams); i++ {
  1275  		if s.streams[i] == st {
  1276  			break
  1277  		}
  1278  	}
  1279  	if i < len(s.streams) {
  1280  		s.streams = deleteStreamAt(s.streams, i)
  1281  	}
  1282  }
  1283  
  1284  func (s *subscription) deadLetterCandidate(m *message) bool {
  1285  	if s.proto.DeadLetterPolicy == nil {
  1286  		return false
  1287  	}
  1288  	if m.retriesDone(s.proto.DeadLetterPolicy.MaxDeliveryAttempts) {
  1289  		return true
  1290  	}
  1291  	return false
  1292  }
  1293  
  1294  func (s *subscription) publishToDeadLetter(m *message) {
  1295  	acks := 0
  1296  	if m.acks != nil {
  1297  		acks = *m.acks
  1298  	}
  1299  	deliveries := 0
  1300  	if m.deliveries != nil {
  1301  		deliveries = *m.deliveries
  1302  	}
  1303  	s.deadLetterTopic.publish(m.proto.Message, &Message{
  1304  		PublishTime: m.publishTime,
  1305  		Acks:        acks,
  1306  		Deliveries:  deliveries,
  1307  	})
  1308  }
  1309  
  1310  func deleteStreamAt(s []*stream, i int) []*stream {
  1311  	// Preserve order for round-robin delivery.
  1312  	return append(s[:i], s[i+1:]...)
  1313  }
  1314  
  1315  type message struct {
  1316  	proto       *pb.ReceivedMessage
  1317  	publishTime time.Time
  1318  	ackDeadline time.Time
  1319  	deliveries  *int
  1320  	acks        *int
  1321  	streamIndex int // index of stream that currently owns msg, for round-robin delivery
  1322  }
  1323  
  1324  // A message is outstanding if it is owned by some stream.
  1325  func (m *message) outstanding() bool {
  1326  	return !m.ackDeadline.IsZero()
  1327  }
  1328  
  1329  // A message is outstanding if it is owned by some stream.
  1330  func (m *message) retriesDone(maxRetries int32) bool {
  1331  	return m.deliveries != nil && int32(*m.deliveries) >= maxRetries
  1332  }
  1333  
  1334  func (m *message) makeAvailable() {
  1335  	m.ackDeadline = time.Time{}
  1336  }
  1337  
// stream is the server side of one StreamingPull call on a subscription.
//
// Fields, as used in this file:
//   - msgc: unbuffered channel through which tryDeliverMessage hands messages
//     to sendLoop.
//   - gstream: the gRPC stream that sendLoop writes to and recvLoop reads from.
//   - ackTimeout: lease length for messages delivered on this stream. It is
//     read by tryDeliverMessage and updated by handleStreamingPullRequest.
//   - timeout: if positive, pull ends the stream after this long.
//   - enableExactlyOnceDelivery, enableOrdering: copied from the subscription
//     when the stream is created and reported in every response.
type stream struct {
	sub                       *subscription
	done                      chan struct{} // closed when the stream is finished
	msgc                      chan *pb.ReceivedMessage
	gstream                   pb.Subscriber_StreamingPullServer
	ackTimeout                time.Duration
	timeout                   time.Duration
	enableExactlyOnceDelivery bool
	enableOrdering            bool
}
  1348  
  1349  // pull manages the StreamingPull interaction for the life of the stream.
  1350  func (st *stream) pull(wg *sync.WaitGroup) error {
  1351  	errc := make(chan error, 2)
  1352  	wg.Add(2)
  1353  	go func() {
  1354  		defer wg.Done()
  1355  		errc <- st.sendLoop()
  1356  	}()
  1357  	go func() {
  1358  		defer wg.Done()
  1359  		errc <- st.recvLoop()
  1360  	}()
  1361  	var tchan <-chan time.Time
  1362  	if st.timeout > 0 {
  1363  		tchan = time.After(st.timeout)
  1364  	}
  1365  	// Wait until one of the goroutines returns an error, or we time out.
  1366  	var err error
  1367  	select {
  1368  	case err = <-errc:
  1369  		if err == io.EOF {
  1370  			err = nil
  1371  		}
  1372  	case <-tchan:
  1373  	}
  1374  	close(st.done) // stop the other goroutine
  1375  	return err
  1376  }
  1377  
  1378  func (st *stream) sendLoop() error {
  1379  	for {
  1380  		select {
  1381  		case <-st.done:
  1382  			return nil
  1383  		case rm := <-st.msgc:
  1384  			res := &pb.StreamingPullResponse{
  1385  				ReceivedMessages: []*pb.ReceivedMessage{rm},
  1386  				SubscriptionProperties: &pb.StreamingPullResponse_SubscriptionProperties{
  1387  					ExactlyOnceDeliveryEnabled: st.enableExactlyOnceDelivery,
  1388  					MessageOrderingEnabled:     st.enableOrdering,
  1389  				},
  1390  			}
  1391  			if err := st.gstream.Send(res); err != nil {
  1392  				return err
  1393  			}
  1394  		}
  1395  	}
  1396  }
  1397  
  1398  func (st *stream) recvLoop() error {
  1399  	for {
  1400  		req, err := st.gstream.Recv()
  1401  		if err != nil {
  1402  			return err
  1403  		}
  1404  		st.sub.handleStreamingPullRequest(st, req)
  1405  	}
  1406  }
  1407  
  1408  func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
  1409  	// Lock the entire server.
  1410  	s.mu.Lock()
  1411  	defer s.mu.Unlock()
  1412  
  1413  	for _, ackID := range req.AckIds {
  1414  		s.ack(ackID)
  1415  	}
  1416  	for i, id := range req.ModifyDeadlineAckIds {
  1417  		s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i]))
  1418  	}
  1419  	if req.StreamAckDeadlineSeconds > 0 {
  1420  		st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
  1421  	}
  1422  }
  1423  
  1424  // Must be called with the lock held.
  1425  func (s *subscription) ack(id string) {
  1426  	m := s.msgs[id]
  1427  	if m != nil {
  1428  		(*m.acks)++
  1429  		delete(s.msgs, id)
  1430  	}
  1431  }
  1432  
  1433  // Must be called with the lock held.
  1434  func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
  1435  	m := s.msgs[id]
  1436  	if m == nil { // already acked: ignore.
  1437  		return
  1438  	}
  1439  	if d == 0 { // nack
  1440  		m.makeAvailable()
  1441  	} else { // extend the deadline by d
  1442  		m.ackDeadline = s.timeNowFunc().Add(d)
  1443  	}
  1444  }
  1445  
  1446  func secsToDur(secs int32) time.Duration {
  1447  	return time.Duration(secs) * time.Second
  1448  }
  1449  
  1450  // runReactor looks up the reactors for a function, then launches them until handled=true
  1451  // or err is returned. If the reactor returns nil, the function returns defaultObj instead.
  1452  func (s *GServer) runReactor(req interface{}, funcName string, defaultObj interface{}) (bool, interface{}, error) {
  1453  	if val, ok := s.reactorOptions[funcName]; ok {
  1454  		for _, reactor := range val {
  1455  			handled, ret, err := reactor.React(req)
  1456  			// If handled=true, that means the reactor has successfully reacted to the request,
  1457  			// so use the output directly. If err occurs, that means the request is invalidated
  1458  			// by the reactor somehow.
  1459  			if handled || err != nil {
  1460  				if ret == nil {
  1461  					ret = defaultObj
  1462  				}
  1463  				return true, ret, err
  1464  			}
  1465  		}
  1466  	}
  1467  	return false, nil, nil
  1468  }
  1469  
  1470  // errorInjectionReactor is a reactor to inject an error message with status code.
  1471  type errorInjectionReactor struct {
  1472  	code codes.Code
  1473  	msg  string
  1474  }
  1475  
  1476  // React simply returns an error with defined error message and status code.
  1477  func (e *errorInjectionReactor) React(_ interface{}) (handled bool, ret interface{}, err error) {
  1478  	return true, nil, status.Errorf(e.code, e.msg)
  1479  }
  1480  
  1481  // WithErrorInjection creates a ServerReactorOption that injects error with defined status code and
  1482  // message for a certain function.
  1483  func WithErrorInjection(funcName string, code codes.Code, msg string) ServerReactorOption {
  1484  	return ServerReactorOption{
  1485  		FuncName: funcName,
  1486  		Reactor:  &errorInjectionReactor{code: code, msg: msg},
  1487  	}
  1488  }
  1489  
  1490  const letters = "abcdef1234567890"
  1491  
  1492  func genRevID() string {
  1493  	id := make([]byte, 8)
  1494  	for i := range id {
  1495  		id[i] = letters[rand.Intn(len(letters))]
  1496  	}
  1497  	return string(id)
  1498  }
  1499  
  1500  func (s *GServer) CreateSchema(_ context.Context, req *pb.CreateSchemaRequest) (*pb.Schema, error) {
  1501  	s.mu.Lock()
  1502  	defer s.mu.Unlock()
  1503  
  1504  	if handled, ret, err := s.runReactor(req, "CreateSchema", &pb.Schema{}); handled || err != nil {
  1505  		return ret.(*pb.Schema), err
  1506  	}
  1507  
  1508  	name := fmt.Sprintf("%s/schemas/%s", req.Parent, req.SchemaId)
  1509  	sc := &pb.Schema{
  1510  		Name:               name,
  1511  		Type:               req.Schema.Type,
  1512  		Definition:         req.Schema.Definition,
  1513  		RevisionId:         genRevID(),
  1514  		RevisionCreateTime: timestamppb.Now(),
  1515  	}
  1516  	s.schemas[name] = append(s.schemas[name], sc)
  1517  
  1518  	return sc, nil
  1519  }
  1520  
  1521  func (s *GServer) GetSchema(_ context.Context, req *pb.GetSchemaRequest) (*pb.Schema, error) {
  1522  	s.mu.Lock()
  1523  	defer s.mu.Unlock()
  1524  
  1525  	if handled, ret, err := s.runReactor(req, "GetSchema", &pb.Schema{}); handled || err != nil {
  1526  		return ret.(*pb.Schema), err
  1527  	}
  1528  
  1529  	ss := strings.Split(req.Name, "@")
  1530  	var schemaName, revisionID string
  1531  	if len := len(ss); len == 1 {
  1532  		schemaName = ss[0]
  1533  	} else if len == 2 {
  1534  		schemaName = ss[0]
  1535  		revisionID = ss[1]
  1536  	} else {
  1537  		return nil, status.Errorf(codes.InvalidArgument, "schema(%q) name parse error", req.Name)
  1538  	}
  1539  
  1540  	schemaRev, ok := s.schemas[schemaName]
  1541  	if !ok {
  1542  		return nil, status.Errorf(codes.NotFound, "schema(%q) not found", req.Name)
  1543  	}
  1544  
  1545  	if revisionID == "" {
  1546  		return schemaRev[len(schemaRev)-1], nil
  1547  	}
  1548  
  1549  	for _, sc := range schemaRev {
  1550  		if sc.RevisionId == revisionID {
  1551  			return sc, nil
  1552  		}
  1553  	}
  1554  
  1555  	return nil, status.Errorf(codes.NotFound, "schema %q not found", req.Name)
  1556  }
  1557  
  1558  func (s *GServer) ListSchemas(_ context.Context, req *pb.ListSchemasRequest) (*pb.ListSchemasResponse, error) {
  1559  	s.mu.Lock()
  1560  	defer s.mu.Unlock()
  1561  
  1562  	if handled, ret, err := s.runReactor(req, "ListSchemas", &pb.ListSchemasResponse{}); handled || err != nil {
  1563  		return ret.(*pb.ListSchemasResponse), err
  1564  	}
  1565  	ss := make([]*pb.Schema, 0)
  1566  	for _, sc := range s.schemas {
  1567  		ss = append(ss, sc[len(sc)-1])
  1568  	}
  1569  	return &pb.ListSchemasResponse{
  1570  		Schemas: ss,
  1571  	}, nil
  1572  }
  1573  
  1574  func (s *GServer) ListSchemaRevisions(_ context.Context, req *pb.ListSchemaRevisionsRequest) (*pb.ListSchemaRevisionsResponse, error) {
  1575  	s.mu.Lock()
  1576  	defer s.mu.Unlock()
  1577  
  1578  	if handled, ret, err := s.runReactor(req, "ListSchemaRevisions", &pb.ListSchemasResponse{}); handled || err != nil {
  1579  		return ret.(*pb.ListSchemaRevisionsResponse), err
  1580  	}
  1581  	ss := make([]*pb.Schema, 0)
  1582  	ss = append(ss, s.schemas[req.Name]...)
  1583  	return &pb.ListSchemaRevisionsResponse{
  1584  		Schemas: ss,
  1585  	}, nil
  1586  }
  1587  
  1588  func (s *GServer) CommitSchema(_ context.Context, req *pb.CommitSchemaRequest) (*pb.Schema, error) {
  1589  	s.mu.Lock()
  1590  	defer s.mu.Unlock()
  1591  
  1592  	if handled, ret, err := s.runReactor(req, "CommitSchema", &pb.Schema{}); handled || err != nil {
  1593  		return ret.(*pb.Schema), err
  1594  	}
  1595  
  1596  	sc := &pb.Schema{
  1597  		Name:       req.Name,
  1598  		Type:       req.Schema.Type,
  1599  		Definition: req.Schema.Definition,
  1600  	}
  1601  	sc.RevisionId = genRevID()
  1602  	sc.RevisionCreateTime = timestamppb.Now()
  1603  
  1604  	s.schemas[req.Name] = append(s.schemas[req.Name], sc)
  1605  
  1606  	return sc, nil
  1607  }
  1608  
  1609  // RollbackSchema rolls back the current schema to a previous revision by copying and creating a new revision.
  1610  func (s *GServer) RollbackSchema(_ context.Context, req *pb.RollbackSchemaRequest) (*pb.Schema, error) {
  1611  	s.mu.Lock()
  1612  	defer s.mu.Unlock()
  1613  
  1614  	if handled, ret, err := s.runReactor(req, "RollbackSchema", &pb.Schema{}); handled || err != nil {
  1615  		return ret.(*pb.Schema), err
  1616  	}
  1617  
  1618  	for _, sc := range s.schemas[req.Name] {
  1619  		if sc.RevisionId == req.RevisionId {
  1620  			newSchema := *sc
  1621  			newSchema.RevisionId = genRevID()
  1622  			newSchema.RevisionCreateTime = timestamppb.Now()
  1623  			s.schemas[req.Name] = append(s.schemas[req.Name], &newSchema)
  1624  			return &newSchema, nil
  1625  		}
  1626  	}
  1627  	return nil, status.Errorf(codes.NotFound, "schema %q@%q not found", req.Name, req.RevisionId)
  1628  }
  1629  
  1630  func (s *GServer) DeleteSchemaRevision(_ context.Context, req *pb.DeleteSchemaRevisionRequest) (*pb.Schema, error) {
  1631  	s.mu.Lock()
  1632  	defer s.mu.Unlock()
  1633  
  1634  	if handled, ret, err := s.runReactor(req, "DeleteSchemaRevision", &pb.Schema{}); handled || err != nil {
  1635  		return ret.(*pb.Schema), err
  1636  	}
  1637  
  1638  	schemaPath := strings.Split(req.Name, "@")
  1639  	if len(schemaPath) != 2 {
  1640  		return nil, status.Errorf(codes.InvalidArgument, "could not parse revision ID from schema name: %q", req.Name)
  1641  	}
  1642  	schemaName := schemaPath[0]
  1643  	revID := schemaPath[1]
  1644  	schemaRevisions, ok := s.schemas[schemaName]
  1645  	if ok {
  1646  		if len(schemaRevisions) == 1 {
  1647  			return nil, status.Errorf(codes.InvalidArgument, "cannot delete last revision for schema %q", req.Name)
  1648  		}
  1649  		for i, sc := range schemaRevisions {
  1650  			if sc.RevisionId == revID {
  1651  				s.schemas[schemaName] = append(schemaRevisions[:i], schemaRevisions[i+1:]...)
  1652  				return schemaRevisions[len(schemaRevisions)-1], nil
  1653  			}
  1654  		}
  1655  	}
  1656  
  1657  	return nil, status.Errorf(codes.NotFound, "schema %q not found", req.Name)
  1658  }
  1659  
  1660  func (s *GServer) DeleteSchema(_ context.Context, req *pb.DeleteSchemaRequest) (*emptypb.Empty, error) {
  1661  	s.mu.Lock()
  1662  	defer s.mu.Unlock()
  1663  
  1664  	if handled, ret, err := s.runReactor(req, "DeleteSchema", &emptypb.Empty{}); handled || err != nil {
  1665  		return ret.(*emptypb.Empty), err
  1666  	}
  1667  
  1668  	schema := s.schemas[req.Name]
  1669  	if schema == nil {
  1670  		return nil, status.Errorf(codes.NotFound, "schema %q", req.Name)
  1671  	}
  1672  
  1673  	delete(s.schemas, req.Name)
  1674  	return &emptypb.Empty{}, nil
  1675  }
  1676  
  1677  // ValidateSchema mocks the ValidateSchema call but only checks that the schema definition is not empty.
  1678  func (s *GServer) ValidateSchema(_ context.Context, req *pb.ValidateSchemaRequest) (*pb.ValidateSchemaResponse, error) {
  1679  	s.mu.Lock()
  1680  	defer s.mu.Unlock()
  1681  
  1682  	if handled, ret, err := s.runReactor(req, "ValidateSchema", &pb.ValidateSchemaResponse{}); handled || err != nil {
  1683  		return ret.(*pb.ValidateSchemaResponse), err
  1684  	}
  1685  
  1686  	if req.Schema.Definition == "" {
  1687  		return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty")
  1688  	}
  1689  	return &pb.ValidateSchemaResponse{}, nil
  1690  }
  1691  
  1692  // ValidateMessage mocks the ValidateMessage call but only checks that the schema definition to validate the
  1693  // message against is not empty.
  1694  func (s *GServer) ValidateMessage(_ context.Context, req *pb.ValidateMessageRequest) (*pb.ValidateMessageResponse, error) {
  1695  	s.mu.Lock()
  1696  	defer s.mu.Unlock()
  1697  
  1698  	if handled, ret, err := s.runReactor(req, "ValidateMessage", &pb.ValidateMessageResponse{}); handled || err != nil {
  1699  		return ret.(*pb.ValidateMessageResponse), err
  1700  	}
  1701  
  1702  	spec := req.GetSchemaSpec()
  1703  	if valReq, ok := spec.(*pb.ValidateMessageRequest_Name); ok {
  1704  		sc, ok := s.schemas[valReq.Name]
  1705  		if !ok {
  1706  			return nil, status.Errorf(codes.NotFound, "schema(%q) not found", valReq.Name)
  1707  		}
  1708  		schema := sc[len(sc)-1]
  1709  		if schema.Definition == "" {
  1710  			return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty")
  1711  		}
  1712  	}
  1713  	if valReq, ok := spec.(*pb.ValidateMessageRequest_Schema); ok {
  1714  		if valReq.Schema.Definition == "" {
  1715  			return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty")
  1716  		}
  1717  	}
  1718  
  1719  	return &pb.ValidateMessageResponse{}, nil
  1720  }
  1721  

View as plain text