1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package pstest
24
25 import (
26 "context"
27 "fmt"
28 "io"
29 "math/rand"
30 "path"
31 "sort"
32 "strings"
33 "sync"
34 "sync/atomic"
35 "time"
36
37 "cloud.google.com/go/internal/testutil"
38 pb "cloud.google.com/go/pubsub/apiv1/pubsubpb"
39 "go.einride.tech/aip/filtering"
40 "google.golang.org/grpc/codes"
41 "google.golang.org/grpc/status"
42 durpb "google.golang.org/protobuf/types/known/durationpb"
43 "google.golang.org/protobuf/types/known/emptypb"
44 "google.golang.org/protobuf/types/known/timestamppb"
45 )
46
47
48
// ReactorOptions maps an RPC method name (the FuncName of a
// ServerReactorOption, e.g. "CreateTopic" or "Publish") to the reactors
// registered for it, in the order they were passed to NewServerWithPort.
type ReactorOptions map[string][]Reactor

// Reactor lets a test intercept a GServer RPC before the fake's default
// handling runs.
type Reactor interface {
	// React receives the RPC's request message. GServer handlers call
	// runReactor (defined elsewhere) and, when it reports handled or a non-nil
	// error, return its ret (asserted to the RPC's response type) and err
	// directly; presumably these are the values produced by React.
	React(_ interface{}) (handled bool, ret interface{}, err error)
}

// ServerReactorOption registers Reactor for the RPC named FuncName when passed
// to NewServer or NewServerWithPort.
type ServerReactorOption struct {
	FuncName string
	Reactor  Reactor
}

// publishResponse is a canned result queued by Server.AddPublishResponse and
// returned by GServer.Publish when automatic publish responses are disabled.
// Exactly one of resp and err is meaningful.
type publishResponse struct {
	resp *pb.PublishResponse
	err  error
}
69
70
// Server is a fake Pub/Sub server backed by an in-process gRPC server.
type Server struct {
	srv *testutil.Server
	// Addr is the address of the underlying gRPC server (copied from the
	// testutil server).
	Addr string
	// GServer implements the Publisher, Subscriber and SchemaService gRPC
	// services registered on the server.
	GServer GServer
}

// GServer holds the fake's in-memory state and implements the Pub/Sub gRPC
// services. Its exported methods are the RPC handlers.
type GServer struct {
	// Embedded so that RPCs not implemented here fall back to the generated
	// Unimplemented handlers.
	pb.UnimplementedPublisherServer
	pb.UnimplementedSubscriberServer
	pb.UnimplementedSchemaServiceServer

	// timeNowFunc holds a func() time.Time used as the server clock; see
	// Server.SetTimeNowFunc and GServer.now.
	timeNowFunc atomic.Value

	// mu guards the server state below. Subscriptions are handed a pointer to
	// this same mutex, so it also guards per-subscription state.
	mu             sync.Mutex
	topics         map[string]*topic        // keyed by topic name
	subs           map[string]*subscription // keyed by subscription name
	msgs           []*Message               // every published message, in publish order
	msgsByID       map[string]*Message      // msgs indexed by message ID
	wg             sync.WaitGroup           // tracks subscription delivery goroutines and streaming pulls
	nextID         int                      // counter used to generate message IDs ("m<n>")
	streamTimeout  time.Duration            // passed to streams created by StreamingPull
	reactorOptions ReactorOptions

	// schemas presumably holds the revisions of each schema keyed by schema
	// name; it is used by schema RPCs outside this view — TODO confirm.
	schemas map[string][]*pb.Schema

	// publishResponses queues canned responses that Publish returns when
	// autoPublishResponse is false.
	publishResponses chan *publishResponse

	// autoPublishResponse makes Publish store messages and generate IDs. When
	// false, Publish instead returns the next queued publishResponse.
	autoPublishResponse bool
}
106
107
108 func NewServer(opts ...ServerReactorOption) *Server {
109 return NewServerWithPort(0, opts...)
110 }
111
112
113 func NewServerWithPort(port int, opts ...ServerReactorOption) *Server {
114 srv, err := testutil.NewServerWithPort(port)
115 if err != nil {
116 panic(fmt.Sprintf("pstest.NewServerWithPort: %v", err))
117 }
118 reactorOptions := ReactorOptions{}
119 for _, opt := range opts {
120 reactorOptions[opt.FuncName] = append(reactorOptions[opt.FuncName], opt.Reactor)
121 }
122 s := &Server{
123 srv: srv,
124 Addr: srv.Addr,
125 GServer: GServer{
126 topics: map[string]*topic{},
127 subs: map[string]*subscription{},
128 msgsByID: map[string]*Message{},
129 reactorOptions: reactorOptions,
130 publishResponses: make(chan *publishResponse, 100),
131 autoPublishResponse: true,
132 schemas: map[string][]*pb.Schema{},
133 },
134 }
135 s.GServer.timeNowFunc.Store(time.Now)
136 pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
137 pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer)
138 pb.RegisterSchemaServiceServer(srv.Gsrv, &s.GServer)
139 srv.Start()
140 return s
141 }
142
143
144
145 func (s *Server) SetTimeNowFunc(f func() time.Time) {
146 s.GServer.timeNowFunc.Store(f)
147 }
148
149 func (s *GServer) now() time.Time {
150 return s.timeNowFunc.Load().(func() time.Time)()
151 }
152
153
154
155
156
157
158 func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
159 return s.PublishOrdered(topic, data, attrs, "")
160 }
161
162
163
164
165
166
167 func (s *Server) PublishOrdered(topic string, data []byte, attrs map[string]string, orderingKey string) string {
168 const topicPattern = "projects/*/topics/*"
169 ok, err := path.Match(topicPattern, topic)
170 if err != nil {
171 panic(err)
172 }
173 if !ok {
174 panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
175 }
176 s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
177 req := &pb.PublishRequest{
178 Topic: topic,
179 Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs, OrderingKey: orderingKey}},
180 }
181 res, err := s.GServer.Publish(context.TODO(), req)
182 if err != nil {
183 panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
184 }
185 return res.MessageIds[0]
186 }
187
188
189
190 func (s *Server) AddPublishResponse(pbr *pb.PublishResponse, err error) {
191 pr := &publishResponse{}
192 if err != nil {
193 pr.err = err
194 } else {
195 pr.resp = pbr
196 }
197 s.GServer.publishResponses <- pr
198 }
199
200
201
202
203 func (s *Server) SetAutoPublishResponse(autoPublishResponse bool) {
204 s.GServer.mu.Lock()
205 defer s.GServer.mu.Unlock()
206 s.GServer.autoPublishResponse = autoPublishResponse
207 }
208
209
210
211 func (s *Server) ResetPublishResponses(size int) {
212 s.GServer.mu.Lock()
213 defer s.GServer.mu.Unlock()
214 s.GServer.publishResponses = make(chan *publishResponse, size)
215 }
216
217
218
219
220
221 func (s *Server) SetStreamTimeout(d time.Duration) {
222 s.GServer.mu.Lock()
223 defer s.GServer.mu.Unlock()
224 s.GServer.streamTimeout = d
225 }
226
227
// Message is a message recorded by the fake server, as returned by
// Server.Messages and Server.Message.
type Message struct {
	ID          string
	Data        []byte
	Attributes  map[string]string
	PublishTime time.Time
	// Deliveries, Acks and Modacks are copies of the unexported counters
	// below, refreshed each time Server.Messages or Server.Message is called.
	Deliveries  int
	Acks        int
	Modacks     []Modack
	OrderingKey string

	// Live counters, updated under the server mutex. Each subscription's
	// per-message state points at deliveries and acks, so every subscription
	// of the topic contributes to the same counts.
	deliveries int
	acks       int
	modacks    []Modack
}

