...

Source file src/github.com/fergusstrange/embedded-postgres/prepare_database.go

Documentation: github.com/fergusstrange/embedded-postgres

     1  package embeddedpostgres
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"os"
    10  	"os/exec"
    11  	"path/filepath"
    12  
    13  	"github.com/lib/pq"
    14  )
    15  
    16  const (
    17  	fmtCloseDBConn = "unable to close database connection: %w"
    18  	fmtAfterError  = "%v happened after error: %w"
    19  )
    20  
    21  type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger *os.File) error
    22  type createDatabase func(port uint32, username, password, database string) error
    23  
    24  func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger *os.File) error {
    25  	passwordFile, err := createPasswordFile(runtimePath, password)
    26  	if err != nil {
    27  		return err
    28  	}
    29  
    30  	args := []string{
    31  		"-A", "password",
    32  		"-U", username,
    33  		"-D", pgDataDir,
    34  		fmt.Sprintf("--pwfile=%s", passwordFile),
    35  	}
    36  
    37  	if locale != "" {
    38  		args = append(args, fmt.Sprintf("--locale=%s", locale))
    39  	}
    40  
    41  	postgresInitDBBinary := filepath.Join(binaryExtractLocation, "bin/initdb")
    42  	postgresInitDBProcess := exec.Command(postgresInitDBBinary, args...)
    43  	postgresInitDBProcess.Stderr = logger
    44  	postgresInitDBProcess.Stdout = logger
    45  
    46  	if err = postgresInitDBProcess.Run(); err != nil {
    47  		logContent, readLogsErr := readLogsOrTimeout(logger) // we want to preserve the original error
    48  		if readLogsErr != nil {
    49  			logContent = []byte(string(logContent) + " - " + readLogsErr.Error())
    50  		}
    51  		return fmt.Errorf("unable to init database using '%s': %w\n%s", postgresInitDBProcess.String(), err, string(logContent))
    52  	}
    53  
    54  	if err = os.Remove(passwordFile); err != nil {
    55  		return fmt.Errorf("unable to remove password file '%v': %w", passwordFile, err)
    56  	}
    57  
    58  	return nil
    59  }
    60  
    61  func createPasswordFile(runtimePath, password string) (string, error) {
    62  	passwordFileLocation := filepath.Join(runtimePath, "pwfile")
    63  	if err := os.WriteFile(passwordFileLocation, []byte(password), 0600); err != nil {
    64  		return "", fmt.Errorf("unable to write password file to %s", passwordFileLocation)
    65  	}
    66  
    67  	return passwordFileLocation, nil
    68  }
    69  
    70  func defaultCreateDatabase(port uint32, username, password, database string) (err error) {
    71  	if database == "postgres" {
    72  		return nil
    73  	}
    74  
    75  	conn, err := openDatabaseConnection(port, username, password, "postgres")
    76  	if err != nil {
    77  		return errorCustomDatabase(database, err)
    78  	}
    79  
    80  	db := sql.OpenDB(conn)
    81  	defer func() {
    82  		err = connectionClose(db, err)
    83  	}()
    84  
    85  	if _, err := db.Exec(fmt.Sprintf("CREATE DATABASE \"%s\"", database)); err != nil {
    86  		return errorCustomDatabase(database, err)
    87  	}
    88  
    89  	return nil
    90  }
    91  
    92  // connectionClose closes the database connection and handles the error of the function that used the database connection
    93  func connectionClose(db io.Closer, err error) error {
    94  	closeErr := db.Close()
    95  	if closeErr != nil {
    96  		closeErr = fmt.Errorf(fmtCloseDBConn, closeErr)
    97  
    98  		if err != nil {
    99  			err = fmt.Errorf(fmtAfterError, closeErr, err)
   100  		} else {
   101  			err = closeErr
   102  		}
   103  	}
   104  
   105  	return err
   106  }
   107  
   108  func healthCheckDatabaseOrTimeout(config Config) error {
   109  	healthCheckSignal := make(chan bool)
   110  
   111  	defer close(healthCheckSignal)
   112  
   113  	timeout, cancelFunc := context.WithTimeout(context.Background(), config.startTimeout)
   114  
   115  	defer cancelFunc()
   116  
   117  	go func() {
   118  		for timeout.Err() == nil {
   119  			if err := healthCheckDatabase(config.port, config.database, config.username, config.password); err != nil {
   120  				continue
   121  			}
   122  			healthCheckSignal <- true
   123  
   124  			break
   125  		}
   126  	}()
   127  
   128  	select {
   129  	case <-healthCheckSignal:
   130  		return nil
   131  	case <-timeout.Done():
   132  		return errors.New("timed out waiting for database to become available")
   133  	}
   134  }
   135  
   136  func healthCheckDatabase(port uint32, database, username, password string) (err error) {
   137  	conn, err := openDatabaseConnection(port, username, password, database)
   138  	if err != nil {
   139  		return err
   140  	}
   141  
   142  	db := sql.OpenDB(conn)
   143  	defer func() {
   144  		err = connectionClose(db, err)
   145  	}()
   146  
   147  	if _, err := db.Query("SELECT 1"); err != nil {
   148  		return err
   149  	}
   150  
   151  	return nil
   152  }
   153  
   154  func openDatabaseConnection(port uint32, username string, password string, database string) (*pq.Connector, error) {
   155  	conn, err := pq.NewConnector(fmt.Sprintf("host=localhost port=%d user=%s password=%s dbname=%s sslmode=disable",
   156  		port,
   157  		username,
   158  		password,
   159  		database))
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	return conn, nil
   165  }
   166  
   167  func errorCustomDatabase(database string, err error) error {
   168  	return fmt.Errorf("unable to connect to create database with custom name %s with the following error: %s", database, err)
   169  }
   170  

View as plain text