package server

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"edge-infra.dev/pkg/lib/fog"

	"cloud.google.com/go/pubsub"
	"google.golang.org/api/option"
	"google.golang.org/grpc"
)

var ErrPollMaxRetries = fmt.Errorf("maximum poll retries reached")

func (p *PSQLInjector) receiverMux() (*ReceiverMux, error) {
	return NewReceiverMux(&ReceiverMuxConfig{
		PollFunc:                     p.sql.GetBannerProjectIDs,
		Handler:                      p.HandleMsg,
		ForemanProjectID:             p.cfg.ForemanProjectID,
		SubscriptionID:               p.cfg.SubscriptionID,
		TopicID:                      p.cfg.TopicID,
		Conn:                         p.cfg.TestPubSubConn,
		PollPeriod:                   p.cfg.PollBannersPeriod,
		PollMaxRetries:               p.cfg.PollBannersMaxRetries,
		PollSubscriptionExistsPeriod: p.cfg.PollSubscriptionExistsPeriod,
	})
}

type ReceiverMuxConfig struct {
	// Handler receives messages from each of the multiplexed subscriptions.
	Handler ReceiverHandler
	// PollFunc returns a slice of GCP project IDs.
	PollFunc         func(context.Context) ([]string, error)
	PollPeriod       time.Duration
	PollMaxRetries   int
	SubscriptionID   string
	TopicID          string           // TODO perhaps check this in the Receiver... or remove it since it's not used
	Conn             *grpc.ClientConn // optional grpc client for testing
	ForemanProjectID string           // needed to create a shared pubsub.Client
	// PollSubscriptionExistsPeriod is passed into Receivers so they can periodically check whether their subscription still exists.
	PollSubscriptionExistsPeriod time.Duration
}

// ReceiverMux is a pubsub.Subscription multiplexer.
//
// It polls PollFunc for project IDs, dynamically subscribes to the configured
// pubsub.Subscription in each of those projects, and routes every pubsub.Message
// to a common Handler function.
type ReceiverMux struct {
	cfg    *ReceiverMuxConfig
	client *pubsub.Client

	// pollErrs is used synchronously (please don't introduce a race condition).
	pollErrs []error

	sync.Mutex // protects "receivers"
	receivers  map[string]*Receiver
}

func NewReceiverMux(cfg *ReceiverMuxConfig) (*ReceiverMux, error) {
	var googleOptions []option.ClientOption
	if cfg.Conn != nil {
		googleOptions = append(googleOptions, option.WithGRPCConn(cfg.Conn))
	}
	client, err := pubsub.NewClient(context.Background(), cfg.ForemanProjectID, googleOptions...)
	if err != nil {
		return nil, fmt.Errorf("error creating pubsub client: %w", err)
	}

	// Test the poll func and exit early so we can enter CLBO quickly if there is a startup problem.
	var timeoutCtx, timeoutCtxCancel = context.WithTimeout(context.Background(), time.Minute)
	defer timeoutCtxCancel()
	if banners, err := cfg.PollFunc(timeoutCtx); err != nil {
		return nil, fmt.Errorf("failed to retrieve banners at startup: %w", err)
	} else if len(banners) == 0 {
		return nil, fmt.Errorf("found zero banners at startup")
	}

	return &ReceiverMux{
		cfg:       cfg,
		client:    client,
		receivers: make(map[string]*Receiver),
	}, nil
}

func (rm *ReceiverMux) Run(ctx context.Context) error {
	log := fog.FromContext(ctx).WithName("banner-mux-run")

	ctx, stopReceivers := context.WithCancel(ctx)
	var wg sync.WaitGroup
	defer func() {
		log.Info("stopping all receivers")
		stopReceivers()
		wg.Wait()
		log.Info("all receivers have exited")
	}()

	var pollTicker = time.NewTicker(rm.cfg.PollPeriod)
	defer pollTicker.Stop()

	for {
		// Poll for banner changes.
		added, dropped, err := rm.Poll(ctx)
		if err != nil {
			// Poll returns an error only once the number of consecutive poll failures exceeds PollMaxRetries,
			// so Run knows to exit. Otherwise Poll logs the failure and returns nil.
			return err
		}

		for _, r := range added {
			wg.Add(1)
			go func() {
				defer wg.Done()
				r.Start(ctx)
				// Wait until the receiver exits before deleting it from the rm.receivers map.
				// This ensures concurrent `pubsub.Subscription.Receive` calls never occur on the same Subscription.
				rm.Lock()
				defer rm.Unlock()
				delete(rm.receivers, r.projectID)
			}()
		}

		for _, r := range dropped {
			// Close does not block; it wraps a context.CancelFunc.
			_ = r.Close()
		}

		select {
		case <-pollTicker.C:
		case <-ctx.Done():
			return nil
		}
	}
}

func (rm *ReceiverMux) Poll(ctx context.Context) (added, dropped []*Receiver, maxRetryError error) {
	projectIDs, err := rm.cfg.PollFunc(ctx)
	if err != nil {
		log := fog.FromContext(ctx).WithName("banner-poll")
		log.Error(err, "failed to poll banners",
			"retry_count", len(rm.pollErrs),
			"max_retries", rm.cfg.PollMaxRetries,
		)
		rm.pollErrs = append(rm.pollErrs, err)
		if len(rm.pollErrs) > rm.cfg.PollMaxRetries {
			log.Error(ErrPollMaxRetries, "max retries exhausted",
				"max_retries", rm.cfg.PollMaxRetries,
				"errs", rm.pollErrs,
			)
			return nil, nil, errors.Join(ErrPollMaxRetries, errors.Join(rm.pollErrs...))
		}
		// Pretend nothing changed.
		return nil, nil, nil
	}
	rm.pollErrs = nil

	// diffReceivers locks the ReceiverMux and must be quick, so heavy work such as spawning
	// goroutines and closing receivers is deferred until after the lock is released.
	added, dropped = rm.diffReceivers(projectIDs...)
	return added, dropped, nil
}

// NOTE:
// The `rm.receivers` map is populated here to prevent lock churn; however, the `added` receivers are not started yet.
// ReceiverMux.Run starts each `added` receiver in its own goroutine without holding the lock.
// When `Receiver.Start` returns, that goroutine locks the ReceiverMux and deletes the receiver from `rm.receivers`.
func (rm *ReceiverMux) diffReceivers(projectIDs ...string) (added, dropped []*Receiver) {
	var keep = make(map[string]bool)
	for _, projectID := range projectIDs {
		keep[projectID] = true
	}

	rm.Lock()
	defer rm.Unlock()

	// Create receivers when we find new banners (this also recreates receivers that quit unexpectedly).
	for _, projectID := range projectIDs {
		if _, found := rm.receivers[projectID]; found {
			continue
		}
		var pollPeriod = rm.cfg.PollSubscriptionExistsPeriod
		var sub = rm.client.SubscriptionInProject(rm.cfg.SubscriptionID, projectID)
		var r = NewReceiver(projectID, sub, rm.cfg.Handler, pollPeriod)
		added = append(added, r)
		// Add all the new receivers to the map here, but don't start them while the ReceiverMux is locked.
		rm.receivers[projectID] = r
	}

	// Find banners that have been dropped.
	for projectID, r := range rm.receivers {
		if !keep[projectID] {
			// Run loops through these and simply calls `r.Close()`, which causes `Receiver.Start` to return.
			// Before the goroutine that called `Receiver.Start` exits, it locks the ReceiverMux and deletes the receiver from `rm.receivers`.
			dropped = append(dropped, r)
		}
	}
	return added, dropped
}

// TODO plan on converting the health check's fields into prometheus metrics.
type ReceiverMuxHealthCheck struct {
	Count          int
	HealthyCount   int
	UnhealthyCount int
	Receivers      map[string]ReceiverHealthCheck
}

func (rm *ReceiverMux) HealthCheck() ReceiverMuxHealthCheck {
	rm.Lock()
	defer rm.Unlock()

	var bmhc = ReceiverMuxHealthCheck{
		Count:     len(rm.receivers),
		Receivers: make(map[string]ReceiverHealthCheck),
	}
	for projectID, r := range rm.receivers {
		var rhc = r.HealthCheck()
		if rhc.Healthy {
			bmhc.HealthyCount++
		} else {
			bmhc.UnhealthyCount++
		}
		bmhc.Receivers[projectID] = rhc
	}
	return bmhc
}
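// What follows is a minimal sketch of the retry accounting in Poll, written as a
// standalone in-package test (the file name receiver_mux_test.go and the test name
// are hypothetical): the first PollMaxRetries failures are absorbed and return nil
// ("pretend nothing changed"), and the next failure surfaces ErrPollMaxRetries.
// It assumes fog.FromContext returns a usable or no-op logger when the context
// carries none; if it does not, seed the context with a logger before calling Poll.
package server

import (
	"context"
	"errors"
	"fmt"
	"testing"
)

func TestPollMaxRetries(t *testing.T) {
	rm := &ReceiverMux{
		cfg: &ReceiverMuxConfig{
			// Always fail so Poll never reaches diffReceivers (and never needs a pubsub client).
			PollFunc: func(context.Context) ([]string, error) {
				return nil, fmt.Errorf("boom")
			},
			PollMaxRetries: 3,
		},
	}

	ctx := context.Background()

	// The first PollMaxRetries failures are logged and swallowed: Poll returns nil.
	for i := 0; i < rm.cfg.PollMaxRetries; i++ {
		if _, _, err := rm.Poll(ctx); err != nil {
			t.Fatalf("poll %d: expected nil error, got %v", i+1, err)
		}
	}

	// The next failure pushes the consecutive error count past PollMaxRetries,
	// so Poll returns the joined error wrapping ErrPollMaxRetries.
	if _, _, err := rm.Poll(ctx); !errors.Is(err, ErrPollMaxRetries) {
		t.Fatalf("expected ErrPollMaxRetries, got %v", err)
	}
}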