// Modack records one ModifyAckDeadline request received for a message.
type Modack struct {
	AckID       string
	AckDeadline int32 // requested deadline, in seconds
	ReceivedAt  time.Time
}
250
251
252 func (s *Server) Messages() []*Message {
253 s.GServer.mu.Lock()
254 defer s.GServer.mu.Unlock()
255
256 var msgs []*Message
257 for _, m := range s.GServer.msgs {
258 m.Deliveries = m.deliveries
259 m.Acks = m.acks
260 m.Modacks = append([]Modack(nil), m.modacks...)
261 msgs = append(msgs, m)
262 }
263 return msgs
264 }
265
266
267
268 func (s *Server) Message(id string) *Message {
269 s.GServer.mu.Lock()
270 defer s.GServer.mu.Unlock()
271
272 m := s.GServer.msgsByID[id]
273 if m != nil {
274 m.Deliveries = m.deliveries
275 m.Acks = m.acks
276 m.Modacks = append([]Modack(nil), m.modacks...)
277 }
278 return m
279 }
280
281
282 func (s *Server) Wait() {
283 s.GServer.wg.Wait()
284 }
285
286
287
288 func (s *Server) ClearMessages() {
289 s.GServer.mu.Lock()
290 s.GServer.msgs = nil
291 s.GServer.msgsByID = make(map[string]*Message)
292 for _, sub := range s.GServer.subs {
293 sub.msgs = map[string]*message{}
294 }
295 s.GServer.mu.Unlock()
296 }
297
298
299 func (s *Server) Close() error {
300 s.srv.Close()
301 s.GServer.mu.Lock()
302 defer s.GServer.mu.Unlock()
303 for _, sub := range s.GServer.subs {
304 sub.stop()
305 }
306 return nil
307 }
308
309 func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
310 s.mu.Lock()
311 defer s.mu.Unlock()
312
313 if handled, ret, err := s.runReactor(t, "CreateTopic", &pb.Topic{}); handled || err != nil {
314 return ret.(*pb.Topic), err
315 }
316
317 if s.topics[t.Name] != nil {
318 return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
319 }
320 if err := checkTopicMessageRetention(t.MessageRetentionDuration); err != nil {
321 return nil, err
322 }
323
324 if t.IngestionDataSourceSettings != nil {
325 t.State = pb.Topic_ACTIVE
326 }
327 top := newTopic(t)
328 s.topics[t.Name] = top
329 return top.proto, nil
330 }
331
332 func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
333 s.mu.Lock()
334 defer s.mu.Unlock()
335
336 if handled, ret, err := s.runReactor(req, "GetTopic", &pb.Topic{}); handled || err != nil {
337 return ret.(*pb.Topic), err
338 }
339
340 if t := s.topics[req.Topic]; t != nil {
341 return t.proto, nil
342 }
343 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
344 }
345
346 func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
347 s.mu.Lock()
348 defer s.mu.Unlock()
349
350 if handled, ret, err := s.runReactor(req, "UpdateTopic", &pb.Topic{}); handled || err != nil {
351 return ret.(*pb.Topic), err
352 }
353
354 t := s.topics[req.Topic.Name]
355 if t == nil {
356 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
357 }
358 for _, path := range req.UpdateMask.Paths {
359 switch path {
360 case "labels":
361 t.proto.Labels = req.Topic.Labels
362 case "message_storage_policy":
363 t.proto.MessageStoragePolicy = req.Topic.MessageStoragePolicy
364 case "message_retention_duration":
365 if err := checkTopicMessageRetention(req.Topic.MessageRetentionDuration); err != nil {
366 return nil, err
367 }
368 t.proto.MessageRetentionDuration = req.Topic.MessageRetentionDuration
369 case "schema_settings":
370 t.proto.SchemaSettings = req.Topic.SchemaSettings
371 case "schema_settings.schema":
372 if t.proto.SchemaSettings == nil {
373 t.proto.SchemaSettings = &pb.SchemaSettings{}
374 }
375 t.proto.SchemaSettings.Schema = req.Topic.SchemaSettings.Schema
376 case "schema_settings.encoding":
377 if t.proto.SchemaSettings == nil {
378 t.proto.SchemaSettings = &pb.SchemaSettings{}
379 }
380 t.proto.SchemaSettings.Encoding = req.Topic.SchemaSettings.Encoding
381 case "schema_settings.first_revision_id":
382 if t.proto.SchemaSettings == nil {
383 t.proto.SchemaSettings = &pb.SchemaSettings{}
384 }
385 t.proto.SchemaSettings.FirstRevisionId = req.Topic.SchemaSettings.FirstRevisionId
386 case "schema_settings.last_revision_id":
387 if t.proto.SchemaSettings == nil {
388 t.proto.SchemaSettings = &pb.SchemaSettings{}
389 }
390 t.proto.SchemaSettings.LastRevisionId = req.Topic.SchemaSettings.LastRevisionId
391 case "ingestion_data_source_settings":
392 if t.proto.IngestionDataSourceSettings == nil {
393 t.proto.IngestionDataSourceSettings = &pb.IngestionDataSourceSettings{}
394 }
395 t.proto.IngestionDataSourceSettings = req.Topic.IngestionDataSourceSettings
396
397 if t.proto.IngestionDataSourceSettings != nil {
398 t.proto.State = pb.Topic_ACTIVE
399 }
400 default:
401 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
402 }
403 }
404 return t.proto, nil
405 }
406
407 func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
408 s.mu.Lock()
409 defer s.mu.Unlock()
410
411 if handled, ret, err := s.runReactor(req, "ListTopics", &pb.ListTopicsResponse{}); handled || err != nil {
412 return ret.(*pb.ListTopicsResponse), err
413 }
414
415 var names []string
416 for n := range s.topics {
417 if strings.HasPrefix(n, req.Project) {
418 names = append(names, n)
419 }
420 }
421 sort.Strings(names)
422 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
423 if err != nil {
424 return nil, err
425 }
426 res := &pb.ListTopicsResponse{NextPageToken: nextToken}
427 for i := from; i < to; i++ {
428 res.Topics = append(res.Topics, s.topics[names[i]].proto)
429 }
430 return res, nil
431 }
432
433 func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
434 s.mu.Lock()
435 defer s.mu.Unlock()
436
437 if handled, ret, err := s.runReactor(req, "ListTopicSubscriptions", &pb.ListTopicSubscriptionsResponse{}); handled || err != nil {
438 return ret.(*pb.ListTopicSubscriptionsResponse), err
439 }
440
441 var names []string
442 for name, sub := range s.subs {
443 if sub.topic.proto.Name == req.Topic {
444 names = append(names, name)
445 }
446 }
447 sort.Strings(names)
448 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
449 if err != nil {
450 return nil, err
451 }
452 return &pb.ListTopicSubscriptionsResponse{
453 Subscriptions: names[from:to],
454 NextPageToken: nextToken,
455 }, nil
456 }
457
458 func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
459 s.mu.Lock()
460 defer s.mu.Unlock()
461
462 if handled, ret, err := s.runReactor(req, "DeleteTopic", &emptypb.Empty{}); handled || err != nil {
463 return ret.(*emptypb.Empty), err
464 }
465
466 t := s.topics[req.Topic]
467 if t == nil {
468 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
469 }
470 for _, sub := range s.subs {
471 if sub.deadLetterTopic == nil {
472 continue
473 }
474 if req.Topic == sub.deadLetterTopic.proto.Name {
475 return nil, status.Errorf(codes.FailedPrecondition, "topic %q used as deadLetter for %s", req.Topic, sub.proto.Name)
476 }
477 }
478 t.stop()
479 delete(s.topics, req.Topic)
480 return &emptypb.Empty{}, nil
481 }
482
483 func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
484 s.mu.Lock()
485 defer s.mu.Unlock()
486
487 if handled, ret, err := s.runReactor(ps, "CreateSubscription", &pb.Subscription{}); handled || err != nil {
488 return ret.(*pb.Subscription), err
489 }
490
491 if ps.Name == "" {
492 return nil, status.Errorf(codes.InvalidArgument, "missing name")
493 }
494 if s.subs[ps.Name] != nil {
495 return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
496 }
497 if ps.Topic == "" {
498 return nil, status.Errorf(codes.InvalidArgument, "missing topic")
499 }
500 top := s.topics[ps.Topic]
501 if top == nil {
502 return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
503 }
504 if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil {
505 return nil, err
506 }
507 if ps.MessageRetentionDuration == nil {
508 ps.MessageRetentionDuration = defaultMessageRetentionDuration
509 }
510 if err := checkSubMessageRetention(ps.MessageRetentionDuration); err != nil {
511 return nil, err
512 }
513 if ps.PushConfig == nil {
514 ps.PushConfig = &pb.PushConfig{}
515 } else if ps.PushConfig.Wrapper == nil {
516
517 ps.PushConfig.Wrapper = &pb.PushConfig_PubsubWrapper_{
518 PubsubWrapper: &pb.PushConfig_PubsubWrapper{},
519 }
520 }
521
522
523
524 if ps.GetBigqueryConfig() != nil && ps.GetBigqueryConfig().GetTable() != "" {
525 ps.BigqueryConfig.State = pb.BigQueryConfig_ACTIVE
526 }
527 if ps.CloudStorageConfig != nil && ps.CloudStorageConfig.Bucket != "" {
528 ps.CloudStorageConfig.State = pb.CloudStorageConfig_ACTIVE
529 }
530 ps.TopicMessageRetentionDuration = top.proto.MessageRetentionDuration
531 var deadLetterTopic *topic
532 if ps.DeadLetterPolicy != nil {
533 dlTopic, ok := s.topics[ps.DeadLetterPolicy.DeadLetterTopic]
534 if !ok {
535 return nil, status.Errorf(codes.NotFound, "deadLetter topic %q", ps.DeadLetterPolicy.DeadLetterTopic)
536 }
537 deadLetterTopic = dlTopic
538 }
539
540 sub := newSubscription(top, &s.mu, s.now, deadLetterTopic, ps)
541 top.subs[ps.Name] = sub
542 s.subs[ps.Name] = sub
543 sub.start(&s.wg)
544 return ps, nil
545 }
546
547
// minAckDeadlineSecs is the smallest ack deadline, in seconds, accepted by
// checkAckDeadline. Change it with SetMinAckDeadline and restore it with
// ResetMinAckDeadline.
var minAckDeadlineSecs int32

