1 package cassandra
2
3 import (
4 "errors"
5 "fmt"
6 "go.uber.org/atomic"
7 "io"
8 "io/ioutil"
9 nurl "net/url"
10 "strconv"
11 "strings"
12 "time"
13
14 "github.com/gocql/gocql"
15 "github.com/golang-migrate/migrate/v4/database"
16 "github.com/golang-migrate/migrate/v4/database/multistmt"
17 "github.com/hashicorp/go-multierror"
18 )
19
20 func init() {
21 db := new(Cassandra)
22 database.Register("cassandra", db)
23 }
24
var (
	// multiStmtDelimiter separates statements when multi-statement mode is on.
	multiStmtDelimiter = []byte(";")

	// DefaultMultiStatementMaxSize is the size limit (10 MiB) handed to the
	// multi-statement parser when Config.MultiStatementMaxSize is not positive.
	DefaultMultiStatementMaxSize = 10 << 20

	// DefaultMigrationsTable is the table name used when
	// Config.MigrationsTable is empty.
	DefaultMigrationsTable = "schema_migrations"
)

// Sentinel errors exported by this package.
var (
	ErrNilConfig     = errors.New("no config")
	ErrNoKeyspace    = errors.New("no keyspace provided")
	ErrDatabaseDirty = errors.New("database is dirty")
	ErrClosedSession = errors.New("session is closed")
)
39
// Config holds the driver settings accepted by WithInstance.
type Config struct {
	// MigrationsTable is the table holding the version and dirty flag.
	// WithInstance substitutes DefaultMigrationsTable when it is empty.
	MigrationsTable string

	// KeyspaceName is the keyspace being migrated; WithInstance rejects an
	// empty value with ErrNoKeyspace.
	KeyspaceName string

	// MultiStatementEnabled makes Run split each migration on ";" and
	// execute the non-empty statements one at a time.
	MultiStatementEnabled bool

	// MultiStatementMaxSize is the size limit passed to multistmt.Parse in
	// multi-statement mode; WithInstance replaces values <= 0 with
	// DefaultMultiStatementMaxSize.
	MultiStatementMaxSize int
}
46
// Cassandra is the migrate database.Driver backed by a gocql session.
type Cassandra struct {
	// session executes every CQL statement issued by the driver.
	session *gocql.Session
	// isLocked is the lock flag toggled by Lock/Unlock. It exists only in this
	// value; no lock is taken in Cassandra itself.
	isLocked atomic.Bool

	// config holds the settings validated and defaulted by WithInstance.
	config *Config
}
54
55 func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
56 if config == nil {
57 return nil, ErrNilConfig
58 } else if len(config.KeyspaceName) == 0 {
59 return nil, ErrNoKeyspace
60 }
61
62 if session.Closed() {
63 return nil, ErrClosedSession
64 }
65
66 if len(config.MigrationsTable) == 0 {
67 config.MigrationsTable = DefaultMigrationsTable
68 }
69
70 if config.MultiStatementMaxSize <= 0 {
71 config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
72 }
73
74 c := &Cassandra{
75 session: session,
76 config: config,
77 }
78
79 if err := c.ensureVersionTable(); err != nil {
80 return nil, err
81 }
82
83 return c, nil
84 }
85
86 func (c *Cassandra) Open(url string) (database.Driver, error) {
87 u, err := nurl.Parse(url)
88 if err != nil {
89 return nil, err
90 }
91
92
93 if len(u.Path) == 0 {
94 return nil, ErrNoKeyspace
95 }
96
97 cluster := gocql.NewCluster(u.Host)
98 cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
99 cluster.Consistency = gocql.All
100 cluster.Timeout = 1 * time.Minute
101
102 if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
103 authenticator := gocql.PasswordAuthenticator{
104 Username: u.Query().Get("username"),
105 Password: u.Query().Get("password"),
106 }
107 cluster.Authenticator = authenticator
108 }
109
110
111 if len(u.Query().Get("consistency")) > 0 {
112 var consistency gocql.Consistency
113 consistency, err = parseConsistency(u.Query().Get("consistency"))
114 if err != nil {
115 return nil, err
116 }
117
118 cluster.Consistency = consistency
119 }
120 if len(u.Query().Get("protocol")) > 0 {
121 var protoversion int
122 protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
123 if err != nil {
124 return nil, err
125 }
126 cluster.ProtoVersion = protoversion
127 }
128 if len(u.Query().Get("timeout")) > 0 {
129 var timeout time.Duration
130 timeout, err = time.ParseDuration(u.Query().Get("timeout"))
131 if err != nil {
132 return nil, err
133 }
134 cluster.Timeout = timeout
135 }
136
137 if len(u.Query().Get("sslmode")) > 0 {
138 if u.Query().Get("sslmode") != "disable" {
139 sslOpts := &gocql.SslOptions{}
140
141 if len(u.Query().Get("sslrootcert")) > 0 {
142 sslOpts.CaPath = u.Query().Get("sslrootcert")
143 }
144 if len(u.Query().Get("sslcert")) > 0 {
145 sslOpts.CertPath = u.Query().Get("sslcert")
146 }
147 if len(u.Query().Get("sslkey")) > 0 {
148 sslOpts.KeyPath = u.Query().Get("sslkey")
149 }
150
151 if u.Query().Get("sslmode") == "verify-full" {
152 sslOpts.EnableHostVerification = true
153 }
154
155 cluster.SslOpts = sslOpts
156 }
157 }
158
159 if len(u.Query().Get("disable-host-lookup")) > 0 {
160 if flag, err := strconv.ParseBool(u.Query().Get("disable-host-lookup")); err != nil && flag {
161 cluster.DisableInitialHostLookup = true
162 } else if err != nil {
163 return nil, err
164 }
165 }
166
167 session, err := cluster.CreateSession()
168 if err != nil {
169 return nil, err
170 }
171
172 multiStatementMaxSize := DefaultMultiStatementMaxSize
173 if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
174 multiStatementMaxSize, err = strconv.Atoi(s)
175 if err != nil {
176 return nil, err
177 }
178 }
179
180 return WithInstance(session, &Config{
181 KeyspaceName: strings.TrimPrefix(u.Path, "/"),
182 MigrationsTable: u.Query().Get("x-migrations-table"),
183 MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
184 MultiStatementMaxSize: multiStatementMaxSize,
185 })
186 }
187
188 func (c *Cassandra) Close() error {
189 c.session.Close()
190 return nil
191 }
192
193 func (c *Cassandra) Lock() error {
194 if !c.isLocked.CAS(false, true) {
195 return database.ErrLocked
196 }
197 return nil
198 }
199
200 func (c *Cassandra) Unlock() error {
201 if !c.isLocked.CAS(true, false) {
202 return database.ErrNotLocked
203 }
204 return nil
205 }
206
207 func (c *Cassandra) Run(migration io.Reader) error {
208 if c.config.MultiStatementEnabled {
209 var err error
210 if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
211 tq := strings.TrimSpace(string(m))
212 if tq == "" {
213 return true
214 }
215 if e := c.session.Query(tq).Exec(); e != nil {
216 err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
217 return false
218 }
219 return true
220 }); e != nil {
221 return e
222 }
223 return err
224 }
225
226 migr, err := ioutil.ReadAll(migration)
227 if err != nil {
228 return err
229 }
230
231 if err := c.session.Query(string(migr)).Exec(); err != nil {
232
233 return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
234 }
235 return nil
236 }
237
238 func (c *Cassandra) SetVersion(version int, dirty bool) error {
239
240
241 squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
242 dquery := `DELETE FROM "` + c.config.MigrationsTable + `" WHERE version = ?`
243 iter := c.session.Query(squery).Iter()
244 var previous int
245 for iter.Scan(&previous) {
246 if err := c.session.Query(dquery, previous).Exec(); err != nil {
247 return &database.Error{OrigErr: err, Query: []byte(dquery)}
248 }
249 }
250 if err := iter.Close(); err != nil {
251 return &database.Error{OrigErr: err, Query: []byte(squery)}
252 }
253
254
255
256
257 if version >= 0 || (version == database.NilVersion && dirty) {
258 query := `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
259 if err := c.session.Query(query, version, dirty).Exec(); err != nil {
260 return &database.Error{OrigErr: err, Query: []byte(query)}
261 }
262 }
263
264 return nil
265 }
266
267
268 func (c *Cassandra) Version() (version int, dirty bool, err error) {
269 query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
270 err = c.session.Query(query).Scan(&version, &dirty)
271 switch {
272 case err == gocql.ErrNotFound:
273 return database.NilVersion, false, nil
274
275 case err != nil:
276 if _, ok := err.(*gocql.Error); ok {
277 return database.NilVersion, false, nil
278 }
279 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
280
281 default:
282 return version, dirty, nil
283 }
284 }
285
286 func (c *Cassandra) Drop() error {
287
288 query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
289 iter := c.session.Query(query).Iter()
290 var tableName string
291 for iter.Scan(&tableName) {
292 err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
293 if err != nil {
294 return err
295 }
296 }
297
298 return nil
299 }
300
301
302
303
304 func (c *Cassandra) ensureVersionTable() (err error) {
305 if err = c.Lock(); err != nil {
306 return err
307 }
308
309 defer func() {
310 if e := c.Unlock(); e != nil {
311 if err == nil {
312 err = e
313 } else {
314 err = multierror.Append(err, e)
315 }
316 }
317 }()
318
319 err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
320 if err != nil {
321 return err
322 }
323 if _, _, err = c.Version(); err != nil {
324 return err
325 }
326 return nil
327 }
328
329
330
331 func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
332 defer func() {
333 if r := recover(); r != nil {
334 var ok bool
335 err, ok = r.(error)
336 if !ok {
337 err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
338 }
339 }
340 }()
341 consistency = gocql.ParseConsistency(consistencyStr)
342
343 return consistency, nil
344 }
345
View as plain text