...

Source file src/github.com/jackc/pgx/v5/pgproto3/frontend.go

Documentation: github.com/jackc/pgx/v5/pgproto3

     1  package pgproto3
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  )
    10  
    11  // Frontend acts as a client for the PostgreSQL wire protocol version 3.
    12  type Frontend struct {
    13  	cr *chunkReader
    14  	w  io.Writer
    15  
    16  	// tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
    17  	// before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is
    18  	// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
    19  	tracer *tracer
    20  
    21  	wbuf        []byte
    22  	encodeError error
    23  
    24  	// Backend message flyweights
    25  	authenticationOk                AuthenticationOk
    26  	authenticationCleartextPassword AuthenticationCleartextPassword
    27  	authenticationMD5Password       AuthenticationMD5Password
    28  	authenticationGSS               AuthenticationGSS
    29  	authenticationGSSContinue       AuthenticationGSSContinue
    30  	authenticationSASL              AuthenticationSASL
    31  	authenticationSASLContinue      AuthenticationSASLContinue
    32  	authenticationSASLFinal         AuthenticationSASLFinal
    33  	backendKeyData                  BackendKeyData
    34  	bindComplete                    BindComplete
    35  	closeComplete                   CloseComplete
    36  	commandComplete                 CommandComplete
    37  	copyBothResponse                CopyBothResponse
    38  	copyData                        CopyData
    39  	copyInResponse                  CopyInResponse
    40  	copyOutResponse                 CopyOutResponse
    41  	copyDone                        CopyDone
    42  	dataRow                         DataRow
    43  	emptyQueryResponse              EmptyQueryResponse
    44  	errorResponse                   ErrorResponse
    45  	functionCallResponse            FunctionCallResponse
    46  	noData                          NoData
    47  	noticeResponse                  NoticeResponse
    48  	notificationResponse            NotificationResponse
    49  	parameterDescription            ParameterDescription
    50  	parameterStatus                 ParameterStatus
    51  	parseComplete                   ParseComplete
    52  	readyForQuery                   ReadyForQuery
    53  	rowDescription                  RowDescription
    54  	portalSuspended                 PortalSuspended
    55  
    56  	bodyLen    int
    57  	msgType    byte
    58  	partialMsg bool
    59  	authType   uint32
    60  }
    61  
    62  // NewFrontend creates a new Frontend.
    63  func NewFrontend(r io.Reader, w io.Writer) *Frontend {
    64  	cr := newChunkReader(r, 0)
    65  	return &Frontend{cr: cr, w: w}
    66  }
    67  
    68  // Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error
    69  // encountered will be returned from Flush.
    70  //
    71  // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
    72  // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
    73  // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
    74  // behind an interface.
    75  func (f *Frontend) Send(msg FrontendMessage) {
    76  	if f.encodeError != nil {
    77  		return
    78  	}
    79  
    80  	prevLen := len(f.wbuf)
    81  	newBuf, err := msg.Encode(f.wbuf)
    82  	if err != nil {
    83  		f.encodeError = err
    84  		return
    85  	}
    86  	f.wbuf = newBuf
    87  
    88  	if f.tracer != nil {
    89  		f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
    90  	}
    91  }
    92  
    93  // Flush writes any pending messages to the backend (i.e. the server).
    94  func (f *Frontend) Flush() error {
    95  	if err := f.encodeError; err != nil {
    96  		f.encodeError = nil
    97  		f.wbuf = f.wbuf[:0]
    98  		return &writeError{err: err, safeToRetry: true}
    99  	}
   100  
   101  	if len(f.wbuf) == 0 {
   102  		return nil
   103  	}
   104  
   105  	n, err := f.w.Write(f.wbuf)
   106  
   107  	const maxLen = 1024
   108  	if len(f.wbuf) > maxLen {
   109  		f.wbuf = make([]byte, 0, maxLen)
   110  	} else {
   111  		f.wbuf = f.wbuf[:0]
   112  	}
   113  
   114  	if err != nil {
   115  		return &writeError{err: err, safeToRetry: n == 0}
   116  	}
   117  
   118  	return nil
   119  }
   120  
   121  // Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function
   122  // PQtrace.
   123  func (f *Frontend) Trace(w io.Writer, options TracerOptions) {
   124  	f.tracer = &tracer{
   125  		w:             w,
   126  		buf:           &bytes.Buffer{},
   127  		TracerOptions: options,
   128  	}
   129  }
   130  
   131  // Untrace stops tracing.
   132  func (f *Frontend) Untrace() {
   133  	f.tracer = nil
   134  }
   135  
   136  // SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any
   137  // error encountered will be returned from Flush.
   138  func (f *Frontend) SendBind(msg *Bind) {
   139  	if f.encodeError != nil {
   140  		return
   141  	}
   142  
   143  	prevLen := len(f.wbuf)
   144  	newBuf, err := msg.Encode(f.wbuf)
   145  	if err != nil {
   146  		f.encodeError = err
   147  		return
   148  	}
   149  	f.wbuf = newBuf
   150  
   151  	if f.tracer != nil {
   152  		f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
   153  	}
   154  }
   155  
   156  // SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any
   157  // error encountered will be returned from Flush.
   158  func (f *Frontend) SendParse(msg *Parse) {
   159  	if f.encodeError != nil {
   160  		return
   161  	}
   162  
   163  	prevLen := len(f.wbuf)
   164  	newBuf, err := msg.Encode(f.wbuf)
   165  	if err != nil {
   166  		f.encodeError = err
   167  		return
   168  	}
   169  	f.wbuf = newBuf
   170  
   171  	if f.tracer != nil {
   172  		f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
   173  	}
   174  }
   175  
   176  // SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any
   177  // error encountered will be returned from Flush.
   178  func (f *Frontend) SendClose(msg *Close) {
   179  	if f.encodeError != nil {
   180  		return
   181  	}
   182  
   183  	prevLen := len(f.wbuf)
   184  	newBuf, err := msg.Encode(f.wbuf)
   185  	if err != nil {
   186  		f.encodeError = err
   187  		return
   188  	}
   189  	f.wbuf = newBuf
   190  
   191  	if f.tracer != nil {
   192  		f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
   193  	}
   194  }
   195  
   196  // SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is
   197  // called. Any error encountered will be returned from Flush.
   198  func (f *Frontend) SendDescribe(msg *Describe) {
   199  	if f.encodeError != nil {
   200  		return
   201  	}
   202  
   203  	prevLen := len(f.wbuf)
   204  	newBuf, err := msg.Encode(f.wbuf)
   205  	if err != nil {
   206  		f.encodeError = err
   207  		return
   208  	}
   209  	f.wbuf = newBuf
   210  
   211  	if f.tracer != nil {
   212  		f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
   213  	}
   214  }
   215  
   216  // SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called.
   217  // Any error encountered will be returned from Flush.
   218  func (f *Frontend) SendExecute(msg *Execute) {
   219  	if f.encodeError != nil {
   220  		return
   221  	}
   222  
   223  	prevLen := len(f.wbuf)
   224  	newBuf, err := msg.Encode(f.wbuf)
   225  	if err != nil {
   226  		f.encodeError = err
   227  		return
   228  	}
   229  	f.wbuf = newBuf
   230  
   231  	if f.tracer != nil {
   232  		f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
   233  	}
   234  }
   235  
   236  // SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any
   237  // error encountered will be returned from Flush.
   238  func (f *Frontend) SendSync(msg *Sync) {
   239  	if f.encodeError != nil {
   240  		return
   241  	}
   242  
   243  	prevLen := len(f.wbuf)
   244  	newBuf, err := msg.Encode(f.wbuf)
   245  	if err != nil {
   246  		f.encodeError = err
   247  		return
   248  	}
   249  	f.wbuf = newBuf
   250  
   251  	if f.tracer != nil {
   252  		f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
   253  	}
   254  }
   255  
   256  // SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any
   257  // error encountered will be returned from Flush.
   258  func (f *Frontend) SendQuery(msg *Query) {
   259  	if f.encodeError != nil {
   260  		return
   261  	}
   262  
   263  	prevLen := len(f.wbuf)
   264  	newBuf, err := msg.Encode(f.wbuf)
   265  	if err != nil {
   266  		f.encodeError = err
   267  		return
   268  	}
   269  	f.wbuf = newBuf
   270  
   271  	if f.tracer != nil {
   272  		f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
   273  	}
   274  }
   275  
   276  // SendUnbufferedEncodedCopyData immediately sends an encoded CopyData message to the backend (i.e. the server). This method
   277  // is more efficient than sending a CopyData message with Send as the message data is not copied to the internal buffer
   278  // before being written out. The internal buffer is flushed before the message is sent.
   279  func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error {
   280  	err := f.Flush()
   281  	if err != nil {
   282  		return err
   283  	}
   284  
   285  	n, err := f.w.Write(msg)
   286  	if err != nil {
   287  		return &writeError{err: err, safeToRetry: n == 0}
   288  	}
   289  
   290  	if f.tracer != nil {
   291  		f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{})
   292  	}
   293  
   294  	return nil
   295  }
   296  
   297  func translateEOFtoErrUnexpectedEOF(err error) error {
   298  	if err == io.EOF {
   299  		return io.ErrUnexpectedEOF
   300  	}
   301  	return err
   302  }
   303  
   304  // Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
   305  func (f *Frontend) Receive() (BackendMessage, error) {
   306  	if !f.partialMsg {
   307  		header, err := f.cr.Next(5)
   308  		if err != nil {
   309  			return nil, translateEOFtoErrUnexpectedEOF(err)
   310  		}
   311  
   312  		f.msgType = header[0]
   313  
   314  		msgLength := int(binary.BigEndian.Uint32(header[1:]))
   315  		if msgLength < 4 {
   316  			return nil, fmt.Errorf("invalid message length: %d", msgLength)
   317  		}
   318  
   319  		f.bodyLen = msgLength - 4
   320  		f.partialMsg = true
   321  	}
   322  
   323  	msgBody, err := f.cr.Next(f.bodyLen)
   324  	if err != nil {
   325  		return nil, translateEOFtoErrUnexpectedEOF(err)
   326  	}
   327  
   328  	f.partialMsg = false
   329  
   330  	var msg BackendMessage
   331  	switch f.msgType {
   332  	case '1':
   333  		msg = &f.parseComplete
   334  	case '2':
   335  		msg = &f.bindComplete
   336  	case '3':
   337  		msg = &f.closeComplete
   338  	case 'A':
   339  		msg = &f.notificationResponse
   340  	case 'c':
   341  		msg = &f.copyDone
   342  	case 'C':
   343  		msg = &f.commandComplete
   344  	case 'd':
   345  		msg = &f.copyData
   346  	case 'D':
   347  		msg = &f.dataRow
   348  	case 'E':
   349  		msg = &f.errorResponse
   350  	case 'G':
   351  		msg = &f.copyInResponse
   352  	case 'H':
   353  		msg = &f.copyOutResponse
   354  	case 'I':
   355  		msg = &f.emptyQueryResponse
   356  	case 'K':
   357  		msg = &f.backendKeyData
   358  	case 'n':
   359  		msg = &f.noData
   360  	case 'N':
   361  		msg = &f.noticeResponse
   362  	case 'R':
   363  		var err error
   364  		msg, err = f.findAuthenticationMessageType(msgBody)
   365  		if err != nil {
   366  			return nil, err
   367  		}
   368  	case 's':
   369  		msg = &f.portalSuspended
   370  	case 'S':
   371  		msg = &f.parameterStatus
   372  	case 't':
   373  		msg = &f.parameterDescription
   374  	case 'T':
   375  		msg = &f.rowDescription
   376  	case 'V':
   377  		msg = &f.functionCallResponse
   378  	case 'W':
   379  		msg = &f.copyBothResponse
   380  	case 'Z':
   381  		msg = &f.readyForQuery
   382  	default:
   383  		return nil, fmt.Errorf("unknown message type: %c", f.msgType)
   384  	}
   385  
   386  	err = msg.Decode(msgBody)
   387  	if err != nil {
   388  		return nil, err
   389  	}
   390  
   391  	if f.tracer != nil {
   392  		f.tracer.traceMessage('B', int32(5+len(msgBody)), msg)
   393  	}
   394  
   395  	return msg, nil
   396  }
   397  
   398  // Authentication message type constants.
   399  // See src/include/libpq/pqcomm.h for all
   400  // constants.
   401  const (
   402  	AuthTypeOk                = 0
   403  	AuthTypeCleartextPassword = 3
   404  	AuthTypeMD5Password       = 5
   405  	AuthTypeSCMCreds          = 6
   406  	AuthTypeGSS               = 7
   407  	AuthTypeGSSCont           = 8
   408  	AuthTypeSSPI              = 9
   409  	AuthTypeSASL              = 10
   410  	AuthTypeSASLContinue      = 11
   411  	AuthTypeSASLFinal         = 12
   412  )
   413  
   414  func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
   415  	if len(src) < 4 {
   416  		return nil, errors.New("authentication message too short")
   417  	}
   418  	f.authType = binary.BigEndian.Uint32(src[:4])
   419  
   420  	switch f.authType {
   421  	case AuthTypeOk:
   422  		return &f.authenticationOk, nil
   423  	case AuthTypeCleartextPassword:
   424  		return &f.authenticationCleartextPassword, nil
   425  	case AuthTypeMD5Password:
   426  		return &f.authenticationMD5Password, nil
   427  	case AuthTypeSCMCreds:
   428  		return nil, errors.New("AuthTypeSCMCreds is unimplemented")
   429  	case AuthTypeGSS:
   430  		return &f.authenticationGSS, nil
   431  	case AuthTypeGSSCont:
   432  		return &f.authenticationGSSContinue, nil
   433  	case AuthTypeSSPI:
   434  		return nil, errors.New("AuthTypeSSPI is unimplemented")
   435  	case AuthTypeSASL:
   436  		return &f.authenticationSASL, nil
   437  	case AuthTypeSASLContinue:
   438  		return &f.authenticationSASLContinue, nil
   439  	case AuthTypeSASLFinal:
   440  		return &f.authenticationSASLFinal, nil
   441  	default:
   442  		return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
   443  	}
   444  }
   445  
   446  // GetAuthType returns the authType used in the current state of the frontend.
   447  // See SetAuthType for more information.
   448  func (f *Frontend) GetAuthType() uint32 {
   449  	return f.authType
   450  }
   451  
   452  func (f *Frontend) ReadBufferLen() int {
   453  	return f.cr.wp - f.cr.rp
   454  }
   455  

View as plain text