1 package sql
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "net"
8 "path"
9 "testing"
10 "time"
11
12 "github.com/bazelbuild/rules_go/go/runfiles"
13 pgsql "github.com/fergusstrange/embedded-postgres"
14 "github.com/google/uuid"
15 _ "github.com/jackc/pgx/v4/stdlib"
16
17 "edge-infra.dev/pkg/f8n/kinform/model"
18 sovereign "edge-infra.dev/pkg/f8n/sovereign/model"
19 "edge-infra.dev/pkg/lib/build/bazel"
20 "edge-infra.dev/pkg/lib/compression"
21 )
22
// skipAll, when true, makes the database-backed tests skip immediately.
var skipAll bool

// pgDSN is the connection string for the embedded Postgres server; TestMain
// fills it in before any test runs.
var pgDSN string
27
28 func TestMain(m *testing.M) {
29 embeddedTxzFile, err := runfiles.Rlocation(path.Join("edge_infra", "hack", "tools", "postgres.txz"))
30 if err != nil {
31 panic(err)
32 }
33
34 tempDir, err := bazel.NewTestTmpDir("edge-infra-kinform-sql-test-*")
35 if err != nil {
36 panic(err)
37 }
38 pgRuntimePath := path.Join(tempDir, "runtime")
39 pgTempDir := tempDir
40
41 err = compression.DecompressTarXz(embeddedTxzFile, tempDir)
42 if err != nil {
43 panic(err)
44 }
45
46 pgUser := postgres
47 pgPass := postgres
48 pgHost := "127.0.0.1"
49 pgDB := postgres
50 port, err := findOpenPort()
51 if err != nil {
52 panic(err)
53 }
54 pgDSN = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s", pgHost, port, pgUser, pgPass, pgDB)
55
56 cfg := pgsql.DefaultConfig()
57 cfg = cfg.Username(pgUser)
58 cfg = cfg.Password(pgPass)
59 cfg = cfg.Port(uint32(port))
60 cfg = cfg.Version(pgsql.V14)
61 cfg = cfg.RuntimePath(pgRuntimePath)
62 cfg = cfg.BinariesPath(pgTempDir)
63 cfg = cfg.Database(pgDB)
64
65 pg := pgsql.NewDatabase(cfg)
66 if err := pg.Start(); err != nil {
67 panic(err)
68 }
69 defer func() {
70 if err := pg.Stop(); err != nil {
71 panic(err)
72 }
73 }()
74
75 m.Run()
76 }
77
// findOpenPort asks the OS for an ephemeral TCP port by listening on
// "localhost:0". It reports the port that was assigned and then releases the
// listener.
//
// Because the listener is closed before returning, another process could grab
// the port before the caller binds it. The port is only likely to still be
// free.
func findOpenPort() (int, error) {
	tcpAddr, err := net.ResolveTCPAddr("tcp", "localhost:0")
	if err != nil {
		return 0, err
	}

	listener, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		return 0, err
	}

	assigned := listener.Addr().(*net.TCPAddr).Port
	// Free the port so the caller can bind it; the close error is not
	// actionable here.
	listener.Close()
	return assigned, nil
}
90
// TestIntegration has no test body; it unconditionally skips.
// The skip message is logged so a verbose run shows why it was skipped.
func TestIntegration(t *testing.T) {
	t.Skip("integration test not implemented")
}
96
97
98 func TestLabelMatching(t *testing.T) {
99 if skipAll {
100
101 t.SkipNow()
102 }
103
104 t.Logf("using dsn %s\n", pgDSN)
105 db, err := sql.Open("pgx", pgDSN)
106 if err != nil {
107 t.Fatal(err)
108 }
109 if err := db.Ping(); err != nil {
110 t.Fatal(err)
111 }
112
113
114 err = execSchema(db)
115 if err != nil {
116 t.Fatal(err)
117 }
118
119 ctx := context.Background()
120 dbHandle := &DBHandle{DB: db}
121
122
123 clusterID, err := dbHandle.InsertCluster(ctx, "test-cluster")
124 if err != nil {
125 t.Fatal(err)
126 }
127 t.Logf("inserted new cluster. id: %v", clusterID)
128 _, err = dbHandle.InsertClusterLabel(ctx, clusterID, "test", "true")
129 if err != nil {
130 t.Fatal(err)
131 }
132
133
134 clusters, err := dbHandle.GetClustersMatchingArtifactLabels(ctx, uuid.Nil)
135 if err != nil {
136 t.Fatal(err)
137 }
138 if len(clusters) != 0 {
139 t.Fatal("found clusters matching nil artifact. this is unexpected")
140 }
141
142
143 nilDigest := "0000000000000000000000000000000000000000000000000000000000000000"
144 artifactVersionID, err := dbHandle.InsertArtifactVersion(ctx, "kinform-test", "latest", nilDigest)
145 if err != nil {
146 t.Fatal(err)
147 }
148 artifact := sovereign.Artifact{
149 ProjectID: "ret-edge-test",
150 Repository: "workloads",
151 ArtifactVersion: artifactVersionID,
152 }
153 artifactID, err := dbHandle.InsertArtifact(ctx, artifact)
154 if err != nil {
155 t.Fatal(err)
156 }
157
158
159 _, err = dbHandle.InsertArtifactLabel(ctx, artifactID, "test", "true")
160 if err != nil {
161 t.Fatal(err)
162 }
163
164
165 clusters, err = dbHandle.GetClustersMatchingArtifactLabels(ctx, artifactID)
166 if err != nil {
167 t.Fatal(err)
168 }
169 t.Logf("found %v clusters with labels matching artifact with id %v", len(clusters), artifactID)
170 if len(clusters) != 1 {
171 t.Fatalf("expected artifact %v to have matched the test cluster %v", artifactID, clusterID)
172 }
173 }
174
175 func TestConnectCluster(t *testing.T) {
176 if skipAll {
177
178 t.SkipNow()
179 }
180
181 t.Logf("using dsn %s\n", pgDSN)
182 db, err := sql.Open("pgx", pgDSN)
183 if err != nil {
184 t.Fatal(err)
185 }
186 if err := db.Ping(); err != nil {
187 t.Fatal(err)
188 }
189 err = execSchema(db)
190 if err != nil {
191 t.Fatal(err)
192 }
193
194 ctx := context.Background()
195 dbHandle := &DBHandle{DB: db}
196
197 clusterID := uuid.New()
198 sessionID := uuid.New()
199 heartbeat := model.ClusterHeartbeat{
200 Cluster: clusterID,
201 ClusterVersion: model.ClusterVersionInfo{},
202 Timestamp: time.Now(),
203 SessionID: sessionID,
204 }
205 err = dbHandle.UpdateClusterHeartbeatWithSession(ctx, heartbeat)
206 if err != nil {
207 t.Fatal(err)
208 }
209
210
211 for i := 0; i < 50; i++ {
212 heartbeat.Timestamp = time.Now()
213 err = dbHandle.UpdateClusterHeartbeatWithSession(ctx, heartbeat)
214 if err != nil {
215 t.Fatal(err)
216 }
217 }
218
219
220 testQuery := `SELECT count(*) FROM kinform_sessions WHERE cluster = $1`
221 row := db.QueryRowContext(ctx, testQuery, clusterID)
222 var rowCount int
223 err = row.Scan(&rowCount)
224 if err != nil {
225 t.Fatal(err)
226 }
227 if rowCount != 1 {
228 t.Fatalf("expected to find 1 row. got %d", rowCount)
229 }
230 }
231
232
233
234
235
236
237
238
239
240
241
242
243
244
View as plain text