     1  // Copyright 2016 The etcd Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package concurrency
    17  import (
    18  	"context"
    19  	"math"
    21  	v3 "go.etcd.io/etcd/client/v3"
    22  )
// STM is an interface for software transactional memory.
//
// The exported methods are what an apply function uses to read and stage
// writes; the unexported methods are driven by runSTM to retry and commit.
type STM interface {
	// Get returns the value for a key and inserts the key in the txn's read set.
	// If Get fails, it aborts the transaction with an error, never returning.
	// An empty string is returned when the key has no stored value.
	Get(key ...string) string
	// Put adds a value for a key to the write set.
	Put(key, val string, opts ...v3.OpOption)
	// Rev returns the revision of a key in the read set.
	// It is 0 when the key has no stored value.
	Rev(key string) int64
	// Del deletes a key.
	// The deletion is staged in the write set like a Put.
	Del(key string)

	// commit attempts to apply the txn's changes to the server.
	// It returns the txn response on success and nil when the conflict
	// checks failed and the transaction has to be retried.
	commit() *v3.TxnResponse
	// reset discards the read and write sets gathered so far; runSTM calls it
	// before every attempt.
	reset()
}
    41  // Isolation is an enumeration of transactional isolation levels which
    42  // describes how transactions should interfere and conflict.
    43  type Isolation int
    45  const (
    46  	// SerializableSnapshot provides serializable isolation and also checks
    47  	// for write conflicts.
    48  	SerializableSnapshot Isolation = iota
    49  	// Serializable reads within the same transaction attempt return data
    50  	// from the at the revision of the first read.
    51  	Serializable
    52  	// RepeatableReads reads within the same transaction attempt always
    53  	// return the same data.
    54  	RepeatableReads
    55  	// ReadCommitted reads keys from any committed revision.
    56  	ReadCommitted
    57  )
    59  // stmError safely passes STM errors through panic to the STM error channel.
    60  type stmError struct{ err error }
// stmOptions collects the settings applied by the stmOption functions
// passed to NewSTM.
type stmOptions struct {
	// iso selects which STM implementation mkSTM builds.
	iso Isolation
	// ctx is the context handed to every txn request issued by the STM.
	ctx context.Context
	// prefetch lists keys that NewSTM reads with a single Get before apply runs.
	prefetch []string
}
// stmOption mutates the stmOptions used to configure a transaction in NewSTM.
type stmOption func(*stmOptions)
    70  // WithIsolation specifies the transaction isolation level.
    71  func WithIsolation(lvl Isolation) stmOption {
    72  	return func(so *stmOptions) { so.iso = lvl }
    73  }
    75  // WithAbortContext specifies the context for permanently aborting the transaction.
    76  func WithAbortContext(ctx context.Context) stmOption {
    77  	return func(so *stmOptions) { so.ctx = ctx }
    78  }
    80  // WithPrefetch is a hint to prefetch a list of keys before trying to apply.
    81  // If an STM transaction will unconditionally fetch a set of keys, prefetching
    82  // those keys will save the round-trip cost from requesting each key one by one
    83  // with Get().
    84  func WithPrefetch(keys ...string) stmOption {
    85  	return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
    86  }
    88  // NewSTM initiates a new STM instance, using serializable snapshot isolation by default.
    89  func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
    90  	opts := &stmOptions{ctx: c.Ctx()}
    91  	for _, f := range so {
    92  		f(opts)
    93  	}
    94  	if len(opts.prefetch) != 0 {
    95  		f := apply
    96  		apply = func(s STM) error {
    97  			s.Get(opts.prefetch...)
    98  			return f(s)
    99  		}
   100  	}
   101  	return runSTM(mkSTM(c, opts), apply)
   102  }
   104  func mkSTM(c *v3.Client, opts *stmOptions) STM {
   105  	switch opts.iso {
   106  	case SerializableSnapshot:
   107  		s := &stmSerializable{
   108  			stm:      stm{client: c, ctx: opts.ctx},
   109  			prefetch: make(map[string]*v3.GetResponse),
   110  		}
   111  		s.conflicts = func() []v3.Cmp {
   112  			return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
   113  		}
   114  		return s
   115  	case Serializable:
   116  		s := &stmSerializable{
   117  			stm:      stm{client: c, ctx: opts.ctx},
   118  			prefetch: make(map[string]*v3.GetResponse),
   119  		}
   120  		s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
   121  		return s
   122  	case RepeatableReads:
   123  		s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
   124  		s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
   125  		return s
   126  	case ReadCommitted:
   127  		s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
   128  		s.conflicts = func() []v3.Cmp { return nil }
   129  		return s
   130  	default:
   131  		panic("unsupported stm")
   132  	}
   133  }
