1
2
3
4 package pgx
5
6 import (
7 "context"
8 "database/sql"
9 "fmt"
10 "go.uber.org/atomic"
11 "io"
12 "io/ioutil"
13 nurl "net/url"
14 "regexp"
15 "strconv"
16 "strings"
17 "time"
18
19 "github.com/golang-migrate/migrate/v4"
20 "github.com/golang-migrate/migrate/v4/database"
21 "github.com/golang-migrate/migrate/v4/database/multistmt"
22 multierror "github.com/hashicorp/go-multierror"
23 "github.com/jackc/pgconn"
24 "github.com/jackc/pgerrcode"
25 _ "github.com/jackc/pgx/v4/stdlib"
26 )
27
28 func init() {
29 db := Postgres{}
30 database.Register("pgx", &db)
31 }
32
var (
	// multiStmtDelimiter is the statement separator handed to multistmt.Parse
	// when a migration is executed in multi-statement mode.
	multiStmtDelimiter = []byte(";")

	// DefaultMigrationsTable is the migrations table name used when
	// Config.MigrationsTable is left empty.
	DefaultMigrationsTable = "schema_migrations"

	// DefaultMultiStatementMaxSize is the default statement size limit
	// (10 MiB) used for multi-statement migrations.
	DefaultMultiStatementMaxSize = 10 << 20
)
39
// Sentinel errors returned by this driver.
var (
	// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName is returned when the current database name resolves
	// to an empty string.
	ErrNoDatabaseName = fmt.Errorf("no database name")

	// ErrNoSchema is returned when the current schema resolves to an empty
	// string.
	ErrNoSchema = fmt.Errorf("no schema")

	// ErrDatabaseDirty signals that the database is in a dirty state.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
46
// Config carries the settings of a Postgres migration driver.
type Config struct {
	// MigrationsTable names the table that stores the schema version.
	// WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName is the database being migrated. WithInstance fills it in
	// from CURRENT_DATABASE() when empty.
	DatabaseName string
	// SchemaName is the schema being migrated. WithInstance fills it in from
	// CURRENT_SCHEMA() when empty.
	SchemaName string
	// migrationsSchemaName is the schema of the migrations table as resolved
	// by WithInstance.
	migrationsSchemaName string
	// migrationsTableName is the bare migrations table name as resolved by
	// WithInstance.
	migrationsTableName string
	// StatementTimeout, when non-zero, bounds the execution time of each
	// migration statement.
	StatementTimeout time.Duration
	// MigrationsTableQuoted tells WithInstance to parse MigrationsTable as
	// double-quoted identifiers, optionally schema-qualified.
	MigrationsTableQuoted bool
	// MultiStatementEnabled makes Run split a migration into individual
	// statements and execute them one at a time.
	MultiStatementEnabled bool
	// MultiStatementMaxSize is the size limit passed to the multi-statement
	// parser.
	MultiStatementMaxSize int
}
58
59 type Postgres struct {
60
61 conn *sql.Conn
62 db *sql.DB
63 isLocked atomic.Bool
64
65
66 config *Config
67 }
68
69 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
70 if config == nil {
71 return nil, ErrNilConfig
72 }
73
74 if err := instance.Ping(); err != nil {
75 return nil, err
76 }
77
78 if config.DatabaseName == "" {
79 query := `SELECT CURRENT_DATABASE()`
80 var databaseName string
81 if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
82 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
83 }
84
85 if len(databaseName) == 0 {
86 return nil, ErrNoDatabaseName
87 }
88
89 config.DatabaseName = databaseName
90 }
91
92 if config.SchemaName == "" {
93 query := `SELECT CURRENT_SCHEMA()`
94 var schemaName string
95 if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
96 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
97 }
98
99 if len(schemaName) == 0 {
100 return nil, ErrNoSchema
101 }
102
103 config.SchemaName = schemaName
104 }
105
106 if len(config.MigrationsTable) == 0 {
107 config.MigrationsTable = DefaultMigrationsTable
108 }
109
110 config.migrationsSchemaName = config.SchemaName
111 config.migrationsTableName = config.MigrationsTable
112 if config.MigrationsTableQuoted {
113 re := regexp.MustCompile(`"(.*?)"`)
114 result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
115 config.migrationsTableName = result[len(result)-1][1]
116 if len(result) == 2 {
117 config.migrationsSchemaName = result[0][1]
118 } else if len(result) > 2 {
119 return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
120 }
121 }
122
123 conn, err := instance.Conn(context.Background())
124
125 if err != nil {
126 return nil, err
127 }
128
129 px := &Postgres{
130 conn: conn,
131 db: instance,
132 config: config,
133 }
134
135 if err := px.ensureVersionTable(); err != nil {
136 return nil, err
137 }
138
139 return px, nil
140 }
141
142 func (p *Postgres) Open(url string) (database.Driver, error) {
143 purl, err := nurl.Parse(url)
144 if err != nil {
145 return nil, err
146 }
147
148
149
150
151 purl.Scheme = "postgres"
152
153 db, err := sql.Open("pgx", migrate.FilterCustomQuery(purl).String())
154 if err != nil {
155 return nil, err
156 }
157
158 migrationsTable := purl.Query().Get("x-migrations-table")
159 migrationsTableQuoted := false
160 if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
161 migrationsTableQuoted, err = strconv.ParseBool(s)
162 if err != nil {
163 return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
164 }
165 }
166 if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
167 return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
168 }
169
170 statementTimeoutString := purl.Query().Get("x-statement-timeout")
171 statementTimeout := 0
172 if statementTimeoutString != "" {
173 statementTimeout, err = strconv.Atoi(statementTimeoutString)
174 if err != nil {
175 return nil, err
176 }
177 }
178
179 multiStatementMaxSize := DefaultMultiStatementMaxSize
180 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
181 multiStatementMaxSize, err = strconv.Atoi(s)
182 if err != nil {
183 return nil, err
184 }
185 if multiStatementMaxSize <= 0 {
186 multiStatementMaxSize = DefaultMultiStatementMaxSize
187 }
188 }
189
190 multiStatementEnabled := false
191 if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
192 multiStatementEnabled, err = strconv.ParseBool(s)
193 if err != nil {
194 return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
195 }
196 }
197
198 px, err := WithInstance(db, &Config{
199 DatabaseName: purl.Path,
200 MigrationsTable: migrationsTable,
201 MigrationsTableQuoted: migrationsTableQuoted,
202 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
203 MultiStatementEnabled: multiStatementEnabled,
204 MultiStatementMaxSize: multiStatementMaxSize,
205 })
206
207 if err != nil {
208 return nil, err
209 }
210
211 return px, nil
212 }
213
214 func (p *Postgres) Close() error {
215 connErr := p.conn.Close()
216 dbErr := p.db.Close()
217 if connErr != nil || dbErr != nil {
218 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
219 }
220 return nil
221 }
222
223
224 func (p *Postgres) Lock() error {
225 return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
226 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
227 if err != nil {
228 return err
229 }
230
231
232 query := `SELECT pg_advisory_lock($1)`
233 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
234 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
235 }
236 return nil
237 })
238 }
239
240 func (p *Postgres) Unlock() error {
241 return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
242 aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
243 if err != nil {
244 return err
245 }
246
247 query := `SELECT pg_advisory_unlock($1)`
248 if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
249 return &database.Error{OrigErr: err, Query: []byte(query)}
250 }
251 return nil
252 })
253 }
254
255 func (p *Postgres) Run(migration io.Reader) error {
256 if p.config.MultiStatementEnabled {
257 var err error
258 if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
259 if err = p.runStatement(m); err != nil {
260 return false
261 }
262 return true
263 }); e != nil {
264 return e
265 }
266 return err
267 }
268 migr, err := ioutil.ReadAll(migration)
269 if err != nil {
270 return err
271 }
272 return p.runStatement(migr)
273 }
274
275 func (p *Postgres) runStatement(statement []byte) error {
276 ctx := context.Background()
277 if p.config.StatementTimeout != 0 {
278 var cancel context.CancelFunc
279 ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
280 defer cancel()
281 }
282 query := string(statement)
283 if strings.TrimSpace(query) == "" {
284 return nil
285 }
286 if _, err := p.conn.ExecContext(ctx, query); err != nil {
287
288 if pgErr, ok := err.(*pgconn.PgError); ok {
289 var line uint
290 var col uint
291 var lineColOK bool
292 line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
293 message := fmt.Sprintf("migration failed: %s", pgErr.Message)
294 if lineColOK {
295 message = fmt.Sprintf("%s (column %d)", message, col)
296 }
297 if pgErr.Detail != "" {
298 message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
299 }
300 return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
301 }
302 return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
303 }
304 return nil
305 }
306
// computeLineFromPos converts a 1-based character position within s into a
// 1-based line number and a column on that line.
//
// CRLF sequences are collapsed to LF before counting, and the position is
// measured in runes, not bytes. ok is false, with line and col zero, when pos
// is not a usable position: less than 1 (no position supplied, or a negative
// value that would otherwise panic when slicing) or beyond the end of s.
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
	// Replace CRLF with LF.
	s = strings.ReplaceAll(s, "\r\n", "\n")

	runes := []rune(s)
	if pos < 1 || pos > len(runes) {
		return 0, 0, false
	}

	// Everything up to and including the reported character.
	sel := runes[:pos]
	line = uint(runesCount(sel, newLine) + 1)
	col = uint(pos - 1 - runesLastIndex(sel, newLine))
	return line, col, true
}

