1 package engine
2
3 import (
4 "context"
5 "fmt"
6 "strings"
7 "testing"
8 "time"
9
10 "github.com/go-logr/logr"
11 "github.com/go-logr/logr/funcr"
12 "gotest.tools/v3/assert"
13 "gotest.tools/v3/poll"
14
15 "edge-infra.dev/pkg/edge/rollouts"
16 "edge-infra.dev/pkg/edge/rollouts/drivers"
17 "edge-infra.dev/pkg/edge/rollouts/internal"
18 )
19
20 func newStdoutLogger() logr.Logger {
21 return funcr.New(func(prefix, args string) {
22 if prefix != "" {
23 fmt.Printf("%s: %s\n", prefix, args)
24 } else {
25 fmt.Println(args)
26 }
27 }, funcr.Options{})
28 }
29
30 func TestNewEngine(t *testing.T) {
31 store := drivers.NewInMemoryRolloutStore(internal.NewExampleInMemStore())
32 assert.Assert(t, store != nil)
33
34 logger := newStdoutLogger()
35 engine := NewRolloutEngine(store, logger, time.Millisecond*10)
36 assert.Assert(t, engine != nil)
37 }
38
39 func TestEngineRun(t *testing.T) {
40 pollDelay := 5 * time.Millisecond
41 pollTimeout := 50 * time.Millisecond
42
43 inMemStore := internal.NewExampleInMemStore()
44 store := drivers.NewInMemoryRolloutStore(inMemStore)
45 assert.Assert(t, store != nil)
46
47 ctx, cancel := context.WithCancelCause(context.Background())
48 defer cancel(nil)
49 logger := newStdoutLogger()
50 engine := NewRolloutEngine(store, logger, time.Millisecond*10)
51 assert.Assert(t, engine != nil)
52
53
54 testRollout, err := store.GetRollout(ctx, internal.ExampleRolloutID)
55 assert.NilError(t, err)
56
57
58 engineRunErr := engine.Run(ctx)
59 assert.NilError(t, engineRunErr)
60
61 engineListenErr := engine.Listen(ctx)
62 assert.NilError(t, engineListenErr)
63
64
65
66 tg1Check := func(_ poll.LogT) poll.Result {
67 tg1 := testRollout.Nodes[rollouts.NodeKey("tg1")]
68 if tg1.GetState() == rollouts.InProgress {
69 fmt.Println("reached tg1Check success")
70 return poll.Success()
71 }
72 return poll.Continue("waiting on tg1 to be InProgress. current state: %s", tg1.GetState())
73 }
74 poll.WaitOn(t, tg1Check, poll.WithDelay(pollDelay), poll.WithTimeout(pollTimeout))
75
76
77
78
79
80 devClusterIDs, err := inMemStore.GetClusterLabelMatches("dev")
81 assert.NilError(t, err)
82 inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, devClusterIDs)
83
84 tg1CompleteCheck := func(_ poll.LogT) poll.Result {
85 tg1 := testRollout.Nodes[rollouts.NodeKey("tg1")]
86 if tg1.GetState() == rollouts.Complete {
87 fmt.Println("reached tg1CompleteCheck success")
88 return poll.Success()
89 }
90 return poll.Continue("waiting on tg1 to be %s. current state: %s", rollouts.Complete, tg1.GetState())
91 }
92 poll.WaitOn(t, tg1CompleteCheck, poll.WithDelay(pollDelay), poll.WithTimeout(pollTimeout))
93
94
95 stageEastClusterIDs, err := inMemStore.GetClusterLabelMatches("staging:east")
96 assert.NilError(t, err)
97 stageWestClusterIDs, err := inMemStore.GetClusterLabelMatches("staging:west")
98 assert.NilError(t, err)
99
100 inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, stageEastClusterIDs)
101 inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, stageWestClusterIDs)
102
103 tg2And3CompleteCheck := func(_ poll.LogT) poll.Result {
104 tg2 := testRollout.Nodes[rollouts.NodeKey("tg2")]
105 tg3 := testRollout.Nodes[rollouts.NodeKey("tg3")]
106 if tg2.GetState() == rollouts.Complete && tg3.GetState() == rollouts.Complete {
107 fmt.Println("reached tg2And3CompleteCheck success")
108 return poll.Success()
109 }
110 return poll.Continue("waiting on tg2 and tg3 to be %s. current state: tg2: %s tg3: %s", rollouts.Complete, tg2.GetState(), tg3.GetState())
111 }
112 poll.WaitOn(t, tg2And3CompleteCheck, poll.WithDelay(pollDelay), poll.WithTimeout(pollTimeout))
113
114
115 ag1PendingCheck := func(_ poll.LogT) poll.Result {
116 ag1 := testRollout.Nodes[rollouts.NodeKey("ag1")]
117 if ag1.GetState() == rollouts.Pending {
118 fmt.Println("reached ag1PendingCheck success")
119 return poll.Success()
120 }
121 return poll.Continue("waiting on ag1 to be %s. current state: %s", rollouts.Pending, ag1.GetState())
122 }
123 poll.WaitOn(t, ag1PendingCheck, poll.WithDelay(pollDelay), poll.WithTimeout(pollTimeout))
124
125
126 ag1 := testRollout.Nodes[rollouts.NodeKey("ag1")].(*rollouts.ApprovalGate)
127 err = inMemStore.OpenApprovalGate(ag1.GetKey())
128 assert.NilError(t, err)
129 ag1CompleteCheck := func(_ poll.LogT) poll.Result {
130 if ag1.GetState() == rollouts.Complete {
131 fmt.Println("reached ag1CompleteCheck success")
132 return poll.Success()
133 }
134 return poll.Continue("waiting on ag1 to be %s. current state: %s", rollouts.Complete, ag1.GetState())
135 }
136 poll.WaitOn(t, ag1CompleteCheck, poll.WithDelay(pollDelay), poll.WithTimeout(pollTimeout))
137
138
139 tg4PendingCheck := func(_ poll.LogT) poll.Result {
140 tg4 := testRollout.Nodes[rollouts.NodeKey("tg4")]
141 if tg4.GetState() == rollouts.Pending {
142 fmt.Println("reached tg4PendingCheck success")
143 return poll.Success()
144 }
145 return poll.Continue("waiting on tg4 to be %s. current state: %s", rollouts.Pending, tg4.GetState())
146 }
147 poll.WaitOn(t, tg4PendingCheck, poll.WithDelay(pollDelay), poll.WithTimeout(pollTimeout))
148
149
150 prodClusterIDs, err := inMemStore.GetClusterLabelMatches("prod:us")
151 assert.NilError(t, err)
152 inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, prodClusterIDs)
153
154 tg4CompleteCheck := func(_ poll.LogT) poll.Result {
155 tg4 := testRollout.Nodes[rollouts.NodeKey("tg4")]
156 if tg4.GetState() == rollouts.Complete {
157 fmt.Println("reached tg4CompleteCheck success")
158 return poll.Success()
159 }
160 return poll.Continue("waiting on tg4 to be %s. current state: %s", rollouts.Complete, tg4.GetState())
161 }
162 poll.WaitOn(t, tg4CompleteCheck, poll.WithDelay(pollDelay), poll.WithTimeout(pollTimeout))
163
164
165 completeCheck := func(_ poll.LogT) poll.Result {
166 incompleteNodes := []string{}
167 for key, node := range testRollout.Nodes {
168 if node.GetState() != rollouts.Complete {
169 incompleteNodes = append(incompleteNodes, string(key))
170 }
171 }
172 if len(incompleteNodes) > 0 {
173 return poll.Continue("waiting on nodes [%s] to complete", strings.Join(incompleteNodes, ","))
174 }
175
176
177 return poll.Success()
178 }
179 poll.WaitOn(t, completeCheck, poll.WithDelay(pollDelay), poll.WithTimeout(pollTimeout))
180
181
182 cancel(nil)
183 <-ctx.Done()
184 }
185
186 var engineChanConditions = internal.ConditionMap{
187 "tg1": {
188 NodeState: rollouts.InProgress,
189 Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) {
190 clusterIDs, _ := inMemStore.GetClusterLabelMatches("dev")
191 inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs)
192 }},
193 "g1": {
194 NodeState: rollouts.Complete,
195 Action: func(_ *internal.InMemStore, _ rollouts.NodeExecutionResult) {
196 }},
197 "tg2": {
198 NodeState: rollouts.InProgress,
199 Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) {
200 clusterIDs, _ := inMemStore.GetClusterLabelMatches("staging:east")
201 inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs)
202 }},
203 "tg3": {
204 NodeState: rollouts.InProgress,
205 Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) {
206 clusterIDs, _ := inMemStore.GetClusterLabelMatches("staging:west")
207 inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs)
208 }},
209 "ag1": {
210 NodeState: rollouts.Pending,
211 Action: func(inMemStore *internal.InMemStore, result rollouts.NodeExecutionResult) {
212 _ = inMemStore.OpenApprovalGate(result.Key)
213 }},
214 "tg4": {
215 NodeState: rollouts.InProgress,
216 Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) {
217 clusterIDs, _ := inMemStore.GetClusterLabelMatches("prod:us")
218 inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs)
219 }},
220 }
221
222 func TestEngineChanRun(t *testing.T) {
223 inMemStore := internal.NewExampleInMemStore()
224 store := drivers.NewInMemoryRolloutStore(inMemStore)
225 assert.Assert(t, store != nil)
226
227 ctx, cancel := context.WithCancelCause(context.Background())
228 defer cancel(nil)
229 logger := newStdoutLogger()
230 engine := NewRolloutEngine(store, logger, time.Millisecond*10)
231 assert.Assert(t, engine != nil)
232
233 err := engine.Run(ctx)
234 if err != nil {
235 t.Fatal(err)
236 }
237
238 timeout := time.NewTimer(time.Second * 2)
239 testloop:
240 for {
241 fmt.Println("in for")
242 select {
243 case nodeResult := <-engine.resultChan:
244 fmt.Println("selecting engine event")
245 switch nodeResult.RolloutState {
246 case rollouts.RolloutComplete:
247 t.Log("done")
248 cancel(nil)
249 <-ctx.Done()
250 break testloop
251 default:
252 fmt.Println("received from event chan")
253 fmt.Println(nodeResult.Message)
254 fmt.Println("running event check")
255 err := internal.ModifyStore(inMemStore, engineChanConditions, nodeResult)
256 if err != nil {
257 t.Fatal(err)
258 }
259 }
260 case <-timeout.C:
261 fmt.Println("timeout")
262 cancel(nil)
263 <-ctx.Done()
264 t.Fatal("test timed out")
265 }
266 }
267 }
268
269 func TestEngineRunListen(t *testing.T) {
270 inMemStore := internal.NewExampleInMemStore()
271 store := drivers.NewInMemoryRolloutStore(inMemStore)
272 assert.Assert(t, store != nil)
273
274 ctx, cancel := context.WithCancelCause(context.Background())
275 defer cancel(nil)
276 logger := newStdoutLogger()
277 engine := NewRolloutEngine(store, logger, time.Millisecond*10)
278 assert.Assert(t, engine != nil)
279
280 err := engine.Run(ctx)
281 if err != nil {
282 t.Fatal(err)
283 }
284 err = engine.Listen(ctx)
285 if err != nil {
286 t.Fatal(err)
287 }
288
289 <-time.After(time.Second * 1)
290 cancel(nil)
291
292
293
294 assert.NilError(t, err)
295 }
296
View as plain text