// stmResponse is the outcome that runSTM's worker goroutine hands back over
// its result channel.
type stmResponse struct {
	// resp is the response of the committed txn; nil when the run failed.
	resp *v3.TxnResponse
	// err is the error returned by apply or carried by an stmError panic.
	err error
}
   140  func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
   141  	outc := make(chan stmResponse, 1)
   142  	go func() {
   143  		defer func() {
   144  			if r := recover(); r != nil {
   145  				e, ok := r.(stmError)
   146  				if !ok {
   147  					// client apply panicked
   148  					panic(r)
   149  				}
   150  				outc <- stmResponse{nil, e.err}
   151  			}
   152  		}()
   153  		var out stmResponse
   154  		for {
   155  			s.reset()
   156  			if out.err = apply(s); out.err != nil {
   157  				break
   158  			}
   159  			if out.resp = s.commit(); out.resp != nil {
   160  				break
   161  			}
   162  		}
   163  		outc <- out
   164  	}()
   165  	r := <-outc
   166  	return r.resp, r.err
   167  }
// stm implements repeatable-read software transactional memory over etcd.
// It is also embedded by stmSerializable, which reuses its read/write sets
// and fetch logic.
type stm struct {
	// client and ctx are used to issue every txn request.
	client *v3.Client
	ctx    context.Context
	// rset holds read key values and revisions
	rset readSet
	// wset holds overwritten keys and their values
	wset writeSet
	// getOpts are the opts used for gets
	getOpts []v3.OpOption
	// conflicts computes the current conflicts on the txn; commit uses the
	// returned comparisons as the txn's If guard.
	conflicts func() []v3.Cmp
}
// stmPut is one pending write (a put or a delete) recorded in the write set.
type stmPut struct {
	// val is the value Get reports for the key while the write is pending;
	// it is the empty string for deletes.
	val string
	// op is the operation sent to the server when the txn commits.
	op v3.Op
}
// readSet maps each key read by the transaction to the range response that
// was obtained for it.
type readSet map[string]*v3.GetResponse
   190  func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
   191  	for i, resp := range txnresp.Responses {
   192  		rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
   193  	}
   194  }
   196  // first returns the store revision from the first fetch
   197  func (rs readSet) first() int64 {
   198  	ret := int64(math.MaxInt64 - 1)
   199  	for _, resp := range rs {
   200  		if rev := resp.Header.Revision; rev < ret {
   201  			ret = rev
   202  		}
   203  	}
   204  	return ret
   205  }
   207  // cmps guards the txn from updates to read set
   208  func (rs readSet) cmps() []v3.Cmp {
   209  	cmps := make([]v3.Cmp, 0, len(rs))
   210  	for k, rk := range rs {
   211  		cmps = append(cmps, isKeyCurrent(k, rk))
   212  	}
   213  	return cmps
   214  }
