...

Source file src/github.com/lib/pq/copy.go

Documentation: github.com/lib/pq

     1  package pq
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"encoding/binary"
     8  	"errors"
     9  	"fmt"
    10  	"sync"
    11  )
    12  
    13  var (
    14  	errCopyInClosed               = errors.New("pq: copyin statement has already been closed")
    15  	errBinaryCopyNotSupported     = errors.New("pq: only text format supported for COPY")
    16  	errCopyToNotSupported         = errors.New("pq: COPY TO is not supported")
    17  	errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
    18  	errCopyInProgress             = errors.New("pq: COPY in progress")
    19  )
    20  
    21  // CopyIn creates a COPY FROM statement which can be prepared with
    22  // Tx.Prepare().  The target table should be visible in search_path.
    23  func CopyIn(table string, columns ...string) string {
    24  	buffer := bytes.NewBufferString("COPY ")
    25  	BufferQuoteIdentifier(table, buffer)
    26  	buffer.WriteString(" (")
    27  	makeStmt(buffer, columns...)
    28  	return buffer.String()
    29  }
    30  
    31  // MakeStmt makes the stmt string for CopyIn and CopyInSchema.
    32  func makeStmt(buffer *bytes.Buffer, columns ...string) {
    33  	//s := bytes.NewBufferString()
    34  	for i, col := range columns {
    35  		if i != 0 {
    36  			buffer.WriteString(", ")
    37  		}
    38  		BufferQuoteIdentifier(col, buffer)
    39  	}
    40  	buffer.WriteString(") FROM STDIN")
    41  }
    42  
    43  // CopyInSchema creates a COPY FROM statement which can be prepared with
    44  // Tx.Prepare().
    45  func CopyInSchema(schema, table string, columns ...string) string {
    46  	buffer := bytes.NewBufferString("COPY ")
    47  	BufferQuoteIdentifier(schema, buffer)
    48  	buffer.WriteRune('.')
    49  	BufferQuoteIdentifier(table, buffer)
    50  	buffer.WriteString(" (")
    51  	makeStmt(buffer, columns...)
    52  	return buffer.String()
    53  }
    54  
// copyin is the driver.Stmt returned by prepareCopyIn for a COPY ... FROM
// STDIN statement. Row data is accumulated in buffer as a single CopyData
// message and written out by flush. Server responses are read concurrently
// by the resploop goroutine, which reports back through mu and done.
type copyin struct {
	cn      *conn  // connection the COPY runs on
	buffer  []byte // pending CopyData message: 'd', 4 length bytes, row data
	// NOTE(review): rowData is allocated in prepareCopyIn but never used in
	// this file; it looks vestigial.
	rowData chan []byte
	done    chan bool // receives one value when resploop exits; Close waits on it

	closed bool // set by Close; Exec and CopyData then return errCopyInClosed

	// mu guards the state that resploop writes and the statement methods read.
	mu struct {
		sync.Mutex
		err error // first error recorded by setError
		driver.Result // command result recorded by setResult; nil until set
	}
}
    69  
    70  const ciBufferSize = 64 * 1024
    71  
    72  // flush buffer before the buffer is filled up and needs reallocation
    73  const ciBufferFlushSize = 63 * 1024
    74  
// prepareCopyIn sends q, expected to be a COPY ... FROM STDIN statement, as
// a simple Query ('Q') message and waits for the server to enter copy-in
// mode. On success it starts resploop on a new goroutine and returns the new
// *copyin as the driver.Stmt.
//
// It returns:
//   - errCopyNotSupportedOutsideTxn if cn is not inside a transaction;
//   - errBinaryCopyNotSupported if the server asks for a non-text format;
//   - errCopyToNotSupported if the server answers as for a COPY TO;
//   - the server's error (parsed by parseError) otherwise.
//
// In the binary-format and COPY TO cases, the COPY is aborted with a CopyFail
// ('f') message before returning. Unexpected messages mark the connection
// bad and call errorf.
// NOTE(review): errorf presumably panics and relies on the caller's deferred
// errRecover — confirm.
func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
	// COPY is only accepted inside a transaction.
	if !cn.isInTransaction() {
		return nil, errCopyNotSupportedOutsideTxn
	}

	ci := &copyin{
		cn:      cn,
		buffer:  make([]byte, 0, ciBufferSize),
		rowData: make(chan []byte),
		// Capacity 1: resploop sends on done exactly once right before it
		// returns, so that send never blocks.
		done:    make(chan bool, 1),
	}
	// add CopyData identifier + 4 bytes for message length
	// (flush fills in the length just before the message is written).
	ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)

	// Send the statement as a simple Query message.
	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)

	// Wait until the server either enters copy-in mode or rejects the query.
