...

Source file src/edge-infra.dev/pkg/edge/rollouts/drivers/inmem.go

Documentation: edge-infra.dev/pkg/edge/rollouts/drivers

     1  package drivers
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/go-logr/logr"
     9  
    10  	"edge-infra.dev/pkg/edge/rollouts"
    11  	"edge-infra.dev/pkg/edge/rollouts/internal"
    12  )
    13  
    14  var _ RolloutStore = &InMemoryRolloutStore{}
    15  var _ rolloutDriver = &inMemoryDriver{}
    16  
    17  type InMemoryRolloutStore struct {
    18  	Store *internal.InMemStore
    19  }
    20  
    21  func NewInMemoryRolloutStore(store *internal.InMemStore) *InMemoryRolloutStore {
    22  	return &InMemoryRolloutStore{
    23  		Store: store,
    24  	}
    25  }
    26  
    27  func (s *InMemoryRolloutStore) GetRollouts(_ context.Context) ([]*rollouts.RolloutGraph, error) {
    28  	rollouts := []*rollouts.RolloutGraph{}
    29  	for _, bannerRollouts := range s.Store.GraphStore {
    30  		for _, rollout := range bannerRollouts {
    31  			rollouts = append(rollouts, rollout)
    32  		}
    33  	}
    34  	return rollouts, nil
    35  }
    36  
    37  func (s *InMemoryRolloutStore) GetBannerRollouts(_ context.Context, bannerID string) ([]*rollouts.RolloutGraph, error) {
    38  	rollouts := []*rollouts.RolloutGraph{}
    39  
    40  	if bannerRollouts, found := s.Store.GraphStore[bannerID]; found {
    41  		for _, rollout := range bannerRollouts {
    42  			rollouts = append(rollouts, rollout)
    43  		}
    44  		return rollouts, nil
    45  	}
    46  
    47  	return rollouts, fmt.Errorf("banner %s has no rollouts", bannerID)
    48  }
    49  
    50  func (s *InMemoryRolloutStore) GetRollout(_ context.Context, rolloutID string) (*rollouts.RolloutGraph, error) {
    51  	for _, bannerRollouts := range s.Store.GraphStore {
    52  		if rollout, found := bannerRollouts[rolloutID]; found {
    53  			return rollout, nil
    54  		}
    55  	}
    56  	return nil, fmt.Errorf("rollout %s not found", rolloutID)
    57  }
    58  
    59  func (s *InMemoryRolloutStore) GetRolloutPlans(_ context.Context) ([]rollouts.RolloutPlan, error) {
    60  	plans := []rollouts.RolloutPlan{}
    61  	for _, bannerPlans := range s.Store.PlanStore {
    62  		for _, plan := range bannerPlans {
    63  			plans = append(plans, plan)
    64  		}
    65  	}
    66  	return plans, nil
    67  }
    68  
    69  func (s *InMemoryRolloutStore) GetRolloutPlan(_ context.Context, planID string) (rollouts.RolloutPlan, error) {
    70  	for _, bannerPlans := range s.Store.PlanStore {
    71  		if plan, found := bannerPlans[planID]; found {
    72  			return plan, nil
    73  		}
    74  	}
    75  	return rollouts.RolloutPlan{}, fmt.Errorf("plan %s not found", planID)
    76  }
    77  
    78  func (s *InMemoryRolloutStore) GetBannerRolloutPlans(_ context.Context, bannerID string) ([]rollouts.RolloutPlan, error) {
    79  	plans := []rollouts.RolloutPlan{}
    80  	if bannerPlans, found := s.Store.PlanStore[bannerID]; found {
    81  		for _, plan := range bannerPlans {
    82  			plans = append(plans, plan)
    83  		}
    84  	}
    85  	return plans, nil
    86  }
    87  
    88  func (s *InMemoryRolloutStore) OpenApprovalGate(_ context.Context, _ string, nodeKey rollouts.NodeKey) error {
    89  	if _, found := s.Store.ApprovalGateStore[nodeKey]; found {
    90  		s.Store.ApprovalGateStore[nodeKey] = rollouts.GateApproved
    91  		return nil
    92  	}
    93  	return fmt.Errorf("gate %s requested for approval, but not found", nodeKey)
    94  }
    95  
    96  func (s *InMemoryRolloutStore) NewRolloutInstance(_ context.Context, logger logr.Logger, rollout *rollouts.RolloutGraph, resultChan chan rollouts.NodeExecutionResult) (*RolloutInstance, error) {
    97  	driver, err := newInMemoryDriver(s.Store, logger.WithName("inmemory-driver"), rollout)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	return &RolloutInstance{
   102  		driver:     driver,
   103  		logger:     logger,
   104  		NotifyChan: resultChan,
   105  	}, nil
   106  }
   107  
   108  type inMemoryDriver struct {
   109  	store   *internal.InMemStore
   110  	logger  logr.Logger
   111  	rollout *rollouts.RolloutGraph
   112  }
   113  
   114  func newInMemoryDriver(store *internal.InMemStore, logger logr.Logger, rollout *rollouts.RolloutGraph) (*inMemoryDriver, error) {
   115  	if store == nil {
   116  		return nil, fmt.Errorf("nil store passed to driver creation")
   117  	}
   118  	if rollout == nil {
   119  		return nil, fmt.Errorf("nil rollout passed to driver creation")
   120  	}
   121  
   122  	return &inMemoryDriver{
   123  		store:   store,
   124  		logger:  logger,
   125  		rollout: rollout,
   126  	}, nil
   127  }
   128  
   129  func (d *inMemoryDriver) getRollout() *rollouts.RolloutGraph {
   130  	return d.rollout
   131  }
   132  
   133  func (d *inMemoryDriver) execute(ctx context.Context, node rollouts.RolloutGraphNode) (rollouts.NodeExecutionResult, error) {
   134  	switch n := node.(type) {
   135  	case *rollouts.TargetGroup:
   136  		return d.executeTargetGroup(ctx, n)
   137  	case *rollouts.TimerGate:
   138  		return d.checkTimerGate(ctx, n)
   139  	case *rollouts.ApprovalGate:
   140  		return d.checkApprovalGate(ctx, n)
   141  	default:
   142  		return rollouts.NodeExecutionResult{
   143  			RolloutState: rollouts.RolloutError,
   144  			Message:      fmt.Sprintf("unknown type for node %v", node),
   145  		}, fmt.Errorf("unknown type for node %v", node)
   146  	}
   147  }
   148  
   149  func (d *inMemoryDriver) matchTargetGroup(_ context.Context, selector string) (rollouts.MatchResult, error) {
   150  	return d.store.GetClusterLabelMatches(selector)
   151  }
   152  
   153  func (d *inMemoryDriver) checkTargetGroup(_ context.Context, artifactName string, artifactVersion string, clusterMatches rollouts.MatchResult) (rollouts.TargetGroupState, error) {
   154  	tgVersionStates, err := d.store.GetTargetGroupVersionStates(artifactName, clusterMatches)
   155  	if err != nil {
   156  		return rollouts.TargetGroupError, fmt.Errorf("could not find states for [%s]", clusterMatches.Matches)
   157  	}
   158  	for _, artifactState := range tgVersionStates {
   159  		if artifactState.Version == artifactVersion && !artifactState.Ready {
   160  			return rollouts.TargetGroupApplied, nil
   161  		}
   162  
   163  		if artifactState.Version != artifactVersion {
   164  			return rollouts.TargetGroupPending, nil
   165  		}
   166  	}
   167  	return rollouts.TargetGroupReady, nil
   168  }
   169  
   170  func (d *inMemoryDriver) applyTargetGroup(_ context.Context, artifactName string, artifactVersion string, clusterMatches rollouts.MatchResult) error { //nolint
   171  	return d.store.ApplyClusterArtifactVersions(artifactName, artifactVersion, clusterMatches)
   172  }
   173  
   174  func (d *inMemoryDriver) executeTargetGroup(ctx context.Context, tg *rollouts.TargetGroup) (rollouts.NodeExecutionResult, error) {
   175  	d.logger.Info(fmt.Sprintln("executing node:", tg.GetLabel()))
   176  	d.logger.Info(fmt.Sprintln("selector:", tg.Selector))
   177  	result := rollouts.NodeExecutionResult{Key: tg.GetKey()}
   178  	switch tg.State {
   179  	case rollouts.Complete:
   180  		result.State = rollouts.Complete
   181  		result.Message = fmt.Sprintf("rollout %s node %s complete", d.rollout.ID, tg.GetKey())
   182  		return result, nil
   183  	// If first time seeing node, (Pending status), set InProgress
   184  	case rollouts.Pending:
   185  		tg.State = rollouts.InProgress
   186  	case rollouts.InProgress:
   187  		break
   188  	default:
   189  		return result, fmt.Errorf("unknown NodeState %s", tg.State)
   190  	}
   191  
   192  	// steps:
   193  	// 1. check if the targets in the group are at the desired version or not. if yes - return
   194  	// 2. set targets to the desired version / apply the desired configuration
   195  
   196  	matches, err := d.matchTargetGroup(ctx, tg.Selector)
   197  	if err != nil {
   198  		result.RolloutState = rollouts.RolloutError
   199  		result.Message = fmt.Sprintf("rollout %s failed to match target group %s for selector %s", d.rollout.ID, tg.GetKey(), tg.Selector)
   200  		return result, err
   201  	}
   202  
   203  	tgState, err := d.checkTargetGroup(ctx, tg.ArtifactName, tg.ArtifactVersion, matches)
   204  	if err != nil {
   205  		result.RolloutState = rollouts.RolloutError
   206  		result.Message = fmt.Sprintf("rollout %s failed to check target group %s state", d.rollout.ID, tg.GetKey())
   207  		return result, err
   208  	}
   209  
   210  	switch tgState {
   211  	case rollouts.TargetGroupError:
   212  		result.Message = fmt.Sprintf("Target Group %s in error", tg.GetKey())
   213  		result.RolloutState = rollouts.RolloutError
   214  		return result, err
   215  	case rollouts.TargetGroupApplied:
   216  		result.Message = fmt.Sprintf("Target Group %s waiting for completion", tg.GetKey())
   217  		return result, nil
   218  	case rollouts.TargetGroupReady:
   219  		tg.State = rollouts.Complete
   220  		result.State = rollouts.Complete
   221  		return result, nil
   222  	}
   223  
   224  	if err := d.applyTargetGroup(ctx, tg.ArtifactName, tg.ArtifactVersion, matches); err == nil {
   225  		tg.State = rollouts.InProgress
   226  		result.State = rollouts.InProgress
   227  		result.Message = fmt.Sprintf("Target Group %s applied", tg.GetKey())
   228  	}
   229  
   230  	// node has not finished yet, but no error occurred
   231  	return result, nil
   232  }
   233  
   234  func (d *inMemoryDriver) checkTimerGate(_ context.Context, tg *rollouts.TimerGate) (rollouts.NodeExecutionResult, error) {
   235  	tg.State = rollouts.Pending
   236  	open := tg.IsOpen()
   237  	result := rollouts.NodeExecutionResult{Key: tg.GetKey()}
   238  	result.Message = fmt.Sprintf("rollout %s timer gate %s", d.rollout.ID, tg.GetKey())
   239  	if open {
   240  		tg.State = rollouts.Complete
   241  		result.State = rollouts.Complete
   242  		result.Message = fmt.Sprintf("%s is open", result.Message)
   243  	} else {
   244  		timeLeft := tg.StartTime.Add(tg.Delay).UTC().Format(time.RFC3339)
   245  		result.Message = fmt.Sprintf("%s is closed and will open at %s", result.Message, timeLeft)
   246  	}
   247  	return result, nil
   248  }
   249  
   250  func (d *inMemoryDriver) checkApprovalGate(_ context.Context, ag *rollouts.ApprovalGate) (rollouts.NodeExecutionResult, error) {
   251  	result := rollouts.NodeExecutionResult{Key: ag.GetKey()}
   252  	if ag.IsOpen() {
   253  		result.Message = fmt.Sprintf("rollout %s approval gate %s is open", d.rollout.ID, ag.GetKey())
   254  		result.State = rollouts.Complete
   255  		return result, nil
   256  	}
   257  
   258  	gateStatus, err := d.store.GetApprovalGateStatus(ag)
   259  	if err != nil {
   260  		result.Message = fmt.Sprintf("rollout %s approval gate %s status not found", d.rollout.ID, ag.GetKey())
   261  		return result, fmt.Errorf("did not find gate %s", string(ag.GetKey()))
   262  	}
   263  
   264  	switch gateStatus {
   265  	case rollouts.GateApproved:
   266  		ag.GateState = rollouts.Open
   267  		ag.State = rollouts.Complete
   268  		result.State = rollouts.Complete
   269  		result.Message = fmt.Sprintf("rollout %s approval gate %s is open", d.rollout.ID, ag.GetKey())
   270  		return result, nil
   271  	case rollouts.GatePending:
   272  		ag.GateState = rollouts.Closed
   273  		ag.State = rollouts.Pending
   274  		result.State = rollouts.Pending
   275  		result.Message = fmt.Sprintf("rollout %s approval gate %s is pending", d.rollout.ID, ag.GetKey())
   276  		return result, nil
   277  	case rollouts.GateDenied:
   278  		ag.GateState = rollouts.Closed
   279  		result.State = rollouts.Halted
   280  		result.Message = fmt.Sprintf("rollout %s approval gate %s is denied, rollout halted", d.rollout.ID, ag.GetKey())
   281  		return result, nil
   282  	default:
   283  		result.Message = fmt.Sprintf("unknown gate status %s for rollout %s, gate key %s", gateStatus, d.rollout.ID, ag.GetKey())
   284  		return result, fmt.Errorf("unknown gate status %s for rollout %s, gate key %s", gateStatus, d.rollout.ID, ag.GetKey())
   285  	}
   286  }
   287  
   288  // The graph store for an in memory store is a map to pointers, so all changes are
   289  // automatically persisted
   290  func (d *inMemoryDriver) persistGraph(_ context.Context) error {
   291  	return nil
   292  }
   293  

View as plain text