// writeSet maps each key written by the transaction to its pending write.
type writeSet map[string]stmPut
   218  func (ws writeSet) get(keys ...string) *stmPut {
   219  	for _, key := range keys {
   220  		if wv, ok := ws[key]; ok {
   221  			return &wv
   222  		}
   223  	}
   224  	return nil
   225  }
   227  // cmps returns a cmp list testing no writes have happened past rev
   228  func (ws writeSet) cmps(rev int64) []v3.Cmp {
   229  	cmps := make([]v3.Cmp, 0, len(ws))
   230  	for key := range ws {
   231  		cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
   232  	}
   233  	return cmps
   234  }
   236  // puts is the list of ops for all pending writes
   237  func (ws writeSet) puts() []v3.Op {
   238  	puts := make([]v3.Op, 0, len(ws))
   239  	for _, v := range ws {
   240  		puts = append(puts, v.op)
   241  	}
   242  	return puts
   243  }
   245  func (s *stm) Get(keys ...string) string {
   246  	if wv := s.wset.get(keys...); wv != nil {
   247  		return wv.val
   248  	}
   249  	return respToValue(s.fetch(keys...))
   250  }
   252  func (s *stm) Put(key, val string, opts ...v3.OpOption) {
   253  	s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
   254  }
   256  func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
   258  func (s *stm) Rev(key string) int64 {
   259  	if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
   260  		return resp.Kvs[0].ModRevision
   261  	}
   262  	return 0
   263  }
   265  func (s *stm) commit() *v3.TxnResponse {
   266  	txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
   267  	if err != nil {
   268  		panic(stmError{err})
   269  	}
   270  	if txnresp.Succeeded {
   271  		return txnresp
   272  	}
   273  	return nil
   274  }
   276  func (s *stm) fetch(keys ...string) *v3.GetResponse {
   277  	if len(keys) == 0 {
   278  		return nil
   279  	}
   280  	ops := make([]v3.Op, len(keys))
   281  	for i, key := range keys {
   282  		if resp, ok := s.rset[key]; ok {
   283  			return resp
   284  		}
   285  		ops[i] = v3.OpGet(key, s.getOpts...)
   286  	}
   287  	txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
   288  	if err != nil {
   289  		panic(stmError{err})
   290  	}
   291  	s.rset.add(keys, txnresp)
   292  	return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
   293  }
   295  func (s *stm) reset() {
   296  	s.rset = make(map[string]*v3.GetResponse)
   297  	s.wset = make(map[string]stmPut)
   298  }
// stmSerializable extends stm so that, after the first read of an attempt,
// later reads are pinned to that read's revision.
type stmSerializable struct {
	stm
	// prefetch holds responses obtained from the Else branch of a failed
	// commit; Get moves matching entries into the read set instead of
	// fetching them again.
	prefetch map[string]*v3.GetResponse
}
   305  func (s *stmSerializable) Get(keys ...string) string {
   306  	if wv := s.wset.get(keys...); wv != nil {
   307  		return wv.val
   308  	}
   309  	firstRead := len(s.rset) == 0
   310  	for _, key := range keys {
   311  		if resp, ok := s.prefetch[key]; ok {
   312  			delete(s.prefetch, key)
   313  			s.rset[key] = resp
   314  		}
   315  	}
   316  	resp := s.stm.fetch(keys...)
   317  	if firstRead {
   318  		// txn's base revision is defined by the first read
   319  		s.getOpts = []v3.OpOption{
   320  			v3.WithRev(resp.Header.Revision),
   321  			v3.WithSerializable(),
   322  		}
   323  	}
   324  	return respToValue(resp)
   325  }
   327  func (s *stmSerializable) Rev(key string) int64 {
   328  	s.Get(key)
   329  	return s.stm.Rev(key)
   330  }
   332  func (s *stmSerializable) gets() ([]string, []v3.Op) {
   333  	keys := make([]string, 0, len(s.rset))
   334  	ops := make([]v3.Op, 0, len(s.rset))
   335  	for k := range s.rset {
   336  		keys = append(keys, k)
   337  		ops = append(ops, v3.OpGet(k))
   338  	}
   339  	return keys, ops
   340  }
   342  func (s *stmSerializable) commit() *v3.TxnResponse {
   343  	keys, getops := s.gets()
   344  	txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
   345  	// use Else to prefetch keys in case of conflict to save a round trip
   346  	txnresp, err := txn.Else(getops...).Commit()
   347  	if err != nil {
   348  		panic(stmError{err})
   349  	}
   350  	if txnresp.Succeeded {
   351  		return txnresp
   352  	}
   353  	// load prefetch with Else data
   354  	s.rset.add(keys, txnresp)
   355  	s.prefetch = s.rset
   356  	s.getOpts = nil
   357  	return nil
   358  }
   360  func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
   361  	if len(r.Kvs) != 0 {
   362  		return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
   363  	}
   364  	return v3.Compare(v3.ModRevision(k), "=", 0)
   365  }
   367  func respToValue(resp *v3.GetResponse) string {
   368  	if resp == nil || len(resp.Kvs) == 0 {
   369  		return ""
   370  	}
   371  	return string(resp.Kvs[0].Value)
   372  }
   374  // NewSTMRepeatable is deprecated.
   375  func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
   376  	return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads))
   377  }
   379  // NewSTMSerializable is deprecated.
   380  func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
   381  	return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable))
   382  }
   384  // NewSTMReadCommitted is deprecated.
   385  func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
   386  	return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted))
   387  }