// newLine is the line separator used when locating an error position.
const newLine = '\n'

// runesCount returns how many times target occurs in input.
func runesCount(input []rune, target rune) int {
	var count int
	for _, r := range input {
		if r == target {
			count++
		}
	}
	return count
}

// runesLastIndex returns the index of the last occurrence of target in
// input, or -1 when target does not occur.
func runesLastIndex(input []rune, target rune) int {
	for i := len(input) - 1; i >= 0; i-- {
		if input[i] == target {
			return i
		}
	}
	return -1
}
341
342 func (p *Postgres) SetVersion(version int, dirty bool) error {
343 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
344 if err != nil {
345 return &database.Error{OrigErr: err, Err: "transaction start failed"}
346 }
347
348 query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
349 if _, err := tx.Exec(query); err != nil {
350 if errRollback := tx.Rollback(); errRollback != nil {
351 err = multierror.Append(err, errRollback)
352 }
353 return &database.Error{OrigErr: err, Query: []byte(query)}
354 }
355
356
357
358
359 if version >= 0 || (version == database.NilVersion && dirty) {
360 query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
361 if _, err := tx.Exec(query, version, dirty); err != nil {
362 if errRollback := tx.Rollback(); errRollback != nil {
363 err = multierror.Append(err, errRollback)
364 }
365 return &database.Error{OrigErr: err, Query: []byte(query)}
366 }
367 }
368
369 if err := tx.Commit(); err != nil {
370 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
371 }
372
373 return nil
374 }
375
376 func (p *Postgres) Version() (version int, dirty bool, err error) {
377 query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
378 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
379 switch {
380 case err == sql.ErrNoRows:
381 return database.NilVersion, false, nil
382
383 case err != nil:
384 if e, ok := err.(*pgconn.PgError); ok {
385 if e.SQLState() == pgerrcode.UndefinedTable {
386 return database.NilVersion, false, nil
387 }
388 }
389 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
390
391 default:
392 return version, dirty, nil
393 }
394 }
395
396 func (p *Postgres) Drop() (err error) {
397
398 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
399 tables, err := p.conn.QueryContext(context.Background(), query)
400 if err != nil {
401 return &database.Error{OrigErr: err, Query: []byte(query)}
402 }
403 defer func() {
404 if errClose := tables.Close(); errClose != nil {
405 err = multierror.Append(err, errClose)
406 }
407 }()
408
409
410 tableNames := make([]string, 0)
411 for tables.Next() {
412 var tableName string
413 if err := tables.Scan(&tableName); err != nil {
414 return err
415 }
416 if len(tableName) > 0 {
417 tableNames = append(tableNames, tableName)
418 }
419 }
420 if err := tables.Err(); err != nil {
421 return &database.Error{OrigErr: err, Query: []byte(query)}
422 }
423
424 if len(tableNames) > 0 {
425
426 for _, t := range tableNames {
427 query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
428 if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
429 return &database.Error{OrigErr: err, Query: []byte(query)}
430 }
431 }
432 }
433
434 return nil
435 }
436
437
438
439
440 func (p *Postgres) ensureVersionTable() (err error) {
441 if err = p.Lock(); err != nil {
442 return err
443 }
444
445 defer func() {
446 if e := p.Unlock(); e != nil {
447 if err == nil {
448 err = e
449 } else {
450 err = multierror.Append(err, e)
451 }
452 }
453 }()
454
455
456
457
458
459 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
460 row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
461
462 var count int
463 err = row.Scan(&count)
464 if err != nil {
465 return &database.Error{OrigErr: err, Query: []byte(query)}
466 }
467
468 if count == 1 {
469 return nil
470 }
471
472 query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
473 if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
474 return &database.Error{OrigErr: err, Query: []byte(query)}
475 }
476
477 return nil
478 }
479
480
// quoteIdentifier wraps name in double quotes for use as an SQL identifier.
// The name is cut at the first NUL byte, if any, and embedded double quotes
// are doubled.
func quoteIdentifier(name string) string {
	if nul := strings.IndexByte(name, 0); nul >= 0 {
		name = name[:nul]
	}
	escaped := strings.ReplaceAll(name, `"`, `""`)
	return `"` + escaped + `"`
}
488
View as plain text