1
2
3
4 package firebird
5
6 import (
7 "context"
8 "database/sql"
9 "fmt"
10 "github.com/golang-migrate/migrate/v4"
11 "github.com/golang-migrate/migrate/v4/database"
12 "github.com/hashicorp/go-multierror"
13 _ "github.com/nakagami/firebirdsql"
14 "go.uber.org/atomic"
15 "io"
16 "io/ioutil"
17 nurl "net/url"
18 )
19
20 func init() {
21 db := Firebird{}
22 database.Register("firebird", &db)
23 database.Register("firebirdsql", &db)
24 }
25
var (
	// DefaultMigrationsTable is the version-table name WithInstance uses
	// when Config.MigrationsTable is empty.
	DefaultMigrationsTable = "schema_migrations"

	// ErrNilConfig is returned by WithInstance when it receives a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")
)
31
// Config carries the settings WithInstance uses to build a Firebird driver.
type Config struct {
	// DatabaseName is filled by Open from the DSN's URL path. Nothing in
	// this file reads it back.
	DatabaseName string

	// MigrationsTable names the table that records the schema version.
	// WithInstance substitutes DefaultMigrationsTable when it is empty.
	MigrationsTable string
}
36
37 type Firebird struct {
38
39 conn *sql.Conn
40 db *sql.DB
41 isLocked atomic.Bool
42
43
44 config *Config
45 }
46
47 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
48 if config == nil {
49 return nil, ErrNilConfig
50 }
51
52 if err := instance.Ping(); err != nil {
53 return nil, err
54 }
55
56 if len(config.MigrationsTable) == 0 {
57 config.MigrationsTable = DefaultMigrationsTable
58 }
59
60 conn, err := instance.Conn(context.Background())
61 if err != nil {
62 return nil, err
63 }
64
65 fb := &Firebird{
66 conn: conn,
67 db: instance,
68 config: config,
69 }
70
71 if err := fb.ensureVersionTable(); err != nil {
72 return nil, err
73 }
74
75 return fb, nil
76 }
77
78 func (f *Firebird) Open(dsn string) (database.Driver, error) {
79 purl, err := nurl.Parse(dsn)
80 if err != nil {
81 return nil, err
82 }
83
84 db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String())
85 if err != nil {
86 return nil, err
87 }
88
89 px, err := WithInstance(db, &Config{
90 MigrationsTable: purl.Query().Get("x-migrations-table"),
91 DatabaseName: purl.Path,
92 })
93
94 if err != nil {
95 return nil, err
96 }
97
98 return px, nil
99 }
100
101 func (f *Firebird) Close() error {
102 connErr := f.conn.Close()
103 dbErr := f.db.Close()
104 if connErr != nil || dbErr != nil {
105 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
106 }
107 return nil
108 }
109
110 func (f *Firebird) Lock() error {
111 if !f.isLocked.CAS(false, true) {
112 return database.ErrLocked
113 }
114 return nil
115 }
116
117 func (f *Firebird) Unlock() error {
118 if !f.isLocked.CAS(true, false) {
119 return database.ErrNotLocked
120 }
121 return nil
122 }
123
124 func (f *Firebird) Run(migration io.Reader) error {
125 migr, err := ioutil.ReadAll(migration)
126 if err != nil {
127 return err
128 }
129
130
131 query := string(migr[:])
132 if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
133 return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
134 }
135
136 return nil
137 }
138
139 func (f *Firebird) SetVersion(version int, dirty bool) error {
140
141
142
143
144
145
146
147 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
148 DELETE FROM "%v";
149 INSERT INTO "%v" (version, dirty) VALUES (%v, %v);
150 END;`,
151 f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty))
152
153 if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
154 return &database.Error{OrigErr: err, Query: []byte(query)}
155 }
156
157 return nil
158 }
159
160 func (f *Firebird) Version() (version int, dirty bool, err error) {
161 var d int
162 query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable)
163 err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d)
164 switch {
165 case err == sql.ErrNoRows:
166 return database.NilVersion, false, nil
167 case err != nil:
168 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
169
170 default:
171 return version, itob(d), nil
172 }
173 }
174
175 func (f *Firebird) Drop() (err error) {
176
177 query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);`
178 tables, err := f.conn.QueryContext(context.Background(), query)
179 if err != nil {
180 return &database.Error{OrigErr: err, Query: []byte(query)}
181 }
182 defer func() {
183 if errClose := tables.Close(); errClose != nil {
184 err = multierror.Append(err, errClose)
185 }
186 }()
187
188
189 tableNames := make([]string, 0)
190 for tables.Next() {
191 var tableName string
192 if err := tables.Scan(&tableName); err != nil {
193 return err
194 }
195 if len(tableName) > 0 {
196 tableNames = append(tableNames, tableName)
197 }
198 }
199 if err := tables.Err(); err != nil {
200 return &database.Error{OrigErr: err, Query: []byte(query)}
201 }
202
203
204 for _, t := range tableNames {
205 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
206 if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
207 execute statement 'drop table "%v"';
208 END;`,
209 t, t)
210
211 if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
212 return &database.Error{OrigErr: err, Query: []byte(query)}
213 }
214 }
215
216 return nil
217 }
218
219
220 func (f *Firebird) ensureVersionTable() (err error) {
221 if err = f.Lock(); err != nil {
222 return err
223 }
224
225 defer func() {
226 if e := f.Unlock(); e != nil {
227 if err == nil {
228 err = e
229 } else {
230 err = multierror.Append(err, e)
231 }
232 }
233 }()
234
235 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
236 if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
237 execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)';
238 END;`,
239 f.config.MigrationsTable, f.config.MigrationsTable)
240
241 if _, err = f.conn.ExecContext(context.Background(), query); err != nil {
242 return &database.Error{OrigErr: err, Query: []byte(query)}
243 }
244
245 return nil
246 }
247
248
// btoi encodes a bool as an int: 1 for true, 0 for false. SetVersion uses
// it to render the dirty flag into SQL.
func btoi(v bool) int {
	n := 0
	if v {
		n = 1
	}
	return n
}
255
256
// itob decodes an integer flag: zero is false, and any other value
// (including negatives) is true. Version uses it for the dirty column.
func itob(v int) bool {
	switch v {
	case 0:
		return false
	default:
		return true
	}
}
260
View as plain text