1 package test
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "io"
8 "testing"
9 )
10
11 var (
12 _ CleanUpDB = &sql.DB{}
13 )
14
15
16
17
// CleanUpDB is the set of database operations used to wipe test databases:
// running queries, executing statements, starting transactions, and closing
// the handle. *sql.DB satisfies it.
type CleanUpDB interface {
	io.Closer

	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
25
26
27
28
29
30 func ResetBoulderTestDatabase(t testing.TB) func() {
31 return resetTestDatabase(t, context.Background(), "boulder")
32 }
33
34
35
36
37
38
39 func ResetIncidentsTestDatabase(t testing.TB) func() {
40 return resetTestDatabase(t, context.Background(), "incidents")
41 }
42
43 func resetTestDatabase(t testing.TB, ctx context.Context, dbPrefix string) func() {
44 db, err := sql.Open("mysql", fmt.Sprintf("test_setup@tcp(boulder-proxysql:6033)/%s_sa_test", dbPrefix))
45 if err != nil {
46 t.Fatalf("Couldn't create db: %s", err)
47 }
48 err = deleteEverythingInAllTables(ctx, db)
49 if err != nil {
50 t.Fatalf("Failed to delete everything: %s", err)
51 }
52 return func() {
53 err := deleteEverythingInAllTables(ctx, db)
54 if err != nil {
55 t.Fatalf("Failed to truncate tables after the test: %s", err)
56 }
57 _ = db.Close()
58 }
59 }
60
61
62
63
64
65 func deleteEverythingInAllTables(ctx context.Context, db CleanUpDB) error {
66 ts, err := allTableNamesInDB(ctx, db)
67 if err != nil {
68 return err
69 }
70 for _, tn := range ts {
71
72
73
74
75
76 tx, err := db.BeginTx(ctx, nil)
77 if err != nil {
78 return fmt.Errorf("unable to start transaction to delete all rows from table %#v: %s", tn, err)
79 }
80 _, err = tx.ExecContext(ctx, "set FOREIGN_KEY_CHECKS = 0")
81 if err != nil {
82 return fmt.Errorf("unable to disable FOREIGN_KEY_CHECKS to delete all rows from table %#v: %s", tn, err)
83 }
84
85
86
87 _, err = tx.ExecContext(ctx, "delete from `"+tn+"` where 1 = 1")
88 if err != nil {
89 return fmt.Errorf("unable to delete all rows from table %#v: %s", tn, err)
90 }
91 _, err = tx.ExecContext(ctx, "set FOREIGN_KEY_CHECKS = 1")
92 if err != nil {
93 return fmt.Errorf("unable to re-enable FOREIGN_KEY_CHECKS to delete all rows from table %#v: %s", tn, err)
94 }
95 err = tx.Commit()
96 if err != nil {
97 return fmt.Errorf("unable to commit transaction to delete all rows from table %#v: %s", tn, err)
98 }
99
100 _, err = db.ExecContext(ctx, "alter table `"+tn+"` AUTO_INCREMENT = 1")
101 if err != nil {
102 return fmt.Errorf("unable to reset autoincrement on table %#v: %s", tn, err)
103 }
104 }
105 return err
106 }
107
108
109
110
111 func allTableNamesInDB(ctx context.Context, db CleanUpDB) ([]string, error) {
112 r, err := db.QueryContext(ctx, "select table_name from information_schema.tables t where t.table_schema = DATABASE() and t.table_name != 'gorp_migrations';")
113 if err != nil {
114 return nil, err
115 }
116 var ts []string
117 for r.Next() {
118 tableName := ""
119 err = r.Scan(&tableName)
120 if err != nil {
121 return nil, err
122 }
123 ts = append(ts, tableName)
124 }
125 return ts, r.Err()
126 }
127
View as plain text