...

Source file src/edge-infra.dev/pkg/edge/rollouts/drivers/driver.go

Documentation: edge-infra.dev/pkg/edge/rollouts/drivers

     1  package drivers
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  
     8  	"github.com/go-logr/logr"
     9  
    10  	"edge-infra.dev/pkg/edge/rollouts"
    11  )
    12  
    13  type RolloutStore interface {
    14  	GetRollout(ctx context.Context, rolloutID string) (*rollouts.RolloutGraph, error)
    15  	GetRollouts(ctx context.Context) ([]*rollouts.RolloutGraph, error)
    16  	GetBannerRollouts(ctx context.Context, bannerID string) ([]*rollouts.RolloutGraph, error)
    17  	GetRolloutPlan(ctx context.Context, planID string) (rollouts.RolloutPlan, error)
    18  	GetRolloutPlans(ctx context.Context) ([]rollouts.RolloutPlan, error)
    19  	GetBannerRolloutPlans(ctx context.Context, bannerID string) ([]rollouts.RolloutPlan, error)
    20  	OpenApprovalGate(ctx context.Context, rolloutID string, nodeKey rollouts.NodeKey) error
    21  	NewRolloutInstance(ctx context.Context, logger logr.Logger, rollout *rollouts.RolloutGraph, notifyChan chan rollouts.NodeExecutionResult) (*RolloutInstance, error)
    22  }
    23  
    24  type RolloutInstance struct {
    25  	driver     rolloutDriver
    26  	logger     logr.Logger
    27  	NotifyChan chan rollouts.NodeExecutionResult
    28  }
    29  
    30  func (ri *RolloutInstance) sendEvent(event rollouts.NodeExecutionResult) {
    31  	ri.NotifyChan <- event
    32  }
    33  
    34  type rolloutDriver interface {
    35  	getRollout() *rollouts.RolloutGraph
    36  	execute(ctx context.Context, node rollouts.RolloutGraphNode) (rollouts.NodeExecutionResult, error)
    37  	executeTargetGroup(ctx context.Context, tg *rollouts.TargetGroup) (rollouts.NodeExecutionResult, error)
    38  	checkTimerGate(ctx context.Context, tg *rollouts.TimerGate) (rollouts.NodeExecutionResult, error)
    39  	checkApprovalGate(ctx context.Context, ag *rollouts.ApprovalGate) (rollouts.NodeExecutionResult, error)
    40  	persistGraph(ctx context.Context) error
    41  }
    42  
    43  func RunRollout(ctx context.Context, ri *RolloutInstance) (bool, error) {
    44  	// start the rollout
    45  	done, results, err := processStep(ctx, ri)
    46  	if err != nil {
    47  		return false, err
    48  	}
    49  
    50  	// TODO: persist here
    51  	// So far inmem doesn't do anything, but SQL instance would update the jsonb
    52  	// field with the current state of the graph
    53  	if err = ri.driver.persistGraph(ctx); err != nil {
    54  		return false, err
    55  	}
    56  
    57  	if done {
    58  		// doneEvent := events.RolloutEvent{Rollout: rollout, RolloutState: events.RolloutComplete}
    59  		doneResult := rollouts.NodeExecutionResult{
    60  			Message:      fmt.Sprintf("rollout %s complete", ri.driver.getRollout().ID),
    61  			RolloutState: rollouts.RolloutComplete,
    62  		}
    63  		ri.sendEvent(doneResult)
    64  		return true, nil
    65  	}
    66  
    67  	for _, result := range results {
    68  		result.RolloutState = rollouts.RolloutInProgress
    69  		ri.sendEvent(result)
    70  	}
    71  
    72  	// nodes should not be in-progress if their dependencies have not finished
    73  	for _, currKey := range ri.driver.getRollout().Current {
    74  		curr := ri.driver.getRollout().Nodes[currKey]
    75  		allDepsComplete := true
    76  		var incompleteDep rollouts.RolloutGraphNode
    77  		for _, depOfCurr := range curr.GetDependsOn() {
    78  			if depOfCurr.GetState() != rollouts.Complete {
    79  				allDepsComplete = false
    80  				incompleteDep = depOfCurr
    81  			}
    82  		}
    83  		if !allDepsComplete && curr.GetState() == rollouts.InProgress {
    84  			return false, fmt.Errorf("node %s was started without finishing its deps first. incomplete dep: %s",
    85  				curr.GetKey(), incompleteDep.GetKey())
    86  		}
    87  	}
    88  
    89  	return false, nil
    90  }
    91  
    92  func processStep(ctx context.Context, ri *RolloutInstance) (bool, []rollouts.NodeExecutionResult, error) {
    93  	nextNodes := map[rollouts.NodeKey]rollouts.RolloutGraphNode{}
    94  	results := []rollouts.NodeExecutionResult{}
    95  
    96  	// TODO(edge-foundation): Replace/remove print debugging beflre merging
    97  	ri.logger.Info(fmt.Sprint("Processing step. Graph:", ri.driver.getRollout().ID))
    98  
    99  	// there can be multiple "current" nodes, ie in a fan-out style graph
   100  	// TODO: phase 2+ parallelism
   101  	for _, currKey := range ri.driver.getRollout().Current {
   102  		// check if all of the current node's deps are met before proceeding
   103  		curr := ri.driver.getRollout().Nodes[currKey]
   104  		allDepsComplete := true
   105  		for _, dep := range curr.GetDependsOn() {
   106  			// approach 1: let nodes track their state. ie, pending, in-progress, or complete
   107  			if dep.GetState() != rollouts.Complete {
   108  				allDepsComplete = false
   109  				break
   110  			}
   111  			// approach 2: track node results in graph. nodes completely stateless
   112  			// result, ok := g.NodeExecutionResults[dep.Key()]
   113  			// if !ok || !(result.Done()) {
   114  			// 	continue
   115  			// }
   116  		}
   117  		if !allDepsComplete {
   118  			continue
   119  		}
   120  
   121  		result, err := ri.driver.execute(ctx, curr)
   122  		if err != nil {
   123  			// handle error. log, return, and make sure caller can rollback and transactions etc
   124  			return false, nil, err
   125  		}
   126  		results = append(results, result)
   127  
   128  		// if the execution was successful, save that result in the graph by marking the edge as succeeded
   129  		ri.driver.getRollout().NodeExecutionResults[curr.GetKey()] = result
   130  
   131  		success := result.Done()
   132  		if success {
   133  			// if node succeeds, and add its Next nodes to current, without duplicates (eg > 1 incoming edges)
   134  			for _, maybeNext := range curr.GetNext() {
   135  				if _, seen := nextNodes[maybeNext.GetKey()]; !seen {
   136  					ri.logger.Info(fmt.Sprintf("first time seeing %s, adding to next\n", maybeNext.GetKey()))
   137  					nextNodes[maybeNext.GetKey()] = maybeNext
   138  				}
   139  			}
   140  		} else {
   141  			// if node is not successful yet, add it back to the list to process next time
   142  			nextNodes[curr.GetKey()] = curr
   143  		}
   144  	}
   145  
   146  	// set current to list of next nodes to process, only adding nodes whos deps are all complete
   147  	nextKeys := []rollouts.NodeKey{}
   148  	for _, n := range nextNodes {
   149  		nextKeys = append(nextKeys, n.GetKey())
   150  	}
   151  	ri.driver.getRollout().Current = nextKeys
   152  
   153  	// if there are no more nodes, it is complete
   154  	if len(ri.driver.getRollout().Current) == 0 {
   155  		// TODO consider removing the overall "Complete"ness state. track in db / external to graph. callers should
   156  		// use the return value of process step to determine completeness/doneness
   157  		return true, results, nil
   158  	}
   159  
   160  	// stringify just for debug logging
   161  	nextNodesStrs := []string{}
   162  	for _, n := range nextNodes {
   163  		nextNodesStrs = append(nextNodesStrs, fmt.Sprintf("%s (%s)", n.GetLabel(), n.GetKey()))
   164  	}
   165  	// TODO(edge-foundation): Replace/remove print debugging before merging
   166  	// TODO: this still prints out even if the "new" current == last current. only log if it changes?
   167  	ri.logger.Info(fmt.Sprintln("Done processing step. Current node(s):", fmt.Sprintf("[%v]", strings.Join(nextNodesStrs, ","))))
   168  
   169  	return false, results, nil
   170  }
   171  

View as plain text