1 package snowflake
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "go.uber.org/atomic"
8 "io"
9 "io/ioutil"
10 nurl "net/url"
11 "strconv"
12 "strings"
13
14 "github.com/golang-migrate/migrate/v4/database"
15 "github.com/hashicorp/go-multierror"
16 "github.com/lib/pq"
17 sf "github.com/snowflakedb/gosnowflake"
18 )
19
20 func init() {
21 db := Snowflake{}
22 database.Register("snowflake", &db)
23 }
24
// DefaultMigrationsTable is the migrations table name WithInstance uses
// when Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"

// Sentinel errors returned while validating the configuration or the
// connection URL.
var (
	// ErrNilConfig is returned by WithInstance when given a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")
	// ErrNoDatabaseName is returned when no database name is configured
	// and none can be determined.
	ErrNoDatabaseName = fmt.Errorf("no database name")
	// ErrNoPassword is returned by Open when the URL carries no password.
	ErrNoPassword = fmt.Errorf("no password")
	// ErrNoSchema is returned by Open when the schema path segment is empty.
	ErrNoSchema = fmt.Errorf("no schema")
	// ErrNoSchemaOrDatabase is returned by Open when the URL path has fewer
	// than two segments.
	ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name")
)
34
// Config configures the driver built by WithInstance. WithInstance fills in
// empty fields by mutating the struct it is given.
type Config struct {
	// MigrationsTable names the table that stores the version row; when
	// empty, WithInstance substitutes DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName is the database to migrate; when empty, WithInstance
	// queries CURRENT_DATABASE() for it.
	DatabaseName string
}
39
40 type Snowflake struct {
41 isLocked atomic.Bool
42 conn *sql.Conn
43 db *sql.DB
44
45
46 config *Config
47 }
48
49 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
50 if config == nil {
51 return nil, ErrNilConfig
52 }
53
54 if err := instance.Ping(); err != nil {
55 return nil, err
56 }
57
58 if config.DatabaseName == "" {
59 query := `SELECT CURRENT_DATABASE()`
60 var databaseName string
61 if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
62 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
63 }
64
65 if len(databaseName) == 0 {
66 return nil, ErrNoDatabaseName
67 }
68
69 config.DatabaseName = databaseName
70 }
71
72 if len(config.MigrationsTable) == 0 {
73 config.MigrationsTable = DefaultMigrationsTable
74 }
75
76 conn, err := instance.Conn(context.Background())
77
78 if err != nil {
79 return nil, err
80 }
81
82 px := &Snowflake{
83 conn: conn,
84 db: instance,
85 config: config,
86 }
87
88 if err := px.ensureVersionTable(); err != nil {
89 return nil, err
90 }
91
92 return px, nil
93 }
94
95 func (p *Snowflake) Open(url string) (database.Driver, error) {
96 purl, err := nurl.Parse(url)
97 if err != nil {
98 return nil, err
99 }
100
101 password, isPasswordSet := purl.User.Password()
102 if !isPasswordSet {
103 return nil, ErrNoPassword
104 }
105
106 splitPath := strings.Split(purl.Path, "/")
107 if len(splitPath) < 3 {
108 return nil, ErrNoSchemaOrDatabase
109 }
110
111 database := splitPath[2]
112 if len(database) == 0 {
113 return nil, ErrNoDatabaseName
114 }
115
116 schema := splitPath[1]
117 if len(schema) == 0 {
118 return nil, ErrNoSchema
119 }
120
121 cfg := &sf.Config{
122 Account: purl.Host,
123 User: purl.User.Username(),
124 Password: password,
125 Database: database,
126 Schema: schema,
127 }
128
129 dsn, err := sf.DSN(cfg)
130 if err != nil {
131 return nil, err
132 }
133
134 db, err := sql.Open("snowflake", dsn)
135 if err != nil {
136 return nil, err
137 }
138
139 migrationsTable := purl.Query().Get("x-migrations-table")
140
141 px, err := WithInstance(db, &Config{
142 DatabaseName: database,
143 MigrationsTable: migrationsTable,
144 })
145 if err != nil {
146 return nil, err
147 }
148
149 return px, nil
150 }
151
152 func (p *Snowflake) Close() error {
153 connErr := p.conn.Close()
154 dbErr := p.db.Close()
155 if connErr != nil || dbErr != nil {
156 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
157 }
158 return nil
159 }
160
161 func (p *Snowflake) Lock() error {
162 if !p.isLocked.CAS(false, true) {
163 return database.ErrLocked
164 }
165 return nil
166 }
167
168 func (p *Snowflake) Unlock() error {
169 if !p.isLocked.CAS(true, false) {
170 return database.ErrNotLocked
171 }
172 return nil
173 }
174
175 func (p *Snowflake) Run(migration io.Reader) error {
176 migr, err := ioutil.ReadAll(migration)
177 if err != nil {
178 return err
179 }
180
181
182 query := string(migr[:])
183 if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
184 if pgErr, ok := err.(*pq.Error); ok {
185 var line uint
186 var col uint
187 var lineColOK bool
188 if pgErr.Position != "" {
189 if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
190 line, col, lineColOK = computeLineFromPos(query, int(pos))
191 }
192 }
193 message := fmt.Sprintf("migration failed: %s", pgErr.Message)
194 if lineColOK {
195 message = fmt.Sprintf("%s (column %d)", message, col)
196 }
197 if pgErr.Detail != "" {
198 message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
199 }
200 return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
201 }
202 return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
203 }
204
205 return nil
206 }
207
208 func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
209
210 s = strings.Replace(s, "\r\n", "\n", -1)
211
212 runes := []rune(s)
213 if pos > len(runes) {
214 return 0, 0, false
215 }
216 sel := runes[:pos]
217 line = uint(runesCount(sel, newLine) + 1)
218 col = uint(pos - 1 - runesLastIndex(sel, newLine))
219 return line, col, true
220 }
221
// newLine is the line separator computeLineFromPos counts and searches for.
const newLine rune = '\n'
223
// runesCount reports how many elements of input equal target.
func runesCount(input []rune, target rune) int {
	n := 0
	for i := 0; i < len(input); i++ {
		if input[i] != target {
			continue
		}
		n++
	}
	return n
}
233
// runesLastIndex returns the index of the final element of input equal to
// target, or -1 when target does not occur.
func runesLastIndex(input []rune, target rune) int {
	last := -1
	for i, r := range input {
		if r == target {
			last = i
		}
	}
	return last
}
242
243 func (p *Snowflake) SetVersion(version int, dirty bool) error {
244 tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
245 if err != nil {
246 return &database.Error{OrigErr: err, Err: "transaction start failed"}
247 }
248
249 query := `DELETE FROM "` + p.config.MigrationsTable + `"`
250 if _, err := tx.Exec(query); err != nil {
251 if errRollback := tx.Rollback(); errRollback != nil {
252 err = multierror.Append(err, errRollback)
253 }
254 return &database.Error{OrigErr: err, Query: []byte(query)}
255 }
256
257
258
259
260 if version >= 0 || (version == database.NilVersion && dirty) {
261 query = `INSERT INTO "` + p.config.MigrationsTable + `" (version,
262 dirty) VALUES (` + strconv.FormatInt(int64(version), 10) + `,
263 ` + strconv.FormatBool(dirty) + `)`
264 if _, err := tx.Exec(query); err != nil {
265 if errRollback := tx.Rollback(); errRollback != nil {
266 err = multierror.Append(err, errRollback)
267 }
268 return &database.Error{OrigErr: err, Query: []byte(query)}
269 }
270 }
271
272 if err := tx.Commit(); err != nil {
273 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
274 }
275
276 return nil
277 }
278
279 func (p *Snowflake) Version() (version int, dirty bool, err error) {
280 query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
281 err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
282 switch {
283 case err == sql.ErrNoRows:
284 return database.NilVersion, false, nil
285
286 case err != nil:
287 if e, ok := err.(*pq.Error); ok {
288 if e.Code.Name() == "undefined_table" {
289 return database.NilVersion, false, nil
290 }
291 }
292 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
293
294 default:
295 return version, dirty, nil
296 }
297 }
298
299 func (p *Snowflake) Drop() (err error) {
300
301 query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
302 tables, err := p.conn.QueryContext(context.Background(), query)
303 if err != nil {
304 return &database.Error{OrigErr: err, Query: []byte(query)}
305 }
306 defer func() {
307 if errClose := tables.Close(); errClose != nil {
308 err = multierror.Append(err, errClose)
309 }
310 }()
311
312
313 tableNames := make([]string, 0)
314 for tables.Next() {
315 var tableName string
316 if err := tables.Scan(&tableName); err != nil {
317 return err
318 }
319 if len(tableName) > 0 {
320 tableNames = append(tableNames, tableName)
321 }
322 }
323 if err := tables.Err(); err != nil {
324 return &database.Error{OrigErr: err, Query: []byte(query)}
325 }
326
327 if len(tableNames) > 0 {
328
329 for _, t := range tableNames {
330 query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
331 if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
332 return &database.Error{OrigErr: err, Query: []byte(query)}
333 }
334 }
335 }
336
337 return nil
338 }
339
340
341
342
343 func (p *Snowflake) ensureVersionTable() (err error) {
344 if err = p.Lock(); err != nil {
345 return err
346 }
347
348 defer func() {
349 if e := p.Unlock(); e != nil {
350 if err == nil {
351 err = e
352 } else {
353 err = multierror.Append(err, e)
354 }
355 }
356 }()
357
358
359 var count int
360 query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
361 if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
362 return &database.Error{OrigErr: err, Query: []byte(query)}
363 }
364 if count == 1 {
365 return nil
366 }
367
368
369 query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" (
370 version bigint not null primary key, dirty boolean not null)`
371 if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
372 return &database.Error{OrigErr: err, Query: []byte(query)}
373 }
374
375 return nil
376 }
377
View as plain text