1
2
3
4 package postgres
5
6 import (
7 "context"
8 "database/sql"
9 "fmt"
10 "go.uber.org/atomic"
11 "io"
12 "io/ioutil"
13 nurl "net/url"
14 "regexp"
15 "strconv"
16 "strings"
17 "time"
18
19 "github.com/golang-migrate/migrate/v4"
20 "github.com/golang-migrate/migrate/v4/database"
21 "github.com/golang-migrate/migrate/v4/database/multistmt"
22 multierror "github.com/hashicorp/go-multierror"
23 "github.com/lib/pq"
24 )
25
26 func init() {
27 db := Postgres{}
28 database.Register("postgres", &db)
29 database.Register("postgresql", &db)
30 }
31
// multiStmtDelimiter is the byte sequence Run hands to the multi-statement
// parser as the statement separator.
var multiStmtDelimiter = []byte(";")

var (
	// DefaultMigrationsTable is the table name WithInstance falls back to
	// when Config.MigrationsTable is empty.
	DefaultMigrationsTable = "schema_migrations"

	// DefaultMultiStatementMaxSize (10 MiB) is the size Open uses when the
	// x-multi-statement-max-size URL option is absent or not positive.
	DefaultMultiStatementMaxSize = 10 * (1 << 20)
)
38
// Sentinel errors reported by this driver.
var (
	// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName is returned by WithInstance when the server reports
	// an empty current database name.
	ErrNoDatabaseName = fmt.Errorf("no database name")

	// ErrNoSchema is returned by WithInstance when the server reports an
	// empty current schema.
	ErrNoSchema = fmt.Errorf("no schema")

	// ErrDatabaseDirty signals a dirty database. It is exported for callers;
	// nothing in this section of the file returns it.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
45
// Config carries the settings for a Postgres driver instance.
//
// WithInstance fills in empty DatabaseName, SchemaName and MigrationsTable
// fields on the value it is given, and derives the two unexported fields.
type Config struct {
	// MigrationsTable names the version-tracking table. When empty it is
	// replaced by DefaultMigrationsTable.
	MigrationsTable string
	// MigrationsTableQuoted makes WithInstance parse MigrationsTable as one
	// or two double-quoted identifiers ("table" or "schema"."table").
	MigrationsTableQuoted bool
	// MultiStatementEnabled makes Run split a migration on ";" and execute
	// the pieces one by one instead of sending the whole body at once.
	MultiStatementEnabled bool
	// DatabaseName is looked up with CURRENT_DATABASE() when left empty.
	DatabaseName string
	// SchemaName is looked up with CURRENT_SCHEMA() when left empty.
	SchemaName string

	// migrationsSchemaName and migrationsTableName are the resolved schema
	// and table of the version table, set by WithInstance.
	migrationsSchemaName string
	migrationsTableName  string

	// StatementTimeout, when non-zero, bounds each executed statement via a
	// context deadline.
	StatementTimeout time.Duration
	// MultiStatementMaxSize is passed to the multi-statement parser as its
	// size limit.
	MultiStatementMaxSize int
}
57
58 type Postgres struct {
59
60 conn *sql.Conn
61 db *sql.DB
62 isLocked atomic.Bool
63
64
65 config *Config
66 }
67
68 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
69 if config == nil {
70 return nil, ErrNilConfig
71 }
72
73 if err := instance.Ping(); err != nil {
74 return nil, err
75 }
76
77 if config.DatabaseName == "" {
78 query := `SELECT CURRENT_DATABASE()`
79 var databaseName string
80 if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
81 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
82 }
83
84 if len(databaseName) == 0 {
85 return nil, ErrNoDatabaseName
86 }
87
88 config.DatabaseName = databaseName
89 }
90
91 if config.SchemaName == "" {
92 query := `SELECT CURRENT_SCHEMA()`
93 var schemaName string
94 if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
95 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
96 }
97
98 if len(schemaName) == 0 {
99 return nil, ErrNoSchema
100 }
101
102 config.SchemaName = schemaName
103 }
104
105 if len(config.MigrationsTable) == 0 {
106 config.MigrationsTable = DefaultMigrationsTable
107 }
108
109 config.migrationsSchemaName = config.SchemaName
110 config.migrationsTableName = config.MigrationsTable
111 if config.MigrationsTableQuoted {
112 re := regexp.MustCompile(`"(.*?)"`)
113 result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
114 config.migrationsTableName = result[len(result)-1][1]
115 if len(result) == 2 {
116 config.migrationsSchemaName = result[0][1]
117 } else if len(result) > 2 {
118 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
119 }
120 }
121
122 conn, err := instance.Conn(context.Background())
123
124 if err != nil {
125 return nil, err
126 }
127
128 px := &Postgres{
129 conn: conn,
130 db: instance,
131 config: config,
132 }
133
134 if err := px.ensureVersionTable(); err != nil {
135 return nil, err
136 }
137
138 return px, nil
139 }
140
141 func (p *Postgres) Open(url string) (database.Driver, error) {
142 purl, err := nurl.Parse(url)
143 if err != nil {
144 return nil, err
145 }
146
147 db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
148 if err != nil {
149 return nil, err
150 }
151
152 migrationsTable := purl.Query().Get("x-migrations-table")
153 migrationsTableQuoted := false
154 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
155 migrationsTableQuoted, err = strconv.ParseBool(s)
156 if err != nil {
157 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
158 }
159 }
160 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
161 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
162 }
163
164 statementTimeoutString := purl.Query().Get("x-statement-timeout")
165 statementTimeout := 0
166 if statementTimeoutString != "" {
167 statementTimeout, err = strconv.Atoi(statementTimeoutString)
168 if err != nil {
169 return nil, err
170 }
171 }
172
173 multiStatementMaxSize := DefaultMultiStatementMaxSize
174 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
175 multiStatementMaxSize, err = strconv.Atoi(s)
176 if err != nil {
177 return nil, err
178 }
179 if multiStatementMaxSize <= 0 {
180 multiStatementMaxSize = DefaultMultiStatementMaxSize
181 }
182 }
183
184 multiStatementEnabled := false
185 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
186 multiStatementEnabled, err = strconv.ParseBool(s)
187 if err != nil {
188 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
189 }
190 }
191
192 px, err := WithInstance(db, &Config{
193 DatabaseName: purl.Path,
194 MigrationsTable: migrationsTable,
195 MigrationsTableQuoted: migrationsTableQuoted,
196 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
197 MultiStatementEnabled: multiStatementEnabled,
198 MultiStatementMaxSize: multiStatementMaxSize,
199 })
200
201 if err != nil {
202 return nil, err
203 }
204
205 return px, nil
206 }
207
208 func (p *Postgres) Close() error {
209 connErr := p.conn.Close()
210 dbErr := p.db.Close()
211 if connErr != nil || dbErr != nil {
212 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
213 }
214 return nil
215 }
216
217
218 func (p *Postgres) Lock() error {
219 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
220 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
221 if err != nil {
222 return err
223 }
224
225
226 query := `SELECT pg_advisory_lock($1)`
227 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
228 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
229 }
230
231 return nil
232 })
233 }
234
235 func (p *Postgres) Unlock() error {
236 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
237 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
238 if err != nil {
239 return err
240 }
241
242 query := `SELECT pg_advisory_unlock($1)`
243 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
244 return &database.Error{OrigErr: err, Query: []byte(query)}
245 }
246 return nil
247 })
248 }
249
250 func (p *Postgres) Run(migration io.Reader) error {
251 if p.config.MultiStatementEnabled {
252 var err error
253 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
254 if err = p.runStatement(m); err != nil {
255 return false
256 }
257 return true
258 }); e != nil {
259 return e
260 }
261 return err
262 }
263 migr, err := ioutil.ReadAll(migration)
264 if err != nil {
265 return err
266 }
267 return p.runStatement(migr)
268 }
269
270 func (p *Postgres) runStatement(statement []byte) error {
271 ctx := context.Background()
272 if p.config.StatementTimeout != 0 {
273 var cancel context.CancelFunc
274 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
275 defer cancel()
276 }
277 query := string(statement)
278 if strings.TrimSpace(query) == "" {
279 return nil
280 }
281 if _, err := p.conn.ExecContext(ctx, query); err != nil {
282 if pgErr, ok := err.(*pq.Error); ok {
283 var line uint
284 var col uint
285 var lineColOK bool
286 if pgErr.Position != "" {
287 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
288 line, col, lineColOK = computeLineFromPos(query, int(pos))
289 }
290 }
291 message := fmt.Sprintf("migration failed: %s", pgErr.Message)
292 if lineColOK {
293 message = fmt.Sprintf("%s (column %d)", message, col)
294 }
295 if pgErr.Detail != "" {
296 message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
297 }
298 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
299 }
300 return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
301 }
302 return nil
303 }
304
// computeLineFromPos maps a character position inside s to a line and column.
//
// CRLF sequences in s are collapsed to LF first, and pos is counted in runes
// of the result. line is 1 plus the number of newlines among the first pos
// runes; col is the number of runes after the last of those newlines (so
// pos == 0 yields line 1, col 0).
//
// ok is false, with line and col zero, when pos is negative or greater than
// the number of runes in the normalized string.
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
	// Replace CRLF with LF so that a line break always counts as one rune.
	s = strings.Replace(s, "\r\n", "\n", -1)

	runes := []rune(s)
	// A negative pos (e.g. from an overflowing integer conversion in the
	// caller) would otherwise panic when slicing below.
	if pos < 0 || pos > len(runes) {
		return 0, 0, false
	}

	sel := runes[:pos]
	line = uint(runesCount(sel, newLine) + 1)
	col = uint(pos - 1 - runesLastIndex(sel, newLine))
	return line, col, true
}

