...

Source file src/edge-infra.dev/pkg/edge/api/testutils/seededpostgres/db.go

Documentation: edge-infra.dev/pkg/edge/api/testutils/seededpostgres

     1  package seededpostgres
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"net"
     7  	"os"
     8  	"path"
     9  
    10  	edgesql "edge-infra.dev/pkg/edge/api/sql"
    11  	"edge-infra.dev/pkg/edge/api/sql/plugin"
    12  	"edge-infra.dev/pkg/lib/build/bazel"
    13  	"edge-infra.dev/pkg/lib/compression"
    14  	"edge-infra.dev/pkg/lib/gcp/cloudsql"
    15  	"edge-infra.dev/pkg/lib/logging"
    16  
    17  	"github.com/bazelbuild/rules_go/go/runfiles"
    18  	embeddedpostgres "github.com/fergusstrange/embedded-postgres"
    19  	"github.com/golang-migrate/migrate/v4/database/postgres"
    20  )
    21  
    22  var PostgresVersion = embeddedpostgres.V14
    23  
    24  // SeededPostgres wraps a github.com/fergusstrange/embedded-postgres.EmbeddedPostgres object.
    25  type SeededPostgres struct {
    26  	dbname   string
    27  	username string
    28  	password string
    29  	port     int
    30  	tempDir  string
    31  	ep       *embeddedpostgres.EmbeddedPostgres
    32  }
    33  
    34  // New creates an embedded postgres database and seeds it with data.
    35  func New() (*SeededPostgres, error) {
    36  	return NewWithUser("postgres", "postgres", "postgres")
    37  }
    38  
    39  func NewWithUser(dbname, username, password string) (*SeededPostgres, error) {
    40  	if dbname == "" || username == "" || password == "" {
    41  		return nil, fmt.Errorf("NewWithUser arguments must not be empty: dbname=%q username=%q password=%q", dbname, username, password)
    42  	}
    43  
    44  	var cfg = embeddedpostgres.DefaultConfig()
    45  	cfg = cfg.Version(PostgresVersion)
    46  	cfg = cfg.Database(dbname)
    47  	cfg = cfg.Username(username)
    48  	cfg = cfg.Password(password)
    49  
    50  	var port, err = findUnusedPort()
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	cfg = cfg.Port(uint32(port)) /* #nosec G115 */
    55  
    56  	var tempDir string
    57  	if bazel.IsBazelTest() || bazel.IsBazelRun() {
    58  		embeddedTxzFile, err := runfiles.Rlocation(path.Join("edge_infra", "hack", "tools", "postgres.txz"))
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  		tempDir, err = bazel.NewTestTmpDir("edge-infra-api-test-*")
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  		err = compression.DecompressTarXz(embeddedTxzFile, tempDir)
    67  		if err != nil {
    68  			return nil, err
    69  		}
    70  
    71  		cfg = cfg.RuntimePath(path.Join(tempDir, "runtime"))
    72  		cfg = cfg.BinariesPath(tempDir)
    73  	}
    74  
    75  	var sp = &SeededPostgres{
    76  		dbname:   dbname,
    77  		username: username,
    78  		password: password,
    79  		port:     port,
    80  		tempDir:  tempDir,
    81  		ep:       embeddedpostgres.NewDatabase(cfg),
    82  	}
    83  
    84  	err = sp.ep.Start()
    85  	if err != nil {
    86  		_ = sp.Close()
    87  		return nil, err
    88  	}
    89  
    90  	db, err := sp.DB()
    91  	if err != nil {
    92  		_ = sp.Close()
    93  		return nil, err
    94  	}
    95  	defer db.Close()
    96  
    97  	driver, err := postgres.WithInstance(db, &postgres.Config{})
    98  	if err != nil {
    99  		_ = sp.Close()
   100  		return nil, err
   101  	}
   102  	defer driver.Close()
   103  
   104  	_, err = db.Exec("CREATE EXTENSION IF NOT EXISTS \"pgcrypto\"")
   105  	if err != nil {
   106  		_ = sp.Close()
   107  		return nil, err
   108  	}
   109  
   110  	var logger = logging.NewLogger().WithName("seededpostgres")
   111  	var pluginConfig = &plugin.Config{
   112  		MigrationAction: "up",
   113  		Ordered:         true,
   114  		TestMode:        true,
   115  		Data:            Seed,
   116  	}
   117  
   118  	err = edgesql.SetupEdgeTables(pluginConfig, driver, logger, db)
   119  	if err != nil {
   120  		_ = sp.Close()
   121  		return nil, err
   122  	}
   123  	return sp, nil
   124  }
   125  
   126  // Close should be called when done testing to stop the embedded postgres database and free up resources used.
   127  func (sp *SeededPostgres) Close() error {
   128  	var errStop error
   129  	if sp.ep != nil {
   130  		// Stop the embedded postgres, but wait to return the error until after deleting tempDir
   131  		errStop = sp.ep.Stop()
   132  	}
   133  
   134  	errDeleteTempDir := os.RemoveAll(sp.tempDir)
   135  	if errDeleteTempDir != nil {
   136  		return errDeleteTempDir
   137  	}
   138  
   139  	return errStop
   140  }
   141  
   142  // DB connects to the database for the desired user. It uses `edge-infra.dev/pkg/edge/api/sql.PostgresConnection` to create the sql.DB object.
   143  func (sp *SeededPostgres) DB() (*sql.DB, error) {
   144  	return sp.EdgePostgres().NewConnection()
   145  }
   146  
   147  func (sp *SeededPostgres) EdgePostgres() *cloudsql.EdgePostgres {
   148  	return cloudsql.PostgresConnection("localhost", fmt.Sprint(sp.port)).DBName(sp.dbname).Username(sp.username).Password(sp.password)
   149  }
   150  
   151  func (sp *SeededPostgres) Port() int {
   152  	return sp.port
   153  }
   154  
   155  func findUnusedPort() (int, error) {
   156  	addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
   157  	if err != nil {
   158  		return 0, err
   159  	}
   160  	l, err := net.ListenTCP("tcp", addr)
   161  	if err != nil {
   162  		return 0, err
   163  	}
   164  	return l.Addr().(*net.TCPAddr).Port, l.Close()
   165  }
   166  

View as plain text