1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package concurrency
16
17 import (
18 "context"
19 "math"
20
21 v3 "go.etcd.io/etcd/client/v3"
22 )
23
24
// STM is the transaction handle passed to the apply function given to NewSTM.
// Reads go through it so they can be cached and checked for conflicts, and
// writes are buffered until the transaction commits.
type STM interface {
	// Get returns the value the transaction sees for key (the empty string if
	// the key does not exist), recording read keys in the read set. If a read
	// from etcd fails, the transaction is aborted and NewSTM returns the error.
	// With several keys, implementations return the value of the first key
	// already written or cached in this attempt.
	Get(key ...string) string
	// Put buffers a write of val to key (with opts) until commit.
	Put(key, val string, opts ...v3.OpOption)
	// Rev returns the mod revision of key as read by the transaction, or 0 if
	// the key does not exist.
	Rev(key string) int64
	// Del buffers a delete of key until commit.
	Del(key string)

	// commit sends the buffered writes guarded by the conflict checks. It
	// returns the response on success, or nil when the checks failed and the
	// attempt must be retried.
	commit() *v3.TxnResponse
	// reset clears the read and write sets before a new attempt.
	reset()
}
40
41
42
// Isolation selects how an STM transaction reads and which conflicts it
// checks at commit time. The zero value, SerializableSnapshot, is the
// default used by NewSTM.
type Isolation int

const (
	// SerializableSnapshot pins every read of an attempt to the revision of
	// its first read. At commit it requires that no read key has changed and
	// that no written key was modified after the earliest read revision.
	SerializableSnapshot Isolation = iota

	// Serializable pins every read of an attempt to the revision of its first
	// read. At commit it requires that no read key has changed.
	Serializable

	// RepeatableReads caches each read for the rest of the attempt. At commit
	// it requires that no read key has changed.
	RepeatableReads

	// ReadCommitted caches reads within an attempt but performs no conflict
	// checks at commit.
	ReadCommitted
)
58
59
// stmError wraps an error raised by the STM machinery (e.g. a failed etcd
// request) so that it can be thrown with panic from deep inside Get/commit
// and recovered at the top of the transaction loop, which returns err.
type stmError struct{ err error }

// stmOptions collects the settings applied by stmOption functions in NewSTM.
type stmOptions struct {
	iso      Isolation       // isolation level; the zero value is SerializableSnapshot
	ctx      context.Context // context passed to every etcd request of the transaction
	prefetch []string        // keys read with one Get before each attempt of apply
}

// stmOption configures stmOptions; see WithIsolation, WithAbortContext and
// WithPrefetch.
type stmOption func(*stmOptions)
69
70
71 func WithIsolation(lvl Isolation) stmOption {
72 return func(so *stmOptions) { so.iso = lvl }
73 }
74
75
76 func WithAbortContext(ctx context.Context) stmOption {
77 return func(so *stmOptions) { so.ctx = ctx }
78 }
79
80
81
82
83
84 func WithPrefetch(keys ...string) stmOption {
85 return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
86 }
87
88
89 func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
90 opts := &stmOptions{ctx: c.Ctx()}
91 for _, f := range so {
92 f(opts)
93 }
94 if len(opts.prefetch) != 0 {
95 f := apply
96 apply = func(s STM) error {
97 s.Get(opts.prefetch...)
98 return f(s)
99 }
100 }
101 return runSTM(mkSTM(c, opts), apply)
102 }
103
104 func mkSTM(c *v3.Client, opts *stmOptions) STM {
105 switch opts.iso {
106 case SerializableSnapshot:
107 s := &stmSerializable{
108 stm: stm{client: c, ctx: opts.ctx},
109 prefetch: make(map[string]*v3.GetResponse),
110 }
111 s.conflicts = func() []v3.Cmp {
112 return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
113 }
114 return s
115 case Serializable:
116 s := &stmSerializable{
117 stm: stm{client: c, ctx: opts.ctx},
118 prefetch: make(map[string]*v3.GetResponse),
119 }
120 s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
121 return s
122 case RepeatableReads:
123 s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
124 s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
125 return s
126 case ReadCommitted:
127 s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
128 s.conflicts = func() []v3.Cmp { return nil }
129 return s
130 default:
131 panic("unsupported stm")
132 }
133 }
134
// stmResponse pairs the outcome of a transaction run: the committed response,
// or the error that ended the run.
type stmResponse struct {
	resp *v3.TxnResponse
	err  error
}
139
140 func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
141 outc := make(chan stmResponse, 1)
142 go func() {
143 defer func() {
144 if r := recover(); r != nil {
145 e, ok := r.(stmError)
146 if !ok {
147
148 panic(r)
149 }
150 outc <- stmResponse{nil, e.err}
151 }
152 }()
153 var out stmResponse
154 for {
155 s.reset()
156 if out.err = apply(s); out.err != nil {
157 break
158 }
159 if out.resp = s.commit(); out.resp != nil {
160 break
161 }
162 }
163 outc <- out
164 }()
165 r := <-outc
166 return r.resp, r.err
167 }
168
169
// stm is the STM implementation used for the RepeatableReads and
// ReadCommitted isolation levels, and is embedded by stmSerializable. Reads
// are cached in rset for the rest of an attempt. Writes are buffered in wset
// until commit.
type stm struct {
	client *v3.Client
	ctx    context.Context

	// rset holds the Get response of every key read in the current attempt.
	rset readSet

	// wset holds the buffered puts and deletes of the current attempt, by key.
	wset writeSet

	// getOpts are the options applied to every Get issued by fetch.
	getOpts []v3.OpOption

	// conflicts returns the comparisons that must all hold for commit to
	// succeed.
	conflicts func() []v3.Cmp
}
182
// stmPut is a buffered write. val is what Get returns for the key within the
// transaction (empty for a delete), and op is the operation sent at commit.
type stmPut struct {
	val string
	op  v3.Op
}