// SetMinAckDeadline changes the minimum accepted ack deadline to n, truncated
// to whole seconds. n must be at least one second; it panics otherwise.
//
// The setting is package-global and unsynchronized, so tests that change it
// should not run in parallel and should restore it afterwards:
//
//	pstest.SetMinAckDeadline(time.Second)
//	defer pstest.ResetMinAckDeadline()
func SetMinAckDeadline(n time.Duration) {
	if n < time.Second {
		panic("SetMinAckDeadline expects a value greater than or equal to 1 second")
	}
	minAckDeadlineSecs = int32(n / time.Second)
}

// ResetMinAckDeadline sets the minimum accepted ack deadline back to 10 seconds.
func ResetMinAckDeadline() {
	minAckDeadlineSecs = 10
}
568
569 func checkAckDeadline(ads int32) error {
570 if ads < minAckDeadlineSecs || ads > 600 {
571
572 return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
573 }
574 return nil
575 }
576
// Accepted ranges for message_retention_duration.
const (
	// Topics: 10 minutes to 31 days.
	minTopicMessageRetentionDuration = 600 * time.Second
	maxTopicMessageRetentionDuration = 744 * time.Hour

	// Subscriptions: 10 minutes to 7 days.
	minSubMessageRetentionDuration = 600 * time.Second
	maxSubMessageRetentionDuration = 168 * time.Hour
)
585
586 var defaultMessageRetentionDuration = durpb.New(168 * time.Hour)
587
588 func checkTopicMessageRetention(pmrd *durpb.Duration) error {
589 if pmrd == nil {
590 return nil
591 }
592 mrd := pmrd.AsDuration()
593 if mrd < minTopicMessageRetentionDuration || mrd > maxTopicMessageRetentionDuration {
594 return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
595 }
596 return nil
597 }
598
599 func checkSubMessageRetention(pmrd *durpb.Duration) error {
600 if pmrd == nil {
601 return nil
602 }
603 mrd := pmrd.AsDuration()
604 if mrd < minSubMessageRetentionDuration || mrd > maxSubMessageRetentionDuration {
605 return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
606 }
607 return nil
608 }
609
610 func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
611 s.mu.Lock()
612 defer s.mu.Unlock()
613
614 if handled, ret, err := s.runReactor(req, "GetSubscription", &pb.Subscription{}); handled || err != nil {
615 return ret.(*pb.Subscription), err
616 }
617
618 sub, err := s.findSubscription(req.Subscription)
619 if err != nil {
620 return nil, err
621 }
622 return sub.proto, nil
623 }
624
625 func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
626 if req.Subscription == nil {
627 return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
628 }
629 s.mu.Lock()
630 defer s.mu.Unlock()
631
632 if handled, ret, err := s.runReactor(req, "UpdateSubscription", &pb.Subscription{}); handled || err != nil {
633 return ret.(*pb.Subscription), err
634 }
635
636 sub, err := s.findSubscription(req.Subscription.Name)
637 if err != nil {
638 return nil, err
639 }
640 for _, path := range req.UpdateMask.Paths {
641 switch path {
642 case "push_config":
643 sub.proto.PushConfig = req.Subscription.PushConfig
644
645 case "bigquery_config":
646
647
648 sub.proto.BigqueryConfig = req.GetSubscription().GetBigqueryConfig()
649 if sub.proto.GetBigqueryConfig() != nil {
650 if sub.proto.GetBigqueryConfig().GetTable() != "" {
651 sub.proto.BigqueryConfig.State = pb.BigQueryConfig_ACTIVE
652 } else {
653 return nil, status.Errorf(codes.InvalidArgument, "table must be provided")
654 }
655 }
656
657 case "cloud_storage_config":
658 sub.proto.CloudStorageConfig = req.GetSubscription().GetCloudStorageConfig()
659
660
661 if sub.proto.GetCloudStorageConfig() != nil {
662 sub.proto.CloudStorageConfig.State = pb.CloudStorageConfig_ACTIVE
663 }
664
665 case "ack_deadline_seconds":
666 a := req.Subscription.AckDeadlineSeconds
667 if err := checkAckDeadline(a); err != nil {
668 return nil, err
669 }
670 sub.proto.AckDeadlineSeconds = a
671
672 case "retain_acked_messages":
673 sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages
674
675 case "message_retention_duration":
676 if err := checkSubMessageRetention(req.Subscription.MessageRetentionDuration); err != nil {
677 return nil, err
678 }
679 sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
680
681 case "labels":
682 sub.proto.Labels = req.Subscription.Labels
683
684 case "expiration_policy":
685 sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy
686
687 case "dead_letter_policy":
688 sub.proto.DeadLetterPolicy = req.Subscription.DeadLetterPolicy
689 if sub.proto.DeadLetterPolicy != nil {
690 dlTopic, ok := s.topics[sub.proto.DeadLetterPolicy.DeadLetterTopic]
691 if !ok {
692 return nil, status.Errorf(codes.NotFound, "topic %q", sub.proto.DeadLetterPolicy.DeadLetterTopic)
693 }
694 sub.deadLetterTopic = dlTopic
695 }
696
697 case "retry_policy":
698 sub.proto.RetryPolicy = req.Subscription.RetryPolicy
699
700 case "filter":
701 filter, err := parseFilter(req.Subscription.Filter)
702 if err != nil {
703 return nil, status.Errorf(codes.InvalidArgument, "bad filter: %v", err)
704 }
705 sub.filter = &filter
706 sub.proto.Filter = req.Subscription.Filter
707
708 case "enable_exactly_once_delivery":
709 sub.proto.EnableExactlyOnceDelivery = req.Subscription.EnableExactlyOnceDelivery
710 for _, st := range sub.streams {
711 st.enableExactlyOnceDelivery = req.Subscription.EnableExactlyOnceDelivery
712 }
713
714 default:
715 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
716 }
717 }
718 return sub.proto, nil
719 }
720
721 func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
722 s.mu.Lock()
723 defer s.mu.Unlock()
724
725 if handled, ret, err := s.runReactor(req, "ListSubscriptions", &pb.ListSubscriptionsResponse{}); handled || err != nil {
726 return ret.(*pb.ListSubscriptionsResponse), err
727 }
728
729 var names []string
730 for name := range s.subs {
731 if strings.HasPrefix(name, req.Project) {
732 names = append(names, name)
733 }
734 }
735 sort.Strings(names)
736 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
737 if err != nil {
738 return nil, err
739 }
740 res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
741 for i := from; i < to; i++ {
742 res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
743 }
744 return res, nil
745 }
746
747 func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
748 s.mu.Lock()
749 defer s.mu.Unlock()
750
751 if handled, ret, err := s.runReactor(req, "DeleteSubscription", &emptypb.Empty{}); handled || err != nil {
752 return ret.(*emptypb.Empty), err
753 }
754
755 sub, err := s.findSubscription(req.Subscription)
756 if err != nil {
757 return nil, err
758 }
759 sub.stop()
760 delete(s.subs, req.Subscription)
761 sub.topic.deleteSub(sub)
762 return &emptypb.Empty{}, nil
763 }
764
765 func (s *GServer) DetachSubscription(_ context.Context, req *pb.DetachSubscriptionRequest) (*pb.DetachSubscriptionResponse, error) {
766 s.mu.Lock()
767 defer s.mu.Unlock()
768
769 if handled, ret, err := s.runReactor(req, "DetachSubscription", &pb.DetachSubscriptionResponse{}); handled || err != nil {
770 return ret.(*pb.DetachSubscriptionResponse), err
771 }
772
773 sub, err := s.findSubscription(req.Subscription)
774 if err != nil {
775 return nil, err
776 }
777 sub.topic.deleteSub(sub)
778 return &pb.DetachSubscriptionResponse{}, nil
779 }
780
781 func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
782 s.mu.Lock()
783 defer s.mu.Unlock()
784
785 if handled, ret, err := s.runReactor(req, "Publish", &pb.PublishResponse{}); handled || err != nil {
786 return ret.(*pb.PublishResponse), err
787 }
788
789 if req.Topic == "" {
790 return nil, status.Errorf(codes.InvalidArgument, "missing topic")
791 }
792 top := s.topics[req.Topic]
793 if top == nil {
794 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
795 }
796
797 if !s.autoPublishResponse {
798 r := <-s.publishResponses
799 if r.err != nil {
800 return nil, r.err
801 }
802 return r.resp, nil
803 }
804
805 var ids []string
806 for _, pm := range req.Messages {
807 id := fmt.Sprintf("m%d", s.nextID)
808 s.nextID++
809 pm.MessageId = id
810 pubTime := s.now()
811 tsPubTime := timestamppb.New(pubTime)
812 pm.PublishTime = tsPubTime
813 m := &Message{
814 ID: id,
815 Data: pm.Data,
816 Attributes: pm.Attributes,
817 PublishTime: pubTime,
818 OrderingKey: pm.OrderingKey,
819 }
820 top.publish(pm, m)
821 ids = append(ids, id)
822 s.msgs = append(s.msgs, m)
823 s.msgsByID[id] = m
824 }
825 return &pb.PublishResponse{MessageIds: ids}, nil
826 }
827
// topic is the server-side state of a Pub/Sub topic.
type topic struct {
	proto *pb.Topic
	subs  map[string]*subscription // attached subscriptions, keyed by subscription name
}
832
833 func newTopic(pt *pb.Topic) *topic {
834 return &topic{
835 proto: pt,
836 subs: map[string]*subscription{},
837 }
838 }
839
840 func (t *topic) stop() {
841 for _, sub := range t.subs {
842 sub.proto.Topic = "_deleted-topic_"
843 }
844 }
845
846 func (t *topic) deleteSub(sub *subscription) {
847 delete(t.subs, sub.proto.Name)
848 }
849
850 func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
851 for _, s := range t.subs {
852 s.msgs[pm.MessageId] = &message{
853 publishTime: m.PublishTime,
854 proto: &pb.ReceivedMessage{
855 AckId: pm.MessageId,
856 Message: pm,
857 },
858 deliveries: &m.deliveries,
859 acks: &m.acks,
860 streamIndex: -1,
861 }
862 }
863 }
864
// subscription is the server-side state of a Pub/Sub subscription.
type subscription struct {
	topic           *topic
	deadLetterTopic *topic           // destination for messages that exhaust DeadLetterPolicy
	mu              *sync.Mutex      // the owning GServer's mutex (shared by all subscriptions)
	proto           *pb.Subscription // the resource as returned to clients
	ackTimeout      time.Duration    // ack deadline for Pull deliveries; initial deadline for new streams
	msgs            map[string]*message // backlog keyed by ack ID (equal to the message ID)
	streams         []*stream        // open streaming pulls
	done            chan struct{}    // closed by stop to end the delivery goroutine
	timeNowFunc     func() time.Time // the server clock
	filter          *filtering.Filter // parsed proto.Filter; nil means every message matches
}
877
878 func newSubscription(t *topic, mu *sync.Mutex, timeNowFunc func() time.Time, deadLetterTopic *topic, ps *pb.Subscription) *subscription {
879 at := time.Duration(ps.AckDeadlineSeconds) * time.Second
880 if at == 0 {
881 at = 10 * time.Second
882 }
883 ps.State = pb.Subscription_ACTIVE
884 sub := &subscription{
885 topic: t,
886 deadLetterTopic: deadLetterTopic,
887 mu: mu,
888 proto: ps,
889 ackTimeout: at,
890 msgs: map[string]*message{},
891 done: make(chan struct{}),
892 timeNowFunc: timeNowFunc,
893 }
894 if ps.Filter != "" {
895 filter, err := parseFilter(ps.Filter)
896 if err != nil {
897 panic(fmt.Sprintf("pstest: bad filter: %v", err))
898 }
899 sub.filter = &filter
900 }
901 return sub
902 }
903
904 func (s *subscription) start(wg *sync.WaitGroup) {
905 wg.Add(1)
906 go func() {
907 defer wg.Done()
908 for {
909 select {
910 case <-s.done:
911 return
912 case <-time.After(10 * time.Millisecond):
913 s.deliver()
914 }
915 }
916 }()
917 }
918
919 func (s *subscription) stop() {
920 close(s.done)
921 }
922
923 func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
924 s.mu.Lock()
925 defer s.mu.Unlock()
926
927 if handled, ret, err := s.runReactor(req, "Acknowledge", &emptypb.Empty{}); handled || err != nil {
928 return ret.(*emptypb.Empty), err
929 }
930
931 sub, err := s.findSubscription(req.Subscription)
932 if err != nil {
933 return nil, err
934 }
935 for _, id := range req.AckIds {
936 sub.ack(id)
937 }
938 return &emptypb.Empty{}, nil
939 }
940
941 func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
942 s.mu.Lock()
943 defer s.mu.Unlock()
944
945 if handled, ret, err := s.runReactor(req, "ModifyAckDeadline", &emptypb.Empty{}); handled || err != nil {
946 return ret.(*emptypb.Empty), err
947 }
948
949 sub, err := s.findSubscription(req.Subscription)
950 if err != nil {
951 return nil, err
952 }
953 now := time.Now()
954 for _, id := range req.AckIds {
955 s.msgsByID[id].modacks = append(s.msgsByID[id].modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
956 }
957 dur := secsToDur(req.AckDeadlineSeconds)
958 for _, id := range req.AckIds {
959 sub.modifyAckDeadline(id, dur)
960 }
961 return &emptypb.Empty{}, nil
962 }
963
964 func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
965 s.mu.Lock()
966
967 if handled, ret, err := s.runReactor(req, "Pull", &pb.PullResponse{}); handled || err != nil {
968 s.mu.Unlock()
969 return ret.(*pb.PullResponse), err
970 }
971
972 sub, err := s.findSubscription(req.Subscription)
973 if err != nil {
974 s.mu.Unlock()
975 return nil, err
976 }
977 max := int(req.MaxMessages)
978 if max < 0 {
979 s.mu.Unlock()
980 return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative")
981 }
982 if max == 0 {
983 max = 1000
984 }
985 msgs := sub.pull(max)
986 s.mu.Unlock()
987
988
989
990
991
992 if len(msgs) == 0 && !req.ReturnImmediately {
993
994
995 select {
996 case <-ctx.Done():
997 return nil, ctx.Err()
998 case <-time.After(500 * time.Millisecond):
999 s.mu.Lock()
1000 msgs = sub.pull(max)
1001 s.mu.Unlock()
1002 }
1003 }
1004 return &pb.PullResponse{ReceivedMessages: msgs}, nil
1005 }
1006
1007 func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
1008
1009 req, err := sps.Recv()
1010 if err != nil {
1011 return err
1012 }
1013 s.mu.Lock()
1014 sub, err := s.findSubscription(req.Subscription)
1015 s.mu.Unlock()
1016 if err != nil {
1017 return err
1018 }
1019
1020 st := sub.newStream(sps, s.streamTimeout)
1021 st.ackTimeout = time.Duration(req.StreamAckDeadlineSeconds) * time.Second
1022 err = st.pull(&s.wg)
1023 sub.deleteStream(st)
1024 return err
1025 }
1026
1027 func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
1028
1029
1030 var target time.Time
1031 switch v := req.Target.(type) {
1032 case nil:
1033 return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
1034 case *pb.SeekRequest_Time:
1035 target = v.Time.AsTime()
1036 default:
1037 return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v)
1038 }
1039
1040
1041
1042 s.mu.Lock()
1043 defer s.mu.Unlock()
1044
1045 if handled, ret, err := s.runReactor(req, "Seek", &pb.SeekResponse{}); handled || err != nil {
1046 return ret.(*pb.SeekResponse), err
1047 }
1048
1049 sub, err := s.findSubscription(req.Subscription)
1050 if err != nil {
1051 return nil, err
1052 }
1053
1054 for id, m := range sub.msgs {
1055 if m.publishTime.Before(target) {
1056 delete(sub.msgs, id)
1057 (*m.acks)++
1058 }
1059 }
1060
1061
1062 for _, m := range s.msgs {
1063 if m.PublishTime.Before(target) {
1064 continue
1065 }
1066 sub.msgs[m.ID] = &message{
1067 publishTime: m.PublishTime,
1068 proto: &pb.ReceivedMessage{
1069 AckId: m.ID,
1070
1071
1072 },
1073 deliveries: &m.deliveries,
1074 acks: &m.acks,
1075 streamIndex: -1,
1076 }
1077 }
1078 return &pb.SeekResponse{}, nil
1079 }
1080
1081
1082
1083 func (s *GServer) findSubscription(name string) (*subscription, error) {
1084 if name == "" {
1085 return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
1086 }
1087 sub := s.subs[name]
1088 if sub == nil {
1089 return nil, status.Errorf(codes.NotFound, "subscription %s", name)
1090 }
1091 return sub, nil
1092 }
1093
1094
1095 func (s *subscription) pull(max int) []*pb.ReceivedMessage {
1096 now := s.timeNowFunc()
1097 s.maintainMessages(now)
1098 var msgs []*pb.ReceivedMessage
1099 filterMsgs(s.msgs, s.filter)
1100 for id, m := range orderMsgs(s.msgs, s.proto.EnableMessageOrdering) {
1101 if m.outstanding() {
1102 continue
1103 }
1104 if s.deadLetterCandidate(m) {
1105 s.ack(id)
1106 s.publishToDeadLetter(m)
1107 continue
1108 }
1109 (*m.deliveries)++
1110 if s.proto.DeadLetterPolicy != nil {
1111 m.proto.DeliveryAttempt = int32(*m.deliveries)
1112 }
1113 m.ackDeadline = now.Add(s.ackTimeout)
1114 msgs = append(msgs, m.proto)
1115 if len(msgs) >= max {
1116 break
1117 }
1118 }
1119 return msgs
1120 }
1121
1122 func orderMsgs(msgs map[string]*message, enableMessageOrdering bool) map[string]*message {
1123 if !enableMessageOrdering {
1124 return msgs
1125 }
1126 result := make(map[string]*message)
1127
1128 type msg struct {
1129 id string
1130 m *message
1131 }
1132 orderingKeyMap := make(map[string]msg)
1133 for id, m := range msgs {
1134 orderingKey := m.proto.Message.OrderingKey
1135 if orderingKey == "" {
1136 orderingKey = id
1137 }
1138 if val, ok := orderingKeyMap[orderingKey]; !ok || m.proto.Message.PublishTime.AsTime().Before(val.m.proto.Message.PublishTime.AsTime()) {
1139 orderingKeyMap[orderingKey] = msg{m: m, id: id}
1140 }
1141 }
1142 for _, val := range orderingKeyMap {
1143 result[val.id] = val.m
1144 }
1145 return result
1146 }
1147
1148 func filterMsgs(msgs map[string]*message, filter *filtering.Filter) {
1149 if filter == nil {
1150 return
1151 }
1152
1153 filterByAttrs(msgs, filter, func(m *message) messageAttrs {
1154 return m.proto.Message.Attributes
1155 })
1156 }
1157
1158 func (s *subscription) deliver() {
1159 s.mu.Lock()
1160 defer s.mu.Unlock()
1161
1162 now := s.timeNowFunc()
1163 s.maintainMessages(now)
1164
1165 curIndex := 0
1166 filterMsgs(s.msgs, s.filter)
1167 for id, m := range orderMsgs(s.msgs, s.proto.EnableMessageOrdering) {
1168 if m.outstanding() {
1169 continue
1170 }
1171 if s.deadLetterCandidate(m) {
1172 s.ack(id)
1173 s.publishToDeadLetter(m)
1174 continue
1175 }
1176
1177
1178
1179 if m.streamIndex < 0 {
1180 delIndex, ok := s.tryDeliverMessage(m, curIndex, now)
1181 if !ok {
1182 break
1183 }
1184 curIndex = delIndex + 1
1185 m.streamIndex = curIndex
1186 } else {
1187 delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now)
1188 if !ok {
1189 break
1190 }
1191 m.streamIndex = delIndex
1192 }
1193 }
1194 }
1195
1196
1197
1198
1199
1200
1201
1202
1203
// tryDeliverMessage offers m to the subscription's streams, starting at index
// start (taken modulo the number of streams) and wrapping around. It returns
// the index of the stream that took the message.
//
// Each offer is a non-blocking send on the stream's msgc channel. That channel
// is created unbuffered in newStream, so a send only succeeds if that stream's
// sendLoop is waiting for a message at that moment. Streams whose done channel
// is closed are removed from s.streams along the way.
//
// On success, the delivery count is incremented and the ack deadline is set to
// now plus that stream's ack timeout. If no stream accepts the message, it
// returns (0, false). deliver calls this with s.mu held.
func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) {
	// With a dead-letter policy, stamp the attempt number this delivery would
	// be (deliveries+1) onto the proto before sending it. The value is rolled
	// back below if no stream accepts the message.
	if s.proto.DeadLetterPolicy != nil {
		m.proto.DeliveryAttempt = int32(*m.deliveries) + 1
	}

	for i := 0; i < len(s.streams); i++ {
		idx := (i + start) % len(s.streams)

		st := s.streams[idx]
		select {
		case <-st.done:
			// The stream has finished. Drop it and retry with the same i,
			// because the slice has shrunk by one.
			s.streams = deleteStreamAt(s.streams, idx)
			i--

		case st.msgc <- m.proto:
			(*m.deliveries)++
			m.ackDeadline = now.Add(st.ackTimeout)
			return idx, true

		default:
			// The stream is not ready to receive; try the next one.
		}
	}

	// No stream took the message: restore DeliveryAttempt to the real count.
	if s.proto.DeadLetterPolicy != nil {
		m.proto.DeliveryAttempt = int32(*m.deliveries)
	}
	return 0, false
}
1235
1236 const retentionDuration = 10 * time.Minute
1237
1238
1239 func (s *subscription) maintainMessages(now time.Time) {
1240 for id, m := range s.msgs {
1241
1242 if m.outstanding() && now.After(m.ackDeadline) {
1243 m.makeAvailable()
1244 }
1245 pubTime := m.proto.Message.PublishTime.AsTime()
1246
1247 if !m.outstanding() && now.Sub(pubTime) > retentionDuration {
1248 delete(s.msgs, id)
1249 }
1250 }
1251 }
1252
1253 func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
1254 st := &stream{
1255 sub: s,
1256 done: make(chan struct{}),
1257 msgc: make(chan *pb.ReceivedMessage),
1258 gstream: gs,
1259 ackTimeout: s.ackTimeout,
1260 timeout: timeout,
1261 enableExactlyOnceDelivery: s.proto.EnableExactlyOnceDelivery,
1262 enableOrdering: s.proto.EnableMessageOrdering,
1263 }
1264 s.mu.Lock()
1265 s.streams = append(s.streams, st)
1266 s.mu.Unlock()
1267 return st
1268 }
1269
1270 func (s *subscription) deleteStream(st *stream) {
1271 s.mu.Lock()
1272 defer s.mu.Unlock()
1273 var i int
1274 for i = 0; i < len(s.streams); i++ {
1275 if s.streams[i] == st {
1276 break
1277 }
1278 }
1279 if i < len(s.streams) {
1280 s.streams = deleteStreamAt(s.streams, i)
1281 }
1282 }
1283
1284 func (s *subscription) deadLetterCandidate(m *message) bool {
1285 if s.proto.DeadLetterPolicy == nil {
1286 return false
1287 }
1288 if m.retriesDone(s.proto.DeadLetterPolicy.MaxDeliveryAttempts) {
1289 return true
1290 }
1291 return false
1292 }
1293
1294 func (s *subscription) publishToDeadLetter(m *message) {
1295 acks := 0
1296 if m.acks != nil {
1297 acks = *m.acks
1298 }
1299 deliveries := 0
1300 if m.deliveries != nil {
1301 deliveries = *m.deliveries
1302 }
1303 s.deadLetterTopic.publish(m.proto.Message, &Message{
1304 PublishTime: m.publishTime,
1305 Acks: acks,
1306 Deliveries: deliveries,
1307 })
1308 }
1309
1310 func deleteStreamAt(s []*stream, i int) []*stream {
1311
1312 return append(s[:i], s[i+1:]...)
1313 }
1314
// message is a subscription's per-message delivery state.
type message struct {
	// proto is what gets sent to streams. Its Message field carries the
	// attributes, publish time and ordering key used for filtering,
	// ordering and retention.
	proto *pb.ReceivedMessage
	// publishTime is copied onto the record published to the dead-letter topic.
	publishTime time.Time
	// ackDeadline is when the current lease expires. The zero value means the
	// message is not outstanding and may be (re)delivered.
	ackDeadline time.Time
	// deliveries and acks are counters updated through pointers.
	// NOTE(review): presumably shared with the server-level Message record so
	// tests can observe them — confirm where they are assigned.
	deliveries *int
	acks       *int
	// streamIndex is the index into subscription.streams where the next
	// (re)delivery attempt starts. A negative value means the message has not
	// been delivered yet.
	streamIndex int
}
1323
1324
1325 func (m *message) outstanding() bool {
1326 return !m.ackDeadline.IsZero()
1327 }
1328
1329
1330 func (m *message) retriesDone(maxRetries int32) bool {
1331 return m.deliveries != nil && int32(*m.deliveries) >= maxRetries
1332 }
1333
1334 func (m *message) makeAvailable() {
1335 m.ackDeadline = time.Time{}
1336 }
1337
// stream is one StreamingPull connection attached to a subscription.
type stream struct {
	// sub is the subscription this stream receives messages from.
	sub *subscription
	// done is closed by pull when the stream ends. tryDeliverMessage then
	// drops the stream and sendLoop exits.
	done chan struct{}
	// msgc hands messages from tryDeliverMessage to sendLoop (unbuffered).
	msgc chan *pb.ReceivedMessage
	// gstream is the underlying gRPC server stream.
	gstream pb.Subscriber_StreamingPullServer
	// ackTimeout is the lease given to messages delivered on this stream. The
	// client can change it through StreamAckDeadlineSeconds.
	ackTimeout time.Duration
	// timeout limits how long pull serves the stream; non-positive means no limit.
	timeout time.Duration
	// enableExactlyOnceDelivery and enableOrdering are echoed in every
	// response's SubscriptionProperties.
	enableExactlyOnceDelivery bool
	enableOrdering            bool
}
1348
1349
// pull serves the stream. It runs sendLoop and recvLoop in goroutines, both
// tracked by wg, and blocks until either loop returns or, if st.timeout is
// positive, the timeout elapses. It then closes st.done, which stops
// sendLoop and marks the stream as finished for tryDeliverMessage.
//
// It returns the first loop's error, with io.EOF reported as nil. A timeout
// also returns nil.
func (st *stream) pull(wg *sync.WaitGroup) error {
	// Buffered for both loops, so whichever finishes after pull returns can
	// still report its error without blocking.
	errc := make(chan error, 2)
	wg.Add(2)
	go func() {
		defer wg.Done()
		errc <- st.sendLoop()
	}()
	go func() {
		defer wg.Done()
		errc <- st.recvLoop()
	}()
	// A nil channel never fires, so without a timeout only errc can end the select.
	var tchan <-chan time.Time
	if st.timeout > 0 {
		tchan = time.After(st.timeout)
	}

	var err error
	select {
	case err = <-errc:
		// A client closing its side is a normal shutdown, not an error.
		if err == io.EOF {
			err = nil
		}
	case <-tchan:
	}
	close(st.done)
	return err
}
1377
1378 func (st *stream) sendLoop() error {
1379 for {
1380 select {
1381 case <-st.done:
1382 return nil
1383 case rm := <-st.msgc:
1384 res := &pb.StreamingPullResponse{
1385 ReceivedMessages: []*pb.ReceivedMessage{rm},
1386 SubscriptionProperties: &pb.StreamingPullResponse_SubscriptionProperties{
1387 ExactlyOnceDeliveryEnabled: st.enableExactlyOnceDelivery,
1388 MessageOrderingEnabled: st.enableOrdering,
1389 },
1390 }
1391 if err := st.gstream.Send(res); err != nil {
1392 return err
1393 }
1394 }
1395 }
1396 }
1397
1398 func (st *stream) recvLoop() error {
1399 for {
1400 req, err := st.gstream.Recv()
1401 if err != nil {
1402 return err
1403 }
1404 st.sub.handleStreamingPullRequest(st, req)
1405 }
1406 }
1407
1408 func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
1409
1410 s.mu.Lock()
1411 defer s.mu.Unlock()
1412
1413 for _, ackID := range req.AckIds {
1414 s.ack(ackID)
1415 }
1416 for i, id := range req.ModifyDeadlineAckIds {
1417 s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i]))
1418 }
1419 if req.StreamAckDeadlineSeconds > 0 {
1420 st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
1421 }
1422 }
1423
1424
1425 func (s *subscription) ack(id string) {
1426 m := s.msgs[id]
1427 if m != nil {
1428 (*m.acks)++
1429 delete(s.msgs, id)
1430 }
1431 }
1432
1433
1434 func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
1435 m := s.msgs[id]
1436 if m == nil {
1437 return
1438 }
1439 if d == 0 {
1440 m.makeAvailable()
1441 } else {
1442 m.ackDeadline = s.timeNowFunc().Add(d)
1443 }
1444 }
1445
// secsToDur converts a whole number of seconds, as carried in request fields
// such as ModifyDeadlineSeconds, to a time.Duration.
func secsToDur(secs int32) time.Duration {
	return time.Second * time.Duration(secs)
}
1449
1450
1451
1452 func (s *GServer) runReactor(req interface{}, funcName string, defaultObj interface{}) (bool, interface{}, error) {
1453 if val, ok := s.reactorOptions[funcName]; ok {
1454 for _, reactor := range val {
1455 handled, ret, err := reactor.React(req)
1456
1457
1458
1459 if handled || err != nil {
1460 if ret == nil {
1461 ret = defaultObj
1462 }
1463 return true, ret, err
1464 }
1465 }
1466 }
1467 return false, nil, nil
1468 }
1469
1470
1471 type errorInjectionReactor struct {
1472 code codes.Code
1473 msg string
1474 }
1475
1476
1477 func (e *errorInjectionReactor) React(_ interface{}) (handled bool, ret interface{}, err error) {
1478 return true, nil, status.Errorf(e.code, e.msg)
1479 }
1480
1481
1482
1483 func WithErrorInjection(funcName string, code codes.Code, msg string) ServerReactorOption {
1484 return ServerReactorOption{
1485 FuncName: funcName,
1486 Reactor: &errorInjectionReactor{code: code, msg: msg},
1487 }
1488 }
1489
// letters is the alphabet used for generated schema revision IDs.
const letters = "abcdef1234567890"

