1 package sqlserver
2
3 import (
4 "context"
5 "database/sql"
6 "fmt"
7 "io"
8 "io/ioutil"
9 nurl "net/url"
10 "strconv"
11 "strings"
12
13 "go.uber.org/atomic"
14
15 "github.com/Azure/go-autorest/autorest/adal"
16 mssql "github.com/denisenkom/go-mssqldb"
17 "github.com/golang-migrate/migrate/v4"
18 "github.com/golang-migrate/migrate/v4/database"
19 "github.com/hashicorp/go-multierror"
20 )
21
22 func init() {
23 database.Register("sqlserver", &SQLServer{})
24 }
25
26
// DefaultMigrationsTable is the migrations table name used by WithInstance
// when Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"

var (
	// ErrNilConfig is returned by WithInstance when it is given a nil config.
	ErrNilConfig = fmt.Errorf("no config")
	// ErrNoDatabaseName is returned by WithInstance when no database name was
	// configured and DB_NAME() returned an empty string.
	ErrNoDatabaseName = fmt.Errorf("no database name")
	// ErrNoSchema is returned by WithInstance when no schema was configured
	// and SCHEMA_NAME() returned an empty string.
	ErrNoSchema = fmt.Errorf("no schema")
	// ErrDatabaseDirty signals a dirty database; no code visible in this file
	// returns it.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")
	// ErrMultipleAuthOptionsPassed is returned by Open when the URL carries
	// both a password and useMsi=true, which are mutually exclusive.
	// (Error strings follow Go convention: lowercase, no trailing punctuation.)
	ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed")
)
36
37 var lockErrorMap = map[mssql.ReturnStatus]string{
38 -1: "The lock request timed out.",
39 -2: "The lock request was canceled.",
40 -3: "The lock request was chosen as a deadlock victim.",
41 -999: "Parameter validation or other call error.",
42 }
43
44
// Config holds the settings for the SQL Server migration driver.
// WithInstance fills in any empty field (modifying the Config in place).
type Config struct {
	// MigrationsTable is the table that stores the current version and dirty
	// flag. WithInstance uses DefaultMigrationsTable when it is empty.
	MigrationsTable string
	// DatabaseName is combined with SchemaName to derive the advisory lock id.
	// WithInstance queries DB_NAME() when it is empty.
	DatabaseName string
	// SchemaName is combined with DatabaseName to derive the advisory lock id.
	// WithInstance queries SCHEMA_NAME() when it is empty.
	SchemaName string
}
50
51
52 type SQLServer struct {
53
54 conn *sql.Conn
55 db *sql.DB
56 isLocked atomic.Bool
57
58
59 config *Config
60 }
61
62
63
64
65 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
66 if config == nil {
67 return nil, ErrNilConfig
68 }
69
70 if err := instance.Ping(); err != nil {
71 return nil, err
72 }
73
74 if config.DatabaseName == "" {
75 query := `SELECT DB_NAME()`
76 var databaseName string
77 if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
78 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
79 }
80
81 if len(databaseName) == 0 {
82 return nil, ErrNoDatabaseName
83 }
84
85 config.DatabaseName = databaseName
86 }
87
88 if config.SchemaName == "" {
89 query := `SELECT SCHEMA_NAME()`
90 var schemaName string
91 if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
92 return nil, &database.Error{OrigErr: err, Query: []byte(query)}
93 }
94
95 if len(schemaName) == 0 {
96 return nil, ErrNoSchema
97 }
98
99 config.SchemaName = schemaName
100 }
101
102 if len(config.MigrationsTable) == 0 {
103 config.MigrationsTable = DefaultMigrationsTable
104 }
105
106 conn, err := instance.Conn(context.Background())
107
108 if err != nil {
109 return nil, err
110 }
111
112 ss := &SQLServer{
113 conn: conn,
114 db: instance,
115 config: config,
116 }
117
118 if err := ss.ensureVersionTable(); err != nil {
119 return nil, err
120 }
121
122 return ss, nil
123 }
124
125
126 func (ss *SQLServer) Open(url string) (database.Driver, error) {
127 purl, err := nurl.Parse(url)
128 if err != nil {
129 return nil, err
130 }
131
132 useMsiParam := purl.Query().Get("useMsi")
133 useMsi := false
134 if len(useMsiParam) > 0 {
135 useMsi, err = strconv.ParseBool(useMsiParam)
136 if err != nil {
137 return nil, err
138 }
139 }
140
141 if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
142 return nil, ErrMultipleAuthOptionsPassed
143 }
144
145 filteredURL := migrate.FilterCustomQuery(purl).String()
146
147 var db *sql.DB
148 if useMsi {
149 resource := getAADResourceFromServerUri(purl)
150 tokenProvider, err := getMSITokenProvider(resource)
151 if err != nil {
152 return nil, err
153 }
154
155 connector, err := mssql.NewAccessTokenConnector(
156 filteredURL, tokenProvider)
157 if err != nil {
158 return nil, err
159 }
160
161 db = sql.OpenDB(connector)
162
163 } else {
164 db, err = sql.Open("sqlserver", filteredURL)
165 if err != nil {
166 return nil, err
167 }
168 }
169
170 migrationsTable := purl.Query().Get("x-migrations-table")
171
172 px, err := WithInstance(db, &Config{
173 DatabaseName: purl.Path,
174 MigrationsTable: migrationsTable,
175 })
176
177 if err != nil {
178 return nil, err
179 }
180
181 return px, nil
182 }
183
184
185 func (ss *SQLServer) Close() error {
186 connErr := ss.conn.Close()
187 dbErr := ss.db.Close()
188 if connErr != nil || dbErr != nil {
189 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
190 }
191 return nil
192 }
193
194
195 func (ss *SQLServer) Lock() error {
196 return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
197 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
198 if err != nil {
199 return err
200 }
201
202
203
204
205 query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
206
207 var status mssql.ReturnStatus
208 if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
209 return nil
210 } else if err != nil {
211 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
212 } else {
213 return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
214 }
215 })
216 }
217
218
219 func (ss *SQLServer) Unlock() error {
220 return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
221 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
222 if err != nil {
223 return err
224 }
225
226
227 query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
228 if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
229 return &database.Error{OrigErr: err, Query: []byte(query)}
230 }
231
232 return nil
233 })
234 }
235
236
237 func (ss *SQLServer) Run(migration io.Reader) error {
238 migr, err := ioutil.ReadAll(migration)
239 if err != nil {
240 return err
241 }
242
243
244 query := string(migr[:])
245 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
246 if msErr, ok := err.(mssql.Error); ok {
247 message := fmt.Sprintf("migration failed: %s", msErr.Message)
248 if msErr.ProcName != "" {
249 message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
250 }
251 return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
252 }
253 return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
254 }
255
256 return nil
257 }
258
259
260 func (ss *SQLServer) SetVersion(version int, dirty bool) error {
261
262 tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
263 if err != nil {
264 return &database.Error{OrigErr: err, Err: "transaction start failed"}
265 }
266
267 query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
268 if _, err := tx.Exec(query); err != nil {
269 if errRollback := tx.Rollback(); errRollback != nil {
270 err = multierror.Append(err, errRollback)
271 }
272 return &database.Error{OrigErr: err, Query: []byte(query)}
273 }
274
275
276
277
278 if version >= 0 || (version == database.NilVersion && dirty) {
279 var dirtyBit int
280 if dirty {
281 dirtyBit = 1
282 }
283 query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
284 if _, err := tx.Exec(query, version, dirtyBit); err != nil {
285 if errRollback := tx.Rollback(); errRollback != nil {
286 err = multierror.Append(err, errRollback)
287 }
288 return &database.Error{OrigErr: err, Query: []byte(query)}
289 }
290 }
291
292 if err := tx.Commit(); err != nil {
293 return &database.Error{OrigErr: err, Err: "transaction commit failed"}
294 }
295
296 return nil
297 }
298
299
300 func (ss *SQLServer) Version() (version int, dirty bool, err error) {
301 query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
302 err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
303 switch {
304 case err == sql.ErrNoRows:
305 return database.NilVersion, false, nil
306
307 case err != nil:
308
309 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
310
311 default:
312 return version, dirty, nil
313 }
314 }
315
316
317 func (ss *SQLServer) Drop() error {
318
319
320 query := `
321 DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
322
323 SET @Cursor = CURSOR FAST_FORWARD FOR
324 SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
325 FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
326 LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
327
328 OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
329
330 WHILE (@@FETCH_STATUS = 0)
331 BEGIN
332 Exec sp_executesql @Sql
333 FETCH NEXT FROM @Cursor INTO @Sql
334 END
335
336 CLOSE @Cursor DEALLOCATE @Cursor`
337
338 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
339 return &database.Error{OrigErr: err, Query: []byte(query)}
340 }
341
342
343 query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
344 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
345 return &database.Error{OrigErr: err, Query: []byte(query)}
346 }
347
348 return nil
349 }
350
351 func (ss *SQLServer) ensureVersionTable() (err error) {
352 if err = ss.Lock(); err != nil {
353 return err
354 }
355
356 defer func() {
357 if e := ss.Unlock(); e != nil {
358 if err == nil {
359 err = e
360 } else {
361 err = multierror.Append(err, e)
362 }
363 }
364 }()
365
366 query := `IF NOT EXISTS
367 (SELECT *
368 FROM sysobjects
369 WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]')
370 AND OBJECTPROPERTY(id, N'IsUserTable') = 1
371 )
372 CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
373
374 if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
375 return &database.Error{OrigErr: err, Query: []byte(query)}
376 }
377
378 return nil
379 }
380
381 func getMSITokenProvider(resource string) (func() (string, error), error) {
382 msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
383 if err != nil {
384 return nil, err
385 }
386
387 return func() (string, error) {
388 err := msi.EnsureFresh()
389 if err != nil {
390 return "", err
391 }
392 token := msi.OAuthToken()
393 return token, nil
394 }, nil
395 }
396
397
398
399
// getAADResourceFromServerUri derives the Azure AD resource URI from the
// server host. It drops the first DNS label and prefixes "https://"; for
// example, myserver.database.windows.net becomes https://database.windows.net.
// A host without a dot yields just "https://".
func getAADResourceFromServerUri(purl *nurl.URL) string {
	host := purl.Hostname()
	domain := ""
	if dot := strings.Index(host, "."); dot >= 0 {
		domain = host[dot+1:]
	}
	return fmt.Sprintf("https://%s", domain)
}
403
View as plain text