...

Source file src/github.com/jackc/pgx/v5/pgproto3/backend_test.go

Documentation: github.com/jackc/pgx/v5/pgproto3

     1  package pgproto3_test
     2  
     3  import (
     4  	"io"
     5  	"testing"
     6  
     7  	"github.com/jackc/pgx/v5/internal/pgio"
     8  	"github.com/jackc/pgx/v5/pgproto3"
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  func TestBackendReceiveInterrupted(t *testing.T) {
    14  	t.Parallel()
    15  
    16  	server := &interruptReader{}
    17  	server.push([]byte{'Q', 0, 0, 0, 6})
    18  
    19  	backend := pgproto3.NewBackend(server, nil)
    20  
    21  	msg, err := backend.Receive()
    22  	if err == nil {
    23  		t.Fatal("expected err")
    24  	}
    25  	if msg != nil {
    26  		t.Fatalf("did not expect msg, but %v", msg)
    27  	}
    28  
    29  	server.push([]byte{'I', 0})
    30  
    31  	msg, err = backend.Receive()
    32  	if err != nil {
    33  		t.Fatal(err)
    34  	}
    35  	if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
    36  		t.Fatalf("unexpected msg: %v", msg)
    37  	}
    38  }
    39  
    40  func TestBackendReceiveUnexpectedEOF(t *testing.T) {
    41  	t.Parallel()
    42  
    43  	server := &interruptReader{}
    44  	server.push([]byte{'Q', 0, 0, 0, 6})
    45  
    46  	backend := pgproto3.NewBackend(server, nil)
    47  
    48  	// Receive regular msg
    49  	msg, err := backend.Receive()
    50  	assert.Nil(t, msg)
    51  	assert.Equal(t, io.ErrUnexpectedEOF, err)
    52  
    53  	// Receive StartupMessage msg
    54  	dst := []byte{}
    55  	dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read
    56  	dst = pgio.AppendUint32(dst, 1)    // only send 1 byte
    57  	server.push(dst)
    58  
    59  	msg, err = backend.ReceiveStartupMessage()
    60  	assert.Nil(t, msg)
    61  	assert.Equal(t, io.ErrUnexpectedEOF, err)
    62  }
    63  
    64  func TestStartupMessage(t *testing.T) {
    65  	t.Parallel()
    66  
    67  	t.Run("valid StartupMessage", func(t *testing.T) {
    68  		want := &pgproto3.StartupMessage{
    69  			ProtocolVersion: pgproto3.ProtocolVersionNumber,
    70  			Parameters: map[string]string{
    71  				"username": "tester",
    72  			},
    73  		}
    74  		dst, err := want.Encode([]byte{})
    75  		require.NoError(t, err)
    76  
    77  		server := &interruptReader{}
    78  		server.push(dst)
    79  
    80  		backend := pgproto3.NewBackend(server, nil)
    81  
    82  		msg, err := backend.ReceiveStartupMessage()
    83  		require.NoError(t, err)
    84  		require.Equal(t, want, msg)
    85  	})
    86  
    87  	t.Run("invalid packet length", func(t *testing.T) {
    88  		wantErr := "invalid length of startup packet"
    89  		tests := []struct {
    90  			name      string
    91  			packetLen uint32
    92  		}{
    93  			{
    94  				name: "large packet length",
    95  				// Since the StartupMessage contains the "Length of message contents
    96  				//  in bytes, including self", the max startup packet length is actually
    97  				//  10000+4. Therefore, let's go past the limit with 10005
    98  				packetLen: 10005,
    99  			},
   100  			{
   101  				name:      "short packet length",
   102  				packetLen: 3,
   103  			},
   104  		}
   105  		for _, tt := range tests {
   106  			t.Run(tt.name, func(t *testing.T) {
   107  				server := &interruptReader{}
   108  				dst := []byte{}
   109  				dst = pgio.AppendUint32(dst, tt.packetLen)
   110  				dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
   111  				server.push(dst)
   112  
   113  				backend := pgproto3.NewBackend(server, nil)
   114  
   115  				msg, err := backend.ReceiveStartupMessage()
   116  				require.Error(t, err)
   117  				require.Nil(t, msg)
   118  				require.Contains(t, err.Error(), wantErr)
   119  			})
   120  		}
   121  	})
   122  }
   123  
   124  func TestBackendReceiveExceededMaxBodyLen(t *testing.T) {
   125  	t.Parallel()
   126  
   127  	server := &interruptReader{}
   128  	server.push([]byte{'Q', 0, 0, 10, 10})
   129  
   130  	backend := pgproto3.NewBackend(server, nil)
   131  
   132  	// Set max body len to 5
   133  	backend.SetMaxBodyLen(5)
   134  
   135  	// Receive regular msg
   136  	msg, err := backend.Receive()
   137  	assert.Nil(t, msg)
   138  	var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr
   139  	assert.ErrorAs(t, err, &invalidBodyLenErr)
   140  }
   141  

View as plain text