// newLine is the rune treated as a line terminator by computeLineFromPos.
const newLine = '\n'

// runesCount returns how many elements of input equal target.
func runesCount(input []rune, target rune) int {
	var count int
	for _, r := range input {
		if r == target {
			count++
		}
	}
	return count
}

// runesLastIndex returns the index of the last element of input equal to
// target, or -1 when there is none.
func runesLastIndex(input []rune, target rune) int {
	for i := len(input) - 1; i >= 0; i-- {
		if input[i] == target {
			return i
		}
	}
	return -1
}
339
340 func (p *Postgres) SetVersion(version int, dirty bool) error {
341 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
342 if err != nil {
343 return &database.Error{OrigErr: err, Err: "transaction start failed"}
344 }
345
346 query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName)
347 if _, err := tx.Exec(query); err != nil {
348 if errRollback := tx.Rollback(); errRollback != nil {
349 err = multierror.Append(err, errRollback)
350 }
351 return &database.Error{OrigErr: err, Query: []byte(query)}
352 }
353
354
355
356
357 if version >= 0 || (version == database.NilVersion && dirty) {
358 query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
359 if _, err := tx.Exec(query, version, dirty); err != nil {
360 if errRollback := tx.Rollback(); errRollback != nil {
361 err = multierror.Append(err, errRollback)
362 }
363 return &database.Error{OrigErr: err, Query: []byte(query)}
364 }
365 }
366
367 if err := tx.Commit(); err != nil {
368 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
369 }
370
371 return nil
372 }
373
374 func (p *Postgres) Version() (version int, dirty bool, err error) {
375 query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
376 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
377 switch {
378 case err == sql.ErrNoRows:
379 return database.NilVersion, false, nil
380
381 case err != nil:
382 if e, ok := err.(*pq.Error); ok {
383 if e.Code.Name() == "undefined_table" {
384 return database.NilVersion, false, nil
385 }
386 }
387 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
388
389 default:
390 return version, dirty, nil
391 }
392 }
393
394 func (p *Postgres) Drop() (err error) {
395
396 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
397 tables, err := p.conn.QueryContext(context.Background(), query)
398 if err != nil {
399 return &database.Error{OrigErr: err, Query: []byte(query)}
400 }
401 defer func() {
402 if errClose := tables.Close(); errClose != nil {
403 err = multierror.Append(err, errClose)
404 }
405 }()
406
407
408 tableNames := make([]string, 0)
409 for tables.Next() {
410 var tableName string
411 if err := tables.Scan(&tableName); err != nil {
412 return err
413 }
414 if len(tableName) > 0 {
415 tableNames = append(tableNames, tableName)
416 }
417 }
418 if err := tables.Err(); err != nil {
419 return &database.Error{OrigErr: err, Query: []byte(query)}
420 }
421
422 if len(tableNames) > 0 {
423
424 for _, t := range tableNames {
425 query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
426 if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
427 return &database.Error{OrigErr: err, Query: []byte(query)}
428 }
429 }
430 }
431
432 return nil
433 }
434
435
436
437
438 func (p *Postgres) ensureVersionTable() (err error) {
439 if err = p.Lock(); err != nil {
440 return err
441 }
442
443 defer func() {
444 if e := p.Unlock(); e != nil {
445 if err == nil {
446 err = e
447 } else {
448 err = multierror.Append(err, e)
449 }
450 }
451 }()
452
453
454
455
456
457 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
458 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
459
460 var count int
461 err = row.Scan(&count)
462 if err != nil {
463 return &database.Error{OrigErr: err, Query: []byte(query)}
464 }
465
466 if count == 1 {
467 return nil
468 }
469
470 query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
471 if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
472 return &database.Error{OrigErr: err, Query: []byte(query)}
473 }
474
475 return nil
476 }
477
View as plain text