...

Source file src/github.com/jackc/pgproto3/v2/frontend.go

Documentation: github.com/jackc/pgproto3/v2

     1  package pgproto3
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  )
     9  
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
//
// It reads backend messages from a ChunkReader (see Receive) and writes
// encoded frontend messages to an io.Writer (see Send).
type Frontend struct {
	// cr supplies the incoming bytes; Receive pulls the header and body from it via Next.
	cr ChunkReader
	// w is the destination Send writes encoded messages to.
	w  io.Writer

	// Backend message flyweights
	//
	// Receive decodes into these fields and returns a pointer to one of them,
	// so a returned message is overwritten the next time a message of the same
	// kind is decoded.
	authenticationOk                AuthenticationOk
	authenticationCleartextPassword AuthenticationCleartextPassword
	authenticationMD5Password       AuthenticationMD5Password
	authenticationGSS               AuthenticationGSS
	authenticationGSSContinue       AuthenticationGSSContinue
	authenticationSASL              AuthenticationSASL
	authenticationSASLContinue      AuthenticationSASLContinue
	authenticationSASLFinal         AuthenticationSASLFinal
	backendKeyData                  BackendKeyData
	bindComplete                    BindComplete
	closeComplete                   CloseComplete
	commandComplete                 CommandComplete
	copyBothResponse                CopyBothResponse
	copyData                        CopyData
	copyInResponse                  CopyInResponse
	copyOutResponse                 CopyOutResponse
	copyDone                        CopyDone
	dataRow                         DataRow
	emptyQueryResponse              EmptyQueryResponse
	errorResponse                   ErrorResponse
	functionCallResponse            FunctionCallResponse
	noData                          NoData
	noticeResponse                  NoticeResponse
	notificationResponse            NotificationResponse
	parameterDescription            ParameterDescription
	parameterStatus                 ParameterStatus
	parseComplete                   ParseComplete
	readyForQuery                   ReadyForQuery
	rowDescription                  RowDescription
	portalSuspended                 PortalSuspended

	// State of the message currently being read by Receive.
	//
	// bodyLen is the length field from the most recent header minus 4.
	bodyLen    int
	// msgType is the first byte of the most recent header.
	msgType    byte
	// partialMsg is true once a header has been read and stays true until the
	// matching body has been obtained, so a later Receive call resumes with the
	// body instead of reading a new header.
	partialMsg bool
	// authType is the code taken from the first four bytes of the last
	// authentication ('R') message body; it is exposed through GetAuthType.
	authType   uint32
}
    52  
    53  // NewFrontend creates a new Frontend.
    54  func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
    55  	return &Frontend{cr: cr, w: w}
    56  }
    57  
    58  // Send sends a message to the backend.
    59  func (f *Frontend) Send(msg FrontendMessage) error {
    60  	buf, err := msg.Encode(nil)
    61  	if err != nil {
    62  		return err
    63  	}
    64  	_, err = f.w.Write(buf)
    65  	return err
    66  }
    67  
    68  func translateEOFtoErrUnexpectedEOF(err error) error {
    69  	if err == io.EOF {
    70  		return io.ErrUnexpectedEOF
    71  	}
    72  	return err
    73  }
    74  
    75  // Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
    76  func (f *Frontend) Receive() (BackendMessage, error) {
    77  	if !f.partialMsg {
    78  		header, err := f.cr.Next(5)
    79  		if err != nil {
    80  			return nil, translateEOFtoErrUnexpectedEOF(err)
    81  		}
    82  
    83  		f.msgType = header[0]
    84  		f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
    85  		f.partialMsg = true
    86  		if f.bodyLen < 0 {
    87  			return nil, errors.New("invalid message with negative body length received")
    88  		}
    89  	}
    90  
    91  	msgBody, err := f.cr.Next(f.bodyLen)
    92  	if err != nil {
    93  		return nil, translateEOFtoErrUnexpectedEOF(err)
    94  	}
    95  
    96  	f.partialMsg = false
    97  
    98  	var msg BackendMessage
    99  	switch f.msgType {
   100  	case '1':
   101  		msg = &f.parseComplete
   102  	case '2':
   103  		msg = &f.bindComplete
   104  	case '3':
   105  		msg = &f.closeComplete
   106  	case 'A':
   107  		msg = &f.notificationResponse
   108  	case 'c':
   109  		msg = &f.copyDone
   110  	case 'C':
   111  		msg = &f.commandComplete
   112  	case 'd':
   113  		msg = &f.copyData
   114  	case 'D':
   115  		msg = &f.dataRow
   116  	case 'E':
   117  		msg = &f.errorResponse
   118  	case 'G':
   119  		msg = &f.copyInResponse
   120  	case 'H':
   121  		msg = &f.copyOutResponse
   122  	case 'I':
   123  		msg = &f.emptyQueryResponse
   124  	case 'K':
   125  		msg = &f.backendKeyData
   126  	case 'n':
   127  		msg = &f.noData
   128  	case 'N':
   129  		msg = &f.noticeResponse
   130  	case 'R':
   131  		var err error
   132  		msg, err = f.findAuthenticationMessageType(msgBody)
   133  		if err != nil {
   134  			return nil, err
   135  		}
   136  	case 's':
   137  		msg = &f.portalSuspended
   138  	case 'S':
   139  		msg = &f.parameterStatus
   140  	case 't':
   141  		msg = &f.parameterDescription
   142  	case 'T':
   143  		msg = &f.rowDescription
   144  	case 'V':
   145  		msg = &f.functionCallResponse
   146  	case 'W':
   147  		msg = &f.copyBothResponse
   148  	case 'Z':
   149  		msg = &f.readyForQuery
   150  	default:
   151  		return nil, fmt.Errorf("unknown message type: %c", f.msgType)
   152  	}
   153  
   154  	err = msg.Decode(msgBody)
   155  	return msg, err
   156  }
   157  
