package drivers import ( "context" "fmt" "time" "github.com/go-logr/logr" "edge-infra.dev/pkg/edge/rollouts" "edge-infra.dev/pkg/edge/rollouts/internal" ) var _ RolloutStore = &InMemoryRolloutStore{} var _ rolloutDriver = &inMemoryDriver{} type InMemoryRolloutStore struct { Store *internal.InMemStore } func NewInMemoryRolloutStore(store *internal.InMemStore) *InMemoryRolloutStore { return &InMemoryRolloutStore{ Store: store, } } func (s *InMemoryRolloutStore) GetRollouts(_ context.Context) ([]*rollouts.RolloutGraph, error) { rollouts := []*rollouts.RolloutGraph{} for _, bannerRollouts := range s.Store.GraphStore { for _, rollout := range bannerRollouts { rollouts = append(rollouts, rollout) } } return rollouts, nil } func (s *InMemoryRolloutStore) GetBannerRollouts(_ context.Context, bannerID string) ([]*rollouts.RolloutGraph, error) { rollouts := []*rollouts.RolloutGraph{} if bannerRollouts, found := s.Store.GraphStore[bannerID]; found { for _, rollout := range bannerRollouts { rollouts = append(rollouts, rollout) } return rollouts, nil } return rollouts, fmt.Errorf("banner %s has no rollouts", bannerID) } func (s *InMemoryRolloutStore) GetRollout(_ context.Context, rolloutID string) (*rollouts.RolloutGraph, error) { for _, bannerRollouts := range s.Store.GraphStore { if rollout, found := bannerRollouts[rolloutID]; found { return rollout, nil } } return nil, fmt.Errorf("rollout %s not found", rolloutID) } func (s *InMemoryRolloutStore) GetRolloutPlans(_ context.Context) ([]rollouts.RolloutPlan, error) { plans := []rollouts.RolloutPlan{} for _, bannerPlans := range s.Store.PlanStore { for _, plan := range bannerPlans { plans = append(plans, plan) } } return plans, nil } func (s *InMemoryRolloutStore) GetRolloutPlan(_ context.Context, planID string) (rollouts.RolloutPlan, error) { for _, bannerPlans := range s.Store.PlanStore { if plan, found := bannerPlans[planID]; found { return plan, nil } } return rollouts.RolloutPlan{}, fmt.Errorf("plan %s not found", planID) } func (s *InMemoryRolloutStore) GetBannerRolloutPlans(_ context.Context, bannerID string) ([]rollouts.RolloutPlan, error) { plans := []rollouts.RolloutPlan{} if bannerPlans, found := s.Store.PlanStore[bannerID]; found { for _, plan := range bannerPlans { plans = append(plans, plan) } } return plans, nil } func (s *InMemoryRolloutStore) OpenApprovalGate(_ context.Context, _ string, nodeKey rollouts.NodeKey) error { if _, found := s.Store.ApprovalGateStore[nodeKey]; found { s.Store.ApprovalGateStore[nodeKey] = rollouts.GateApproved return nil } return fmt.Errorf("gate %s requested for approval, but not found", nodeKey) } func (s *InMemoryRolloutStore) NewRolloutInstance(_ context.Context, logger logr.Logger, rollout *rollouts.RolloutGraph, resultChan chan rollouts.NodeExecutionResult) (*RolloutInstance, error) { driver, err := newInMemoryDriver(s.Store, logger.WithName("inmemory-driver"), rollout) if err != nil { return nil, err } return &RolloutInstance{ driver: driver, logger: logger, NotifyChan: resultChan, }, nil } type inMemoryDriver struct { store *internal.InMemStore logger logr.Logger rollout *rollouts.RolloutGraph } func newInMemoryDriver(store *internal.InMemStore, logger logr.Logger, rollout *rollouts.RolloutGraph) (*inMemoryDriver, error) { if store == nil { return nil, fmt.Errorf("nil store passed to driver creation") } if rollout == nil { return nil, fmt.Errorf("nil rollout passed to driver creation") } return &inMemoryDriver{ store: store, logger: logger, rollout: rollout, }, nil } func (d *inMemoryDriver) getRollout() *rollouts.RolloutGraph { return d.rollout } func (d *inMemoryDriver) execute(ctx context.Context, node rollouts.RolloutGraphNode) (rollouts.NodeExecutionResult, error) { switch n := node.(type) { case *rollouts.TargetGroup: return d.executeTargetGroup(ctx, n) case *rollouts.TimerGate: return d.checkTimerGate(ctx, n) case *rollouts.ApprovalGate: return d.checkApprovalGate(ctx, n) default: return rollouts.NodeExecutionResult{ RolloutState: rollouts.RolloutError, Message: fmt.Sprintf("unknown type for node %v", node), }, fmt.Errorf("unknown type for node %v", node) } } func (d *inMemoryDriver) matchTargetGroup(_ context.Context, selector string) (rollouts.MatchResult, error) { return d.store.GetClusterLabelMatches(selector) } func (d *inMemoryDriver) checkTargetGroup(_ context.Context, artifactName string, artifactVersion string, clusterMatches rollouts.MatchResult) (rollouts.TargetGroupState, error) { tgVersionStates, err := d.store.GetTargetGroupVersionStates(artifactName, clusterMatches) if err != nil { return rollouts.TargetGroupError, fmt.Errorf("could not find states for [%s]", clusterMatches.Matches) } for _, artifactState := range tgVersionStates { if artifactState.Version == artifactVersion && !artifactState.Ready { return rollouts.TargetGroupApplied, nil } if artifactState.Version != artifactVersion { return rollouts.TargetGroupPending, nil } } return rollouts.TargetGroupReady, nil } func (d *inMemoryDriver) applyTargetGroup(_ context.Context, artifactName string, artifactVersion string, clusterMatches rollouts.MatchResult) error { //nolint return d.store.ApplyClusterArtifactVersions(artifactName, artifactVersion, clusterMatches) } func (d *inMemoryDriver) executeTargetGroup(ctx context.Context, tg *rollouts.TargetGroup) (rollouts.NodeExecutionResult, error) { d.logger.Info(fmt.Sprintln("executing node:", tg.GetLabel())) d.logger.Info(fmt.Sprintln("selector:", tg.Selector)) result := rollouts.NodeExecutionResult{Key: tg.GetKey()} switch tg.State { case rollouts.Complete: result.State = rollouts.Complete result.Message = fmt.Sprintf("rollout %s node %s complete", d.rollout.ID, tg.GetKey()) return result, nil // If first time seeing node, (Pending status), set InProgress case rollouts.Pending: tg.State = rollouts.InProgress case rollouts.InProgress: break default: return result, fmt.Errorf("unknown NodeState %s", tg.State) } // steps: // 1. check if the targets in the group are at the desired version or not. if yes - return // 2. set targets to the desired version / apply the desired configuration matches, err := d.matchTargetGroup(ctx, tg.Selector) if err != nil { result.RolloutState = rollouts.RolloutError result.Message = fmt.Sprintf("rollout %s failed to match target group %s for selector %s", d.rollout.ID, tg.GetKey(), tg.Selector) return result, err } tgState, err := d.checkTargetGroup(ctx, tg.ArtifactName, tg.ArtifactVersion, matches) if err != nil { result.RolloutState = rollouts.RolloutError result.Message = fmt.Sprintf("rollout %s failed to check target group %s state", d.rollout.ID, tg.GetKey()) return result, err } switch tgState { case rollouts.TargetGroupError: result.Message = fmt.Sprintf("Target Group %s in error", tg.GetKey()) result.RolloutState = rollouts.RolloutError return result, err case rollouts.TargetGroupApplied: result.Message = fmt.Sprintf("Target Group %s waiting for completion", tg.GetKey()) return result, nil case rollouts.TargetGroupReady: tg.State = rollouts.Complete result.State = rollouts.Complete return result, nil } if err := d.applyTargetGroup(ctx, tg.ArtifactName, tg.ArtifactVersion, matches); err == nil { tg.State = rollouts.InProgress result.State = rollouts.InProgress result.Message = fmt.Sprintf("Target Group %s applied", tg.GetKey()) } // node has not finished yet, but no error occurred return result, nil } func (d *inMemoryDriver) checkTimerGate(_ context.Context, tg *rollouts.TimerGate) (rollouts.NodeExecutionResult, error) { tg.State = rollouts.Pending open := tg.IsOpen() result := rollouts.NodeExecutionResult{Key: tg.GetKey()} result.Message = fmt.Sprintf("rollout %s timer gate %s", d.rollout.ID, tg.GetKey()) if open { tg.State = rollouts.Complete result.State = rollouts.Complete result.Message = fmt.Sprintf("%s is open", result.Message) } else { timeLeft := tg.StartTime.Add(tg.Delay).UTC().Format(time.RFC3339) result.Message = fmt.Sprintf("%s is closed and will open at %s", result.Message, timeLeft) } return result, nil } func (d *inMemoryDriver) checkApprovalGate(_ context.Context, ag *rollouts.ApprovalGate) (rollouts.NodeExecutionResult, error) { result := rollouts.NodeExecutionResult{Key: ag.GetKey()} if ag.IsOpen() { result.Message = fmt.Sprintf("rollout %s approval gate %s is open", d.rollout.ID, ag.GetKey()) result.State = rollouts.Complete return result, nil } gateStatus, err := d.store.GetApprovalGateStatus(ag) if err != nil { result.Message = fmt.Sprintf("rollout %s approval gate %s status not found", d.rollout.ID, ag.GetKey()) return result, fmt.Errorf("did not find gate %s", string(ag.GetKey())) } switch gateStatus { case rollouts.GateApproved: ag.GateState = rollouts.Open ag.State = rollouts.Complete result.State = rollouts.Complete result.Message = fmt.Sprintf("rollout %s approval gate %s is open", d.rollout.ID, ag.GetKey()) return result, nil case rollouts.GatePending: ag.GateState = rollouts.Closed ag.State = rollouts.Pending result.State = rollouts.Pending result.Message = fmt.Sprintf("rollout %s approval gate %s is pending", d.rollout.ID, ag.GetKey()) return result, nil case rollouts.GateDenied: ag.GateState = rollouts.Closed result.State = rollouts.Halted result.Message = fmt.Sprintf("rollout %s approval gate %s is denied, rollout halted", d.rollout.ID, ag.GetKey()) return result, nil default: result.Message = fmt.Sprintf("unknown gate status %s for rollout %s, gate key %s", gateStatus, d.rollout.ID, ag.GetKey()) return result, fmt.Errorf("unknown gate status %s for rollout %s, gate key %s", gateStatus, d.rollout.ID, ag.GetKey()) } } // The graph store for an in memory store is a map to pointers, so all changes are // automatically persisted func (d *inMemoryDriver) persistGraph(_ context.Context) error { return nil }