1 package drivers
2
3 import (
4 "context"
5 "fmt"
6 "time"
7
8 "github.com/go-logr/logr"
9
10 "edge-infra.dev/pkg/edge/rollouts"
11 "edge-infra.dev/pkg/edge/rollouts/internal"
12 )
13
14 var _ RolloutStore = &InMemoryRolloutStore{}
15 var _ rolloutDriver = &inMemoryDriver{}
16
17 type InMemoryRolloutStore struct {
18 Store *internal.InMemStore
19 }
20
21 func NewInMemoryRolloutStore(store *internal.InMemStore) *InMemoryRolloutStore {
22 return &InMemoryRolloutStore{
23 Store: store,
24 }
25 }
26
27 func (s *InMemoryRolloutStore) GetRollouts(_ context.Context) ([]*rollouts.RolloutGraph, error) {
28 rollouts := []*rollouts.RolloutGraph{}
29 for _, bannerRollouts := range s.Store.GraphStore {
30 for _, rollout := range bannerRollouts {
31 rollouts = append(rollouts, rollout)
32 }
33 }
34 return rollouts, nil
35 }
36
37 func (s *InMemoryRolloutStore) GetBannerRollouts(_ context.Context, bannerID string) ([]*rollouts.RolloutGraph, error) {
38 rollouts := []*rollouts.RolloutGraph{}
39
40 if bannerRollouts, found := s.Store.GraphStore[bannerID]; found {
41 for _, rollout := range bannerRollouts {
42 rollouts = append(rollouts, rollout)
43 }
44 return rollouts, nil
45 }
46
47 return rollouts, fmt.Errorf("banner %s has no rollouts", bannerID)
48 }
49
50 func (s *InMemoryRolloutStore) GetRollout(_ context.Context, rolloutID string) (*rollouts.RolloutGraph, error) {
51 for _, bannerRollouts := range s.Store.GraphStore {
52 if rollout, found := bannerRollouts[rolloutID]; found {
53 return rollout, nil
54 }
55 }
56 return nil, fmt.Errorf("rollout %s not found", rolloutID)
57 }
58
59 func (s *InMemoryRolloutStore) GetRolloutPlans(_ context.Context) ([]rollouts.RolloutPlan, error) {
60 plans := []rollouts.RolloutPlan{}
61 for _, bannerPlans := range s.Store.PlanStore {
62 for _, plan := range bannerPlans {
63 plans = append(plans, plan)
64 }
65 }
66 return plans, nil
67 }
68
69 func (s *InMemoryRolloutStore) GetRolloutPlan(_ context.Context, planID string) (rollouts.RolloutPlan, error) {
70 for _, bannerPlans := range s.Store.PlanStore {
71 if plan, found := bannerPlans[planID]; found {
72 return plan, nil
73 }
74 }
75 return rollouts.RolloutPlan{}, fmt.Errorf("plan %s not found", planID)
76 }
77
78 func (s *InMemoryRolloutStore) GetBannerRolloutPlans(_ context.Context, bannerID string) ([]rollouts.RolloutPlan, error) {
79 plans := []rollouts.RolloutPlan{}
80 if bannerPlans, found := s.Store.PlanStore[bannerID]; found {
81 for _, plan := range bannerPlans {
82 plans = append(plans, plan)
83 }
84 }
85 return plans, nil
86 }
87
88 func (s *InMemoryRolloutStore) OpenApprovalGate(_ context.Context, _ string, nodeKey rollouts.NodeKey) error {
89 if _, found := s.Store.ApprovalGateStore[nodeKey]; found {
90 s.Store.ApprovalGateStore[nodeKey] = rollouts.GateApproved
91 return nil
92 }
93 return fmt.Errorf("gate %s requested for approval, but not found", nodeKey)
94 }
95
96 func (s *InMemoryRolloutStore) NewRolloutInstance(_ context.Context, logger logr.Logger, rollout *rollouts.RolloutGraph, resultChan chan rollouts.NodeExecutionResult) (*RolloutInstance, error) {
97 driver, err := newInMemoryDriver(s.Store, logger.WithName("inmemory-driver"), rollout)
98 if err != nil {
99 return nil, err
100 }
101 return &RolloutInstance{
102 driver: driver,
103 logger: logger,
104 NotifyChan: resultChan,
105 }, nil
106 }
107
108 type inMemoryDriver struct {
109 store *internal.InMemStore
110 logger logr.Logger
111 rollout *rollouts.RolloutGraph
112 }
113
114 func newInMemoryDriver(store *internal.InMemStore, logger logr.Logger, rollout *rollouts.RolloutGraph) (*inMemoryDriver, error) {
115 if store == nil {
116 return nil, fmt.Errorf("nil store passed to driver creation")
117 }
118 if rollout == nil {
119 return nil, fmt.Errorf("nil rollout passed to driver creation")
120 }
121
122 return &inMemoryDriver{
123 store: store,
124 logger: logger,
125 rollout: rollout,
126 }, nil
127 }
128
129 func (d *inMemoryDriver) getRollout() *rollouts.RolloutGraph {
130 return d.rollout
131 }
132
133 func (d *inMemoryDriver) execute(ctx context.Context, node rollouts.RolloutGraphNode) (rollouts.NodeExecutionResult, error) {
134 switch n := node.(type) {
135 case *rollouts.TargetGroup:
136 return d.executeTargetGroup(ctx, n)
137 case *rollouts.TimerGate:
138 return d.checkTimerGate(ctx, n)
139 case *rollouts.ApprovalGate:
140 return d.checkApprovalGate(ctx, n)
141 default:
142 return rollouts.NodeExecutionResult{
143 RolloutState: rollouts.RolloutError,
144 Message: fmt.Sprintf("unknown type for node %v", node),
145 }, fmt.Errorf("unknown type for node %v", node)
146 }
147 }
148
149 func (d *inMemoryDriver) matchTargetGroup(_ context.Context, selector string) (rollouts.MatchResult, error) {
150 return d.store.GetClusterLabelMatches(selector)
151 }
152
153 func (d *inMemoryDriver) checkTargetGroup(_ context.Context, artifactName string, artifactVersion string, clusterMatches rollouts.MatchResult) (rollouts.TargetGroupState, error) {
154 tgVersionStates, err := d.store.GetTargetGroupVersionStates(artifactName, clusterMatches)
155 if err != nil {
156 return rollouts.TargetGroupError, fmt.Errorf("could not find states for [%s]", clusterMatches.Matches)
157 }
158 for _, artifactState := range tgVersionStates {
159 if artifactState.Version == artifactVersion && !artifactState.Ready {
160 return rollouts.TargetGroupApplied, nil
161 }
162
163 if artifactState.Version != artifactVersion {
164 return rollouts.TargetGroupPending, nil
165 }
166 }
167 return rollouts.TargetGroupReady, nil
168 }
169
170 func (d *inMemoryDriver) applyTargetGroup(_ context.Context, artifactName string, artifactVersion string, clusterMatches rollouts.MatchResult) error {
171 return d.store.ApplyClusterArtifactVersions(artifactName, artifactVersion, clusterMatches)
172 }
173
174 func (d *inMemoryDriver) executeTargetGroup(ctx context.Context, tg *rollouts.TargetGroup) (rollouts.NodeExecutionResult, error) {
175 d.logger.Info(fmt.Sprintln("executing node:", tg.GetLabel()))
176 d.logger.Info(fmt.Sprintln("selector:", tg.Selector))
177 result := rollouts.NodeExecutionResult{Key: tg.GetKey()}
178 switch tg.State {
179 case rollouts.Complete:
180 result.State = rollouts.Complete
181 result.Message = fmt.Sprintf("rollout %s node %s complete", d.rollout.ID, tg.GetKey())
182 return result, nil
183
184 case rollouts.Pending:
185 tg.State = rollouts.InProgress
186 case rollouts.InProgress:
187 break
188 default:
189 return result, fmt.Errorf("unknown NodeState %s", tg.State)
190 }
191
192
193
194
195
196 matches, err := d.matchTargetGroup(ctx, tg.Selector)
197 if err != nil {
198 result.RolloutState = rollouts.RolloutError
199 result.Message = fmt.Sprintf("rollout %s failed to match target group %s for selector %s", d.rollout.ID, tg.GetKey(), tg.Selector)
200 return result, err
201 }
202
203 tgState, err := d.checkTargetGroup(ctx, tg.ArtifactName, tg.ArtifactVersion, matches)
204 if err != nil {
205 result.RolloutState = rollouts.RolloutError
206 result.Message = fmt.Sprintf("rollout %s failed to check target group %s state", d.rollout.ID, tg.GetKey())
207 return result, err
208 }
209
210 switch tgState {
211 case rollouts.TargetGroupError:
212 result.Message = fmt.Sprintf("Target Group %s in error", tg.GetKey())
213 result.RolloutState = rollouts.RolloutError
214 return result, err
215 case rollouts.TargetGroupApplied:
216 result.Message = fmt.Sprintf("Target Group %s waiting for completion", tg.GetKey())
217 return result, nil
218 case rollouts.TargetGroupReady:
219 tg.State = rollouts.Complete
220 result.State = rollouts.Complete
221 return result, nil
222 }
223
224 if err := d.applyTargetGroup(ctx, tg.ArtifactName, tg.ArtifactVersion, matches); err == nil {
225 tg.State = rollouts.InProgress
226 result.State = rollouts.InProgress
227 result.Message = fmt.Sprintf("Target Group %s applied", tg.GetKey())
228 }
229
230
231 return result, nil
232 }
233
234 func (d *inMemoryDriver) checkTimerGate(_ context.Context, tg *rollouts.TimerGate) (rollouts.NodeExecutionResult, error) {
235 tg.State = rollouts.Pending
236 open := tg.IsOpen()
237 result := rollouts.NodeExecutionResult{Key: tg.GetKey()}
238 result.Message = fmt.Sprintf("rollout %s timer gate %s", d.rollout.ID, tg.GetKey())
239 if open {
240 tg.State = rollouts.Complete
241 result.State = rollouts.Complete
242 result.Message = fmt.Sprintf("%s is open", result.Message)
243 } else {
244 timeLeft := tg.StartTime.Add(tg.Delay).UTC().Format(time.RFC3339)
245 result.Message = fmt.Sprintf("%s is closed and will open at %s", result.Message, timeLeft)
246 }
247 return result, nil
248 }
249
250 func (d *inMemoryDriver) checkApprovalGate(_ context.Context, ag *rollouts.ApprovalGate) (rollouts.NodeExecutionResult, error) {
251 result := rollouts.NodeExecutionResult{Key: ag.GetKey()}
252 if ag.IsOpen() {
253 result.Message = fmt.Sprintf("rollout %s approval gate %s is open", d.rollout.ID, ag.GetKey())
254 result.State = rollouts.Complete
255 return result, nil
256 }
257
258 gateStatus, err := d.store.GetApprovalGateStatus(ag)
259 if err != nil {
260 result.Message = fmt.Sprintf("rollout %s approval gate %s status not found", d.rollout.ID, ag.GetKey())
261 return result, fmt.Errorf("did not find gate %s", string(ag.GetKey()))
262 }
263
264 switch gateStatus {
265 case rollouts.GateApproved:
266 ag.GateState = rollouts.Open
267 ag.State = rollouts.Complete
268 result.State = rollouts.Complete
269 result.Message = fmt.Sprintf("rollout %s approval gate %s is open", d.rollout.ID, ag.GetKey())
270 return result, nil
271 case rollouts.GatePending:
272 ag.GateState = rollouts.Closed
273 ag.State = rollouts.Pending
274 result.State = rollouts.Pending
275 result.Message = fmt.Sprintf("rollout %s approval gate %s is pending", d.rollout.ID, ag.GetKey())
276 return result, nil
277 case rollouts.GateDenied:
278 ag.GateState = rollouts.Closed
279 result.State = rollouts.Halted
280 result.Message = fmt.Sprintf("rollout %s approval gate %s is denied, rollout halted", d.rollout.ID, ag.GetKey())
281 return result, nil
282 default:
283 result.Message = fmt.Sprintf("unknown gate status %s for rollout %s, gate key %s", gateStatus, d.rollout.ID, ag.GetKey())
284 return result, fmt.Errorf("unknown gate status %s for rollout %s, gate key %s", gateStatus, d.rollout.ID, ag.GetKey())
285 }
286 }
287
288
289
290 func (d *inMemoryDriver) persistGraph(_ context.Context) error {
291 return nil
292 }
293
View as plain text