// Authentication message type constants.
// See src/include/libpq/pqcomm.h for all
// constants.
//
// These are the codes findAuthenticationMessageType compares against the
// 4-byte big-endian value at the start of an authentication ('R') message
// body. Codes not listed here are reported as an unknown authentication type.
const (
	AuthTypeOk                = 0
	AuthTypeCleartextPassword = 3
	AuthTypeMD5Password       = 5
	AuthTypeSCMCreds          = 6 // recognized, but rejected as unimplemented
	AuthTypeGSS               = 7
	AuthTypeGSSCont           = 8
	AuthTypeSSPI              = 9 // recognized, but rejected as unimplemented
	AuthTypeSASL              = 10
	AuthTypeSASLContinue      = 11
	AuthTypeSASLFinal         = 12
)
   173  
   174  func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
   175  	if len(src) < 4 {
   176  		return nil, errors.New("authentication message too short")
   177  	}
   178  	f.authType = binary.BigEndian.Uint32(src[:4])
   179  
   180  	switch f.authType {
   181  	case AuthTypeOk:
   182  		return &f.authenticationOk, nil
   183  	case AuthTypeCleartextPassword:
   184  		return &f.authenticationCleartextPassword, nil
   185  	case AuthTypeMD5Password:
   186  		return &f.authenticationMD5Password, nil
   187  	case AuthTypeSCMCreds:
   188  		return nil, errors.New("AuthTypeSCMCreds is unimplemented")
   189  	case AuthTypeGSS:
   190  		return &f.authenticationGSS, nil
   191  	case AuthTypeGSSCont:
   192  		return &f.authenticationGSSContinue, nil
   193  	case AuthTypeSSPI:
   194  		return nil, errors.New("AuthTypeSSPI is unimplemented")
   195  	case AuthTypeSASL:
   196  		return &f.authenticationSASL, nil
   197  	case AuthTypeSASLContinue:
   198  		return &f.authenticationSASLContinue, nil
   199  	case AuthTypeSASLFinal:
   200  		return &f.authenticationSASLFinal, nil
   201  	default:
   202  		return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
   203  	}
   204  }
   205  
// GetAuthType returns the authType used in the current state of the frontend.
// See SetAuthType for more information.
//
// The value is the code read from the most recent authentication ('R') message
// handled by Receive. It is stored before the code is validated, so it is also
// updated for unimplemented or unknown codes. Before any authentication
// message has been received it is 0, which is also the value of AuthTypeOk.
//
// NOTE(review): SetAuthType is not defined in this file — presumably it lives
// on another type in the package; confirm the reference is still accurate.
func (f *Frontend) GetAuthType() uint32 {
	return f.authType
}
   211  

View as plain text