1 package clickhouse
2
3 import (
4 "database/sql"
5 "fmt"
6 "io"
7 "io/ioutil"
8 "net/url"
9 "strconv"
10 "strings"
11 "time"
12
13 "go.uber.org/atomic"
14
15 "github.com/golang-migrate/migrate/v4"
16 "github.com/golang-migrate/migrate/v4/database"
17 "github.com/golang-migrate/migrate/v4/database/multistmt"
18 "github.com/hashicorp/go-multierror"
19 )
20
var (
	// multiStmtDelimiter is the separator used to split a migration into
	// individual statements when multi-statement mode is enabled (see Run).
	multiStmtDelimiter = []byte(";")

	// DefaultMigrationsTable is the version table name used when
	// Config.MigrationsTable is left empty.
	DefaultMigrationsTable = "schema_migrations"
	// DefaultMigrationsTableEngine is the table engine used when
	// Config.MigrationsTableEngine is left empty.
	DefaultMigrationsTableEngine = "TinyLog"
	// DefaultMultiStatementMaxSize (10 MiB) is the statement-splitter size
	// limit used when Config.MultiStatementMaxSize is not positive.
	DefaultMultiStatementMaxSize = 10 << 20

	// ErrNilConfig is returned by WithInstance when it is given a nil config.
	ErrNilConfig = fmt.Errorf("no config")
)
30
// Config holds the settings of a ClickHouse migrate driver. Zero values are
// replaced with defaults when the driver is initialised.
type Config struct {
	// DatabaseName is the database holding the migrations table; when empty
	// the server's currentDatabase() is used.
	DatabaseName string
	// ClusterName, when non-empty, makes the migrations table be created
	// with ON CLUSTER <ClusterName>.
	ClusterName string
	// MigrationsTable is the version table name (DefaultMigrationsTable
	// when empty).
	MigrationsTable string
	// MigrationsTableEngine is the engine of the version table
	// (DefaultMigrationsTableEngine when empty). Engines whose name ends in
	// "Tree" get an ORDER BY sequence clause.
	MigrationsTableEngine string
	// MultiStatementEnabled makes Run split migrations on ";" and execute
	// each statement separately.
	MultiStatementEnabled bool
	// MultiStatementMaxSize bounds the multi-statement splitter
	// (DefaultMultiStatementMaxSize when not positive).
	MultiStatementMaxSize int
}
39
40 func init() {
41 database.Register("clickhouse", &ClickHouse{})
42 }
43
44 func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
45 if config == nil {
46 return nil, ErrNilConfig
47 }
48
49 if err := conn.Ping(); err != nil {
50 return nil, err
51 }
52
53 ch := &ClickHouse{
54 conn: conn,
55 config: config,
56 }
57
58 if err := ch.init(); err != nil {
59 return nil, err
60 }
61
62 return ch, nil
63 }
64
65 type ClickHouse struct {
66 conn *sql.DB
67 config *Config
68 isLocked atomic.Bool
69 }
70
71 func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
72 purl, err := url.Parse(dsn)
73 if err != nil {
74 return nil, err
75 }
76 q := migrate.FilterCustomQuery(purl)
77 q.Scheme = "tcp"
78 conn, err := sql.Open("clickhouse", q.String())
79 if err != nil {
80 return nil, err
81 }
82
83 multiStatementMaxSize := DefaultMultiStatementMaxSize
84 if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
85 multiStatementMaxSize, err = strconv.Atoi(s)
86 if err != nil {
87 return nil, err
88 }
89 }
90
91 migrationsTableEngine := DefaultMigrationsTableEngine
92 if s := purl.Query().Get("x-migrations-table-engine"); len(s) > 0 {
93 migrationsTableEngine = s
94 }
95
96 ch = &ClickHouse{
97 conn: conn,
98 config: &Config{
99 MigrationsTable: purl.Query().Get("x-migrations-table"),
100 MigrationsTableEngine: migrationsTableEngine,
101 DatabaseName: purl.Query().Get("database"),
102 ClusterName: purl.Query().Get("x-cluster-name"),
103 MultiStatementEnabled: purl.Query().Get("x-multi-statement") == "true",
104 MultiStatementMaxSize: multiStatementMaxSize,
105 },
106 }
107
108 if err := ch.init(); err != nil {
109 return nil, err
110 }
111
112 return ch, nil
113 }
114
115 func (ch *ClickHouse) init() error {
116 if len(ch.config.DatabaseName) == 0 {
117 if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil {
118 return err
119 }
120 }
121
122 if len(ch.config.MigrationsTable) == 0 {
123 ch.config.MigrationsTable = DefaultMigrationsTable
124 }
125
126 if ch.config.MultiStatementMaxSize <= 0 {
127 ch.config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
128 }
129
130 if len(ch.config.MigrationsTableEngine) == 0 {
131 ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine
132 }
133
134 return ch.ensureVersionTable()
135 }
136
137 func (ch *ClickHouse) Run(r io.Reader) error {
138 if ch.config.MultiStatementEnabled {
139 var err error
140 if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool {
141 tq := strings.TrimSpace(string(m))
142 if tq == "" {
143 return true
144 }
145 if _, e := ch.conn.Exec(string(m)); e != nil {
146 err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
147 return false
148 }
149 return true
150 }); e != nil {
151 return e
152 }
153 return err
154 }
155
156 migration, err := ioutil.ReadAll(r)
157 if err != nil {
158 return err
159 }
160
161 if _, err := ch.conn.Exec(string(migration)); err != nil {
162 return database.Error{OrigErr: err, Err: "migration failed", Query: migration}
163 }
164
165 return nil
166 }
167 func (ch *ClickHouse) Version() (int, bool, error) {
168 var (
169 version int
170 dirty uint8
171 query = "SELECT version, dirty FROM `" + ch.config.MigrationsTable + "` ORDER BY sequence DESC LIMIT 1"
172 )
173 if err := ch.conn.QueryRow(query).Scan(&version, &dirty); err != nil {
174 if err == sql.ErrNoRows {
175 return database.NilVersion, false, nil
176 }
177 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
178 }
179 return version, dirty == 1, nil
180 }
181
182 func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
183 var (
184 bool = func(v bool) uint8 {
185 if v {
186 return 1
187 }
188 return 0
189 }
190 tx, err = ch.conn.Begin()
191 )
192 if err != nil {
193 return err
194 }
195
196 query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)"
197 if _, err := tx.Exec(query, version, bool(dirty), time.Now().UnixNano()); err != nil {
198 return &database.Error{OrigErr: err, Query: []byte(query)}
199 }
200
201 return tx.Commit()
202 }
203
204
205
206
207 func (ch *ClickHouse) ensureVersionTable() (err error) {
208 if err = ch.Lock(); err != nil {
209 return err
210 }
211
212 defer func() {
213 if e := ch.Unlock(); e != nil {
214 if err == nil {
215 err = e
216 } else {
217 err = multierror.Append(err, e)
218 }
219 }
220 }()
221
222 var (
223 table string
224 query = "SHOW TABLES FROM " + ch.config.DatabaseName + " LIKE '" + ch.config.MigrationsTable + "'"
225 )
226
227 if err := ch.conn.QueryRow(query).Scan(&table); err != nil {
228 if err != sql.ErrNoRows {
229 return &database.Error{OrigErr: err, Query: []byte(query)}
230 }
231 } else {
232 return nil
233 }
234
235
236 if len(ch.config.ClusterName) > 0 {
237 query = fmt.Sprintf(`
238 CREATE TABLE %s ON CLUSTER %s (
239 version Int64,
240 dirty UInt8,
241 sequence UInt64
242 ) Engine=%s`, ch.config.MigrationsTable, ch.config.ClusterName, ch.config.MigrationsTableEngine)
243 } else {
244 query = fmt.Sprintf(`
245 CREATE TABLE %s (
246 version Int64,
247 dirty UInt8,
248 sequence UInt64
249 ) Engine=%s`, ch.config.MigrationsTable, ch.config.MigrationsTableEngine)
250 }
251
252 if strings.HasSuffix(ch.config.MigrationsTableEngine, "Tree") {
253 query = fmt.Sprintf(`%s ORDER BY sequence`, query)
254 }
255
256 if _, err := ch.conn.Exec(query); err != nil {
257 return &database.Error{OrigErr: err, Query: []byte(query)}
258 }
259 return nil
260 }
261
262 func (ch *ClickHouse) Drop() (err error) {
263 query := "SHOW TABLES FROM " + ch.config.DatabaseName
264 tables, err := ch.conn.Query(query)
265
266 if err != nil {
267 return &database.Error{OrigErr: err, Query: []byte(query)}
268 }
269 defer func() {
270 if errClose := tables.Close(); errClose != nil {
271 err = multierror.Append(err, errClose)
272 }
273 }()
274
275 for tables.Next() {
276 var table string
277 if err := tables.Scan(&table); err != nil {
278 return err
279 }
280
281 query = "DROP TABLE IF EXISTS " + ch.config.DatabaseName + "." + table
282
283 if _, err := ch.conn.Exec(query); err != nil {
284 return &database.Error{OrigErr: err, Query: []byte(query)}
285 }
286 }
287 if err := tables.Err(); err != nil {
288 return &database.Error{OrigErr: err, Query: []byte(query)}
289 }
290
291 return nil
292 }
293
294 func (ch *ClickHouse) Lock() error {
295 if !ch.isLocked.CAS(false, true) {
296 return database.ErrLocked
297 }
298
299 return nil
300 }
301 func (ch *ClickHouse) Unlock() error {
302 if !ch.isLocked.CAS(true, false) {
303 return database.ErrNotLocked
304 }
305
306 return nil
307 }
308 func (ch *ClickHouse) Close() error { return ch.conn.Close() }
309
View as plain text