1 package spanner
2
3 import (
4 "errors"
5 "fmt"
6 "io"
7 "io/ioutil"
8 "log"
9 nurl "net/url"
10 "regexp"
11 "strconv"
12 "strings"
13
14 "context"
15
16 "cloud.google.com/go/spanner"
17 sdb "cloud.google.com/go/spanner/admin/database/apiv1"
18 "cloud.google.com/go/spanner/spansql"
19
20 "github.com/golang-migrate/migrate/v4"
21 "github.com/golang-migrate/migrate/v4/database"
22
23 "github.com/hashicorp/go-multierror"
24 uatomic "go.uber.org/atomic"
25 "google.golang.org/api/iterator"
26 adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
27 )
28
29 func init() {
30 db := Spanner{}
31 database.Register("spanner", &db)
32 }
33
34
// DefaultMigrationsTable is the migrations table name used by WithInstance
// when Config.MigrationsTable is empty.
const DefaultMigrationsTable = "SchemaMigrations"

// States of the driver's in-process lock flag.
const (
	unlockedVal = iota // lock is free
	lockedVal          // lock is held
)

// Sentinel errors of the Spanner driver.
var (
	// ErrNilConfig is returned by WithInstance when config is nil.
	ErrNilConfig = errors.New("no config")
	// ErrNoDatabaseName is returned by WithInstance when
	// Config.DatabaseName is empty.
	ErrNoDatabaseName = errors.New("no database name")
	// ErrNoSchema has the message "no schema".
	// NOTE(review): not referenced elsewhere in this file.
	ErrNoSchema = errors.New("no schema")
	// ErrDatabaseDirty has the message "database is dirty".
	// NOTE(review): not referenced elsewhere in this file.
	ErrDatabaseDirty = errors.New("database is dirty")
	// ErrLockHeld is returned by Lock when the lock is already held.
	ErrLockHeld = errors.New("unable to obtain lock")
	// ErrLockNotHeld is returned by Unlock when the lock is not held.
	ErrLockNotHeld = errors.New("unable to release already released lock")
)
51
52
// Config holds the settings for a Spanner migration driver.
type Config struct {
	// MigrationsTable names the table that stores the current migration
	// version and dirty flag. WithInstance sets it to DefaultMigrationsTable
	// when it is empty.
	MigrationsTable string
	// DatabaseName is sent as the Database field of every admin DDL request
	// issued by this driver. WithInstance rejects an empty value with
	// ErrNoDatabaseName.
	DatabaseName string

	// CleanStatements controls how Run submits a migration. When false, the
	// whole migration text is sent as a single DDL statement. When true, the
	// migration is parsed with spansql.ParseDDL and each parsed statement is
	// re-rendered via its SQL method and sent as a separate statement.
	CleanStatements bool
}
62
63
// Spanner is the migrate database driver for Cloud Spanner. A zero value is
// registered under the "spanner" name in init; usable instances come from
// Open or WithInstance.
type Spanner struct {
	// db holds the admin client (DDL) and the data client (reads and writes
	// of the migrations table).
	db *DB

	// config holds the settings passed to WithInstance.
	config *Config

	// lock is the in-process flag that Lock and Unlock flip between
	// unlockedVal and lockedVal with compare-and-swap. It lives only in
	// memory and is not stored in the database.
	lock *uatomic.Uint32
}
71
// DB bundles the two Spanner clients used by the driver: admin issues schema
// operations (UpdateDatabaseDdl, GetDatabaseDdl) and data reads and writes
// rows of the migrations table.
type DB struct {
	admin *sdb.DatabaseAdminClient
	data  *spanner.Client
}
76
77 func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB {
78 return &DB{
79 admin: &admin,
80 data: &data,
81 }
82 }
83
84
85 func WithInstance(instance *DB, config *Config) (database.Driver, error) {
86 if config == nil {
87 return nil, ErrNilConfig
88 }
89
90 if len(config.DatabaseName) == 0 {
91 return nil, ErrNoDatabaseName
92 }
93
94 if len(config.MigrationsTable) == 0 {
95 config.MigrationsTable = DefaultMigrationsTable
96 }
97
98 sx := &Spanner{
99 db: instance,
100 config: config,
101 lock: uatomic.NewUint32(unlockedVal),
102 }
103
104 if err := sx.ensureVersionTable(); err != nil {
105 return nil, err
106 }
107
108 return sx, nil
109 }
110
111
112 func (s *Spanner) Open(url string) (database.Driver, error) {
113 purl, err := nurl.Parse(url)
114 if err != nil {
115 return nil, err
116 }
117
118 ctx := context.Background()
119
120 adminClient, err := sdb.NewDatabaseAdminClient(ctx)
121 if err != nil {
122 return nil, err
123 }
124 dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
125 dataClient, err := spanner.NewClient(ctx, dbname)
126 if err != nil {
127 log.Fatal(err)
128 }
129
130 migrationsTable := purl.Query().Get("x-migrations-table")
131
132 cleanQuery := purl.Query().Get("x-clean-statements")
133 clean := false
134 if cleanQuery != "" {
135 clean, err = strconv.ParseBool(cleanQuery)
136 if err != nil {
137 return nil, err
138 }
139 }
140
141 db := &DB{admin: adminClient, data: dataClient}
142 return WithInstance(db, &Config{
143 DatabaseName: dbname,
144 MigrationsTable: migrationsTable,
145 CleanStatements: clean,
146 })
147 }
148
149
150 func (s *Spanner) Close() error {
151 s.db.data.Close()
152 return s.db.admin.Close()
153 }
154
155
156
157 func (s *Spanner) Lock() error {
158 if swapped := s.lock.CAS(unlockedVal, lockedVal); swapped {
159 return nil
160 }
161 return ErrLockHeld
162 }
163
164
165 func (s *Spanner) Unlock() error {
166 if swapped := s.lock.CAS(lockedVal, unlockedVal); swapped {
167 return nil
168 }
169 return ErrLockNotHeld
170 }
171
172
173 func (s *Spanner) Run(migration io.Reader) error {
174 migr, err := ioutil.ReadAll(migration)
175 if err != nil {
176 return err
177 }
178
179 stmts := []string{string(migr)}
180 if s.config.CleanStatements {
181 stmts, err = cleanStatements(migr)
182 if err != nil {
183 return err
184 }
185 }
186
187 ctx := context.Background()
188 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
189 Database: s.config.DatabaseName,
190 Statements: stmts,
191 })
192
193 if err != nil {
194 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
195 }
196
197 if err := op.Wait(ctx); err != nil {
198 return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
199 }
200
201 return nil
202 }
203
204
205 func (s *Spanner) SetVersion(version int, dirty bool) error {
206 ctx := context.Background()
207
208 _, err := s.db.data.ReadWriteTransaction(ctx,
209 func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
210 m := []*spanner.Mutation{
211 spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
212 spanner.Insert(s.config.MigrationsTable,
213 []string{"Version", "Dirty"},
214 []interface{}{version, dirty},
215 )}
216 return txn.BufferWrite(m)
217 })
218 if err != nil {
219 return &database.Error{OrigErr: err}
220 }
221
222 return nil
223 }
224
225
226 func (s *Spanner) Version() (version int, dirty bool, err error) {
227 ctx := context.Background()
228
229 stmt := spanner.Statement{
230 SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
231 }
232 iter := s.db.data.Single().Query(ctx, stmt)
233 defer iter.Stop()
234
235 row, err := iter.Next()
236 switch err {
237 case iterator.Done:
238 return database.NilVersion, false, nil
239 case nil:
240 var v int64
241 if err = row.Columns(&v, &dirty); err != nil {
242 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
243 }
244 version = int(v)
245 default:
246 return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
247 }
248
249 return version, dirty, nil
250 }
251
252 var nameMatcher = regexp.MustCompile(`(CREATE TABLE\s(\S+)\s)|(CREATE.+INDEX\s(\S+)\s)`)
253
254
255
256
257
258
259
260 func (s *Spanner) Drop() error {
261 ctx := context.Background()
262 res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
263 Database: s.config.DatabaseName,
264 })
265 if err != nil {
266 return &database.Error{OrigErr: err, Err: "drop failed"}
267 }
268 if len(res.Statements) == 0 {
269 return nil
270 }
271
272 stmts := make([]string, 0)
273 for i := len(res.Statements) - 1; i >= 0; i-- {
274 s := res.Statements[i]
275 m := nameMatcher.FindSubmatch([]byte(s))
276
277 if len(m) == 0 {
278 continue
279 } else if tbl := m[2]; len(tbl) > 0 {
280 stmts = append(stmts, fmt.Sprintf(`DROP TABLE %s`, tbl))
281 } else if idx := m[4]; len(idx) > 0 {
282 stmts = append(stmts, fmt.Sprintf(`DROP INDEX %s`, idx))
283 }
284 }
285
286 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
287 Database: s.config.DatabaseName,
288 Statements: stmts,
289 })
290 if err != nil {
291 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
292 }
293 if err := op.Wait(ctx); err != nil {
294 return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
295 }
296
297 return nil
298 }
299
300
301
302
303 func (s *Spanner) ensureVersionTable() (err error) {
304 if err = s.Lock(); err != nil {
305 return err
306 }
307
308 defer func() {
309 if e := s.Unlock(); e != nil {
310 if err == nil {
311 err = e
312 } else {
313 err = multierror.Append(err, e)
314 }
315 }
316 }()
317
318 ctx := context.Background()
319 tbl := s.config.MigrationsTable
320 iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"})
321 if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
322 return nil
323 }
324
325 stmt := fmt.Sprintf(`CREATE TABLE %s (
326 Version INT64 NOT NULL,
327 Dirty BOOL NOT NULL
328 ) PRIMARY KEY(Version)`, tbl)
329
330 op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
331 Database: s.config.DatabaseName,
332 Statements: []string{stmt},
333 })
334
335 if err != nil {
336 return &database.Error{OrigErr: err, Query: []byte(stmt)}
337 }
338 if err := op.Wait(ctx); err != nil {
339 return &database.Error{OrigErr: err, Query: []byte(stmt)}
340 }
341
342 return nil
343 }
344
345 func cleanStatements(migration []byte) ([]string, error) {
346
347
348
349 ddl, err := spansql.ParseDDL("", string(migration))
350 if err != nil {
351 return nil, err
352 }
353 stmts := make([]string, 0, len(ddl.List))
354 for _, stmt := range ddl.List {
355 stmts = append(stmts, stmt.SQL())
356 }
357 return stmts, nil
358 }
359
View as plain text