...

Source file src/github.com/fergusstrange/embedded-postgres/remote_fetch.go

Documentation: github.com/fergusstrange/embedded-postgres

     1  package embeddedpostgres
     2  
     3  import (
     4  	"archive/zip"
     5  	"bytes"
     6  	"crypto/sha256"
     7  	"encoding/hex"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"net/http"
    13  	"os"
    14  	"path/filepath"
    15  	"strings"
    16  )
    17  
// RemoteFetchStrategy provides a strategy to fetch a Postgres binary so that it is available for use.
//
// Invoking the function performs the fetch; failure is reported through the returned error.
type RemoteFetchStrategy func() error
    20  
    21  //nolint:funlen
    22  func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy {
    23  	return func() error {
    24  		operatingSystem, architecture, version := versionStrategy()
    25  
    26  		jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",
    27  			remoteFetchHost,
    28  			operatingSystem,
    29  			architecture,
    30  			version,
    31  			operatingSystem,
    32  			architecture,
    33  			version)
    34  
    35  		jarDownloadResponse, err := http.Get(jarDownloadURL)
    36  		if err != nil {
    37  			return fmt.Errorf("unable to connect to %s", remoteFetchHost)
    38  		}
    39  
    40  		defer closeBody(jarDownloadResponse)()
    41  
    42  		if jarDownloadResponse.StatusCode != http.StatusOK {
    43  			return fmt.Errorf("no version found matching %s", version)
    44  		}
    45  
    46  		jarBodyBytes, err := io.ReadAll(jarDownloadResponse.Body)
    47  		if err != nil {
    48  			return errorFetchingPostgres(err)
    49  		}
    50  
    51  		shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL)
    52  		shaDownloadResponse, err := http.Get(shaDownloadURL)
    53  
    54  		defer closeBody(shaDownloadResponse)()
    55  
    56  		if err == nil && shaDownloadResponse.StatusCode == http.StatusOK {
    57  			if shaBodyBytes, err := io.ReadAll(shaDownloadResponse.Body); err == nil {
    58  				jarChecksum := sha256.Sum256(jarBodyBytes)
    59  				if !bytes.Equal(shaBodyBytes, []byte(hex.EncodeToString(jarChecksum[:]))) {
    60  					return errors.New("downloaded checksums do not match")
    61  				}
    62  			}
    63  		}
    64  
    65  		return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL)
    66  	}
    67  }
    68  
    69  func closeBody(resp *http.Response) func() {
    70  	return func() {
    71  		if err := resp.Body.Close(); err != nil {
    72  			log.Fatal(err)
    73  		}
    74  	}
    75  }
    76  
    77  func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error {
    78  	size := contentLength
    79  	// if the content length is not set (i.e. chunked encoding),
    80  	// we need to use the length of the bodyBytes otherwise
    81  	// the unzip operation will fail
    82  	if contentLength < 0 {
    83  		size = int64(len(bodyBytes))
    84  	}
    85  	zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), size)
    86  	if err != nil {
    87  		return errorFetchingPostgres(err)
    88  	}
    89  
    90  	cacheLocation, _ := cacheLocator()
    91  
    92  	if err := os.MkdirAll(filepath.Dir(cacheLocation), 0755); err != nil {
    93  		return errorExtractingPostgres(err)
    94  	}
    95  
    96  	for _, file := range zipReader.File {
    97  		if !file.FileHeader.FileInfo().IsDir() && strings.HasSuffix(file.FileHeader.Name, ".txz") {
    98  			if err := decompressSingleFile(file, cacheLocation); err != nil {
    99  				return err
   100  			}
   101  
   102  			// we have successfully found the file, return early
   103  			return nil
   104  		}
   105  	}
   106  
   107  	return fmt.Errorf("error fetching postgres: cannot find binary in archive retrieved from %s", downloadURL)
   108  }
   109  
   110  func decompressSingleFile(file *zip.File, cacheLocation string) error {
   111  	renamed := false
   112  
   113  	archiveReader, err := file.Open()
   114  	if err != nil {
   115  		return errorExtractingPostgres(err)
   116  	}
   117  
   118  	archiveBytes, err := io.ReadAll(archiveReader)
   119  	if err != nil {
   120  		return errorExtractingPostgres(err)
   121  	}
   122  
   123  	// if multiple processes attempt to extract
   124  	// to prevent file corruption when multiple processes attempt to extract at the same time
   125  	// first to a cache location, and then move the file into place.
   126  	tmp, err := os.CreateTemp(filepath.Dir(cacheLocation), "temp_")
   127  	if err != nil {
   128  		return errorExtractingPostgres(err)
   129  	}
   130  	defer func() {
   131  		// if anything failed before the rename then the temporary file should be cleaned up.
   132  		// if the rename was successful then there is no temporary file to remove.
   133  		if !renamed {
   134  			if err := os.Remove(tmp.Name()); err != nil {
   135  				panic(err)
   136  			}
   137  		}
   138  	}()
   139  
   140  	if _, err := tmp.Write(archiveBytes); err != nil {
   141  		return errorExtractingPostgres(err)
   142  	}
   143  
   144  	// Windows cannot rename a file if is it still open.
   145  	// The file needs to be manually closed to allow the rename to happen
   146  	if err := tmp.Close(); err != nil {
   147  		return errorExtractingPostgres(err)
   148  	}
   149  
   150  	if err := renameOrIgnore(tmp.Name(), cacheLocation); err != nil {
   151  		return errorExtractingPostgres(err)
   152  	}
   153  	renamed = true
   154  
   155  	return nil
   156  }
   157  
   158  func errorExtractingPostgres(err error) error {
   159  	return fmt.Errorf("unable to extract postgres archive: %s", err)
   160  }
   161  
   162  func errorFetchingPostgres(err error) error {
   163  	return fmt.Errorf("error fetching postgres: %s", err)
   164  }
   165  

View as plain text