1 package pgproto3_test
2
3 import (
4 "io"
5 "testing"
6
7 "github.com/jackc/pgx/v5/internal/pgio"
8 "github.com/jackc/pgx/v5/pgproto3"
9 "github.com/stretchr/testify/assert"
10 "github.com/stretchr/testify/require"
11 )
12
13 func TestBackendReceiveInterrupted(t *testing.T) {
14 t.Parallel()
15
16 server := &interruptReader{}
17 server.push([]byte{'Q', 0, 0, 0, 6})
18
19 backend := pgproto3.NewBackend(server, nil)
20
21 msg, err := backend.Receive()
22 if err == nil {
23 t.Fatal("expected err")
24 }
25 if msg != nil {
26 t.Fatalf("did not expect msg, but %v", msg)
27 }
28
29 server.push([]byte{'I', 0})
30
31 msg, err = backend.Receive()
32 if err != nil {
33 t.Fatal(err)
34 }
35 if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" {
36 t.Fatalf("unexpected msg: %v", msg)
37 }
38 }
39
40 func TestBackendReceiveUnexpectedEOF(t *testing.T) {
41 t.Parallel()
42
43 server := &interruptReader{}
44 server.push([]byte{'Q', 0, 0, 0, 6})
45
46 backend := pgproto3.NewBackend(server, nil)
47
48
49 msg, err := backend.Receive()
50 assert.Nil(t, msg)
51 assert.Equal(t, io.ErrUnexpectedEOF, err)
52
53
54 dst := []byte{}
55 dst = pgio.AppendUint32(dst, 1000)
56 dst = pgio.AppendUint32(dst, 1)
57 server.push(dst)
58
59 msg, err = backend.ReceiveStartupMessage()
60 assert.Nil(t, msg)
61 assert.Equal(t, io.ErrUnexpectedEOF, err)
62 }
63
64 func TestStartupMessage(t *testing.T) {
65 t.Parallel()
66
67 t.Run("valid StartupMessage", func(t *testing.T) {
68 want := &pgproto3.StartupMessage{
69 ProtocolVersion: pgproto3.ProtocolVersionNumber,
70 Parameters: map[string]string{
71 "username": "tester",
72 },
73 }
74 dst, err := want.Encode([]byte{})
75 require.NoError(t, err)
76
77 server := &interruptReader{}
78 server.push(dst)
79
80 backend := pgproto3.NewBackend(server, nil)
81
82 msg, err := backend.ReceiveStartupMessage()
83 require.NoError(t, err)
84 require.Equal(t, want, msg)
85 })
86
87 t.Run("invalid packet length", func(t *testing.T) {
88 wantErr := "invalid length of startup packet"
89 tests := []struct {
90 name string
91 packetLen uint32
92 }{
93 {
94 name: "large packet length",
95
96
97
98 packetLen: 10005,
99 },
100 {
101 name: "short packet length",
102 packetLen: 3,
103 },
104 }
105 for _, tt := range tests {
106 t.Run(tt.name, func(t *testing.T) {
107 server := &interruptReader{}
108 dst := []byte{}
109 dst = pgio.AppendUint32(dst, tt.packetLen)
110 dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
111 server.push(dst)
112
113 backend := pgproto3.NewBackend(server, nil)
114
115 msg, err := backend.ReceiveStartupMessage()
116 require.Error(t, err)
117 require.Nil(t, msg)
118 require.Contains(t, err.Error(), wantErr)
119 })
120 }
121 })
122 }
123
124 func TestBackendReceiveExceededMaxBodyLen(t *testing.T) {
125 t.Parallel()
126
127 server := &interruptReader{}
128 server.push([]byte{'Q', 0, 0, 10, 10})
129
130 backend := pgproto3.NewBackend(server, nil)
131
132
133 backend.SetMaxBodyLen(5)
134
135
136 msg, err := backend.Receive()
137 assert.Nil(t, msg)
138 var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr
139 assert.ErrorAs(t, err, &invalidBodyLenErr)
140 }
141
View as plain text