...

Source file src/github.com/fergusstrange/embedded-postgres/remote_fetch_test.go

Documentation: github.com/fergusstrange/embedded-postgres

     1  package embeddedpostgres
     2  
     3  import (
     4  	"archive/zip"
     5  	"crypto/sha256"
     6  	"encoding/hex"
     7  	"github.com/stretchr/testify/require"
     8  	"io"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"os"
    12  	"path"
    13  	"path/filepath"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  )
    19  
    20  func Test_defaultRemoteFetchStrategy_ErrorWhenHttpGet(t *testing.T) {
    21  	remoteFetchStrategy := defaultRemoteFetchStrategy("http://localhost:1234/maven2",
    22  		testVersionStrategy(),
    23  		testCacheLocator())
    24  
    25  	err := remoteFetchStrategy()
    26  
    27  	assert.EqualError(t, err, "unable to connect to http://localhost:1234/maven2")
    28  }
    29  
    30  func Test_defaultRemoteFetchStrategy_ErrorWhenHttpStatusNot200(t *testing.T) {
    31  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    32  		w.WriteHeader(http.StatusNotFound)
    33  	}))
    34  	defer server.Close()
    35  
    36  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL,
    37  		testVersionStrategy(),
    38  		testCacheLocator())
    39  
    40  	err := remoteFetchStrategy()
    41  
    42  	assert.EqualError(t, err, "no version found matching 1.2.3")
    43  }
    44  
    45  func Test_defaultRemoteFetchStrategy_ErrorWhenBodyReadIssue(t *testing.T) {
    46  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    47  		w.Header().Set("Content-Length", "1")
    48  	}))
    49  	defer server.Close()
    50  
    51  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
    52  		testVersionStrategy(),
    53  		testCacheLocator())
    54  
    55  	err := remoteFetchStrategy()
    56  
    57  	assert.EqualError(t, err, "error fetching postgres: unexpected EOF")
    58  }
    59  
    60  func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) {
    61  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    62  		if strings.HasSuffix(r.RequestURI, ".sha256") {
    63  			w.WriteHeader(http.StatusNotFound)
    64  			return
    65  		}
    66  	}))
    67  	defer server.Close()
    68  
    69  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
    70  		testVersionStrategy(),
    71  		testCacheLocator())
    72  
    73  	err := remoteFetchStrategy()
    74  
    75  	assert.EqualError(t, err, "error fetching postgres: zip: not a valid zip file")
    76  }
    77  
    78  func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) {
    79  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    80  		if strings.HasSuffix(r.RequestURI, ".sha256") {
    81  			w.WriteHeader(404)
    82  			return
    83  		}
    84  
    85  		if _, err := w.Write([]byte("lolz")); err != nil {
    86  			panic(err)
    87  		}
    88  	}))
    89  	defer server.Close()
    90  
    91  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
    92  		testVersionStrategy(),
    93  		testCacheLocator())
    94  
    95  	err := remoteFetchStrategy()
    96  
    97  	assert.EqualError(t, err, "error fetching postgres: zip: not a valid zip file")
    98  }
    99  
   100  func Test_defaultRemoteFetchStrategy_ErrorWhenNoSubTarArchive(t *testing.T) {
   101  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   102  		if strings.HasSuffix(r.RequestURI, ".sha256") {
   103  			w.WriteHeader(http.StatusNotFound)
   104  			return
   105  		}
   106  
   107  		MyZipWriter := zip.NewWriter(w)
   108  
   109  		if err := MyZipWriter.Close(); err != nil {
   110  			t.Error(err)
   111  		}
   112  	}))
   113  	defer server.Close()
   114  
   115  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
   116  		testVersionStrategy(),
   117  		testCacheLocator())
   118  
   119  	err := remoteFetchStrategy()
   120  
   121  	assert.EqualError(t, err, "error fetching postgres: cannot find binary in archive retrieved from "+server.URL+"/maven2/io/zonky/test/postgres/embedded-postgres-binaries-darwin-amd64/1.2.3/embedded-postgres-binaries-darwin-amd64-1.2.3.jar")
   122  }
   123  
   124  func Test_defaultRemoteFetchStrategy_ErrorWhenCannotExtractSubArchive(t *testing.T) {
   125  	jarFile, cleanUp := createTempZipArchive()
   126  	defer cleanUp()
   127  
   128  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   129  		if strings.HasSuffix(r.RequestURI, ".sha256") {
   130  			w.WriteHeader(http.StatusNotFound)
   131  			return
   132  		}
   133  
   134  		bytes, err := os.ReadFile(jarFile)
   135  		if err != nil {
   136  			panic(err)
   137  		}
   138  		if _, err := w.Write(bytes); err != nil {
   139  			panic(err)
   140  		}
   141  	}))
   142  	defer server.Close()
   143  
   144  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
   145  		testVersionStrategy(),
   146  		func() (s string, b bool) {
   147  			return filepath.FromSlash("/invalid"), false
   148  		})
   149  
   150  	err := remoteFetchStrategy()
   151  
   152  	assert.Regexp(t, "^unable to extract postgres archive:.+$", err)
   153  }
   154  
   155  func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateCacheDirectory(t *testing.T) {
   156  	jarFile, cleanUp := createTempZipArchive()
   157  	defer cleanUp()
   158  
   159  	fileBlockingExtractDirectory := filepath.Join(filepath.Dir(jarFile), "a_file_blocking_extract")
   160  
   161  	if _, err := os.Create(fileBlockingExtractDirectory); err != nil {
   162  		panic(err)
   163  	}
   164  
   165  	cacheLocation := filepath.Join(fileBlockingExtractDirectory, "cache_file.jar")
   166  
   167  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   168  		if strings.HasSuffix(r.RequestURI, ".sha256") {
   169  			w.WriteHeader(http.StatusNotFound)
   170  			return
   171  		}
   172  
   173  		bytes, err := os.ReadFile(jarFile)
   174  		if err != nil {
   175  			panic(err)
   176  		}
   177  		if _, err := w.Write(bytes); err != nil {
   178  			panic(err)
   179  		}
   180  	}))
   181  
   182  	defer server.Close()
   183  
   184  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
   185  		testVersionStrategy(),
   186  		func() (s string, b bool) {
   187  			return cacheLocation, false
   188  		})
   189  
   190  	err := remoteFetchStrategy()
   191  
   192  	assert.Regexp(t, "^unable to extract postgres archive:.+$", err)
   193  }
   194  
   195  func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *testing.T) {
   196  	jarFile, cleanUp := createTempZipArchive()
   197  	defer cleanUp()
   198  
   199  	cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_directory", "cache_file.jar")
   200  
   201  	if err := os.MkdirAll(cacheLocation, os.ModePerm); err != nil {
   202  		panic(err)
   203  	}
   204  
   205  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   206  		if strings.HasSuffix(r.RequestURI, ".sha256") {
   207  			w.WriteHeader(http.StatusNotFound)
   208  			return
   209  		}
   210  
   211  		bytes, err := os.ReadFile(jarFile)
   212  		if err != nil {
   213  			panic(err)
   214  		}
   215  		if _, err := w.Write(bytes); err != nil {
   216  			panic(err)
   217  		}
   218  	}))
   219  	defer server.Close()
   220  
   221  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
   222  		testVersionStrategy(),
   223  		func() (s string, b bool) {
   224  			return "/\\000", false
   225  		})
   226  
   227  	err := remoteFetchStrategy()
   228  
   229  	assert.Regexp(t, "^unable to extract postgres archive:.+$", err)
   230  }
   231  
   232  func Test_defaultRemoteFetchStrategy_ErrorWhenSHA256NotMatch(t *testing.T) {
   233  	jarFile, cleanUp := createTempZipArchive()
   234  	defer cleanUp()
   235  
   236  	cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar")
   237  
   238  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   239  		bytes, err := os.ReadFile(jarFile)
   240  		if err != nil {
   241  			panic(err)
   242  		}
   243  
   244  		if strings.HasSuffix(r.RequestURI, ".sha256") {
   245  			w.WriteHeader(200)
   246  			if _, err := w.Write([]byte("literallyN3verGonnaWork")); err != nil {
   247  				panic(err)
   248  			}
   249  
   250  			return
   251  		}
   252  
   253  		if _, err := w.Write(bytes); err != nil {
   254  			panic(err)
   255  		}
   256  	}))
   257  	defer server.Close()
   258  
   259  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
   260  		testVersionStrategy(),
   261  		func() (s string, b bool) {
   262  			return cacheLocation, false
   263  		})
   264  
   265  	err := remoteFetchStrategy()
   266  
   267  	assert.EqualError(t, err, "downloaded checksums do not match")
   268  }
   269  
   270  func Test_defaultRemoteFetchStrategy(t *testing.T) {
   271  	jarFile, cleanUp := createTempZipArchive()
   272  	defer cleanUp()
   273  
   274  	cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar")
   275  
   276  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   277  		bytes, err := os.ReadFile(jarFile)
   278  		if err != nil {
   279  			panic(err)
   280  		}
   281  
   282  		if strings.HasSuffix(r.RequestURI, ".sha256") {
   283  			w.WriteHeader(200)
   284  			contentHash := sha256.Sum256(bytes)
   285  			if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil {
   286  				panic(err)
   287  			}
   288  
   289  			return
   290  		}
   291  
   292  		if _, err := w.Write(bytes); err != nil {
   293  			panic(err)
   294  		}
   295  	}))
   296  	defer server.Close()
   297  
   298  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
   299  		testVersionStrategy(),
   300  		func() (s string, b bool) {
   301  			return cacheLocation, false
   302  		})
   303  
   304  	err := remoteFetchStrategy()
   305  
   306  	assert.NoError(t, err)
   307  	assert.FileExists(t, cacheLocation)
   308  }
   309  
   310  func Test_defaultRemoteFetchStrategyWithExistingDownload(t *testing.T) {
   311  	jarFile, cleanUp := createTempZipArchive()
   312  	defer cleanUp()
   313  
   314  	// create a temp directory for testing
   315  	tempFile, err := os.MkdirTemp("", "cache_output")
   316  	if err != nil {
   317  		panic(err)
   318  	}
   319  	// clean up once the test is finished.
   320  	defer func() {
   321  		if err := os.RemoveAll(tempFile); err != nil {
   322  			panic(err)
   323  		}
   324  	}()
   325  
   326  	cacheLocation := path.Join(tempFile, "temp.jar")
   327  
   328  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   329  		bytes, err := os.ReadFile(jarFile)
   330  		if err != nil {
   331  			panic(err)
   332  		}
   333  
   334  		if strings.HasSuffix(r.RequestURI, ".sha256") {
   335  			w.WriteHeader(200)
   336  			contentHash := sha256.Sum256(bytes)
   337  			if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil {
   338  				panic(err)
   339  			}
   340  
   341  			return
   342  		}
   343  
   344  		if _, err := w.Write(bytes); err != nil {
   345  			panic(err)
   346  		}
   347  	}))
   348  	defer server.Close()
   349  
   350  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
   351  		testVersionStrategy(),
   352  		func() (s string, b bool) {
   353  			return cacheLocation, false
   354  		})
   355  
   356  	// call it the remoteFetchStrategy(). The output location should be empty and a new file created
   357  	err = remoteFetchStrategy()
   358  	assert.NoError(t, err)
   359  	assert.FileExists(t, cacheLocation)
   360  	out1, err := os.ReadFile(cacheLocation)
   361  
   362  	// write some bad data to the file, this helps us test that the file is overwritten
   363  	err = os.WriteFile(cacheLocation, []byte("invalid"), 0600)
   364  	assert.NoError(t, err)
   365  
   366  	// call the remoteFetchStrategy() again, this time the file should be overwritten
   367  	err = remoteFetchStrategy()
   368  	assert.NoError(t, err)
   369  	assert.FileExists(t, cacheLocation)
   370  
   371  	// ensure that the file contents are the same from both downloads, and that it doesn't contain the `invalid` data.
   372  	out2, err := os.ReadFile(cacheLocation)
   373  	assert.Equal(t, out1, out2)
   374  }
   375  
   376  func Test_defaultRemoteFetchStrategy_whenContentLengthNotSet(t *testing.T) {
   377  	jarFile, cleanUp := createTempZipArchive()
   378  	defer cleanUp()
   379  
   380  	cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar")
   381  
   382  	bytes, err := os.ReadFile(jarFile)
   383  	if err != nil {
   384  		require.NoError(t, err)
   385  	}
   386  	contentHash := sha256.Sum256(bytes)
   387  
   388  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   389  		if strings.HasSuffix(r.RequestURI, ".sha256") {
   390  			w.WriteHeader(200)
   391  			if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil {
   392  				panic(err)
   393  			}
   394  
   395  			return
   396  		}
   397  
   398  		f, err := os.Open(jarFile)
   399  		if err != nil {
   400  			panic(err)
   401  		}
   402  
   403  		// stream the file back so that Go uses
   404  		// chunked encoding and never sets Content-Length
   405  		_, _ = io.Copy(w, f)
   406  	}))
   407  	defer server.Close()
   408  
   409  	remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
   410  		testVersionStrategy(),
   411  		func() (s string, b bool) {
   412  			return cacheLocation, false
   413  		})
   414  
   415  	err = remoteFetchStrategy()
   416  
   417  	assert.NoError(t, err)
   418  	assert.FileExists(t, cacheLocation)
   419  }
   420  

View as plain text