1 package sqlcipher
2
3 import (
4 "database/sql"
5 "fmt"
6 "go.uber.org/atomic"
7 "io"
8 "io/ioutil"
9 nurl "net/url"
10 "strconv"
11 "strings"
12
13 "github.com/golang-migrate/migrate/v4"
14 "github.com/golang-migrate/migrate/v4/database"
15 "github.com/hashicorp/go-multierror"
16 _ "github.com/mutecomm/go-sqlcipher/v4"
17 )
18
19 func init() {
20 database.Register("sqlcipher", &Sqlite{})
21 }
22
var (
	// DefaultMigrationsTable is the version-table name used when neither the
	// Config nor the x-migrations-table URL option supplies one.
	DefaultMigrationsTable = "schema_migrations"

	// ErrDatabaseDirty reports that the database is in a dirty state.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")
	// ErrNilConfig is returned by WithInstance when config is nil.
	ErrNilConfig = fmt.Errorf("no config")
	// ErrNoDatabaseName reports a missing database name.
	ErrNoDatabaseName = fmt.Errorf("no database name")
)
29
// Config holds the settings for a Sqlite driver created via WithInstance.
type Config struct {
	// MigrationsTable names the table that stores the schema version.
	// WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName is populated by Open from the URL path.
	DatabaseName string
	// NoTxWrap makes Run execute each migration outside a transaction.
	NoTxWrap bool
}
35
36 type Sqlite struct {
37 db *sql.DB
38 isLocked atomic.Bool
39
40 config *Config
41 }
42
43 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
44 if config == nil {
45 return nil, ErrNilConfig
46 }
47
48 if err := instance.Ping(); err != nil {
49 return nil, err
50 }
51
52 if len(config.MigrationsTable) == 0 {
53 config.MigrationsTable = DefaultMigrationsTable
54 }
55
56 mx := &Sqlite{
57 db: instance,
58 config: config,
59 }
60 if err := mx.ensureVersionTable(); err != nil {
61 return nil, err
62 }
63 return mx, nil
64 }
65
66
67
68
69 func (m *Sqlite) ensureVersionTable() (err error) {
70 if err = m.Lock(); err != nil {
71 return err
72 }
73
74 defer func() {
75 if e := m.Unlock(); e != nil {
76 if err == nil {
77 err = e
78 } else {
79 err = multierror.Append(err, e)
80 }
81 }
82 }()
83
84 query := fmt.Sprintf(`
85 CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
86 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
87 `, m.config.MigrationsTable, m.config.MigrationsTable)
88
89 if _, err := m.db.Exec(query); err != nil {
90 return err
91 }
92 return nil
93 }
94
95 func (m *Sqlite) Open(url string) (database.Driver, error) {
96 purl, err := nurl.Parse(url)
97 if err != nil {
98 return nil, err
99 }
100 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1)
101 db, err := sql.Open("sqlite3", dbfile)
102 if err != nil {
103 return nil, err
104 }
105
106 qv := purl.Query()
107
108 migrationsTable := qv.Get("x-migrations-table")
109 if len(migrationsTable) == 0 {
110 migrationsTable = DefaultMigrationsTable
111 }
112
113 noTxWrap := false
114 if v := qv.Get("x-no-tx-wrap"); v != "" {
115 noTxWrap, err = strconv.ParseBool(v)
116 if err != nil {
117 return nil, fmt.Errorf("x-no-tx-wrap: %s", err)
118 }
119 }
120
121 mx, err := WithInstance(db, &Config{
122 DatabaseName: purl.Path,
123 MigrationsTable: migrationsTable,
124 NoTxWrap: noTxWrap,
125 })
126 if err != nil {
127 return nil, err
128 }
129 return mx, nil
130 }
131
132 func (m *Sqlite) Close() error {
133 return m.db.Close()
134 }
135
136 func (m *Sqlite) Drop() (err error) {
137 query := `SELECT name FROM sqlite_master WHERE type = 'table';`
138 tables, err := m.db.Query(query)
139 if err != nil {
140 return &database.Error{OrigErr: err, Query: []byte(query)}
141 }
142 defer func() {
143 if errClose := tables.Close(); errClose != nil {
144 err = multierror.Append(err, errClose)
145 }
146 }()
147
148 tableNames := make([]string, 0)
149 for tables.Next() {
150 var tableName string
151 if err := tables.Scan(&tableName); err != nil {
152 return err
153 }
154 if len(tableName) > 0 {
155 tableNames = append(tableNames, tableName)
156 }
157 }
158 if err := tables.Err(); err != nil {
159 return &database.Error{OrigErr: err, Query: []byte(query)}
160 }
161
162 if len(tableNames) > 0 {
163 for _, t := range tableNames {
164 query := "DROP TABLE " + t
165 err = m.executeQuery(query)
166 if err != nil {
167 return &database.Error{OrigErr: err, Query: []byte(query)}
168 }
169 }
170 query := "VACUUM"
171 _, err = m.db.Query(query)
172 if err != nil {
173 return &database.Error{OrigErr: err, Query: []byte(query)}
174 }
175 }
176
177 return nil
178 }
179
180 func (m *Sqlite) Lock() error {
181 if !m.isLocked.CAS(false, true) {
182 return database.ErrLocked
183 }
184 return nil
185 }
186
187 func (m *Sqlite) Unlock() error {
188 if !m.isLocked.CAS(true, false) {
189 return database.ErrNotLocked
190 }
191 return nil
192 }
193
194 func (m *Sqlite) Run(migration io.Reader) error {
195 migr, err := ioutil.ReadAll(migration)
196 if err != nil {
197 return err
198 }
199 query := string(migr[:])
200
201 if m.config.NoTxWrap {
202 return m.executeQueryNoTx(query)
203 }
204 return m.executeQuery(query)
205 }
206
207 func (m *Sqlite) executeQuery(query string) error {
208 tx, err := m.db.Begin()
209 if err != nil {
210 return &database.Error{OrigErr: err, Err: "transaction start failed"}
211 }
212 if _, err := tx.Exec(query); err != nil {
213 if errRollback := tx.Rollback(); errRollback != nil {
214 err = multierror.Append(err, errRollback)
215 }
216 return &database.Error{OrigErr: err, Query: []byte(query)}
217 }
218 if err := tx.Commit(); err != nil {
219 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
220 }
221 return nil
222 }
223
224 func (m *Sqlite) executeQueryNoTx(query string) error {
225 if _, err := m.db.Exec(query); err != nil {
226 return &database.Error{OrigErr: err, Query: []byte(query)}
227 }
228 return nil
229 }
230
231 func (m *Sqlite) SetVersion(version int, dirty bool) error {
232 tx, err := m.db.Begin()
233 if err != nil {
234 return &database.Error{OrigErr: err, Err: "transaction start failed"}
235 }
236
237 query := "DELETE FROM " + m.config.MigrationsTable
238 if _, err := tx.Exec(query); err != nil {
239 return &database.Error{OrigErr: err, Query: []byte(query)}
240 }
241
242
243
244
245 if version >= 0 || (version == database.NilVersion && dirty) {
246 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable)
247 if _, err := tx.Exec(query, version, dirty); err != nil {
248 if errRollback := tx.Rollback(); errRollback != nil {
249 err = multierror.Append(err, errRollback)
250 }
251 return &database.Error{OrigErr: err, Query: []byte(query)}
252 }
253 }
254
255 if err := tx.Commit(); err != nil {
256 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
257 }
258
259 return nil
260 }
261
262 func (m *Sqlite) Version() (version int, dirty bool, err error) {
263 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
264 err = m.db.QueryRow(query).Scan(&version, &dirty)
265 if err != nil {
266 return database.NilVersion, false, nil
267 }
268 return version, dirty, nil
269 }
270
View as plain text