package pgproto3_test import ( "io" "testing" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestBackendReceiveInterrupted(t *testing.T) { t.Parallel() server := &interruptReader{} server.push([]byte{'Q', 0, 0, 0, 6}) backend := pgproto3.NewBackend(server, nil) msg, err := backend.Receive() if err == nil { t.Fatal("expected err") } if msg != nil { t.Fatalf("did not expect msg, but %v", msg) } server.push([]byte{'I', 0}) msg, err = backend.Receive() if err != nil { t.Fatal(err) } if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" { t.Fatalf("unexpected msg: %v", msg) } } func TestBackendReceiveUnexpectedEOF(t *testing.T) { t.Parallel() server := &interruptReader{} server.push([]byte{'Q', 0, 0, 0, 6}) backend := pgproto3.NewBackend(server, nil) // Receive regular msg msg, err := backend.Receive() assert.Nil(t, msg) assert.Equal(t, io.ErrUnexpectedEOF, err) // Receive StartupMessage msg dst := []byte{} dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read dst = pgio.AppendUint32(dst, 1) // only send 1 byte server.push(dst) msg, err = backend.ReceiveStartupMessage() assert.Nil(t, msg) assert.Equal(t, io.ErrUnexpectedEOF, err) } func TestStartupMessage(t *testing.T) { t.Parallel() t.Run("valid StartupMessage", func(t *testing.T) { want := &pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{ "username": "tester", }, } dst, err := want.Encode([]byte{}) require.NoError(t, err) server := &interruptReader{} server.push(dst) backend := pgproto3.NewBackend(server, nil) msg, err := backend.ReceiveStartupMessage() require.NoError(t, err) require.Equal(t, want, msg) }) t.Run("invalid packet length", func(t *testing.T) { wantErr := "invalid length of startup packet" tests := []struct { name string packetLen uint32 }{ { name: "large packet length", // Since the StartupMessage contains the "Length of message contents // in bytes, including self", the max startup packet length is actually // 10000+4. Therefore, let's go past the limit with 10005 packetLen: 10005, }, { name: "short packet length", packetLen: 3, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { server := &interruptReader{} dst := []byte{} dst = pgio.AppendUint32(dst, tt.packetLen) dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) server.push(dst) backend := pgproto3.NewBackend(server, nil) msg, err := backend.ReceiveStartupMessage() require.Error(t, err) require.Nil(t, msg) require.Contains(t, err.Error(), wantErr) }) } }) } func TestBackendReceiveExceededMaxBodyLen(t *testing.T) { t.Parallel() server := &interruptReader{} server.push([]byte{'Q', 0, 0, 10, 10}) backend := pgproto3.NewBackend(server, nil) // Set max body len to 5 backend.SetMaxBodyLen(5) // Receive regular msg msg, err := backend.Receive() assert.Nil(t, msg) var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr assert.ErrorAs(t, err, &invalidBodyLenErr) }