awaitCopyInResponse:
	for {
		t, r := cn.recv1()
		switch t {
		case 'G':
			// Copy-in response: its first byte is the overall format, and
			// only 0 (text) is supported.
			if r.byte() != 0 {
				err = errBinaryCopyNotSupported
				break awaitCopyInResponse
			}
			// From here on, server messages are consumed by resploop.
			go ci.resploop()
			return ci, nil
		case 'H':
			// Copy-out response: the statement was a COPY TO.
			err = errCopyToNotSupported
			break awaitCopyInResponse
		case 'E':
			// Remember the error and keep reading until ReadyForQuery.
			err = parseError(r)
		case 'Z':
			// ReadyForQuery is only expected after an error was reported.
			if err == nil {
				ci.setBad(driver.ErrBadConn)
				errorf("unexpected ReadyForQuery in response to COPY")
			}
			cn.processReadyForQuery(r)
			return nil, err
		default:
			ci.setBad(driver.ErrBadConn)
			errorf("unknown response for copy query: %q", t)
		}
	}

	// something went wrong, abort COPY before we return
	// Send CopyFail ('f') carrying err's text, then drain messages until
	// ReadyForQuery.
	b = cn.writeBuf('f')
	b.string(err.Error())
	cn.send(b)

	for {
		t, r := cn.recv1()
		switch t {
		case 'c', 'C', 'E':
			// Ignored while waiting for ReadyForQuery.
		case 'Z':
			// correctly aborted, we're done
			cn.processReadyForQuery(r)
			return nil, err
		default:
			ci.setBad(driver.ErrBadConn)
			errorf("unknown response for CopyFail: %q", t)
		}
	}
}
   141  
   142  func (ci *copyin) flush(buf []byte) {
   143  	// set message length (without message identifier)
   144  	binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
   145  
   146  	_, err := ci.cn.c.Write(buf)
   147  	if err != nil {
   148  		panic(err)
   149  	}
   150  }
   151  
