...

Source file src/github.com/jackc/pgproto3/v2/backend.go

Documentation: github.com/jackc/pgproto3/v2

     1  package pgproto3
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  )
     9  
// Backend acts as a server for the PostgreSQL wire protocol version 3.
// It reads frontend messages from a ChunkReader and writes backend messages
// to an io.Writer.
type Backend struct {
	// cr supplies the incoming bytes; messages are pulled from it with Next.
	cr ChunkReader
	// w receives the encoded bytes of each message passed to Send.
	w  io.Writer

	// Frontend message flyweights. Receive and ReceiveStartupMessage decode
	// into these fields and return pointers to them, so a returned message is
	// overwritten by a later call that decodes the same message kind.
	bind           Bind
	cancelRequest  CancelRequest
	_close         Close
	copyFail       CopyFail
	copyData       CopyData
	copyDone       CopyDone
	describe       Describe
	execute        Execute
	flush          Flush
	functionCall   FunctionCall
	gssEncRequest  GSSEncRequest
	parse          Parse
	query          Query
	sslRequest     SSLRequest
	startupMessage StartupMessage
	sync           Sync
	terminate      Terminate

	// bodyLen is the body length from the most recently read message header
	// (the 32-bit length field minus its own 4 bytes).
	bodyLen    int
	// msgType is the type byte from the most recently read message header.
	msgType    byte
	// partialMsg is true while a header has been read but its body has not
	// yet been read; Receive then skips reading a new header.
	partialMsg bool
	// authType is the value stored by SetAuthType; Receive uses it to choose
	// which message type a 'p' message is decoded into.
	authType   uint32
}
    39  