// genRevID returns a pseudo-random 8-character schema revision ID whose
// characters are drawn from letters using math/rand.
func genRevID() string {
	const idLen = 8
	var buf [idLen]byte
	for i := 0; i < idLen; i++ {
		buf[i] = letters[rand.Intn(len(letters))]
	}
	return string(buf[:])
}
1499
1500 func (s *GServer) CreateSchema(_ context.Context, req *pb.CreateSchemaRequest) (*pb.Schema, error) {
1501 s.mu.Lock()
1502 defer s.mu.Unlock()
1503
1504 if handled, ret, err := s.runReactor(req, "CreateSchema", &pb.Schema{}); handled || err != nil {
1505 return ret.(*pb.Schema), err
1506 }
1507
1508 name := fmt.Sprintf("%s/schemas/%s", req.Parent, req.SchemaId)
1509 sc := &pb.Schema{
1510 Name: name,
1511 Type: req.Schema.Type,
1512 Definition: req.Schema.Definition,
1513 RevisionId: genRevID(),
1514 RevisionCreateTime: timestamppb.Now(),
1515 }
1516 s.schemas[name] = append(s.schemas[name], sc)
1517
1518 return sc, nil
1519 }
1520
1521 func (s *GServer) GetSchema(_ context.Context, req *pb.GetSchemaRequest) (*pb.Schema, error) {
1522 s.mu.Lock()
1523 defer s.mu.Unlock()
1524
1525 if handled, ret, err := s.runReactor(req, "GetSchema", &pb.Schema{}); handled || err != nil {
1526 return ret.(*pb.Schema), err
1527 }
1528
1529 ss := strings.Split(req.Name, "@")
1530 var schemaName, revisionID string
1531 if len := len(ss); len == 1 {
1532 schemaName = ss[0]
1533 } else if len == 2 {
1534 schemaName = ss[0]
1535 revisionID = ss[1]
1536 } else {
1537 return nil, status.Errorf(codes.InvalidArgument, "schema(%q) name parse error", req.Name)
1538 }
1539
1540 schemaRev, ok := s.schemas[schemaName]
1541 if !ok {
1542 return nil, status.Errorf(codes.NotFound, "schema(%q) not found", req.Name)
1543 }
1544
1545 if revisionID == "" {
1546 return schemaRev[len(schemaRev)-1], nil
1547 }
1548
1549 for _, sc := range schemaRev {
1550 if sc.RevisionId == revisionID {
1551 return sc, nil
1552 }
1553 }
1554
1555 return nil, status.Errorf(codes.NotFound, "schema %q not found", req.Name)
1556 }
1557
1558 func (s *GServer) ListSchemas(_ context.Context, req *pb.ListSchemasRequest) (*pb.ListSchemasResponse, error) {
1559 s.mu.Lock()
1560 defer s.mu.Unlock()
1561
1562 if handled, ret, err := s.runReactor(req, "ListSchemas", &pb.ListSchemasResponse{}); handled || err != nil {
1563 return ret.(*pb.ListSchemasResponse), err
1564 }
1565 ss := make([]*pb.Schema, 0)
1566 for _, sc := range s.schemas {
1567 ss = append(ss, sc[len(sc)-1])
1568 }
1569 return &pb.ListSchemasResponse{
1570 Schemas: ss,
1571 }, nil
1572 }
1573
1574 func (s *GServer) ListSchemaRevisions(_ context.Context, req *pb.ListSchemaRevisionsRequest) (*pb.ListSchemaRevisionsResponse, error) {
1575 s.mu.Lock()
1576 defer s.mu.Unlock()
1577
1578 if handled, ret, err := s.runReactor(req, "ListSchemaRevisions", &pb.ListSchemasResponse{}); handled || err != nil {
1579 return ret.(*pb.ListSchemaRevisionsResponse), err
1580 }
1581 ss := make([]*pb.Schema, 0)
1582 ss = append(ss, s.schemas[req.Name]...)
1583 return &pb.ListSchemaRevisionsResponse{
1584 Schemas: ss,
1585 }, nil
1586 }
1587
1588 func (s *GServer) CommitSchema(_ context.Context, req *pb.CommitSchemaRequest) (*pb.Schema, error) {
1589 s.mu.Lock()
1590 defer s.mu.Unlock()
1591
1592 if handled, ret, err := s.runReactor(req, "CommitSchema", &pb.Schema{}); handled || err != nil {
1593 return ret.(*pb.Schema), err
1594 }
1595
1596 sc := &pb.Schema{
1597 Name: req.Name,
1598 Type: req.Schema.Type,
1599 Definition: req.Schema.Definition,
1600 }
1601 sc.RevisionId = genRevID()
1602 sc.RevisionCreateTime = timestamppb.Now()
1603
1604 s.schemas[req.Name] = append(s.schemas[req.Name], sc)
1605
1606 return sc, nil
1607 }
1608
1609
1610 func (s *GServer) RollbackSchema(_ context.Context, req *pb.RollbackSchemaRequest) (*pb.Schema, error) {
1611 s.mu.Lock()
1612 defer s.mu.Unlock()
1613
1614 if handled, ret, err := s.runReactor(req, "RollbackSchema", &pb.Schema{}); handled || err != nil {
1615 return ret.(*pb.Schema), err
1616 }
1617
1618 for _, sc := range s.schemas[req.Name] {
1619 if sc.RevisionId == req.RevisionId {
1620 newSchema := *sc
1621 newSchema.RevisionId = genRevID()
1622 newSchema.RevisionCreateTime = timestamppb.Now()
1623 s.schemas[req.Name] = append(s.schemas[req.Name], &newSchema)
1624 return &newSchema, nil
1625 }
1626 }
1627 return nil, status.Errorf(codes.NotFound, "schema %q@%q not found", req.Name, req.RevisionId)
1628 }
1629
// DeleteSchemaRevision deletes one revision of a schema. req.Name must have
// the form "<schema name>@<revision id>", otherwise InvalidArgument is
// returned. Deleting the only remaining revision is also rejected with
// InvalidArgument. An unknown schema or revision yields NotFound.
//
// NOTE(review): the returned Schema is read from the pre-deletion slice after
// the in-place removal. append shifts later revisions left but does not clear
// the old last slot. So this returns the deleted revision when it was the
// newest one, and the (unchanged) newest revision otherwise. Confirm which
// revision callers expect.
func (s *GServer) DeleteSchemaRevision(_ context.Context, req *pb.DeleteSchemaRevisionRequest) (*pb.Schema, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if handled, ret, err := s.runReactor(req, "DeleteSchemaRevision", &pb.Schema{}); handled || err != nil {
		return ret.(*pb.Schema), err
	}

	schemaPath := strings.Split(req.Name, "@")
	if len(schemaPath) != 2 {
		return nil, status.Errorf(codes.InvalidArgument, "could not parse revision ID from schema name: %q", req.Name)
	}
	schemaName := schemaPath[0]
	revID := schemaPath[1]
	schemaRevisions, ok := s.schemas[schemaName]
	if ok {
		// Every schema keeps at least one revision.
		if len(schemaRevisions) == 1 {
			return nil, status.Errorf(codes.InvalidArgument, "cannot delete last revision for schema %q", req.Name)
		}
		for i, sc := range schemaRevisions {
			if sc.RevisionId == revID {
				// Removes index i in place, reusing schemaRevisions' backing array.
				s.schemas[schemaName] = append(schemaRevisions[:i], schemaRevisions[i+1:]...)
				return schemaRevisions[len(schemaRevisions)-1], nil
			}
		}
	}

	return nil, status.Errorf(codes.NotFound, "schema %q not found", req.Name)
}
1659
1660 func (s *GServer) DeleteSchema(_ context.Context, req *pb.DeleteSchemaRequest) (*emptypb.Empty, error) {
1661 s.mu.Lock()
1662 defer s.mu.Unlock()
1663
1664 if handled, ret, err := s.runReactor(req, "DeleteSchema", &emptypb.Empty{}); handled || err != nil {
1665 return ret.(*emptypb.Empty), err
1666 }
1667
1668 schema := s.schemas[req.Name]
1669 if schema == nil {
1670 return nil, status.Errorf(codes.NotFound, "schema %q", req.Name)
1671 }
1672
1673 delete(s.schemas, req.Name)
1674 return &emptypb.Empty{}, nil
1675 }
1676
1677
1678 func (s *GServer) ValidateSchema(_ context.Context, req *pb.ValidateSchemaRequest) (*pb.ValidateSchemaResponse, error) {
1679 s.mu.Lock()
1680 defer s.mu.Unlock()
1681
1682 if handled, ret, err := s.runReactor(req, "ValidateSchema", &pb.ValidateSchemaResponse{}); handled || err != nil {
1683 return ret.(*pb.ValidateSchemaResponse), err
1684 }
1685
1686 if req.Schema.Definition == "" {
1687 return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty")
1688 }
1689 return &pb.ValidateSchemaResponse{}, nil
1690 }
1691
1692
1693
1694 func (s *GServer) ValidateMessage(_ context.Context, req *pb.ValidateMessageRequest) (*pb.ValidateMessageResponse, error) {
1695 s.mu.Lock()
1696 defer s.mu.Unlock()
1697
1698 if handled, ret, err := s.runReactor(req, "ValidateMessage", &pb.ValidateMessageResponse{}); handled || err != nil {
1699 return ret.(*pb.ValidateMessageResponse), err
1700 }
1701
1702 spec := req.GetSchemaSpec()
1703 if valReq, ok := spec.(*pb.ValidateMessageRequest_Name); ok {
1704 sc, ok := s.schemas[valReq.Name]
1705 if !ok {
1706 return nil, status.Errorf(codes.NotFound, "schema(%q) not found", valReq.Name)
1707 }
1708 schema := sc[len(sc)-1]
1709 if schema.Definition == "" {
1710 return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty")
1711 }
1712 }
1713 if valReq, ok := spec.(*pb.ValidateMessageRequest_Schema); ok {
1714 if valReq.Schema.Definition == "" {
1715 return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty")
1716 }
1717 }
1718
1719 return &pb.ValidateMessageResponse{}, nil
1720 }
1721
View as plain text