// resploop runs on its own goroutine, started by prepareCopyIn once the
// server has entered copy-in mode. It reads server messages and:
//   - records the command result (setResult) and the first error (setError);
//   - passes notices to the connection's noticeHandler, if one is set;
//   - sends on ci.done and returns after ReadyForQuery, a read failure, or an
//     unexpected message type.
//
// The last two stop conditions also mark the connection bad.
func (ci *copyin) resploop() {
	for {
		var r readBuf
		t, err := ci.cn.recvMessage(&r)
		if err != nil {
			ci.setBad(driver.ErrBadConn)
			ci.setError(err)
			ci.done <- true
			return
		}
		switch t {
		case 'C':
			// complete
			// Keep the parsed command result for getResult; the second
			// return value of parseComplete is intentionally discarded.
			res, _ := ci.cn.parseComplete(r.string())
			ci.setResult(res)
		case 'N':
			// Notice: forwarded to the user's handler, if any.
			if n := ci.cn.noticeHandler; n != nil {
				n(parseError(&r))
			}
		case 'Z':
			// ReadyForQuery: the server is done with the COPY.
			ci.cn.processReadyForQuery(&r)
			ci.done <- true
			return
		case 'E':
			// Server error: remember it (first one wins) and keep reading;
			// Exec, CopyData and Close report it via ci.err().
			err := parseError(&r)
			ci.setError(err)
		default:
			ci.setBad(driver.ErrBadConn)
			ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
			ci.done <- true
			return
		}
	}
}
   186  
   187  func (ci *copyin) setBad(err error) {
   188  	ci.cn.err.set(err)
   189  }
   190  
   191  func (ci *copyin) getBad() error {
   192  	return ci.cn.err.get()
   193  }
   194  
   195  func (ci *copyin) err() error {
   196  	ci.mu.Lock()
   197  	err := ci.mu.err
   198  	ci.mu.Unlock()
   199  	return err
   200  }
   201  
   202  // setError() sets ci.err if one has not been set already.  Caller must not be
   203  // holding ci.Mutex.
   204  func (ci *copyin) setError(err error) {
   205  	ci.mu.Lock()
   206  	if ci.mu.err == nil {
   207  		ci.mu.err = err
   208  	}
   209  	ci.mu.Unlock()
   210  }
   211  
   212  func (ci *copyin) setResult(result driver.Result) {
   213  	ci.mu.Lock()
   214  	ci.mu.Result = result
   215  	ci.mu.Unlock()
   216  }
   217  
   218  func (ci *copyin) getResult() driver.Result {
   219  	ci.mu.Lock()
   220  	result := ci.mu.Result
   221  	ci.mu.Unlock()
   222  	if result == nil {
   223  		return driver.RowsAffected(0)
   224  	}
   225  	return result
   226  }
   227  
   228  func (ci *copyin) NumInput() int {
   229  	return -1
   230  }
   231  
   232  func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
   233  	return nil, ErrNotSupported
   234  }
   235  
   236  // Exec inserts values into the COPY stream. The insert is asynchronous
   237  // and Exec can return errors from previous Exec calls to the same
   238  // COPY stmt.
   239  //
   240  // You need to call Exec(nil) to sync the COPY stream and to get any
   241  // errors from pending data, since Stmt.Close() doesn't return errors
   242  // to the user.
   243  func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
   244  	if ci.closed {
   245  		return nil, errCopyInClosed
   246  	}
   247  
   248  	if err := ci.getBad(); err != nil {
   249  		return nil, err
   250  	}
   251  	defer ci.cn.errRecover(&err)
   252  
   253  	if err := ci.err(); err != nil {
   254  		return nil, err
   255  	}
   256  
   257  	if len(v) == 0 {
   258  		if err := ci.Close(); err != nil {
   259  			return driver.RowsAffected(0), err
   260  		}
   261  
   262  		return ci.getResult(), nil
   263  	}
   264  
   265  	numValues := len(v)
   266  	for i, value := range v {
   267  		ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
   268  		if i < numValues-1 {
   269  			ci.buffer = append(ci.buffer, '\t')
   270  		}
   271  	}
   272  
   273  	ci.buffer = append(ci.buffer, '\n')
   274  
   275  	if len(ci.buffer) > ciBufferFlushSize {
   276  		ci.flush(ci.buffer)
   277  		// reset buffer, keep bytes for message identifier and length
   278  		ci.buffer = ci.buffer[:5]
   279  	}
   280  
   281  	return driver.RowsAffected(0), nil
   282  }
   283  
   284  // CopyData inserts a raw string into the COPY stream. The insert is
   285  // asynchronous and CopyData can return errors from previous CopyData calls to
   286  // the same COPY stmt.
   287  //
   288  // You need to call Exec(nil) to sync the COPY stream and to get any
   289  // errors from pending data, since Stmt.Close() doesn't return errors
   290  // to the user.
   291  func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) {
   292  	if ci.closed {
   293  		return nil, errCopyInClosed
   294  	}
   295  
   296  	if finish := ci.cn.watchCancel(ctx); finish != nil {
   297  		defer finish()
   298  	}
   299  
   300  	if err := ci.getBad(); err != nil {
   301  		return nil, err
   302  	}
   303  	defer ci.cn.errRecover(&err)
   304  
   305  	if err := ci.err(); err != nil {
   306  		return nil, err
   307  	}
   308  
   309  	ci.buffer = append(ci.buffer, []byte(line)...)
   310  	ci.buffer = append(ci.buffer, '\n')
   311  
   312  	if len(ci.buffer) > ciBufferFlushSize {
   313  		ci.flush(ci.buffer)
   314  		// reset buffer, keep bytes for message identifier and length
   315  		ci.buffer = ci.buffer[:5]
   316  	}
   317  
   318  	return driver.RowsAffected(0), nil
   319  }
   320  
   321  func (ci *copyin) Close() (err error) {
   322  	if ci.closed { // Don't do anything, we're already closed
   323  		return nil
   324  	}
   325  	ci.closed = true
   326  
   327  	if err := ci.getBad(); err != nil {
   328  		return err
   329  	}
   330  	defer ci.cn.errRecover(&err)
   331  
   332  	if len(ci.buffer) > 0 {
   333  		ci.flush(ci.buffer)
   334  	}
   335  	// Avoid touching the scratch buffer as resploop could be using it.
   336  	err = ci.cn.sendSimpleMessage('c')
   337  	if err != nil {
   338  		return err
   339  	}
   340  
   341  	<-ci.done
   342  	ci.cn.inCopy = false
   343  
   344  	if err := ci.err(); err != nil {
   345  		return err
   346  	}
   347  	return nil
   348  }
   349  

View as plain text