...

Source file src/github.com/ory/x/dbal/migratest/helper.go

Documentation: github.com/ory/x/dbal/migratest

     1  package migratest
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/google/uuid"
     9  	"github.com/jmoiron/sqlx"
    10  	migrate "github.com/rubenv/sql-migrate"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/ory/x/dbal"
    14  	"github.com/ory/x/sqlcon/dockertest"
    15  )
    16  
// MigrationSchemas contains several schemas.
//
// It is an ordered list: each element is one schema, mapping a database
// driver name (matched against dbal.DriverPostgreSQL, dbal.DriverMySQL and
// dbal.DriverCockroachDB by RunPackrMigrationTests) to the packr migration
// source to apply for that database. Order matters: RunPackrMigrationTests
// migrates the schemas up in slice order and down in reverse order.
type MigrationSchemas []map[string]*dbal.PackrMigrationSource
    19  
    20  // RunPackrMigrationTests runs migration tests from packr migrations.
    21  func RunPackrMigrationTests(
    22  	t *testing.T, schema, data MigrationSchemas,
    23  	init, cleanup func(*testing.T, *sqlx.DB),
    24  	runner func(*testing.T, string, *sqlx.DB, int, int, int),
    25  ) {
    26  	if testing.Short() {
    27  		t.SkipNow()
    28  		return
    29  	}
    30  
    31  	var m sync.Mutex
    32  	var dbs = map[string]*sqlx.DB{}
    33  	var mid = uuid.New().String()
    34  
    35  	var dbnames = map[string]bool{}
    36  	for _, ms := range schema {
    37  		for dbname := range ms {
    38  			dbnames[dbname] = true
    39  		}
    40  	}
    41  
    42  	var connectors []func()
    43  	for dbname := range dbnames {
    44  		switch dbname {
    45  		case dbal.DriverPostgreSQL:
    46  			connectors = append(connectors, func() {
    47  				db, err := dockertest.ConnectToTestPostgreSQL()
    48  				if err != nil {
    49  					t.Fatalf("Could not connect to database: %v", err)
    50  				}
    51  				m.Lock()
    52  				dbs[dbal.DriverPostgreSQL] = db
    53  				m.Unlock()
    54  			})
    55  		case dbal.DriverMySQL:
    56  			connectors = append(connectors, func() {
    57  				db, err := dockertest.ConnectToTestMySQL()
    58  				if err != nil {
    59  					t.Fatalf("Could not connect to database: %v", err)
    60  				}
    61  				m.Lock()
    62  				dbs[dbal.DriverMySQL] = db
    63  				m.Unlock()
    64  			})
    65  		case dbal.DriverCockroachDB:
    66  			connectors = append(connectors, func() {
    67  				db, err := dockertest.ConnectToTestCockroachDB()
    68  				if err != nil {
    69  					t.Fatalf("Could not connect to database: %v", err)
    70  				}
    71  				m.Lock()
    72  				dbs[dbal.DriverCockroachDB] = db
    73  				m.Unlock()
    74  			})
    75  		default:
    76  			panic(fmt.Sprintf("Database name %s unknown", dbname))
    77  		}
    78  	}
    79  
    80  	dockertest.Parallel(connectors)
    81  
    82  	if data != nil {
    83  		require.Equal(t, len(schema), len(data))
    84  	}
    85  
    86  	for name, db := range dbs {
    87  		dialect := db.DriverName()
    88  		if dialect == "pgx" {
    89  			dialect = "postgres"
    90  		}
    91  		t.Run(fmt.Sprintf("database=%s", name), func(t *testing.T) {
    92  			init(t, db)
    93  
    94  			for sk, ss := range schema {
    95  				t.Run(fmt.Sprintf("schema=%d/run", sk), func(t *testing.T) {
    96  					steps := len(ss[name].Box.List())
    97  					for step := 0; step < steps; step++ {
    98  						t.Run(fmt.Sprintf("up=%d", step), func(t *testing.T) {
    99  							migrate.SetTable(fmt.Sprintf("%s_%d", mid, sk))
   100  							n, err := migrate.ExecMax(db.DB, dialect, ss[name], migrate.Up, 1)
   101  							require.NoError(t, err)
   102  							require.Equal(t, n, 1, sk)
   103  
   104  							t.Run(fmt.Sprintf("data=%d", step), func(t *testing.T) {
   105  								if data == nil || data[sk] == nil {
   106  									t.Skip("Skipping data creation because no schema specified...")
   107  									return
   108  								}
   109  
   110  								migrate.SetTable(fmt.Sprintf("%s_%d_data", mid, sk))
   111  								n, err = migrate.ExecMax(db.DB, dialect, data[sk][name], migrate.Up, 1)
   112  								require.NoError(t, err)
   113  								require.Equal(t, 1, n)
   114  							})
   115  						})
   116  					}
   117  
   118  					for step := 0; step < steps; step++ {
   119  						t.Run(fmt.Sprintf("runner=%d", step), func(t *testing.T) {
   120  							runner(t, name, db, sk, step, steps)
   121  						})
   122  					}
   123  				})
   124  			}
   125  
   126  			for sk := len(schema) - 1; sk >= 0; sk-- {
   127  				ss := schema[sk]
   128  
   129  				t.Run(fmt.Sprintf("schema=%d/cleanup", sk), func(t *testing.T) {
   130  					steps := len(ss[name].Box.List())
   131  
   132  					migrate.SetTable(fmt.Sprintf("%s_%d", mid, sk))
   133  					for step := 0; step < steps; step++ {
   134  						t.Run(fmt.Sprintf("down=%d", step), func(t *testing.T) {
   135  							n, err := migrate.ExecMax(db.DB, dialect, ss[name], migrate.Down, 1)
   136  							require.NoError(t, err)
   137  							require.Equal(t, 1, n)
   138  						})
   139  					}
   140  				})
   141  			}
   142  
   143  			cleanup(t, db)
   144  		})
   145  	}
   146  }
   147  

View as plain text