...
1 package embeddedpostgres
2
import (
	"archive/tar"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"

	"github.com/xi2/xz"
)
12
13 func defaultTarReader(xzReader *xz.Reader) (func() (*tar.Header, error), func() io.Reader) {
14 tarReader := tar.NewReader(xzReader)
15
16 return func() (*tar.Header, error) {
17 return tarReader.Next()
18 }, func() io.Reader {
19 return tarReader
20 }
21 }
22
23 func decompressTarXz(tarReader func(*xz.Reader) (func() (*tar.Header, error), func() io.Reader), path, extractPath string) error {
24 tempExtractPath, err := os.MkdirTemp(filepath.Dir(extractPath), "temp_")
25 if err != nil {
26 return errorUnableToExtract(path, extractPath, err)
27 }
28 defer func() {
29 if err := os.RemoveAll(tempExtractPath); err != nil {
30 panic(err)
31 }
32 }()
33
34 tarFile, err := os.Open(path)
35 if err != nil {
36 return errorUnableToExtract(path, extractPath, err)
37 }
38
39 defer func() {
40 if err := tarFile.Close(); err != nil {
41 panic(err)
42 }
43 }()
44
45 xzReader, err := xz.NewReader(tarFile, 0)
46 if err != nil {
47 return errorUnableToExtract(path, extractPath, err)
48 }
49
50 readNext, reader := tarReader(xzReader)
51
52 for {
53 header, err := readNext()
54
55 if err == io.EOF {
56 break
57 }
58
59 if err != nil {
60 return errorExtractingPostgres(err)
61 }
62
63 targetPath := filepath.Join(tempExtractPath, header.Name)
64 finalPath := filepath.Join(extractPath, header.Name)
65
66 if err := os.MkdirAll(filepath.Dir(targetPath), os.ModePerm); err != nil {
67 return errorExtractingPostgres(err)
68 }
69
70 if err := os.MkdirAll(filepath.Dir(finalPath), os.ModePerm); err != nil {
71 return errorExtractingPostgres(err)
72 }
73
74 switch header.Typeflag {
75 case tar.TypeReg:
76 outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode))
77 if err != nil {
78 return errorExtractingPostgres(err)
79 }
80
81 if _, err := io.Copy(outFile, reader()); err != nil {
82 return errorExtractingPostgres(err)
83 }
84
85 if err := outFile.Close(); err != nil {
86 return errorExtractingPostgres(err)
87 }
88 case tar.TypeSymlink:
89 if err := os.RemoveAll(targetPath); err != nil {
90 return errorExtractingPostgres(err)
91 }
92
93 if err := os.Symlink(header.Linkname, targetPath); err != nil {
94 return errorExtractingPostgres(err)
95 }
96
97 case tar.TypeDir:
98 if err := os.MkdirAll(finalPath, os.FileMode(header.Mode)); err != nil {
99 return errorExtractingPostgres(err)
100 }
101 continue
102 }
103
104 if err := renameOrIgnore(targetPath, finalPath); err != nil {
105 return errorExtractingPostgres(err)
106 }
107 }
108
109 return nil
110 }
111
// errorUnableToExtract builds the error reported when the postgres archive at
// cacheLocation cannot be extracted into binariesPath. The message includes a
// hint about isolating RuntimePath for parallel tests. The underlying err is
// wrapped with %w, so callers can still inspect it with errors.Is/errors.As.
func errorUnableToExtract(cacheLocation, binariesPath string, err error) error {
	const format = "unable to extract postgres archive %s to %s, if running parallel tests, configure RuntimePath to isolate testing directories, %w"

	return fmt.Errorf(format, cacheLocation, binariesPath, err)
}
119
View as plain text