...

Source file src/github.com/lib/pq/conn_go18.go

Documentation: github.com/lib/pq

     1  package pq
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"time"
    11  )
    12  
    13  const (
    14  	watchCancelDialContextTimeout = time.Second * 10
    15  )
    16  
    17  // Implement the "QueryerContext" interface
    18  func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
    19  	list := make([]driver.Value, len(args))
    20  	for i, nv := range args {
    21  		list[i] = nv.Value
    22  	}
    23  	finish := cn.watchCancel(ctx)
    24  	r, err := cn.query(query, list)
    25  	if err != nil {
    26  		if finish != nil {
    27  			finish()
    28  		}
    29  		return nil, err
    30  	}
    31  	r.finish = finish
    32  	return r, nil
    33  }
    34  
    35  // Implement the "ExecerContext" interface
    36  func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
    37  	list := make([]driver.Value, len(args))
    38  	for i, nv := range args {
    39  		list[i] = nv.Value
    40  	}
    41  
    42  	if finish := cn.watchCancel(ctx); finish != nil {
    43  		defer finish()
    44  	}
    45  
    46  	return cn.Exec(query, list)
    47  }
    48  
    49  // Implement the "ConnPrepareContext" interface
    50  func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
    51  	if finish := cn.watchCancel(ctx); finish != nil {
    52  		defer finish()
    53  	}
    54  	return cn.Prepare(query)
    55  }
    56  
    57  // Implement the "ConnBeginTx" interface
    58  func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
    59  	var mode string
    60  
    61  	switch sql.IsolationLevel(opts.Isolation) {
    62  	case sql.LevelDefault:
    63  		// Don't touch mode: use the server's default
    64  	case sql.LevelReadUncommitted:
    65  		mode = " ISOLATION LEVEL READ UNCOMMITTED"
    66  	case sql.LevelReadCommitted:
    67  		mode = " ISOLATION LEVEL READ COMMITTED"
    68  	case sql.LevelRepeatableRead:
    69  		mode = " ISOLATION LEVEL REPEATABLE READ"
    70  	case sql.LevelSerializable:
    71  		mode = " ISOLATION LEVEL SERIALIZABLE"
    72  	default:
    73  		return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
    74  	}
    75  
    76  	if opts.ReadOnly {
    77  		mode += " READ ONLY"
    78  	} else {
    79  		mode += " READ WRITE"
    80  	}
    81  
    82  	tx, err := cn.begin(mode)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	cn.txnFinish = cn.watchCancel(ctx)
    87  	return tx, nil
    88  }
    89  
    90  func (cn *conn) Ping(ctx context.Context) error {
    91  	if finish := cn.watchCancel(ctx); finish != nil {
    92  		defer finish()
    93  	}
    94  	rows, err := cn.simpleQuery(";")
    95  	if err != nil {
    96  		return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger
    97  	}
    98  	rows.Close()
    99  	return nil
   100  }
   101  
   102  func (cn *conn) watchCancel(ctx context.Context) func() {
   103  	if done := ctx.Done(); done != nil {
   104  		finished := make(chan struct{}, 1)
   105  		go func() {
   106  			select {
   107  			case <-done:
   108  				select {
   109  				case finished <- struct{}{}:
   110  				default:
   111  					// We raced with the finish func, let the next query handle this with the
   112  					// context.
   113  					return
   114  				}
   115  
   116  				// Set the connection state to bad so it does not get reused.
   117  				cn.err.set(ctx.Err())
   118  
   119  				// At this point the function level context is canceled,
   120  				// so it must not be used for the additional network
   121  				// request to cancel the query.
   122  				// Create a new context to pass into the dial.
   123  				ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
   124  				defer cancel()
   125  
   126  				_ = cn.cancel(ctxCancel)
   127  			case <-finished:
   128  			}
   129  		}()
   130  		return func() {
   131  			select {
   132  			case <-finished:
   133  				cn.err.set(ctx.Err())
   134  				cn.Close()
   135  			case finished <- struct{}{}:
   136  			}
   137  		}
   138  	}
   139  	return nil
   140  }
   141  
   142  func (cn *conn) cancel(ctx context.Context) error {
   143  	// Create a new values map (copy). This makes sure the connection created
   144  	// in this method cannot write to the same underlying data, which could
   145  	// cause a concurrent map write panic. This is necessary because cancel
   146  	// is called from a goroutine in watchCancel.
   147  	o := make(values)
   148  	for k, v := range cn.opts {
   149  		o[k] = v
   150  	}
   151  
   152  	c, err := dial(ctx, cn.dialer, o)
   153  	if err != nil {
   154  		return err
   155  	}
   156  	defer c.Close()
   157  
   158  	{
   159  		can := conn{
   160  			c: c,
   161  		}
   162  		err = can.ssl(o)
   163  		if err != nil {
   164  			return err
   165  		}
   166  
   167  		w := can.writeBuf(0)
   168  		w.int32(80877102) // cancel request code
   169  		w.int32(cn.processID)
   170  		w.int32(cn.secretKey)
   171  
   172  		if err := can.sendStartupPacket(w); err != nil {
   173  			return err
   174  		}
   175  	}
   176  
   177  	// Read until EOF to ensure that the server received the cancel.
   178  	{
   179  		_, err := io.Copy(ioutil.Discard, c)
   180  		return err
   181  	}
   182  }
   183  
   184  // Implement the "StmtQueryContext" interface
   185  func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   186  	list := make([]driver.Value, len(args))
   187  	for i, nv := range args {
   188  		list[i] = nv.Value
   189  	}
   190  	finish := st.watchCancel(ctx)
   191  	r, err := st.query(list)
   192  	if err != nil {
   193  		if finish != nil {
   194  			finish()
   195  		}
   196  		return nil, err
   197  	}
   198  	r.finish = finish
   199  	return r, nil
   200  }
   201  
   202  // Implement the "StmtExecContext" interface
   203  func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   204  	list := make([]driver.Value, len(args))
   205  	for i, nv := range args {
   206  		list[i] = nv.Value
   207  	}
   208  
   209  	if finish := st.watchCancel(ctx); finish != nil {
   210  		defer finish()
   211  	}
   212  
   213  	return st.Exec(list)
   214  }
   215  
   216  // watchCancel is implemented on stmt in order to not mark the parent conn as bad
   217  func (st *stmt) watchCancel(ctx context.Context) func() {
   218  	if done := ctx.Done(); done != nil {
   219  		finished := make(chan struct{})
   220  		go func() {
   221  			select {
   222  			case <-done:
   223  				// At this point the function level context is canceled,
   224  				// so it must not be used for the additional network
   225  				// request to cancel the query.
   226  				// Create a new context to pass into the dial.
   227  				ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
   228  				defer cancel()
   229  
   230  				_ = st.cancel(ctxCancel)
   231  				finished <- struct{}{}
   232  			case <-finished:
   233  			}
   234  		}()
   235  		return func() {
   236  			select {
   237  			case <-finished:
   238  			case finished <- struct{}{}:
   239  			}
   240  		}
   241  	}
   242  	return nil
   243  }
   244  
   245  func (st *stmt) cancel(ctx context.Context) error {
   246  	return st.cn.cancel(ctx)
   247  }
   248  

View as plain text