1
2
3
4 package redshift
5
6 import (
7 "context"
8 "database/sql"
9 "fmt"
10 "go.uber.org/atomic"
11 "io"
12 "io/ioutil"
13 nurl "net/url"
14 "strconv"
15 "strings"
16
17 "github.com/golang-migrate/migrate/v4"
18 "github.com/golang-migrate/migrate/v4/database"
19 "github.com/hashicorp/go-multierror"
20 "github.com/lib/pq"
21 )
22
23 func init() {
24 db := Redshift{}
25 database.Register("redshift", &db)
26 }
27
// Package-level default and sentinel errors.
var (
	// DefaultMigrationsTable is the table name WithInstance falls back to
	// when Config.MigrationsTable is empty.
	DefaultMigrationsTable = "schema_migrations"

	// ErrNilConfig is returned by WithInstance when it is handed a nil Config.
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName is returned by WithInstance when no database name was
	// configured and the server reports an empty current database.
	ErrNoDatabaseName = fmt.Errorf("no database name")
)
34
// Config carries the settings used to build a Redshift driver.
type Config struct {
	// MigrationsTable is the name of the table that stores the schema
	// version. WithInstance replaces an empty value with
	// DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName is the name of the target database. WithInstance fills an
	// empty value with the result of SELECT CURRENT_DATABASE().
	DatabaseName string
}
39
// Redshift is the migrate database driver for Redshift. It is registered
// under the name "redshift" and talks to the server through the "postgres"
// database/sql driver.
type Redshift struct {
	// isLocked is the in-process lock flag toggled by Lock and Unlock.
	isLocked atomic.Bool
	// conn is the dedicated connection, taken from db in WithInstance, on
	// which every statement of this driver is executed.
	conn *sql.Conn
	// db is the connection pool conn was obtained from; Close closes it too.
	db *sql.DB

	// config is set by WithInstance, which rejects a nil *Config.
	config *Config
}
48
49 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
50 if config == nil {
51 return nil, ErrNilConfig
52 }
53
54 if err := instance.Ping(); err != nil {
55 return nil, err
56 }
57
58 if config.DatabaseName == "" {
59 query := `SELECT CURRENT_DATABASE()`
60 var databaseName string
61 if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
62 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
63 }
64
65 if len(databaseName) == 0 {
66 return nil, ErrNoDatabaseName
67 }
68
69 config.DatabaseName = databaseName
70 }
71
72 if len(config.MigrationsTable) == 0 {
73 config.MigrationsTable = DefaultMigrationsTable
74 }
75
76 conn, err := instance.Conn(context.Background())
77
78 if err != nil {
79 return nil, err
80 }
81
82 px := &Redshift{
83 conn: conn,
84 db: instance,
85 config: config,
86 }
87
88 if err := px.ensureVersionTable(); err != nil {
89 return nil, err
90 }
91
92 return px, nil
93 }
94
95 func (p *Redshift) Open(url string) (database.Driver, error) {
96 purl, err := nurl.Parse(url)
97 if err != nil {
98 return nil, err
99 }
100 purl.Scheme = "postgres"
101
102 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
103 if err != nil {
104 return nil, err
105 }
106
107 migrationsTable := purl.Query().Get("x-migrations-table")
108
109 px, err := WithInstance(db, &Config{
110 DatabaseName: purl.Path,
111 MigrationsTable: migrationsTable,
112 })
113 if err != nil {
114 return nil, err
115 }
116
117 return px, nil
118 }
119
120 func (p *Redshift) Close() error {
121 connErr := p.conn.Close()
122 dbErr := p.db.Close()
123 if connErr != nil || dbErr != nil {
124 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
125 }
126 return nil
127 }
128
129
130 func (p *Redshift) Lock() error {
131 if !p.isLocked.CAS(false, true) {
132 return database.ErrLocked
133 }
134 return nil
135 }
136
137 func (p *Redshift) Unlock() error {
138 if !p.isLocked.CAS(true, false) {
139 return database.ErrNotLocked
140 }
141 return nil
142 }
143
144 func (p *Redshift) Run(migration io.Reader) error {
145 migr, err := ioutil.ReadAll(migration)
146 if err != nil {
147 return err
148 }
149
150
151 query := string(migr[:])
152 if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
153 if pgErr, ok := err.(*pq.Error); ok {
154 var line uint
155 var col uint
156 var lineColOK bool
157 if pgErr.Position != "" {
158 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
159 line, col, lineColOK = computeLineFromPos(query, int(pos))
160 }
161 }
162 message := fmt.Sprintf("migration failed: %s", pgErr.Message)
163 if lineColOK {
164 message = fmt.Sprintf("%s (column %d)", message, col)
165 }
166 if pgErr.Detail != "" {
167 message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
168 }
169 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
170 }
171 return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
172 }
173
174 return nil
175 }
176
177 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
178
179 s = strings.Replace(s, "\r\n", "\n", -1)
180
181 runes := []rune(s)
182 if pos > len(runes) {
183 return 0, 0, false
184 }
185 sel := runes[:pos]
186 line = uint(runesCount(sel, newLine) + 1)
187 col = uint(pos - 1 - runesLastIndex(sel, newLine))
188 return line, col, true
189 }
190
// newLine is the line separator rune counted and searched for by
// computeLineFromPos.
const newLine = '\n'
192
// runesCount reports how many elements of input are equal to target.
// It returns 0 for a nil or empty slice.
func runesCount(input []rune, target rune) int {
	total := 0
	for i := 0; i < len(input); i++ {
		if input[i] == target {
			total++
		}
	}
	return total
}
202
// runesLastIndex returns the index of the last element of input that equals
// target, or -1 when target does not occur in input.
func runesLastIndex(input []rune, target rune) int {
	last := -1
	for i, r := range input {
		if r == target {
			last = i
		}
	}
	return last
}
211
212 func (p *Redshift) SetVersion(version int, dirty bool) error {
213 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
214 if err != nil {
215 return &database.Error{OrigErr: err, Err: "transaction start failed"}
216 }
217
218 query := `DELETE FROM "` + p.config.MigrationsTable + `"`
219 if _, err := tx.Exec(query); err != nil {
220 if errRollback := tx.Rollback(); errRollback != nil {
221 err = multierror.Append(err, errRollback)
222 }
223 return &database.Error{OrigErr: err, Query: []byte(query)}
224 }
225
226
227
228
229 if version >= 0 || (version == database.NilVersion && dirty) {
230 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
231 if _, err := tx.Exec(query, version, dirty); err != nil {
232 if errRollback := tx.Rollback(); errRollback != nil {
233 err = multierror.Append(err, errRollback)
234 }
235 return &database.Error{OrigErr: err, Query: []byte(query)}
236 }
237 }
238
239 if err := tx.Commit(); err != nil {
240 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
241 }
242
243 return nil
244 }
245
246 func (p *Redshift) Version() (version int, dirty bool, err error) {
247 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
248 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
249 switch {
250 case err == sql.ErrNoRows:
251 return database.NilVersion, false, nil
252
253 case err != nil:
254 if e, ok := err.(*pq.Error); ok {
255 if e.Code.Name() == "undefined_table" {
256 return database.NilVersion, false, nil
257 }
258 }
259 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
260
261 default:
262 return version, dirty, nil
263 }
264 }
265
266 func (p *Redshift) Drop() (err error) {
267
268 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
269 tables, err := p.conn.QueryContext(context.Background(), query)
270 if err != nil {
271 return &database.Error{OrigErr: err, Query: []byte(query)}
272 }
273 defer func() {
274 if errClose := tables.Close(); errClose != nil {
275 err = multierror.Append(err, errClose)
276 }
277 }()
278
279
280 tableNames := make([]string, 0)
281 for tables.Next() {
282 var tableName string
283 if err := tables.Scan(&tableName); err != nil {
284 return err
285 }
286 if len(tableName) > 0 {
287 tableNames = append(tableNames, tableName)
288 }
289 }
290 if err := tables.Err(); err != nil {
291 return &database.Error{OrigErr: err, Query: []byte(query)}
292 }
293
294 if len(tableNames) > 0 {
295
296 for _, t := range tableNames {
297 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
298 if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
299 return &database.Error{OrigErr: err, Query: []byte(query)}
300 }
301 }
302 }
303
304 return nil
305 }
306
307
308
309
310 func (p *Redshift) ensureVersionTable() (err error) {
311 if err = p.Lock(); err != nil {
312 return err
313 }
314
315 defer func() {
316 if e := p.Unlock(); e != nil {
317 if err == nil {
318 err = e
319 } else {
320 err = multierror.Append(err, e)
321 }
322 }
323 }()
324
325
326 var count int
327 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
328 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
329 return &database.Error{OrigErr: err, Query: []byte(query)}
330 }
331 if count == 1 {
332 return nil
333 }
334
335
336 query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
337 if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
338 return &database.Error{OrigErr: err, Query: []byte(query)}
339 }
340 return nil
341 }
342
View as plain text