1 package pgproto3
2
3 import (
4 "bytes"
5 "encoding/binary"
6 "fmt"
7 "io"
8 )
9
10
11 type Backend struct {
12 cr *chunkReader
13 w io.Writer
14
15
16
17 tracer *tracer
18
19 wbuf []byte
20 encodeError error
21
22
23 bind Bind
24 cancelRequest CancelRequest
25 _close Close
26 copyFail CopyFail
27 copyData CopyData
28 copyDone CopyDone
29 describe Describe
30 execute Execute
31 flush Flush
32 functionCall FunctionCall
33 gssEncRequest GSSEncRequest
34 parse Parse
35 query Query
36 sslRequest SSLRequest
37 startupMessage StartupMessage
38 sync Sync
39 terminate Terminate
40
41 bodyLen int
42 maxBodyLen int
43 msgType byte
44 partialMsg bool
45 authType uint32
46 }
47
// Accepted size range, in bytes, for the body of a startup packet, i.e. the
// bytes that follow its 4-byte length prefix. ReceiveStartupMessage rejects
// any packet outside this range before reading the body.
const (
	// minStartupPacketLen leaves room for the 4-byte code that identifies
	// the kind of startup packet.
	minStartupPacketLen = 4
	// maxStartupPacketLen caps how much a client can make the server read
	// before the startup packet has been identified.
	maxStartupPacketLen = 10000
)
52
53
54 func NewBackend(r io.Reader, w io.Writer) *Backend {
55 cr := newChunkReader(r, 0)
56 return &Backend{cr: cr, w: w}
57 }
58
59
60
61 func (b *Backend) Send(msg BackendMessage) {
62 if b.encodeError != nil {
63 return
64 }
65
66 prevLen := len(b.wbuf)
67 newBuf, err := msg.Encode(b.wbuf)
68 if err != nil {
69 b.encodeError = err
70 return
71 }
72 b.wbuf = newBuf
73
74 if b.tracer != nil {
75 b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
76 }
77 }
78
79
80 func (b *Backend) Flush() error {
81 if err := b.encodeError; err != nil {
82 b.encodeError = nil
83 b.wbuf = b.wbuf[:0]
84 return &writeError{err: err, safeToRetry: true}
85 }
86
87 n, err := b.w.Write(b.wbuf)
88
89 const maxLen = 1024
90 if len(b.wbuf) > maxLen {
91 b.wbuf = make([]byte, 0, maxLen)
92 } else {
93 b.wbuf = b.wbuf[:0]
94 }
95
96 if err != nil {
97 return &writeError{err: err, safeToRetry: n == 0}
98 }
99
100 return nil
101 }
102
103
104
105 func (b *Backend) Trace(w io.Writer, options TracerOptions) {
106 b.tracer = &tracer{
107 w: w,
108 buf: &bytes.Buffer{},
109 TracerOptions: options,
110 }
111 }
112
113
114 func (b *Backend) Untrace() {
115 b.tracer = nil
116 }
117
118
119
120
121 func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
122 buf, err := b.cr.Next(4)
123 if err != nil {
124 return nil, err
125 }
126 msgSize := int(binary.BigEndian.Uint32(buf) - 4)
127
128 if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
129 return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
130 }
131
132 buf, err = b.cr.Next(msgSize)
133 if err != nil {
134 return nil, translateEOFtoErrUnexpectedEOF(err)
135 }
136
137 code := binary.BigEndian.Uint32(buf)
138
139 switch code {
140 case ProtocolVersionNumber:
141 err = b.startupMessage.Decode(buf)
142 if err != nil {
143 return nil, err
144 }
145 return &b.startupMessage, nil
146 case sslRequestNumber:
147 err = b.sslRequest.Decode(buf)
148 if err != nil {
149 return nil, err
150 }
151 return &b.sslRequest, nil
152 case cancelRequestCode:
153 err = b.cancelRequest.Decode(buf)
154 if err != nil {
155 return nil, err
156 }
157 return &b.cancelRequest, nil
158 case gssEncReqNumber:
159 err = b.gssEncRequest.Decode(buf)
160 if err != nil {
161 return nil, err
162 }
163 return &b.gssEncRequest, nil
164 default:
165 return nil, fmt.Errorf("unknown startup message code: %d", code)
166 }
167 }
168
169
170 func (b *Backend) Receive() (FrontendMessage, error) {
171 if !b.partialMsg {
172 header, err := b.cr.Next(5)
173 if err != nil {
174 return nil, translateEOFtoErrUnexpectedEOF(err)
175 }
176
177 b.msgType = header[0]
178 b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
179 if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
180 return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
181 }
182 b.partialMsg = true
183 }
184
185 var msg FrontendMessage
186 switch b.msgType {
187 case 'B':
188 msg = &b.bind
189 case 'C':
190 msg = &b._close
191 case 'D':
192 msg = &b.describe
193 case 'E':
194 msg = &b.execute
195 case 'F':
196 msg = &b.functionCall
197 case 'f':
198 msg = &b.copyFail
199 case 'd':
200 msg = &b.copyData
201 case 'c':
202 msg = &b.copyDone
203 case 'H':
204 msg = &b.flush
205 case 'P':
206 msg = &b.parse
207 case 'p':
208 switch b.authType {
209 case AuthTypeSASL:
210 msg = &SASLInitialResponse{}
211 case AuthTypeSASLContinue:
212 msg = &SASLResponse{}
213 case AuthTypeSASLFinal:
214 msg = &SASLResponse{}
215 case AuthTypeGSS, AuthTypeGSSCont:
216 msg = &GSSResponse{}
217 case AuthTypeCleartextPassword, AuthTypeMD5Password:
218 fallthrough
219 default:
220
221 msg = &PasswordMessage{}
222 }
223 case 'Q':
224 msg = &b.query
225 case 'S':
226 msg = &b.sync
227 case 'X':
228 msg = &b.terminate
229 default:
230 return nil, fmt.Errorf("unknown message type: %c", b.msgType)
231 }
232
233 msgBody, err := b.cr.Next(b.bodyLen)
234 if err != nil {
235 return nil, translateEOFtoErrUnexpectedEOF(err)
236 }
237
238 b.partialMsg = false
239
240 err = msg.Decode(msgBody)
241 if err != nil {
242 return nil, err
243 }
244
245 if b.tracer != nil {
246 b.tracer.traceMessage('F', int32(5+len(msgBody)), msg)
247 }
248
249 return msg, nil
250 }
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265 func (b *Backend) SetAuthType(authType uint32) error {
266 switch authType {
267 case AuthTypeOk,
268 AuthTypeCleartextPassword,
269 AuthTypeMD5Password,
270 AuthTypeSCMCreds,
271 AuthTypeGSS,
272 AuthTypeGSSCont,
273 AuthTypeSSPI,
274 AuthTypeSASL,
275 AuthTypeSASLContinue,
276 AuthTypeSASLFinal:
277 b.authType = authType
278 default:
279 return fmt.Errorf("authType not recognized: %d", authType)
280 }
281
282 return nil
283 }
284
285
286
287
288
289
290 func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
291 b.maxBodyLen = maxBodyLen
292 }
293
View as plain text