...
1 package embeddedpostgres
2
3 import (
4 "archive/zip"
5 "bytes"
6 "crypto/sha256"
7 "encoding/hex"
8 "errors"
9 "fmt"
10 "io"
11 "log"
12 "net/http"
13 "os"
14 "path/filepath"
15 "strings"
16 )
17
18
// RemoteFetchStrategy is a function that performs a fetch when called and
// reports failure through its error result. The implementation built by
// defaultRemoteFetchStrategy downloads a Postgres binaries jar over HTTP and
// extracts the contained archive to the cache location.
type RemoteFetchStrategy func() error
20
21
22 func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy {
23 return func() error {
24 operatingSystem, architecture, version := versionStrategy()
25
26 jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",
27 remoteFetchHost,
28 operatingSystem,
29 architecture,
30 version,
31 operatingSystem,
32 architecture,
33 version)
34
35 jarDownloadResponse, err := http.Get(jarDownloadURL)
36 if err != nil {
37 return fmt.Errorf("unable to connect to %s", remoteFetchHost)
38 }
39
40 defer closeBody(jarDownloadResponse)()
41
42 if jarDownloadResponse.StatusCode != http.StatusOK {
43 return fmt.Errorf("no version found matching %s", version)
44 }
45
46 jarBodyBytes, err := io.ReadAll(jarDownloadResponse.Body)
47 if err != nil {
48 return errorFetchingPostgres(err)
49 }
50
51 shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL)
52 shaDownloadResponse, err := http.Get(shaDownloadURL)
53
54 defer closeBody(shaDownloadResponse)()
55
56 if err == nil && shaDownloadResponse.StatusCode == http.StatusOK {
57 if shaBodyBytes, err := io.ReadAll(shaDownloadResponse.Body); err == nil {
58 jarChecksum := sha256.Sum256(jarBodyBytes)
59 if !bytes.Equal(shaBodyBytes, []byte(hex.EncodeToString(jarChecksum[:]))) {
60 return errors.New("downloaded checksums do not match")
61 }
62 }
63 }
64
65 return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL)
66 }
67 }
68
// closeBody returns a function that closes resp's body; it is meant to be
// used as `defer closeBody(resp)()`.
//
// A nil response or a response without a body is a no-op. A failing Close is
// logged rather than treated as fatal, so a close error cannot terminate the
// process that embeds this package.
func closeBody(resp *http.Response) func() {
	return func() {
		if resp == nil || resp.Body == nil {
			return
		}

		if err := resp.Body.Close(); err != nil {
			log.Printf("unable to close response body: %s", err)
		}
	}
}
76
77 func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error {
78 size := contentLength
79
80
81
82 if contentLength < 0 {
83 size = int64(len(bodyBytes))
84 }
85 zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), size)
86 if err != nil {
87 return errorFetchingPostgres(err)
88 }
89
90 cacheLocation, _ := cacheLocator()
91
92 if err := os.MkdirAll(filepath.Dir(cacheLocation), 0755); err != nil {
93 return errorExtractingPostgres(err)
94 }
95
96 for _, file := range zipReader.File {
97 if !file.FileHeader.FileInfo().IsDir() && strings.HasSuffix(file.FileHeader.Name, ".txz") {
98 if err := decompressSingleFile(file, cacheLocation); err != nil {
99 return err
100 }
101
102
103 return nil
104 }
105 }
106
107 return fmt.Errorf("error fetching postgres: cannot find binary in archive retrieved from %s", downloadURL)
108 }
109
110 func decompressSingleFile(file *zip.File, cacheLocation string) error {
111 renamed := false
112
113 archiveReader, err := file.Open()
114 if err != nil {
115 return errorExtractingPostgres(err)
116 }
117
118 archiveBytes, err := io.ReadAll(archiveReader)
119 if err != nil {
120 return errorExtractingPostgres(err)
121 }
122
123
124
125
126 tmp, err := os.CreateTemp(filepath.Dir(cacheLocation), "temp_")
127 if err != nil {
128 return errorExtractingPostgres(err)
129 }
130 defer func() {
131
132
133 if !renamed {
134 if err := os.Remove(tmp.Name()); err != nil {
135 panic(err)
136 }
137 }
138 }()
139
140 if _, err := tmp.Write(archiveBytes); err != nil {
141 return errorExtractingPostgres(err)
142 }
143
144
145
146 if err := tmp.Close(); err != nil {
147 return errorExtractingPostgres(err)
148 }
149
150 if err := renameOrIgnore(tmp.Name(), cacheLocation); err != nil {
151 return errorExtractingPostgres(err)
152 }
153 renamed = true
154
155 return nil
156 }
157
// errorExtractingPostgres wraps err with the "unable to extract postgres
// archive" prefix. The cause stays reachable through errors.Is and errors.As.
// err is expected to be non-nil.
func errorExtractingPostgres(err error) error {
	return fmt.Errorf("unable to extract postgres archive: %w", err)
}
161
// errorFetchingPostgres wraps err with the "error fetching postgres" prefix.
// The cause stays reachable through errors.Is and errors.As. err is expected
// to be non-nil.
func errorFetchingPostgres(err error) error {
	return fmt.Errorf("error fetching postgres: %w", err)
}
165
View as plain text