1 package neo4j
2
3 import (
4 "bytes"
5 "fmt"
6 "io"
7 "io/ioutil"
8 neturl "net/url"
9 "strconv"
10 "sync/atomic"
11
12 "github.com/golang-migrate/migrate/v4/database"
13 "github.com/golang-migrate/migrate/v4/database/multistmt"
14 "github.com/hashicorp/go-multierror"
15 "github.com/neo4j/neo4j-go-driver/neo4j"
16 )
17
18 func init() {
19 db := Neo4j{}
20 database.Register("neo4j", &db)
21 }
22
// DefaultMigrationsLabel is the node label Open configures for the nodes that
// record applied migration versions.
const (
	DefaultMigrationsLabel = "SchemaMigration"
)
24
var (
	// StatementSeparator is the delimiter used to split a migration into
	// individual statements when multi-statement mode is enabled.
	StatementSeparator = []byte(";")

	// DefaultMultiStatementMaxSize (10 MiB) is the size limit Open uses for
	// multi-statement parsing when x-multi-statement-max-size is not given.
	DefaultMultiStatementMaxSize = 10 << 20
)
29
// ErrNilConfig is returned by WithInstance when it is handed a nil *Config.
var ErrNilConfig = fmt.Errorf("no config")
33
// Config holds the settings for a Neo4j migration driver.
type Config struct {
	// MigrationsLabel is the node label used for the version-tracking nodes
	// written by SetVersion and read by Version.
	MigrationsLabel string

	// MultiStatement, when true, makes Run split each migration on
	// StatementSeparator and execute the pieces in one write transaction.
	MultiStatement bool

	// MultiStatementMaxSize is passed to the multi-statement parser as its
	// size limit; it is only consulted when MultiStatement is true.
	MultiStatementMaxSize int
}
39
40 type Neo4j struct {
41 driver neo4j.Driver
42 lock uint32
43
44
45 config *Config
46 }
47
48 func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) {
49 if config == nil {
50 return nil, ErrNilConfig
51 }
52
53 nDriver := &Neo4j{
54 driver: driver,
55 config: config,
56 }
57
58 if err := nDriver.ensureVersionConstraint(); err != nil {
59 return nil, err
60 }
61
62 return nDriver, nil
63 }
64
65 func (n *Neo4j) Open(url string) (database.Driver, error) {
66 uri, err := neturl.Parse(url)
67 if err != nil {
68 return nil, err
69 }
70 password, _ := uri.User.Password()
71 authToken := neo4j.BasicAuth(uri.User.Username(), password, "")
72 uri.User = nil
73 uri.Scheme = "bolt"
74 msQuery := uri.Query().Get("x-multi-statement")
75
76
77 tlsEncrypted := uri.Query().Get("x-tls-encrypted")
78 multi := false
79 encrypted := false
80 if msQuery != "" {
81 multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement"))
82 if err != nil {
83 return nil, err
84 }
85 }
86
87 if tlsEncrypted != "" {
88 encrypted, err = strconv.ParseBool(tlsEncrypted)
89 if err != nil {
90 return nil, err
91 }
92 }
93
94 multiStatementMaxSize := DefaultMultiStatementMaxSize
95 if s := uri.Query().Get("x-multi-statement-max-size"); s != "" {
96 multiStatementMaxSize, err = strconv.Atoi(s)
97 if err != nil {
98 return nil, err
99 }
100 }
101
102 uri.RawQuery = ""
103
104 driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) {
105 config.Encrypted = encrypted
106 })
107 if err != nil {
108 return nil, err
109 }
110
111 return WithInstance(driver, &Config{
112 MigrationsLabel: DefaultMigrationsLabel,
113 MultiStatement: multi,
114 MultiStatementMaxSize: multiStatementMaxSize,
115 })
116 }
117
118 func (n *Neo4j) Close() error {
119 return n.driver.Close()
120 }
121
122
123 func (n *Neo4j) Lock() error {
124 if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) {
125 return database.ErrLocked
126 }
127
128 return nil
129 }
130
131 func (n *Neo4j) Unlock() error {
132 if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) {
133 return database.ErrNotLocked
134 }
135 return nil
136 }
137
138 func (n *Neo4j) Run(migration io.Reader) (err error) {
139 session, err := n.driver.Session(neo4j.AccessModeWrite)
140 if err != nil {
141 return err
142 }
143 defer func() {
144 if cerr := session.Close(); cerr != nil {
145 err = multierror.Append(err, cerr)
146 }
147 }()
148
149 if n.config.MultiStatement {
150 _, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
151 var stmtRunErr error
152 if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
153 trimStmt := bytes.TrimSpace(stmt)
154 if len(trimStmt) == 0 {
155 return true
156 }
157 trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
158 if len(trimStmt) == 0 {
159 return true
160 }
161
162 result, err := transaction.Run(string(trimStmt), nil)
163 if _, err := neo4j.Collect(result, err); err != nil {
164 stmtRunErr = err
165 return false
166 }
167 return true
168 }); err != nil {
169 return nil, err
170 }
171 return nil, stmtRunErr
172 })
173 return err
174 }
175
176 body, err := ioutil.ReadAll(migration)
177 if err != nil {
178 return err
179 }
180
181 _, err = neo4j.Collect(session.Run(string(body[:]), nil))
182 return err
183 }
184
185 func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
186 session, err := n.driver.Session(neo4j.AccessModeWrite)
187 if err != nil {
188 return err
189 }
190 defer func() {
191 if cerr := session.Close(); cerr != nil {
192 err = multierror.Append(err, cerr)
193 }
194 }()
195
196 query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()",
197 n.config.MigrationsLabel)
198 _, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty}))
199 if err != nil {
200 return err
201 }
202 return nil
203 }
204
// MigrationRecord is the version and dirty flag read back from a
// migration-tracking node.
type MigrationRecord struct {
	Version int  // value of the node's version property
	Dirty   bool // value of the node's dirty property
}
209
210 func (n *Neo4j) Version() (version int, dirty bool, err error) {
211 session, err := n.driver.Session(neo4j.AccessModeRead)
212 if err != nil {
213 return database.NilVersion, false, err
214 }
215 defer func() {
216 if cerr := session.Close(); cerr != nil {
217 err = multierror.Append(err, cerr)
218 }
219 }()
220
221 query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
222 ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
223 n.config.MigrationsLabel)
224 result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
225 result, err := transaction.Run(query, nil)
226 if err != nil {
227 return nil, err
228 }
229 if result.Next() {
230 record := result.Record()
231 mr := MigrationRecord{}
232 versionResult, ok := record.Get("version")
233 if !ok {
234 mr.Version = database.NilVersion
235 } else {
236 mr.Version = int(versionResult.(int64))
237 }
238
239 dirtyResult, ok := record.Get("dirty")
240 if ok {
241 mr.Dirty = dirtyResult.(bool)
242 }
243
244 return mr, nil
245 }
246 return nil, result.Err()
247 })
248 if err != nil {
249 return database.NilVersion, false, err
250 }
251 if result == nil {
252 return database.NilVersion, false, err
253 }
254 mr := result.(MigrationRecord)
255 return mr.Version, mr.Dirty, err
256 }
257
258 func (n *Neo4j) Drop() (err error) {
259 session, err := n.driver.Session(neo4j.AccessModeWrite)
260 if err != nil {
261 return err
262 }
263 defer func() {
264 if cerr := session.Close(); cerr != nil {
265 err = multierror.Append(err, cerr)
266 }
267 }()
268
269 if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil {
270 return err
271 }
272 return nil
273 }
274
275 func (n *Neo4j) ensureVersionConstraint() (err error) {
276 session, err := n.driver.Session(neo4j.AccessModeWrite)
277 if err != nil {
278 return err
279 }
280 defer func() {
281 if cerr := session.Close(); cerr != nil {
282 err = multierror.Append(err, cerr)
283 }
284 }()
285
286
291 res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil))
292 if err != nil {
293 return err
294 }
295 if len(res) == 1 {
296 return nil
297 }
298
299 query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel)
300 if _, err := neo4j.Collect(session.Run(query, nil)); err != nil {
301 return err
302 }
303 return nil
304 }
305
View as plain text