1 package migratest
2
3 import (
4 "fmt"
5 "sync"
6 "testing"
7
8 "github.com/google/uuid"
9 "github.com/jmoiron/sqlx"
10 migrate "github.com/rubenv/sql-migrate"
11 "github.com/stretchr/testify/require"
12
13 "github.com/ory/x/dbal"
14 "github.com/ory/x/sqlcon/dockertest"
15 )
16
17
18 type MigrationSchemas []map[string]*dbal.PackrMigrationSource
19
20
21 func RunPackrMigrationTests(
22 t *testing.T, schema, data MigrationSchemas,
23 init, cleanup func(*testing.T, *sqlx.DB),
24 runner func(*testing.T, string, *sqlx.DB, int, int, int),
25 ) {
26 if testing.Short() {
27 t.SkipNow()
28 return
29 }
30
31 var m sync.Mutex
32 var dbs = map[string]*sqlx.DB{}
33 var mid = uuid.New().String()
34
35 var dbnames = map[string]bool{}
36 for _, ms := range schema {
37 for dbname := range ms {
38 dbnames[dbname] = true
39 }
40 }
41
42 var connectors []func()
43 for dbname := range dbnames {
44 switch dbname {
45 case dbal.DriverPostgreSQL:
46 connectors = append(connectors, func() {
47 db, err := dockertest.ConnectToTestPostgreSQL()
48 if err != nil {
49 t.Fatalf("Could not connect to database: %v", err)
50 }
51 m.Lock()
52 dbs[dbal.DriverPostgreSQL] = db
53 m.Unlock()
54 })
55 case dbal.DriverMySQL:
56 connectors = append(connectors, func() {
57 db, err := dockertest.ConnectToTestMySQL()
58 if err != nil {
59 t.Fatalf("Could not connect to database: %v", err)
60 }
61 m.Lock()
62 dbs[dbal.DriverMySQL] = db
63 m.Unlock()
64 })
65 case dbal.DriverCockroachDB:
66 connectors = append(connectors, func() {
67 db, err := dockertest.ConnectToTestCockroachDB()
68 if err != nil {
69 t.Fatalf("Could not connect to database: %v", err)
70 }
71 m.Lock()
72 dbs[dbal.DriverCockroachDB] = db
73 m.Unlock()
74 })
75 default:
76 panic(fmt.Sprintf("Database name %s unknown", dbname))
77 }
78 }
79
80 dockertest.Parallel(connectors)
81
82 if data != nil {
83 require.Equal(t, len(schema), len(data))
84 }
85
86 for name, db := range dbs {
87 dialect := db.DriverName()
88 if dialect == "pgx" {
89 dialect = "postgres"
90 }
91 t.Run(fmt.Sprintf("database=%s", name), func(t *testing.T) {
92 init(t, db)
93
94 for sk, ss := range schema {
95 t.Run(fmt.Sprintf("schema=%d/run", sk), func(t *testing.T) {
96 steps := len(ss[name].Box.List())
97 for step := 0; step < steps; step++ {
98 t.Run(fmt.Sprintf("up=%d", step), func(t *testing.T) {
99 migrate.SetTable(fmt.Sprintf("%s_%d", mid, sk))
100 n, err := migrate.ExecMax(db.DB, dialect, ss[name], migrate.Up, 1)
101 require.NoError(t, err)
102 require.Equal(t, n, 1, sk)
103
104 t.Run(fmt.Sprintf("data=%d", step), func(t *testing.T) {
105 if data == nil || data[sk] == nil {
106 t.Skip("Skipping data creation because no schema specified...")
107 return
108 }
109
110 migrate.SetTable(fmt.Sprintf("%s_%d_data", mid, sk))
111 n, err = migrate.ExecMax(db.DB, dialect, data[sk][name], migrate.Up, 1)
112 require.NoError(t, err)
113 require.Equal(t, 1, n)
114 })
115 })
116 }
117
118 for step := 0; step < steps; step++ {
119 t.Run(fmt.Sprintf("runner=%d", step), func(t *testing.T) {
120 runner(t, name, db, sk, step, steps)
121 })
122 }
123 })
124 }
125
126 for sk := len(schema) - 1; sk >= 0; sk-- {
127 ss := schema[sk]
128
129 t.Run(fmt.Sprintf("schema=%d/cleanup", sk), func(t *testing.T) {
130 steps := len(ss[name].Box.List())
131
132 migrate.SetTable(fmt.Sprintf("%s_%d", mid, sk))
133 for step := 0; step < steps; step++ {
134 t.Run(fmt.Sprintf("down=%d", step), func(t *testing.T) {
135 n, err := migrate.ExecMax(db.DB, dialect, ss[name], migrate.Down, 1)
136 require.NoError(t, err)
137 require.Equal(t, 1, n)
138 })
139 }
140 })
141 }
142
143 cleanup(t, db)
144 })
145 }
146 }
147
View as plain text