...

Source file src/edge-infra.dev/pkg/edge/rollouts/rollout_test.go

Documentation: edge-infra.dev/pkg/edge/rollouts

     1  package rollouts
     2  
     3  import (
     4  	_ "embed"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  )
    10  
    11  //go:embed testdata/simple-graph.json
    12  var simpleGraph []byte
    13  
    14  //go:embed testdata/simple-plan.json
    15  var simplePlan []byte
    16  
    17  //go:embed testdata/bad-edge-plan.json
    18  var badEdgePlan []byte
    19  
    20  //go:embed testdata/bad-edge-graph.json
    21  var badEdgeGraph []byte
    22  
    23  var fakeGraphID = "abc1234-def"
    24  var fakePlanID = "abc1234-def"
    25  
    26  func testSimpleEdges(edges []RolloutGraphEdge) error {
    27  	if len(edges) != 7 {
    28  		return fmt.Errorf("expected 7 edges, got %d", len(edges))
    29  	}
    30  	return nil
    31  }
    32  
    33  func testSimpleNodes(nodes map[NodeKey]RolloutGraphNode) error {
    34  	if len(nodes) != 7 {
    35  		return fmt.Errorf("expected 7 nodes, got %d", len(nodes))
    36  	}
    37  	// tg1
    38  	tg1 := nodes["tg1"]
    39  	if tg1 == nil {
    40  		return fmt.Errorf("expected node with key %s, got nil", "tg1")
    41  	}
    42  	if len(tg1.GetDependsOn()) != 0 {
    43  		return fmt.Errorf("expected first node, tg1, to have no dependencies, got %d", len(tg1.GetDependsOn()))
    44  	}
    45  	if len(tg1.GetNext()) != 1 {
    46  		return fmt.Errorf("expected first node, tg1, to have 1 next node, got %d", len(tg1.GetNext()))
    47  	}
    48  	// tg2
    49  	tg2 := nodes["tg2"]
    50  	if tg2 == nil {
    51  		return fmt.Errorf("expected node with key %s, got nil", "tg2")
    52  	}
    53  	if len(tg2.GetDependsOn()) != 1 {
    54  		return fmt.Errorf("expected node tg2, to have 1 dependencies, got %d", len(tg2.GetDependsOn()))
    55  	}
    56  	if len(tg2.GetNext()) != 1 {
    57  		return fmt.Errorf("expected node tg2, to have 1 next node, got %d", len(tg2.GetNext()))
    58  	}
    59  	// tg3
    60  	tg3 := nodes["tg3"]
    61  	if tg3 == nil {
    62  		return fmt.Errorf("expected node with key %s, got nil", "tg3")
    63  	}
    64  	if len(tg3.GetDependsOn()) != 1 {
    65  		return fmt.Errorf("expected node tg3, to have 1 dependencies, got %d", len(tg3.GetDependsOn()))
    66  	}
    67  	if len(tg3.GetNext()) != 1 {
    68  		return fmt.Errorf("expected node tg3, to have 1 next node, got %d", len(tg3.GetNext()))
    69  	}
    70  	// tg4
    71  	tg4 := nodes["tg4"]
    72  	if tg4 == nil {
    73  		return fmt.Errorf("expected node with key %s, got nil", "tg4")
    74  	}
    75  	if len(tg4.GetDependsOn()) != 1 {
    76  		return fmt.Errorf("expected node tg4, to have 1 dependencies, got %d", len(tg4.GetDependsOn()))
    77  	}
    78  	if len(tg4.GetNext()) != 0 {
    79  		return fmt.Errorf("expected terminal node, tg4, to have 0 next nodes, got %d", len(tg4.GetNext()))
    80  	}
    81  	// tg5
    82  	tg5 := nodes["tg5"]
    83  	if tg5 == nil {
    84  		return fmt.Errorf("expected node with key %s, got nil", "tg5")
    85  	}
    86  	if len(tg5.GetDependsOn()) != 1 {
    87  		return fmt.Errorf("expected node tg5, to have 1 dependencies, got %d", len(tg5.GetDependsOn()))
    88  	}
    89  	if len(tg5.GetNext()) != 0 {
    90  		return fmt.Errorf("expected terminal node, tg5, to have 0 next nodes, got %d", len(tg5.GetNext()))
    91  	}
    92  	// g1
    93  	g1 := nodes["g1"]
    94  	if g1 == nil {
    95  		return fmt.Errorf("expected node with key %s, got nil", "g1")
    96  	}
    97  	if len(g1.GetDependsOn()) != 1 {
    98  		return fmt.Errorf("expected node g1 to have 1 dependencies, got %d", len(g1.GetDependsOn()))
    99  	}
   100  	if len(g1.GetNext()) != 2 {
   101  		return fmt.Errorf("expected node g1 to have 2 next nodes, got %d", len(g1.GetNext()))
   102  	}
   103  	// g2
   104  	g2 := nodes["g2"]
   105  	if g2 == nil {
   106  		return fmt.Errorf("expected node with key %s, got nil", "g2")
   107  	}
   108  	if len(g2.GetDependsOn()) != 2 {
   109  		return fmt.Errorf("expected node g2 to have 2 dependencies, got %d", len(g2.GetDependsOn()))
   110  	}
   111  	if len(g2.GetNext()) != 2 {
   112  		return fmt.Errorf("expected node g2 to have 2 next nodes, got %d", len(g2.GetNext()))
   113  	}
   114  
   115  	return nil
   116  }
   117  
   118  func TestCreateRolloutPlanFromJSON(t *testing.T) {
   119  	plan, err := NewRolloutPlanFromJSON(simplePlan)
   120  	fmt.Printf("%+v\n", plan)
   121  	if err != nil {
   122  		t.Fatalf("failed to create test plan from JSON: %v", err)
   123  	}
   124  
   125  	if len(plan.Initial) != 1 {
   126  		t.Fatalf("expecting Initial nodes to be len 1 but got %d", len(plan.Initial))
   127  	}
   128  
   129  	assert.Equal(t, plan.ID, fakePlanID)
   130  
   131  	nodesErr := testSimpleNodes(plan.Nodes)
   132  	if nodesErr != nil {
   133  		t.Fatal(nodesErr)
   134  	}
   135  
   136  	edgesErr := testSimpleEdges(plan.Edges)
   137  	if edgesErr != nil {
   138  		t.Fatal(edgesErr)
   139  	}
   140  }
   141  
   142  func TestRolloutPlanJSONBadEdge(t *testing.T) {
   143  	plan, err := NewRolloutPlanFromJSON(badEdgePlan)
   144  	fmt.Printf("%+v\n", plan)
   145  	if err != nil {
   146  		t.Logf("successful error: %s", err)
   147  		fmt.Println(err)
   148  	} else {
   149  		t.Fatal("didn't log error")
   150  	}
   151  }
   152  
   153  func TestRolloutGraphJSONBadEdge(t *testing.T) {
   154  	graph, err := NewRolloutGraphFromJSON(badEdgeGraph)
   155  	fmt.Printf("%+v\n", graph)
   156  	if err != nil {
   157  		t.Logf("successful error: %s", err)
   158  		fmt.Println(err)
   159  	} else {
   160  		t.Fatal("didn't log error")
   161  	}
   162  }
   163  
   164  func TestCreateRolloutGraphFromJSON(t *testing.T) {
   165  	rollout, err := NewRolloutGraphFromJSON(simpleGraph)
   166  	if err != nil {
   167  		t.Fatalf("failed create test rollout graph from JSON: %v", err)
   168  	}
   169  
   170  	assert.Equal(t, rollout.ID, fakeGraphID)
   171  
   172  	nodesErr := testSimpleNodes(rollout.Nodes)
   173  	if nodesErr != nil {
   174  		t.Fatal(nodesErr)
   175  	}
   176  
   177  	edgesErr := testSimpleEdges(rollout.Edges)
   178  	if edgesErr != nil {
   179  		t.Fatal(edgesErr)
   180  	}
   181  }
   182  
   183  //func TestRolloutWaitsForAllInputs(t *testing.T) {
   184  //	// TODO(dk185217): Do something with this test other than skip. just skipping to keep logspam down
   185  //	// t.SkipNow()
   186  
   187  //	// for the canonical example graph:
   188  //	// S->TG_1, TG_1->G_1, G_1->TG_2, G_1->TG_3, TG_2->G_2, TG_3->G_2, G_2->TG_4, G_2->TG_5
   189  //	// G_2 is the first node with multiple inputs. Test that G_2 does not execute until TG_2 and TG_3 finish
   190  
   191  //	// if a node has multiple inputs, it must wait until all those inputs are done
   192  //	// before being processed. in the case of a finish node for eg, we might get: (from test fmt logs)
   193  //	// updating Current to: [Finish,Finish]
   194  //	// Processing current step of graph: { Complete: false, Current: [Finish,Finish] }
   195  //	// executing node: Finish
   196  //	// executing node: Finish
   197  //	// updating Current to: []
   198  
   199  //	// in this case, it is incorrect to execute the finish node because all of its inputs are not independently
   200  //	// checked before adding it to the list of next nodes
   201  
   202  //	tt := &testTimer{
   203  //		fakeTime: time.Now(),
   204  //	}
   205  
   206  //	// create a rollout plan
   207  //	plan, _ := NewRolloutPlanFromJSON(simplePlan)
   208  
   209  //	testState := map[string]string{}
   210  //	labels := map[string][]string{
   211  //		"dev":          {"dev0", "dev1", "dev2"},
   212  //		"staging-east": {"stage0-east", "stage1-east"},
   213  //		"staging-west": {"stage2-west", "stage3-west"},
   214  //		"prod-us":      {"prod0-east", "prod1-east", "prod2-west", "prod3-west"},
   215  //		"prod-global":  {"prod4", "prod5"},
   216  //	}
   217  //	rolloutDriver := NewInMemDriver(fakeBannerID, testState, labels, false)
   218  //	// create a rollout from the plan (TODO)
   219  
   220  //	rollout := NewRolloutGraph(plan, rolloutDriver, WithTimeSource(tt))
   221  
   222  //	// start the rollout
   223  //	i := 0
   224  
   225  //	ctx := context.Background()
   226  //	for {
   227  //		done, err := rollout.ProcessStep(ctx)
   228  //		if err != nil {
   229  //			t.Fatalf("failed to process rollout step: %v", err)
   230  //		}
   231  //		if done {
   232  //			break
   233  //		}
   234  
   235  //		// nodes should not be in-progress if their dependencies have not finished
   236  //		for _, currKey := range rollout.Current {
   237  //			curr := rollout.Nodes[currKey]
   238  //			allDepsComplete := true
   239  //			var incompleteDep RolloutGraphNode
   240  //			for _, depOfCurr := range curr.GetDependsOn() {
   241  //				if depOfCurr.GetState() != Complete {
   242  //					allDepsComplete = false
   243  //					incompleteDep = depOfCurr
   244  //				}
   245  //			}
   246  //			if !allDepsComplete && curr.GetState() == InProgress {
   247  //				t.Fatalf("node %s was started without finishing its deps first. incomplete dep: %s",
   248  //					curr.GetKey(), incompleteDep.GetKey())
   249  //			}
   250  //		}
   251  
   252  //		// fast forward into the future. 2 seconds per step
   253  //		fmt.Println("fast forwarding time...")
   254  //		tt.fakeTime = tt.fakeTime.Add(2001 * time.Millisecond)
   255  
   256  //		// simulate an approval. might take... 11 iterations?
   257  //		if i > 11 && !rolloutDriver.approved {
   258  //			fmt.Println("approving...")
   259  //			rolloutDriver.approved = true
   260  //		}
   261  
   262  //		i++
   263  //		if i > 100 {
   264  //			t.Fatal("took too long")
   265  //		}
   266  //	}
   267  //}
   268  
   269  //func TestRolloutGraphConstraints(t *testing.T) {
   270  //	plan := RolloutPlan{Name: "test-rollout-graph-constraints"}
   271  //	rolloutDriver := NewInMemDriver("", nil, nil, false)
   272  //	rollout := NewRolloutGraph(plan, rolloutDriver)
   273  
   274  //	// check if edges and node dependencies are wired up correctly
   275  //	numDeps := 0
   276  //	for _, n := range rollout.Nodes {
   277  //		numDeps += len(n.GetDependsOn())
   278  //	}
   279  //	if numDeps != len(rollout.Edges) {
   280  //		t.Fatalf("number of edges did not match total number of dependent relations. edges: %v, deps: %v",
   281  //			len(rollout.Edges), numDeps)
   282  //	}
   283  //}
   284  
   285  //type testTimer struct {
   286  //	fakeTime time.Time
   287  //}
   288  
   289  //func (t *testTimer) Now() time.Time {
   290  //	return t.fakeTime
   291  //}
   292  

View as plain text