// readSet maps each key read in a transaction attempt to its Get response.
type readSet map[string]*v3.GetResponse
189
190 func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
191 for i, resp := range txnresp.Responses {
192 rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
193 }
194 }
195
196
197 func (rs readSet) first() int64 {
198 ret := int64(math.MaxInt64 - 1)
199 for _, resp := range rs {
200 if rev := resp.Header.Revision; rev < ret {
201 ret = rev
202 }
203 }
204 return ret
205 }
206
207
208 func (rs readSet) cmps() []v3.Cmp {
209 cmps := make([]v3.Cmp, 0, len(rs))
210 for k, rk := range rs {
211 cmps = append(cmps, isKeyCurrent(k, rk))
212 }
213 return cmps
214 }
215
// writeSet maps each key written in a transaction attempt to its buffered write.
type writeSet map[string]stmPut
217
218 func (ws writeSet) get(keys ...string) *stmPut {
219 for _, key := range keys {
220 if wv, ok := ws[key]; ok {
221 return &wv
222 }
223 }
224 return nil
225 }
226
227
228 func (ws writeSet) cmps(rev int64) []v3.Cmp {
229 cmps := make([]v3.Cmp, 0, len(ws))
230 for key := range ws {
231 cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
232 }
233 return cmps
234 }
235
236
237 func (ws writeSet) puts() []v3.Op {
238 puts := make([]v3.Op, 0, len(ws))
239 for _, v := range ws {
240 puts = append(puts, v.op)
241 }
242 return puts
243 }
244
245 func (s *stm) Get(keys ...string) string {
246 if wv := s.wset.get(keys...); wv != nil {
247 return wv.val
248 }
249 return respToValue(s.fetch(keys...))
250 }
251
252 func (s *stm) Put(key, val string, opts ...v3.OpOption) {
253 s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
254 }
255
256 func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
257
258 func (s *stm) Rev(key string) int64 {
259 if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
260 return resp.Kvs[0].ModRevision
261 }
262 return 0
263 }
264
265 func (s *stm) commit() *v3.TxnResponse {
266 txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
267 if err != nil {
268 panic(stmError{err})
269 }
270 if txnresp.Succeeded {
271 return txnresp
272 }
273 return nil
274 }
275
276 func (s *stm) fetch(keys ...string) *v3.GetResponse {
277 if len(keys) == 0 {
278 return nil
279 }
280 ops := make([]v3.Op, len(keys))
281 for i, key := range keys {
282 if resp, ok := s.rset[key]; ok {
283 return resp
284 }
285 ops[i] = v3.OpGet(key, s.getOpts...)
286 }
287 txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
288 if err != nil {
289 panic(stmError{err})
290 }
291 s.rset.add(keys, txnresp)
292 return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
293 }
294
295 func (s *stm) reset() {
296 s.rset = make(map[string]*v3.GetResponse)
297 s.wset = make(map[string]stmPut)
298 }
299
// stmSerializable is the STM used for the Serializable and
// SerializableSnapshot isolation levels. After the first read of an attempt,
// every fetch is pinned to that read's revision.
type stmSerializable struct {
	stm
	// prefetch holds responses re-read by the Else branch of a failed commit.
	// Get moves them into the read set on the next attempt instead of asking
	// etcd again.
	prefetch map[string]*v3.GetResponse
}
304
305 func (s *stmSerializable) Get(keys ...string) string {
306 if wv := s.wset.get(keys...); wv != nil {
307 return wv.val
308 }
309 firstRead := len(s.rset) == 0
310 for _, key := range keys {
311 if resp, ok := s.prefetch[key]; ok {
312 delete(s.prefetch, key)
313 s.rset[key] = resp
314 }
315 }
316 resp := s.stm.fetch(keys...)
317 if firstRead {
318
319 s.getOpts = []v3.OpOption{
320 v3.WithRev(resp.Header.Revision),
321 v3.WithSerializable(),
322 }
323 }
324 return respToValue(resp)
325 }
326
327 func (s *stmSerializable) Rev(key string) int64 {
328 s.Get(key)
329 return s.stm.Rev(key)
330 }
331
332 func (s *stmSerializable) gets() ([]string, []v3.Op) {
333 keys := make([]string, 0, len(s.rset))
334 ops := make([]v3.Op, 0, len(s.rset))
335 for k := range s.rset {
336 keys = append(keys, k)
337 ops = append(ops, v3.OpGet(k))
338 }
339 return keys, ops
340 }
341
342 func (s *stmSerializable) commit() *v3.TxnResponse {
343 keys, getops := s.gets()
344 txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
345
346 txnresp, err := txn.Else(getops...).Commit()
347 if err != nil {
348 panic(stmError{err})
349 }
350 if txnresp.Succeeded {
351 return txnresp
352 }
353
354 s.rset.add(keys, txnresp)
355 s.prefetch = s.rset
356 s.getOpts = nil
357 return nil
358 }
359
360 func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
361 if len(r.Kvs) != 0 {
362 return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
363 }
364 return v3.Compare(v3.ModRevision(k), "=", 0)
365 }
366
367 func respToValue(resp *v3.GetResponse) string {
368 if resp == nil || len(resp.Kvs) == 0 {
369 return ""
370 }
371 return string(resp.Kvs[0].Value)
372 }
373
374
375 func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
376 return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads))
377 }
378
379
380 func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
381 return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable))
382 }
383
384
385 func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
386 return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted))
387 }
388
View as plain text