1 package sqlserver
2
import (
	"context"
	"database/sql"
	sqldriver "database/sql/driver"
	"errors"
	"fmt"
	"log"
	"strings"
	"testing"
	"time"

	"github.com/dhui/dktest"
	"github.com/golang-migrate/migrate/v4"

	dt "github.com/golang-migrate/migrate/v4/database/testing"
	"github.com/golang-migrate/migrate/v4/dktesting"

	_ "github.com/golang-migrate/migrate/v4/source/file"
)
21
const (
	// defaultPort is the SQL Server port inside the container; tests resolve
	// its host-side mapping with c.Port(defaultPort).
	defaultPort = 1433

	// saPassword is the sa login password. It is passed to the container via
	// SA_PASSWORD and embedded in the connection strings built below.
	saPassword = "Root1234"
)
24
25 var (
26 opts = dktest.Options{
27 Env: map[string]string{"ACCEPT_EULA": "Y", "SA_PASSWORD": saPassword, "MSSQL_PID": "Express"},
28 PortRequired: true, ReadyFunc: isReady, PullTimeout: 2 * time.Minute,
29 }
30
31 specs = []dktesting.ContainerSpec{
32 {ImageName: "mcr.microsoft.com/mssql/server:2017-latest", Options: opts},
33 {ImageName: "mcr.microsoft.com/mssql/server:2019-latest", Options: opts},
34 }
35 )
36
37 func msConnectionString(host, port string) string {
38 return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port)
39 }
40
// msConnectionStringMsiWithPassword returns the password-authenticated DSN
// from msConnectionString with an explicit useMsi query parameter appended.
// It lets tests combine a password with either MSI setting.
func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string {
	withPassword := msConnectionString(host, port)
	return withPassword + fmt.Sprintf("&useMsi=%t", useMsi)
}
44
// msConnectionStringMsi returns a DSN for user sa with no password, targeting
// the master database on host:port, with the given useMsi query parameter.
func msConnectionStringMsi(host, port string, useMsi bool) string {
	base := fmt.Sprintf("sqlserver://sa@%v:%v?database=master", host, port)
	return base + fmt.Sprintf("&useMsi=%t", useMsi)
}
48
// isReady is the readiness probe installed as opts.ReadyFunc. It reports true
// only when a sqlserver connection to the master database on the container's
// mapped defaultPort can be opened and pinged within ctx.
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
	ip, port, err := c.Port(defaultPort)
	if err != nil {
		return false
	}

	db, err := sql.Open("sqlserver", msConnectionString(ip, port))
	if err != nil {
		return false
	}
	defer func() {
		if err := db.Close(); err != nil {
			log.Println("close error:", err)
		}
	}()

	if err := db.PingContext(ctx); err != nil {
		// A bad connection is the expected state while the server is starting,
		// so only unexpected errors are logged. errors.Is also matches a
		// wrapped ErrBadConn, which a plain == comparison would miss.
		if !errors.Is(err, sqldriver.ErrBadConn) {
			log.Println(err)
		}
		return false
	}

	return true
}
76
// Test runs the shared driver test suite (dt.Test) with a "SELECT 1"
// migration body against a driver opened on each SQL Server image in specs.
func Test(t *testing.T) {
	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
		host, hostPort, err := c.Port(defaultPort)
		if err != nil {
			t.Fatal(err)
		}

		driver, err := (&SQLServer{}).Open(msConnectionString(host, hostPort))
		if err != nil {
			t.Fatalf("%v", err)
		}
		defer func() {
			if closeErr := driver.Close(); closeErr != nil {
				t.Error(closeErr)
			}
		}()

		dt.Test(t, driver, []byte("SELECT 1"))
	})
}
100
101 func TestMigrate(t *testing.T) {
102 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
103 ip, port, err := c.Port(defaultPort)
104 if err != nil {
105 t.Fatal(err)
106 }
107
108 addr := msConnectionString(ip, port)
109 p := &SQLServer{}
110 d, err := p.Open(addr)
111 if err != nil {
112 t.Fatalf("%v", err)
113 }
114
115 defer func() {
116 if err := d.Close(); err != nil {
117 t.Error(err)
118 }
119 }()
120
121 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "master", d)
122 if err != nil {
123 t.Fatal(err)
124 }
125 dt.TestMigrate(t, m)
126 })
127 }
128
// TestMultiStatement checks that a migration body containing several
// semicolon-separated statements is executed in full. Every table it creates
// must exist in the current schema and database afterwards.
func TestMultiStatement(t *testing.T) {
	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
		// Resolve the SQL port explicitly, like the other tests, instead of
		// relying on whichever port the container happens to expose first.
		ip, port, err := c.Port(defaultPort)
		if err != nil {
			t.Fatal(err)
		}

		addr := msConnectionString(ip, port)
		ms := &SQLServer{}
		d, err := ms.Open(addr)
		if err != nil {
			t.Fatal(err)
		}
		defer func() {
			if err := d.Close(); err != nil {
				t.Error(err)
			}
		}()
		if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
			t.Fatalf("expected err to be nil, got %v", err)
		}

		conn := d.(*SQLServer).conn
		// Check every created table, not just the last one, so a driver that
		// drops earlier statements in the batch is also caught.
		for _, table := range []string{"foo", "bar"} {
			var exists int
			if err := conn.QueryRowContext(context.Background(),
				"SELECT COUNT(1) FROM information_schema.tables WHERE table_name = @p1 AND table_schema = (SELECT schema_name()) AND table_catalog = (SELECT db_name())",
				table,
			).Scan(&exists); err != nil {
				t.Fatal(err)
			}
			if exists != 1 {
				t.Fatalf("expected table %s to exist", table)
			}
		}
	})
}
161
// TestErrorParsing checks the error returned when one statement of a
// multi-statement migration fails. The message must contain the server error,
// the line number, the full migration body and the underlying mssql error.
func TestErrorParsing(t *testing.T) {
	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
		// Resolve the SQL port explicitly, consistent with the other tests.
		ip, port, err := c.Port(defaultPort)
		if err != nil {
			t.Fatal(err)
		}

		addr := msConnectionString(ip, port)
		p := &SQLServer{}
		d, err := p.Open(addr)
		if err != nil {
			t.Fatal(err)
		}
		defer func() {
			if err := d.Close(); err != nil {
				t.Error(err)
			}
		}()

		wantErr := `migration failed: Unknown object type 'TABLEE' used in a CREATE, DROP, or ALTER statement. in line 1:` +
			` CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text); (details: mssql: Unknown object type ` +
			`'TABLEE' used in a CREATE, DROP, or ALTER statement.)`

		err = d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);"))
		if err == nil {
			t.Fatal("expected err but got nil")
		}
		if got := err.Error(); got != wantErr {
			t.Fatalf("expected '%s' but got '%s'", wantErr, got)
		}
	})
}
191
// TestLockWorks runs the shared driver suite, then checks that the migration
// lock can be acquired and released twice in a row. The second cycle shows
// that the lock can be re-acquired after it has been released.
func TestLockWorks(t *testing.T) {
	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
		ip, port, err := c.Port(defaultPort)
		if err != nil {
			t.Fatal(err)
		}

		// Use the shared helper so the DSN selects database=master explicitly.
		// A bare "?master" query is not a recognised connection parameter.
		addr := msConnectionString(ip, port)
		p := &SQLServer{}
		d, err := p.Open(addr)
		if err != nil {
			t.Fatalf("%v", err)
		}
		// Always release the connection, as the other tests do.
		defer func() {
			if err := d.Close(); err != nil {
				t.Error(err)
			}
		}()

		dt.Test(t, d, []byte("SELECT 1"))

		ms := d.(*SQLServer)

		for cycle := 0; cycle < 2; cycle++ {
			if err := ms.Lock(); err != nil {
				t.Fatal(err)
			}
			if err := ms.Unlock(); err != nil {
				t.Fatal(err)
			}
		}
	})
}
229
// TestMsiTrue checks that Open fails when useMsi=true is requested outside of
// an Azure environment.
func TestMsiTrue(t *testing.T) {
	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
		ip, port, err := c.Port(defaultPort)
		if err != nil {
			t.Fatal(err)
		}

		addr := msConnectionStringMsi(ip, port, true)
		p := &SQLServer{}
		if d, err := p.Open(addr); err == nil {
			// Open unexpectedly succeeded. Close the driver before failing so
			// the connection is not leaked.
			if err := d.Close(); err != nil {
				t.Error(err)
			}
			t.Fatal("MSI should fail when not running in an Azure context.")
		}
	})
}
245
// TestOpenWithPasswordAndMSI checks that combining a password with useMsi=true
// is rejected by Open. It also checks that the same DSN with useMsi=false
// opens a working driver that passes the shared suite.
func TestOpenWithPasswordAndMSI(t *testing.T) {
	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
		ip, port, err := c.Port(defaultPort)
		if err != nil {
			t.Fatal(err)
		}

		addr := msConnectionStringMsiWithPassword(ip, port, true)
		p := &SQLServer{}
		if d, err := p.Open(addr); err == nil {
			// Open unexpectedly succeeded. Close the driver before failing so
			// the connection is not leaked.
			if err := d.Close(); err != nil {
				t.Error(err)
			}
			t.Fatal("Open should fail when both password and useMsi=true are passed.")
		}

		addr = msConnectionStringMsiWithPassword(ip, port, false)
		p = &SQLServer{}
		d, err := p.Open(addr)
		if err != nil {
			t.Fatal(err)
		}

		defer func() {
			if err := d.Close(); err != nil {
				t.Error(err)
			}
		}()

		dt.Test(t, d, []byte("SELECT 1"))
	})
}
276
// TestMsiFalse checks that Open fails when neither a password nor useMsi=true
// is supplied for the sa login.
func TestMsiFalse(t *testing.T) {
	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
		ip, port, err := c.Port(defaultPort)
		if err != nil {
			t.Fatal(err)
		}

		addr := msConnectionStringMsi(ip, port, false)
		p := &SQLServer{}
		if d, err := p.Open(addr); err == nil {
			// Open unexpectedly succeeded. Close the driver before failing so
			// the connection is not leaked.
			if err := d.Close(); err != nil {
				t.Error(err)
			}
			t.Fatal("Open should fail since no password was passed and useMsi is false.")
		}
	})
}
292
View as plain text