1 package ql
2
3 import (
4 "database/sql"
5 "fmt"
6 "github.com/hashicorp/go-multierror"
7 "go.uber.org/atomic"
8 "io"
9 "io/ioutil"
10 "strings"
11
12 nurl "net/url"
13
14 "github.com/golang-migrate/migrate/v4"
15 "github.com/golang-migrate/migrate/v4/database"
16 _ "modernc.org/ql/driver"
17 )
18
19 func init() {
20 database.Register("ql", &Ql{})
21 }
22
var (
	// DefaultMigrationsTable is the migrations table name used when neither
	// Config.MigrationsTable nor the x-migrations-table URL parameter is set.
	DefaultMigrationsTable = "schema_migrations"

	// ErrDatabaseDirty reports a dirty database.
	// NOTE(review): not referenced by the driver code in this file.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")

	// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName reports a missing database name.
	// NOTE(review): not referenced by the driver code in this file.
	ErrNoDatabaseName = fmt.Errorf("no database name")

	// ErrAppendPEM reports a failure to append a PEM certificate.
	// NOTE(review): not referenced here; presumably kept for parity with
	// other drivers.
	ErrAppendPEM = fmt.Errorf("failed to append PEM")
)
30
// Config carries the settings for a ql migration driver.
type Config struct {
	// MigrationsTable is the table holding the current version and dirty
	// flag; WithInstance substitutes DefaultMigrationsTable when it is empty.
	MigrationsTable string
	// DatabaseName is set by Open to the path component of the URL.
	DatabaseName string
}
35
36 type Ql struct {
37 db *sql.DB
38 isLocked atomic.Bool
39
40 config *Config
41 }
42
43 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
44 if config == nil {
45 return nil, ErrNilConfig
46 }
47
48 if err := instance.Ping(); err != nil {
49 return nil, err
50 }
51
52 if len(config.MigrationsTable) == 0 {
53 config.MigrationsTable = DefaultMigrationsTable
54 }
55
56 mx := &Ql{
57 db: instance,
58 config: config,
59 }
60 if err := mx.ensureVersionTable(); err != nil {
61 return nil, err
62 }
63 return mx, nil
64 }
65
66
67
68
69 func (m *Ql) ensureVersionTable() (err error) {
70 if err = m.Lock(); err != nil {
71 return err
72 }
73
74 defer func() {
75 if e := m.Unlock(); e != nil {
76 if err == nil {
77 err = e
78 } else {
79 err = multierror.Append(err, e)
80 }
81 }
82 }()
83
84 tx, err := m.db.Begin()
85 if err != nil {
86 return err
87 }
88 if _, err := tx.Exec(fmt.Sprintf(`
89 CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool);
90 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
91 `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil {
92 if err := tx.Rollback(); err != nil {
93 return err
94 }
95 return err
96 }
97 if err := tx.Commit(); err != nil {
98 return err
99 }
100 return nil
101 }
102
103 func (m *Ql) Open(url string) (database.Driver, error) {
104 purl, err := nurl.Parse(url)
105 if err != nil {
106 return nil, err
107 }
108 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1)
109 db, err := sql.Open("ql", dbfile)
110 if err != nil {
111 return nil, err
112 }
113 migrationsTable := purl.Query().Get("x-migrations-table")
114 if len(migrationsTable) == 0 {
115 migrationsTable = DefaultMigrationsTable
116 }
117 mx, err := WithInstance(db, &Config{
118 DatabaseName: purl.Path,
119 MigrationsTable: migrationsTable,
120 })
121 if err != nil {
122 return nil, err
123 }
124 return mx, nil
125 }
126 func (m *Ql) Close() error {
127 return m.db.Close()
128 }
129 func (m *Ql) Drop() (err error) {
130 query := `SELECT Name FROM __Table`
131 tables, err := m.db.Query(query)
132 if err != nil {
133 return &database.Error{OrigErr: err, Query: []byte(query)}
134 }
135 defer func() {
136 if errClose := tables.Close(); errClose != nil {
137 err = multierror.Append(err, errClose)
138 }
139 }()
140
141 tableNames := make([]string, 0)
142 for tables.Next() {
143 var tableName string
144 if err := tables.Scan(&tableName); err != nil {
145 return err
146 }
147 if len(tableName) > 0 {
148 if !strings.HasPrefix(tableName, "__") {
149 tableNames = append(tableNames, tableName)
150 }
151 }
152 }
153 if err := tables.Err(); err != nil {
154 return &database.Error{OrigErr: err, Query: []byte(query)}
155 }
156
157 if len(tableNames) > 0 {
158 for _, t := range tableNames {
159 query := "DROP TABLE " + t
160 err = m.executeQuery(query)
161 if err != nil {
162 return &database.Error{OrigErr: err, Query: []byte(query)}
163 }
164 }
165 }
166
167 return nil
168 }
169 func (m *Ql) Lock() error {
170 if !m.isLocked.CAS(false, true) {
171 return database.ErrLocked
172 }
173 return nil
174 }
175 func (m *Ql) Unlock() error {
176 if !m.isLocked.CAS(true, false) {
177 return database.ErrNotLocked
178 }
179 return nil
180 }
181 func (m *Ql) Run(migration io.Reader) error {
182 migr, err := ioutil.ReadAll(migration)
183 if err != nil {
184 return err
185 }
186 query := string(migr[:])
187
188 return m.executeQuery(query)
189 }
190 func (m *Ql) executeQuery(query string) error {
191 tx, err := m.db.Begin()
192 if err != nil {
193 return &database.Error{OrigErr: err, Err: "transaction start failed"}
194 }
195 if _, err := tx.Exec(query); err != nil {
196 if errRollback := tx.Rollback(); errRollback != nil {
197 err = multierror.Append(err, errRollback)
198 }
199 return &database.Error{OrigErr: err, Query: []byte(query)}
200 }
201 if err := tx.Commit(); err != nil {
202 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
203 }
204 return nil
205 }
206 func (m *Ql) SetVersion(version int, dirty bool) error {
207 tx, err := m.db.Begin()
208 if err != nil {
209 return &database.Error{OrigErr: err, Err: "transaction start failed"}
210 }
211
212 query := "TRUNCATE TABLE " + m.config.MigrationsTable
213 if _, err := tx.Exec(query); err != nil {
214 return &database.Error{OrigErr: err, Query: []byte(query)}
215 }
216
217
218
219
220 if version >= 0 || (version == database.NilVersion && dirty) {
221 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`,
222 m.config.MigrationsTable)
223 if _, err := tx.Exec(query, version, dirty); err != nil {
224 if errRollback := tx.Rollback(); errRollback != nil {
225 err = multierror.Append(err, errRollback)
226 }
227 return &database.Error{OrigErr: err, Query: []byte(query)}
228 }
229 }
230
231 if err := tx.Commit(); err != nil {
232 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
233 }
234
235 return nil
236 }
237
238 func (m *Ql) Version() (version int, dirty bool, err error) {
239 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
240 err = m.db.QueryRow(query).Scan(&version, &dirty)
241 if err != nil {
242 return database.NilVersion, false, nil
243 }
244 return version, dirty, nil
245 }
246
View as plain text