package drivers import ( "context" "fmt" "testing" "time" "github.com/go-logr/logr" "github.com/go-logr/logr/funcr" "gotest.tools/v3/assert" "edge-infra.dev/pkg/edge/rollouts" "edge-infra.dev/pkg/edge/rollouts/internal" ) var engineChanConditions = internal.ConditionMap{ "tg1": { NodeState: rollouts.InProgress, Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) { clusterIDs, _ := inMemStore.GetClusterLabelMatches("dev") inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs) }}, "g1": { NodeState: rollouts.Complete, Action: func(_ *internal.InMemStore, _ rollouts.NodeExecutionResult) { }}, "tg2": { NodeState: rollouts.InProgress, Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) { clusterIDs, _ := inMemStore.GetClusterLabelMatches("staging:east") inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs) }}, "tg3": { NodeState: rollouts.InProgress, Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) { clusterIDs, _ := inMemStore.GetClusterLabelMatches("staging:west") inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs) }}, "ag1": { NodeState: rollouts.Pending, Action: func(inMemStore *internal.InMemStore, result rollouts.NodeExecutionResult) { _ = inMemStore.OpenApprovalGate(result.Key) }}, "tg4": { NodeState: rollouts.InProgress, Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) { clusterIDs, _ := inMemStore.GetClusterLabelMatches("prod:us") inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs) }}, } // func TestNewInMemDriver(t *testing.T) { // inMemStore := internal.NewExampleInMemStore() // fmt.Printf("TestNewInMemDriver address: %p\n", inMemStore) // _, err := newInMemoryDriver(inMemStore, examples.GetSampleGraphFromJSON()) // assert.NilError(t, err) // } func newStdoutLogger() logr.Logger { return funcr.New(func(prefix, args string) { if prefix != "" { fmt.Printf("%s: %s\n", prefix, args) } else { fmt.Println(args) } }, funcr.Options{}) } func TestRunRollouts(t *testing.T) { inMemStore1 := internal.NewExampleInMemStore() fmt.Printf("TestRunRollouts address: %p\n", inMemStore1) store1 := NewInMemoryRolloutStore(inMemStore1) fmt.Printf("TestRunRollouts store address: %p\n", store1) resultChan := make(chan rollouts.NodeExecutionResult) ctx, cancel := context.WithCancelCause(context.Background()) defer cancel(nil) testGraph, err := store1.GetRollout(ctx, internal.ExampleRolloutID) if err != nil { t.Fatal(err) } ri, err := store1.NewRolloutInstance(ctx, newStdoutLogger(), testGraph, resultChan) assert.NilError(t, err) go func(ctx context.Context) { for { rolloutOutcome, _ := RunRollout(ctx, ri) if rolloutOutcome { fmt.Println("rollout done") return } select { case <-time.After(time.Millisecond * 50): fmt.Println("continuing") continue case <-ctx.Done(): fmt.Println("done") return } } }(ctx) timeout := time.NewTimer(time.Second * 3) testloop: for { select { case rolloutStatus := <-ri.NotifyChan: if rolloutStatus.RolloutState == rollouts.RolloutComplete { fmt.Println("rollout complete") cancel(nil) <-ctx.Done() break testloop } // Rollout not done, modify state for next Run err := internal.ModifyStore(store1.Store, engineChanConditions, rolloutStatus) if err != nil { cancel(nil) <-ctx.Done() t.Fatal(err) } case <-timeout.C: cancel(nil) <-ctx.Done() err = fmt.Errorf("check/modify loop timed out") } } assert.NilError(t, err) } func TestPersistGraph(t *testing.T) { inMemStore := internal.NewExampleInMemStore() fmt.Printf("TestPersistGraph address: %p\n", inMemStore) store := NewInMemoryRolloutStore(inMemStore) fmt.Printf("TestPersistGraph store address: %p\n", store) ctx := context.Background() testGraph, err := store.GetRollout(ctx, internal.ExampleRolloutID) if err != nil { t.Fatal(err) } inMemDriver, err := newInMemoryDriver(store.Store, newStdoutLogger(), testGraph) if err != nil { t.Fatal(err) } tg1Node := testGraph.Nodes["tg1"] fmt.Printf("tg1Node address %p\n", tg1Node) tg1 := testGraph.Nodes["tg1"].(*rollouts.TargetGroup) fmt.Println("initial tg1 status", tg1) result, err := inMemDriver.execute(ctx, tg1) if err != nil { t.Fatal(err) } fmt.Println(result.Message) // _, found := inMemStore.GraphStore[internal.ExampleBannerID][internal.ExampleRolloutID] storeRollout, found := inMemStore.GraphStore[internal.ExampleBannerID][internal.ExampleRolloutID] fmt.Printf("storeRollout address: %p\n", storeRollout) if !found { t.Fatal("didn't find rollout") } rolloutTG1 := storeRollout.Nodes[tg1.GetKey()] fmt.Println("rolloutTG1 pointer", rolloutTG1) rolloutTG1, ok := rolloutTG1.(*rollouts.TargetGroup) if !ok { t.Fatal("couldn't coerce") } fmt.Println("store tg1 status:", rolloutTG1.GetState()) fmt.Println("local tg1 status:", tg1.GetState()) assert.Equal(t, tg1.GetState(), rollouts.InProgress) assert.Equal(t, rolloutTG1.GetState(), rollouts.InProgress) // err = inMemDriver.persistGraph(ctx) // if err != nil { // t.Fatal(err) // } // assert.Equal(t, rolloutTG1.GetState(), rollouts.InProgress) }