...

Source file src/go.etcd.io/bbolt/internal/btesting/btesting.go

Documentation: go.etcd.io/bbolt/internal/btesting

     1  package btesting
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"regexp"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/stretchr/testify/require"
    13  
    14  	bolt "go.etcd.io/bbolt"
    15  )
    16  
    17  var statsFlag = flag.Bool("stats", false, "show performance stats")
    18  
    19  // TestFreelistType is used as a env variable for test to indicate the backend type
    20  const TestFreelistType = "TEST_FREELIST_TYPE"
    21  
    22  // DB is a test wrapper for bolt.DB.
    23  type DB struct {
    24  	*bolt.DB
    25  	f string
    26  	o *bolt.Options
    27  	t testing.TB
    28  }
    29  
    30  // MustCreateDB returns a new, open DB at a temporary location.
    31  func MustCreateDB(t testing.TB) *DB {
    32  	return MustCreateDBWithOption(t, nil)
    33  }
    34  
    35  // MustCreateDBWithOption returns a new, open DB at a temporary location with given options.
    36  func MustCreateDBWithOption(t testing.TB, o *bolt.Options) *DB {
    37  	f := filepath.Join(t.TempDir(), "db")
    38  	return MustOpenDBWithOption(t, f, o)
    39  }
    40  
    41  func MustOpenDBWithOption(t testing.TB, f string, o *bolt.Options) *DB {
    42  	t.Logf("Opening bbolt DB at: %s", f)
    43  	if o == nil {
    44  		o = bolt.DefaultOptions
    45  	}
    46  
    47  	freelistType := bolt.FreelistArrayType
    48  	if env := os.Getenv(TestFreelistType); env == string(bolt.FreelistMapType) {
    49  		freelistType = bolt.FreelistMapType
    50  	}
    51  
    52  	o.FreelistType = freelistType
    53  
    54  	db, err := bolt.Open(f, 0666, o)
    55  	require.NoError(t, err)
    56  	resDB := &DB{
    57  		DB: db,
    58  		f:  f,
    59  		o:  o,
    60  		t:  t,
    61  	}
    62  	t.Cleanup(resDB.PostTestCleanup)
    63  	return resDB
    64  }
    65  
    66  func (db *DB) PostTestCleanup() {
    67  	// Check database consistency after every test.
    68  	if db.DB != nil {
    69  		db.MustCheck()
    70  		db.MustClose()
    71  	}
    72  }
    73  
    74  // Close closes the database but does NOT delete the underlying file.
    75  func (db *DB) Close() error {
    76  	if db.DB != nil {
    77  		// Log statistics.
    78  		if *statsFlag {
    79  			db.PrintStats()
    80  		}
    81  		db.t.Logf("Closing bbolt DB at: %s", db.f)
    82  		err := db.DB.Close()
    83  		if err != nil {
    84  			return err
    85  		}
    86  		db.DB = nil
    87  	}
    88  	return nil
    89  }
    90  
    91  // MustClose closes the database but does NOT delete the underlying file.
    92  func (db *DB) MustClose() {
    93  	err := db.Close()
    94  	require.NoError(db.t, err)
    95  }
    96  
    97  func (db *DB) MustDeleteFile() {
    98  	err := os.Remove(db.Path())
    99  	require.NoError(db.t, err)
   100  }
   101  
   102  func (db *DB) SetOptions(o *bolt.Options) {
   103  	db.o = o
   104  }
   105  
   106  // MustReopen reopen the database. Panic on error.
   107  func (db *DB) MustReopen() {
   108  	if db.DB != nil {
   109  		panic("Please call Close() before MustReopen()")
   110  	}
   111  	db.t.Logf("Reopening bbolt DB at: %s", db.f)
   112  	indb, err := bolt.Open(db.Path(), 0666, db.o)
   113  	require.NoError(db.t, err)
   114  	db.DB = indb
   115  }
   116  
   117  // MustCheck runs a consistency check on the database and panics if any errors are found.
   118  func (db *DB) MustCheck() {
   119  	err := db.Update(func(tx *bolt.Tx) error {
   120  		// Collect all the errors.
   121  		var errors []error
   122  		for err := range tx.Check() {
   123  			errors = append(errors, err)
   124  			if len(errors) > 10 {
   125  				break
   126  			}
   127  		}
   128  
   129  		// If errors occurred, copy the DB and print the errors.
   130  		if len(errors) > 0 {
   131  			var path = filepath.Join(db.t.TempDir(), "db.backup")
   132  			err := tx.CopyFile(path, 0600)
   133  			require.NoError(db.t, err)
   134  
   135  			// Print errors.
   136  			fmt.Print("\n\n")
   137  			fmt.Printf("consistency check failed (%d errors)\n", len(errors))
   138  			for _, err := range errors {
   139  				fmt.Println(err)
   140  			}
   141  			fmt.Println("")
   142  			fmt.Println("db saved to:")
   143  			fmt.Println(path)
   144  			fmt.Print("\n\n")
   145  			os.Exit(-1)
   146  		}
   147  
   148  		return nil
   149  	})
   150  	require.NoError(db.t, err)
   151  }
   152  
   153  // Fill - fills the DB using numTx transactions and numKeysPerTx.
   154  func (db *DB) Fill(bucket []byte, numTx int, numKeysPerTx int,
   155  	keyGen func(tx int, key int) []byte,
   156  	valueGen func(tx int, key int) []byte) error {
   157  	for tr := 0; tr < numTx; tr++ {
   158  		err := db.Update(func(tx *bolt.Tx) error {
   159  			b, _ := tx.CreateBucketIfNotExists(bucket)
   160  			for i := 0; i < numKeysPerTx; i++ {
   161  				if err := b.Put(keyGen(tr, i), valueGen(tr, i)); err != nil {
   162  					return err
   163  				}
   164  			}
   165  			return nil
   166  		})
   167  		if err != nil {
   168  			return err
   169  		}
   170  	}
   171  	return nil
   172  }
   173  
   174  func (db *DB) Path() string {
   175  	return db.f
   176  }
   177  
   178  // CopyTempFile copies a database to a temporary file.
   179  func (db *DB) CopyTempFile() {
   180  	path := filepath.Join(db.t.TempDir(), "db.copy")
   181  	err := db.View(func(tx *bolt.Tx) error {
   182  		return tx.CopyFile(path, 0600)
   183  	})
   184  	require.NoError(db.t, err)
   185  	fmt.Println("db copied to: ", path)
   186  }
   187  
   188  // PrintStats prints the database stats
   189  func (db *DB) PrintStats() {
   190  	var stats = db.Stats()
   191  	fmt.Printf("[db] %-20s %-20s %-20s\n",
   192  		fmt.Sprintf("pg(%d/%d)", stats.TxStats.GetPageCount(), stats.TxStats.GetPageAlloc()),
   193  		fmt.Sprintf("cur(%d)", stats.TxStats.GetCursorCount()),
   194  		fmt.Sprintf("node(%d/%d)", stats.TxStats.GetNodeCount(), stats.TxStats.GetNodeDeref()),
   195  	)
   196  	fmt.Printf("     %-20s %-20s %-20s\n",
   197  		fmt.Sprintf("rebal(%d/%v)", stats.TxStats.GetRebalance(), truncDuration(stats.TxStats.GetRebalanceTime())),
   198  		fmt.Sprintf("spill(%d/%v)", stats.TxStats.GetSpill(), truncDuration(stats.TxStats.GetSpillTime())),
   199  		fmt.Sprintf("w(%d/%v)", stats.TxStats.GetWrite(), truncDuration(stats.TxStats.GetWriteTime())),
   200  	)
   201  }
   202  
   203  func truncDuration(d time.Duration) string {
   204  	return regexp.MustCompile(`^(\d+)(\.\d+)`).ReplaceAllString(d.String(), "$1")
   205  }
   206  

View as plain text