1 package awss3
2
3 import (
4 "errors"
5 "io/ioutil"
6 "strings"
7 "testing"
8
9 "github.com/aws/aws-sdk-go/aws"
10 "github.com/aws/aws-sdk-go/service/s3"
11 st "github.com/golang-migrate/migrate/v4/source/testing"
12 "github.com/stretchr/testify/assert"
13 )
14
15 func Test(t *testing.T) {
16 s3Client := fakeS3{
17 bucket: "some-bucket",
18 objects: map[string]string{
19 "staging/migrations/1_foobar.up.sql": "1 up",
20 "staging/migrations/1_foobar.down.sql": "1 down",
21 "prod/migrations/1_foobar.up.sql": "1 up",
22 "prod/migrations/1_foobar.down.sql": "1 down",
23 "prod/migrations/3_foobar.up.sql": "3 up",
24 "prod/migrations/4_foobar.up.sql": "4 up",
25 "prod/migrations/4_foobar.down.sql": "4 down",
26 "prod/migrations/5_foobar.down.sql": "5 down",
27 "prod/migrations/7_foobar.up.sql": "7 up",
28 "prod/migrations/7_foobar.down.sql": "7 down",
29 "prod/migrations/not-a-migration.txt": "",
30 "prod/migrations/0-random-stuff/whatever.txt": "",
31 },
32 }
33 driver, err := WithInstance(&s3Client, &Config{
34 Bucket: "some-bucket",
35 Prefix: "prod/migrations/",
36 })
37 if err != nil {
38 t.Fatal(err)
39 }
40 st.Test(t, driver)
41 }
42
43 func TestParseURI(t *testing.T) {
44 tests := []struct {
45 name string
46 uri string
47 config *Config
48 }{
49 {
50 "with prefix, no trailing slash",
51 "s3://migration-bucket/production",
52 &Config{
53 Bucket: "migration-bucket",
54 Prefix: "production/",
55 },
56 },
57 {
58 "without prefix, no trailing slash",
59 "s3://migration-bucket",
60 &Config{
61 Bucket: "migration-bucket",
62 },
63 },
64 {
65 "with prefix, trailing slash",
66 "s3://migration-bucket/production/",
67 &Config{
68 Bucket: "migration-bucket",
69 Prefix: "production/",
70 },
71 },
72 {
73 "without prefix, trailing slash",
74 "s3://migration-bucket/",
75 &Config{
76 Bucket: "migration-bucket",
77 },
78 },
79 }
80 for _, test := range tests {
81 t.Run(test.name, func(t *testing.T) {
82 actual, err := parseURI(test.uri)
83 if err != nil {
84 t.Fatal(err)
85 }
86 assert.Equal(t, test.config, actual)
87 })
88 }
89 }
90
91 type fakeS3 struct {
92 s3.S3
93 bucket string
94 objects map[string]string
95 }
96
97 func (s *fakeS3) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
98 bucket := aws.StringValue(input.Bucket)
99 if bucket != s.bucket {
100 return nil, errors.New("bucket not found")
101 }
102 prefix := aws.StringValue(input.Prefix)
103 delimiter := aws.StringValue(input.Delimiter)
104 var output s3.ListObjectsOutput
105 for name := range s.objects {
106 if strings.HasPrefix(name, prefix) {
107 if delimiter == "" || !strings.Contains(strings.Replace(name, prefix, "", 1), delimiter) {
108 output.Contents = append(output.Contents, &s3.Object{
109 Key: aws.String(name),
110 })
111 }
112 }
113 }
114 return &output, nil
115 }
116
117 func (s *fakeS3) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) {
118 bucket := aws.StringValue(input.Bucket)
119 if bucket != s.bucket {
120 return nil, errors.New("bucket not found")
121 }
122 if data, ok := s.objects[aws.StringValue(input.Key)]; ok {
123 body := ioutil.NopCloser(strings.NewReader(data))
124 return &s3.GetObjectOutput{Body: body}, nil
125 }
126 return nil, errors.New("object not found")
127 }
128
View as plain text