1 package drivers
2
3 import (
4 "context"
5 "fmt"
6 "strings"
7
8 "github.com/go-logr/logr"
9
10 "edge-infra.dev/pkg/edge/rollouts"
11 )
12
13 type RolloutStore interface {
14 GetRollout(ctx context.Context, rolloutID string) (*rollouts.RolloutGraph, error)
15 GetRollouts(ctx context.Context) ([]*rollouts.RolloutGraph, error)
16 GetBannerRollouts(ctx context.Context, bannerID string) ([]*rollouts.RolloutGraph, error)
17 GetRolloutPlan(ctx context.Context, planID string) (rollouts.RolloutPlan, error)
18 GetRolloutPlans(ctx context.Context) ([]rollouts.RolloutPlan, error)
19 GetBannerRolloutPlans(ctx context.Context, bannerID string) ([]rollouts.RolloutPlan, error)
20 OpenApprovalGate(ctx context.Context, rolloutID string, nodeKey rollouts.NodeKey) error
21 NewRolloutInstance(ctx context.Context, logger logr.Logger, rollout *rollouts.RolloutGraph, notifyChan chan rollouts.NodeExecutionResult) (*RolloutInstance, error)
22 }
23
24 type RolloutInstance struct {
25 driver rolloutDriver
26 logger logr.Logger
27 NotifyChan chan rollouts.NodeExecutionResult
28 }
29
30 func (ri *RolloutInstance) sendEvent(event rollouts.NodeExecutionResult) {
31 ri.NotifyChan <- event
32 }
33
34 type rolloutDriver interface {
35 getRollout() *rollouts.RolloutGraph
36 execute(ctx context.Context, node rollouts.RolloutGraphNode) (rollouts.NodeExecutionResult, error)
37 executeTargetGroup(ctx context.Context, tg *rollouts.TargetGroup) (rollouts.NodeExecutionResult, error)
38 checkTimerGate(ctx context.Context, tg *rollouts.TimerGate) (rollouts.NodeExecutionResult, error)
39 checkApprovalGate(ctx context.Context, ag *rollouts.ApprovalGate) (rollouts.NodeExecutionResult, error)
40 persistGraph(ctx context.Context) error
41 }
42
43 func RunRollout(ctx context.Context, ri *RolloutInstance) (bool, error) {
44
45 done, results, err := processStep(ctx, ri)
46 if err != nil {
47 return false, err
48 }
49
50
51
52
53 if err = ri.driver.persistGraph(ctx); err != nil {
54 return false, err
55 }
56
57 if done {
58
59 doneResult := rollouts.NodeExecutionResult{
60 Message: fmt.Sprintf("rollout %s complete", ri.driver.getRollout().ID),
61 RolloutState: rollouts.RolloutComplete,
62 }
63 ri.sendEvent(doneResult)
64 return true, nil
65 }
66
67 for _, result := range results {
68 result.RolloutState = rollouts.RolloutInProgress
69 ri.sendEvent(result)
70 }
71
72
73 for _, currKey := range ri.driver.getRollout().Current {
74 curr := ri.driver.getRollout().Nodes[currKey]
75 allDepsComplete := true
76 var incompleteDep rollouts.RolloutGraphNode
77 for _, depOfCurr := range curr.GetDependsOn() {
78 if depOfCurr.GetState() != rollouts.Complete {
79 allDepsComplete = false
80 incompleteDep = depOfCurr
81 }
82 }
83 if !allDepsComplete && curr.GetState() == rollouts.InProgress {
84 return false, fmt.Errorf("node %s was started without finishing its deps first. incomplete dep: %s",
85 curr.GetKey(), incompleteDep.GetKey())
86 }
87 }
88
89 return false, nil
90 }
91
92 func processStep(ctx context.Context, ri *RolloutInstance) (bool, []rollouts.NodeExecutionResult, error) {
93 nextNodes := map[rollouts.NodeKey]rollouts.RolloutGraphNode{}
94 results := []rollouts.NodeExecutionResult{}
95
96
97 ri.logger.Info(fmt.Sprint("Processing step. Graph:", ri.driver.getRollout().ID))
98
99
100
101 for _, currKey := range ri.driver.getRollout().Current {
102
103 curr := ri.driver.getRollout().Nodes[currKey]
104 allDepsComplete := true
105 for _, dep := range curr.GetDependsOn() {
106
107 if dep.GetState() != rollouts.Complete {
108 allDepsComplete = false
109 break
110 }
111
112
113
114
115
116 }
117 if !allDepsComplete {
118 continue
119 }
120
121 result, err := ri.driver.execute(ctx, curr)
122 if err != nil {
123
124 return false, nil, err
125 }
126 results = append(results, result)
127
128
129 ri.driver.getRollout().NodeExecutionResults[curr.GetKey()] = result
130
131 success := result.Done()
132 if success {
133
134 for _, maybeNext := range curr.GetNext() {
135 if _, seen := nextNodes[maybeNext.GetKey()]; !seen {
136 ri.logger.Info(fmt.Sprintf("first time seeing %s, adding to next\n", maybeNext.GetKey()))
137 nextNodes[maybeNext.GetKey()] = maybeNext
138 }
139 }
140 } else {
141
142 nextNodes[curr.GetKey()] = curr
143 }
144 }
145
146
147 nextKeys := []rollouts.NodeKey{}
148 for _, n := range nextNodes {
149 nextKeys = append(nextKeys, n.GetKey())
150 }
151 ri.driver.getRollout().Current = nextKeys
152
153
154 if len(ri.driver.getRollout().Current) == 0 {
155
156
157 return true, results, nil
158 }
159
160
161 nextNodesStrs := []string{}
162 for _, n := range nextNodes {
163 nextNodesStrs = append(nextNodesStrs, fmt.Sprintf("%s (%s)", n.GetLabel(), n.GetKey()))
164 }
165
166
167 ri.logger.Info(fmt.Sprintln("Done processing step. Current node(s):", fmt.Sprintf("[%v]", strings.Join(nextNodesStrs, ","))))
168
169 return false, results, nil
170 }
171
View as plain text