1 package integration
2
3 import (
4 "database/sql"
5 "fmt"
6 "strings"
7 "testing"
8
9 "edge-infra.dev/pkg/lib/uuid"
10 "edge-infra.dev/test/f2"
11 "edge-infra.dev/test/f2/x/postgres"
12 )
13
// schema is the DDL executed by CreateTables. It defines:
//   - trigger_set_timestamp(), a plpgsql trigger function that sets
//     NEW.updated_at to NOW();
//   - banners, clusters and terminals tables, where clusters.banner_edge_id
//     references banners and terminals.cluster_edge_id references clusters,
//     both with ON DELETE CASCADE (clusters.banner_edge_id is nullable,
//     terminals.cluster_edge_id is NOT NULL).
//
// The statements use CREATE OR REPLACE / IF NOT EXISTS, so executing the
// schema again against the same database does not fail on existing objects.
//
// NOTE(review): no CREATE TRIGGER in this schema attaches
// trigger_set_timestamp to any table, so updated_at only receives its
// DEFAULT NOW() on insert unless a trigger is created elsewhere — confirm.
const schema = `
CREATE OR REPLACE FUNCTION trigger_set_timestamp()
RETURNS TRIGGER AS
$$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE TABLE IF NOT EXISTS banners (
banner_edge_id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
banner_name text UNIQUE NOT NULL,
project_id text UNIQUE NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE TABLE IF NOT EXISTS clusters (
cluster_edge_id UUID DEFAULT gen_random_uuid() PRIMARY KEY,
cluster_name text NOT NULL,
banner_edge_id UUID references banners (banner_edge_id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE TABLE IF NOT EXISTS terminals (
terminal_id UUID default gen_random_uuid() PRIMARY KEY,
cluster_edge_id UUID NOT NULL,
hostname TEXT NOT NULL DEFAULT 'ien',
FOREIGN KEY(cluster_edge_id) REFERENCES clusters(cluster_edge_id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
`
49
// INSERT statement templates consumed by PopulateTables through
// populateQueries. Each template has three %s verbs that are filled, in
// order, with an insertVals' id, name and parentID.
//
// The values are formatted directly into single-quoted SQL literals rather
// than bound as query parameters, so they must not contain a single quote.
// That holds for the generated UUIDs and fixture names in this file.
const (
	// InsertIntoBanners inserts (banner_edge_id, banner_name, project_id).
	InsertIntoBanners = `INSERT INTO banners (banner_edge_id, banner_name, project_id) VALUES ('%s','%s','%s');`
	// InsertIntoClusters inserts (cluster_edge_id, cluster_name, banner_edge_id).
	InsertIntoClusters = `INSERT INTO clusters (cluster_edge_id, cluster_name, banner_edge_id) VALUES ('%s','%s','%s');`
	// InsertIntoTerminals inserts (terminal_id, hostname, cluster_edge_id);
	// the fixture "name" is stored in the hostname column.
	InsertIntoTerminals = `INSERT INTO terminals (terminal_id, hostname, cluster_edge_id) VALUES ('%s','%s','%s');`
)
55
56 var UUIDs, Names = generateUUIDsAndNames()
57
58 func generateUUIDsAndNames() (map[string][]string, map[string][]string) {
59 var projectUUIDs, bannerUUIDs, clusterUUIDs, terminalUUIDs []string
60 var bannerNames, clusterNames, terminalNames []string
61 for i := 0; i < 4; i++ {
62 projectUUIDs = append(projectUUIDs, uuid.New().UUID)
63 bannerUUIDs = append(bannerUUIDs, uuid.New().UUID)
64 clusterUUIDs = append(clusterUUIDs, uuid.New().UUID)
65 terminalUUIDs = append(terminalUUIDs, uuid.New().UUID)
66
67 bannerNames = append(bannerNames, fmt.Sprintf("banner%d", i))
68 clusterNames = append(clusterNames, fmt.Sprintf("cluster%d", i))
69 terminalNames = append(terminalNames, fmt.Sprintf("terminal%d", i))
70 }
71 uuids := map[string][]string{
72 "projects": projectUUIDs,
73 "banners": bannerUUIDs,
74 "clusters": clusterUUIDs,
75 "terminals": terminalUUIDs,
76 }
77 names := map[string][]string{
78 "banners": bannerNames,
79 "clusters": clusterNames,
80 "terminals": terminalNames,
81 }
82 return uuids, names
83 }
84
85 var (
86 Banners = []insertVals{
87 {UUIDs["banners"][0], "banner0", UUIDs["projects"][0]},
88 {UUIDs["banners"][1], "banner1", UUIDs["projects"][1]},
89 }
90
91 Clusters = []insertVals{
92 {UUIDs["clusters"][0], "cluster0", UUIDs["banners"][0]},
93 {UUIDs["clusters"][1], "cluster1", UUIDs["banners"][1]},
94 }
95
96 Terminals = []insertVals{
97 {UUIDs["terminals"][0], "terminal0", UUIDs["clusters"][0]},
98 {UUIDs["terminals"][1], "terminal1", UUIDs["clusters"][0]},
99 {UUIDs["terminals"][2], "terminal2", UUIDs["clusters"][1]},
100 {UUIDs["terminals"][3], "terminal3", UUIDs["clusters"][1]},
101 }
102 )
103
// insertVals is one fixture row, ready to be formatted into an InsertInto*
// template. Field order matters: the fixture slices use positional literals
// {id, name, parentID}, and populateQueries substitutes them in this order.
type insertVals struct {
	id, name, parentID string // row ID, display name/hostname, parent row's ID
}
109
110 func CreateTables(ctx f2.Context, t *testing.T) (f2.Context, error) {
111 db := postgres.FromContextT(ctx, t).DB()
112 _, err := db.ExecContext(ctx, schema)
113 return ctx, err
114 }
115
116 func PopulateTables(ctx f2.Context, db *sql.DB) (err error) {
117 bannerQueries := populateQueries(InsertIntoBanners, Banners)
118 err = insertIntoTable(ctx, db, bannerQueries)
119 if err != nil {
120 return err
121 }
122
123 clusterQueries := populateQueries(InsertIntoClusters, Clusters)
124 err = insertIntoTable(ctx, db, clusterQueries)
125 if err != nil {
126 return err
127 }
128
129 terminalQueries := populateQueries(InsertIntoTerminals, Terminals)
130 err = insertIntoTable(ctx, db, terminalQueries)
131 if err != nil {
132 return err
133 }
134
135 return nil
136 }
137
138 func populateQueries(query string, vals []insertVals) (res []string) {
139 for _, val := range vals {
140 res = append(res, fmt.Sprintf(query, val.id, val.name, val.parentID))
141 }
142 return res
143 }
144
145 func insertIntoTable(ctx f2.Context, db *sql.DB, queries []string) error {
146 for _, query := range queries {
147 _, err := db.ExecContext(ctx, query)
148 if err != nil {
149 return err
150 }
151 }
152 return nil
153 }
154
155 func DropDatabase(ctx f2.Context, tname string) (f2.Context, error) {
156 db := ctx.Value(contextVal(tname)).(*sql.DB)
157 _, err := db.ExecContext(ctx, fmt.Sprintf("DROP SCHEMA %s CASCADE;", strings.ToLower(tname)))
158 return ctx, err
159 }
160
View as plain text