1 package btesting
2
3 import (
4 "flag"
5 "fmt"
6 "os"
7 "path/filepath"
8 "regexp"
9 "testing"
10 "time"
11
12 "github.com/stretchr/testify/require"
13
14 bolt "go.etcd.io/bbolt"
15 )
16
// statsFlag is the value of the -stats command-line flag (default false).
// When it is set, Close prints the database statistics via PrintStats before
// closing the database.
var statsFlag = flag.Bool("stats", false, "show performance stats")
18
19
// TestFreelistType is the name of the environment variable that
// MustOpenDBWithOption reads to choose the freelist type for opened databases.
const TestFreelistType = "TEST_FREELIST_TYPE"
21
22
23 type DB struct {
24 *bolt.DB
25 f string
26 o *bolt.Options
27 t testing.TB
28 }
29
30
31 func MustCreateDB(t testing.TB) *DB {
32 return MustCreateDBWithOption(t, nil)
33 }
34
35
36 func MustCreateDBWithOption(t testing.TB, o *bolt.Options) *DB {
37 f := filepath.Join(t.TempDir(), "db")
38 return MustOpenDBWithOption(t, f, o)
39 }
40
41 func MustOpenDBWithOption(t testing.TB, f string, o *bolt.Options) *DB {
42 t.Logf("Opening bbolt DB at: %s", f)
43 if o == nil {
44 o = bolt.DefaultOptions
45 }
46
47 freelistType := bolt.FreelistArrayType
48 if env := os.Getenv(TestFreelistType); env == string(bolt.FreelistMapType) {
49 freelistType = bolt.FreelistMapType
50 }
51
52 o.FreelistType = freelistType
53
54 db, err := bolt.Open(f, 0666, o)
55 require.NoError(t, err)
56 resDB := &DB{
57 DB: db,
58 f: f,
59 o: o,
60 t: t,
61 }
62 t.Cleanup(resDB.PostTestCleanup)
63 return resDB
64 }
65
66 func (db *DB) PostTestCleanup() {
67
68 if db.DB != nil {
69 db.MustCheck()
70 db.MustClose()
71 }
72 }
73
74
75 func (db *DB) Close() error {
76 if db.DB != nil {
77
78 if *statsFlag {
79 db.PrintStats()
80 }
81 db.t.Logf("Closing bbolt DB at: %s", db.f)
82 err := db.DB.Close()
83 if err != nil {
84 return err
85 }
86 db.DB = nil
87 }
88 return nil
89 }
90
91
92 func (db *DB) MustClose() {
93 err := db.Close()
94 require.NoError(db.t, err)
95 }
96
97 func (db *DB) MustDeleteFile() {
98 err := os.Remove(db.Path())
99 require.NoError(db.t, err)
100 }
101
102 func (db *DB) SetOptions(o *bolt.Options) {
103 db.o = o
104 }
105
106
107 func (db *DB) MustReopen() {
108 if db.DB != nil {
109 panic("Please call Close() before MustReopen()")
110 }
111 db.t.Logf("Reopening bbolt DB at: %s", db.f)
112 indb, err := bolt.Open(db.Path(), 0666, db.o)
113 require.NoError(db.t, err)
114 db.DB = indb
115 }
116
117
118 func (db *DB) MustCheck() {
119 err := db.Update(func(tx *bolt.Tx) error {
120
121 var errors []error
122 for err := range tx.Check() {
123 errors = append(errors, err)
124 if len(errors) > 10 {
125 break
126 }
127 }
128
129
130 if len(errors) > 0 {
131 var path = filepath.Join(db.t.TempDir(), "db.backup")
132 err := tx.CopyFile(path, 0600)
133 require.NoError(db.t, err)
134
135
136 fmt.Print("\n\n")
137 fmt.Printf("consistency check failed (%d errors)\n", len(errors))
138 for _, err := range errors {
139 fmt.Println(err)
140 }
141 fmt.Println("")
142 fmt.Println("db saved to:")
143 fmt.Println(path)
144 fmt.Print("\n\n")
145 os.Exit(-1)
146 }
147
148 return nil
149 })
150 require.NoError(db.t, err)
151 }
152
153
154 func (db *DB) Fill(bucket []byte, numTx int, numKeysPerTx int,
155 keyGen func(tx int, key int) []byte,
156 valueGen func(tx int, key int) []byte) error {
157 for tr := 0; tr < numTx; tr++ {
158 err := db.Update(func(tx *bolt.Tx) error {
159 b, _ := tx.CreateBucketIfNotExists(bucket)
160 for i := 0; i < numKeysPerTx; i++ {
161 if err := b.Put(keyGen(tr, i), valueGen(tr, i)); err != nil {
162 return err
163 }
164 }
165 return nil
166 })
167 if err != nil {
168 return err
169 }
170 }
171 return nil
172 }
173
174 func (db *DB) Path() string {
175 return db.f
176 }
177
178
179 func (db *DB) CopyTempFile() {
180 path := filepath.Join(db.t.TempDir(), "db.copy")
181 err := db.View(func(tx *bolt.Tx) error {
182 return tx.CopyFile(path, 0600)
183 })
184 require.NoError(db.t, err)
185 fmt.Println("db copied to: ", path)
186 }
187
188
189 func (db *DB) PrintStats() {
190 var stats = db.Stats()
191 fmt.Printf("[db] %-20s %-20s %-20s\n",
192 fmt.Sprintf("pg(%d/%d)", stats.TxStats.GetPageCount(), stats.TxStats.GetPageAlloc()),
193 fmt.Sprintf("cur(%d)", stats.TxStats.GetCursorCount()),
194 fmt.Sprintf("node(%d/%d)", stats.TxStats.GetNodeCount(), stats.TxStats.GetNodeDeref()),
195 )
196 fmt.Printf(" %-20s %-20s %-20s\n",
197 fmt.Sprintf("rebal(%d/%v)", stats.TxStats.GetRebalance(), truncDuration(stats.TxStats.GetRebalanceTime())),
198 fmt.Sprintf("spill(%d/%v)", stats.TxStats.GetSpill(), truncDuration(stats.TxStats.GetSpillTime())),
199 fmt.Sprintf("w(%d/%v)", stats.TxStats.GetWrite(), truncDuration(stats.TxStats.GetWriteTime())),
200 )
201 }
202
// leadingFractionRe matches an integer followed by a fractional part at the
// very start of a formatted duration, e.g. the "1.5" in "1.5s".
// It is compiled once here rather than on every truncDuration call.
var leadingFractionRe = regexp.MustCompile(`^(\d+)(\.\d+)`)

// truncDuration formats d with Duration.String and drops the fractional
// digits of the leading number, so "1.234567ms" becomes "1ms". Strings that
// do not start with a fractional number (such as "1h30m0s" or "0s") are
// returned unchanged.
func truncDuration(d time.Duration) string {
	return leadingFractionRe.ReplaceAllString(d.String(), "$1")
}
206
View as plain text