...

Source file src/edge-infra.dev/pkg/edge/rollouts/drivers/inmem_driver_test.go

Documentation: edge-infra.dev/pkg/edge/rollouts/drivers

     1  package drivers
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/go-logr/logr"
    10  	"github.com/go-logr/logr/funcr"
    11  	"gotest.tools/v3/assert"
    12  
    13  	"edge-infra.dev/pkg/edge/rollouts"
    14  	"edge-infra.dev/pkg/edge/rollouts/internal"
    15  )
    16  
    17  var engineChanConditions = internal.ConditionMap{
    18  	"tg1": {
    19  		NodeState: rollouts.InProgress,
    20  		Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) {
    21  			clusterIDs, _ := inMemStore.GetClusterLabelMatches("dev")
    22  			inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs)
    23  		}},
    24  	"g1": {
    25  		NodeState: rollouts.Complete,
    26  		Action: func(_ *internal.InMemStore, _ rollouts.NodeExecutionResult) {
    27  		}},
    28  	"tg2": {
    29  		NodeState: rollouts.InProgress,
    30  		Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) {
    31  			clusterIDs, _ := inMemStore.GetClusterLabelMatches("staging:east")
    32  			inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs)
    33  		}},
    34  	"tg3": {
    35  		NodeState: rollouts.InProgress,
    36  		Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) {
    37  			clusterIDs, _ := inMemStore.GetClusterLabelMatches("staging:west")
    38  			inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs)
    39  		}},
    40  	"ag1": {
    41  		NodeState: rollouts.Pending,
    42  		Action: func(inMemStore *internal.InMemStore, result rollouts.NodeExecutionResult) {
    43  			_ = inMemStore.OpenApprovalGate(result.Key)
    44  		}},
    45  	"tg4": {
    46  		NodeState: rollouts.InProgress,
    47  		Action: func(inMemStore *internal.InMemStore, _ rollouts.NodeExecutionResult) {
    48  			clusterIDs, _ := inMemStore.GetClusterLabelMatches("prod:us")
    49  			inMemStore.SetClusterArtifactReady(internal.StoreArtifactName, clusterIDs)
    50  		}},
    51  }
    52  
    53  // func TestNewInMemDriver(t *testing.T) {
    54  // 	inMemStore := internal.NewExampleInMemStore()
    55  // 	fmt.Printf("TestNewInMemDriver address: %p\n", inMemStore)
    56  
    57  // 	_, err := newInMemoryDriver(inMemStore, examples.GetSampleGraphFromJSON())
    58  // 	assert.NilError(t, err)
    59  // }
    60  
    61  func newStdoutLogger() logr.Logger {
    62  	return funcr.New(func(prefix, args string) {
    63  		if prefix != "" {
    64  			fmt.Printf("%s: %s\n", prefix, args)
    65  		} else {
    66  			fmt.Println(args)
    67  		}
    68  	}, funcr.Options{})
    69  }
    70  
    71  func TestRunRollouts(t *testing.T) {
    72  	inMemStore1 := internal.NewExampleInMemStore()
    73  	fmt.Printf("TestRunRollouts address: %p\n", inMemStore1)
    74  	store1 := NewInMemoryRolloutStore(inMemStore1)
    75  	fmt.Printf("TestRunRollouts store address: %p\n", store1)
    76  	resultChan := make(chan rollouts.NodeExecutionResult)
    77  	ctx, cancel := context.WithCancelCause(context.Background())
    78  	defer cancel(nil)
    79  	testGraph, err := store1.GetRollout(ctx, internal.ExampleRolloutID)
    80  	if err != nil {
    81  		t.Fatal(err)
    82  	}
    83  	ri, err := store1.NewRolloutInstance(ctx, newStdoutLogger(), testGraph, resultChan)
    84  	assert.NilError(t, err)
    85  
    86  	go func(ctx context.Context) {
    87  		for {
    88  			rolloutOutcome, _ := RunRollout(ctx, ri)
    89  			if rolloutOutcome {
    90  				fmt.Println("rollout done")
    91  				return
    92  			}
    93  			select {
    94  			case <-time.After(time.Millisecond * 50):
    95  				fmt.Println("continuing")
    96  				continue
    97  			case <-ctx.Done():
    98  				fmt.Println("done")
    99  				return
   100  			}
   101  		}
   102  	}(ctx)
   103  
   104  	timeout := time.NewTimer(time.Second * 3)
   105  testloop:
   106  	for {
   107  		select {
   108  		case rolloutStatus := <-ri.NotifyChan:
   109  			if rolloutStatus.RolloutState == rollouts.RolloutComplete {
   110  				fmt.Println("rollout complete")
   111  				cancel(nil)
   112  				<-ctx.Done()
   113  				break testloop
   114  			}
   115  			// Rollout not done, modify state for next Run
   116  			err := internal.ModifyStore(store1.Store, engineChanConditions, rolloutStatus)
   117  			if err != nil {
   118  				cancel(nil)
   119  				<-ctx.Done()
   120  				t.Fatal(err)
   121  			}
   122  		case <-timeout.C:
   123  			cancel(nil)
   124  			<-ctx.Done()
   125  			err = fmt.Errorf("check/modify loop timed out")
   126  		}
   127  	}
   128  	assert.NilError(t, err)
   129  }
   130  
   131  func TestPersistGraph(t *testing.T) {
   132  	inMemStore := internal.NewExampleInMemStore()
   133  	fmt.Printf("TestPersistGraph address: %p\n", inMemStore)
   134  	store := NewInMemoryRolloutStore(inMemStore)
   135  	fmt.Printf("TestPersistGraph store address: %p\n", store)
   136  
   137  	ctx := context.Background()
   138  	testGraph, err := store.GetRollout(ctx, internal.ExampleRolloutID)
   139  	if err != nil {
   140  		t.Fatal(err)
   141  	}
   142  
   143  	inMemDriver, err := newInMemoryDriver(store.Store, newStdoutLogger(), testGraph)
   144  	if err != nil {
   145  		t.Fatal(err)
   146  	}
   147  
   148  	tg1Node := testGraph.Nodes["tg1"]
   149  	fmt.Printf("tg1Node address %p\n", tg1Node)
   150  	tg1 := testGraph.Nodes["tg1"].(*rollouts.TargetGroup)
   151  	fmt.Println("initial tg1 status", tg1)
   152  	result, err := inMemDriver.execute(ctx, tg1)
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  	fmt.Println(result.Message)
   157  
   158  	// _, found := inMemStore.GraphStore[internal.ExampleBannerID][internal.ExampleRolloutID]
   159  	storeRollout, found := inMemStore.GraphStore[internal.ExampleBannerID][internal.ExampleRolloutID]
   160  	fmt.Printf("storeRollout address: %p\n", storeRollout)
   161  	if !found {
   162  		t.Fatal("didn't find rollout")
   163  	}
   164  	rolloutTG1 := storeRollout.Nodes[tg1.GetKey()]
   165  	fmt.Println("rolloutTG1 pointer", rolloutTG1)
   166  	rolloutTG1, ok := rolloutTG1.(*rollouts.TargetGroup)
   167  	if !ok {
   168  		t.Fatal("couldn't coerce")
   169  	}
   170  
   171  	fmt.Println("store tg1 status:", rolloutTG1.GetState())
   172  	fmt.Println("local tg1 status:", tg1.GetState())
   173  
   174  	assert.Equal(t, tg1.GetState(), rollouts.InProgress)
   175  	assert.Equal(t, rolloutTG1.GetState(), rollouts.InProgress)
   176  	// err = inMemDriver.persistGraph(ctx)
   177  	// if err != nil {
   178  	// 	t.Fatal(err)
   179  	// }
   180  	// assert.Equal(t, rolloutTG1.GetState(), rollouts.InProgress)
   181  }
   182  

View as plain text