// Bounds that ReceiveStartupMessage applies to the startup packet size, measured
// after the 4-byte length field itself has been subtracted.
const (
	minStartupPacketLen = 4     // minStartupPacketLen is a single 32-bit int version or code.
	maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
)
    44  
    45  // NewBackend creates a new Backend.
    46  func NewBackend(cr ChunkReader, w io.Writer) *Backend {
    47  	return &Backend{cr: cr, w: w}
    48  }
    49  
    50  // Send sends a message to the frontend.
    51  func (b *Backend) Send(msg BackendMessage) error {
    52  	buf, err := msg.Encode(nil)
    53  	if err != nil {
    54  		return err
    55  	}
    56  
    57  	_, err = b.w.Write(buf)
    58  	return err
    59  }
    60  
    61  // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
    62  // because the initial connection message is "special" and does not include the message type as the first byte. This
    63  // will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
    64  func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
    65  	buf, err := b.cr.Next(4)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	msgSize := int(binary.BigEndian.Uint32(buf) - 4)
    70  
    71  	if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
    72  		return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
    73  	}
    74  
    75  	buf, err = b.cr.Next(msgSize)
    76  	if err != nil {
    77  		return nil, translateEOFtoErrUnexpectedEOF(err)
    78  	}
    79  
    80  	code := binary.BigEndian.Uint32(buf)
    81  
    82  	switch code {
    83  	case ProtocolVersionNumber:
    84  		err = b.startupMessage.Decode(buf)
    85  		if err != nil {
    86  			return nil, err
    87  		}
    88  		return &b.startupMessage, nil
    89  	case sslRequestNumber:
    90  		err = b.sslRequest.Decode(buf)
    91  		if err != nil {
    92  			return nil, err
    93  		}
    94  		return &b.sslRequest, nil
    95  	case cancelRequestCode:
    96  		err = b.cancelRequest.Decode(buf)
    97  		if err != nil {
    98  			return nil, err
    99  		}
   100  		return &b.cancelRequest, nil
   101  	case gssEncReqNumber:
   102  		err = b.gssEncRequest.Decode(buf)
   103  		if err != nil {
   104  			return nil, err
   105  		}
   106  		return &b.gssEncRequest, nil
   107  	default:
   108  		return nil, fmt.Errorf("unknown startup message code: %d", code)
   109  	}
   110  }
   111  
   112  // Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
   113  func (b *Backend) Receive() (FrontendMessage, error) {
   114  	if !b.partialMsg {
   115  		header, err := b.cr.Next(5)
   116  		if err != nil {
   117  			return nil, translateEOFtoErrUnexpectedEOF(err)
   118  		}
   119  
   120  		b.msgType = header[0]
   121  		b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
   122  		b.partialMsg = true
   123  		if b.bodyLen < 0 {
   124  			return nil, errors.New("invalid message with negative body length received")
   125  		}
   126  	}
   127  
   128  	var msg FrontendMessage
   129  	switch b.msgType {
   130  	case 'B':
   131  		msg = &b.bind
   132  	case 'C':
   133  		msg = &b._close
   134  	case 'D':
   135  		msg = &b.describe
   136  	case 'E':
   137  		msg = &b.execute
   138  	case 'F':
   139  		msg = &b.functionCall
   140  	case 'f':
   141  		msg = &b.copyFail
   142  	case 'd':
   143  		msg = &b.copyData
   144  	case 'c':
   145  		msg = &b.copyDone
   146  	case 'H':
   147  		msg = &b.flush
   148  	case 'P':
   149  		msg = &b.parse
   150  	case 'p':
   151  		switch b.authType {
   152  		case AuthTypeSASL:
   153  			msg = &SASLInitialResponse{}
   154  		case AuthTypeSASLContinue:
   155  			msg = &SASLResponse{}
   156  		case AuthTypeSASLFinal:
   157  			msg = &SASLResponse{}
   158  		case AuthTypeGSS, AuthTypeGSSCont:
   159  			msg = &GSSResponse{}
   160  		case AuthTypeCleartextPassword, AuthTypeMD5Password:
   161  			fallthrough
   162  		default:
   163  			// to maintain backwards compatability
   164  			msg = &PasswordMessage{}
   165  		}
   166  	case 'Q':
   167  		msg = &b.query
   168  	case 'S':
   169  		msg = &b.sync
   170  	case 'X':
   171  		msg = &b.terminate
   172  	default:
   173  		return nil, fmt.Errorf("unknown message type: %c", b.msgType)
   174  	}
   175  
   176  	msgBody, err := b.cr.Next(b.bodyLen)
   177  	if err != nil {
   178  		return nil, translateEOFtoErrUnexpectedEOF(err)
   179  	}
   180  
   181  	b.partialMsg = false
   182  
   183  	err = msg.Decode(msgBody)
   184  	return msg, err
   185  }
   186  
   187  // SetAuthType sets the authentication type in the backend.
   188  // Since multiple message types can start with 'p', SetAuthType allows
   189  // contextual identification of FrontendMessages. For example, in the
   190  // PG message flow documentation for PasswordMessage:
   191  //
   192  //			Byte1('p')
   193  //
   194  //	     Identifies the message as a password response. Note that this is also used for
   195  //			GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
   196  //			the context.
   197  //
   198  // Since the Frontend does not know about the state of a backend, it is important
   199  // to call SetAuthType() after an authentication request is received by the Frontend.
   200  func (b *Backend) SetAuthType(authType uint32) error {
   201  	switch authType {
   202  	case AuthTypeOk,
   203  		AuthTypeCleartextPassword,
   204  		AuthTypeMD5Password,
   205  		AuthTypeSCMCreds,
   206  		AuthTypeGSS,
   207  		AuthTypeGSSCont,
   208  		AuthTypeSSPI,
   209  		AuthTypeSASL,
   210  		AuthTypeSASLContinue,
   211  		AuthTypeSASLFinal:
   212  		b.authType = authType
   213  	default:
   214  		return fmt.Errorf("authType not recognized: %d", authType)
   215  	}
   216  
   217  	return nil
   218  }
   219  

View as plain text