1 package multistmt_test
2
3 import (
4 "strings"
5 "testing"
6
7 "github.com/stretchr/testify/assert"
8
9 "github.com/golang-migrate/migrate/v4/database/multistmt"
10 )
11
// maxMigrationSize is the size limit (1 KiB) handed to multistmt.Parse by every
// test in this file.
const maxMigrationSize = 1 << 10
13
14 func TestParse(t *testing.T) {
15 testCases := []struct {
16 name string
17 multiStmt string
18 delimiter string
19 expected []string
20 expectedErr error
21 }{
22 {name: "single statement, no delimiter", multiStmt: "single statement, no delimiter", delimiter: ";",
23 expected: []string{"single statement, no delimiter"}, expectedErr: nil},
24 {name: "single statement, one delimiter", multiStmt: "single statement, one delimiter;", delimiter: ";",
25 expected: []string{"single statement, one delimiter;"}, expectedErr: nil},
26 {name: "two statements, no trailing delimiter", multiStmt: "statement one; statement two", delimiter: ";",
27 expected: []string{"statement one;", " statement two"}, expectedErr: nil},
28 {name: "two statements, with trailing delimiter", multiStmt: "statement one; statement two;", delimiter: ";",
29 expected: []string{"statement one;", " statement two;"}, expectedErr: nil},
30 }
31
32 for _, tc := range testCases {
33 t.Run(tc.name, func(t *testing.T) {
34 stmts := make([]string, 0, len(tc.expected))
35 err := multistmt.Parse(strings.NewReader(tc.multiStmt), []byte(tc.delimiter), maxMigrationSize, func(b []byte) bool {
36 stmts = append(stmts, string(b))
37 return true
38 })
39 assert.Equal(t, tc.expectedErr, err)
40 assert.Equal(t, tc.expected, stmts)
41 })
42 }
43 }
44
45 func TestParseDiscontinue(t *testing.T) {
46 multiStmt := "statement one; statement two"
47 delimiter := ";"
48 expected := []string{"statement one;"}
49
50 stmts := make([]string, 0, len(expected))
51 err := multistmt.Parse(strings.NewReader(multiStmt), []byte(delimiter), maxMigrationSize, func(b []byte) bool {
52 stmts = append(stmts, string(b))
53 return false
54 })
55 assert.Nil(t, err)
56 assert.Equal(t, expected, stmts)
57 }
58
View as plain text