1 package pgproto3
2
3 import (
4 "bytes"
5 "encoding/binary"
6 "errors"
7 "fmt"
8 "io"
9 )
10
11
12 type Frontend struct {
13 cr *chunkReader
14 w io.Writer
15
16
17
18
19 tracer *tracer
20
21 wbuf []byte
22 encodeError error
23
24
25 authenticationOk AuthenticationOk
26 authenticationCleartextPassword AuthenticationCleartextPassword
27 authenticationMD5Password AuthenticationMD5Password
28 authenticationGSS AuthenticationGSS
29 authenticationGSSContinue AuthenticationGSSContinue
30 authenticationSASL AuthenticationSASL
31 authenticationSASLContinue AuthenticationSASLContinue
32 authenticationSASLFinal AuthenticationSASLFinal
33 backendKeyData BackendKeyData
34 bindComplete BindComplete
35 closeComplete CloseComplete
36 commandComplete CommandComplete
37 copyBothResponse CopyBothResponse
38 copyData CopyData
39 copyInResponse CopyInResponse
40 copyOutResponse CopyOutResponse
41 copyDone CopyDone
42 dataRow DataRow
43 emptyQueryResponse EmptyQueryResponse
44 errorResponse ErrorResponse
45 functionCallResponse FunctionCallResponse
46 noData NoData
47 noticeResponse NoticeResponse
48 notificationResponse NotificationResponse
49 parameterDescription ParameterDescription
50 parameterStatus ParameterStatus
51 parseComplete ParseComplete
52 readyForQuery ReadyForQuery
53 rowDescription RowDescription
54 portalSuspended PortalSuspended
55
56 bodyLen int
57 msgType byte
58 partialMsg bool
59 authType uint32
60 }
61
62
63 func NewFrontend(r io.Reader, w io.Writer) *Frontend {
64 cr := newChunkReader(r, 0)
65 return &Frontend{cr: cr, w: w}
66 }
67
68
69
70
71
72
73
74
75 func (f *Frontend) Send(msg FrontendMessage) {
76 if f.encodeError != nil {
77 return
78 }
79
80 prevLen := len(f.wbuf)
81 newBuf, err := msg.Encode(f.wbuf)
82 if err != nil {
83 f.encodeError = err
84 return
85 }
86 f.wbuf = newBuf
87
88 if f.tracer != nil {
89 f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
90 }
91 }
92
93
94 func (f *Frontend) Flush() error {
95 if err := f.encodeError; err != nil {
96 f.encodeError = nil
97 f.wbuf = f.wbuf[:0]
98 return &writeError{err: err, safeToRetry: true}
99 }
100
101 if len(f.wbuf) == 0 {
102 return nil
103 }
104
105 n, err := f.w.Write(f.wbuf)
106
107 const maxLen = 1024
108 if len(f.wbuf) > maxLen {
109 f.wbuf = make([]byte, 0, maxLen)
110 } else {
111 f.wbuf = f.wbuf[:0]
112 }
113
114 if err != nil {
115 return &writeError{err: err, safeToRetry: n == 0}
116 }
117
118 return nil
119 }
120
121
122
123 func (f *Frontend) Trace(w io.Writer, options TracerOptions) {
124 f.tracer = &tracer{
125 w: w,
126 buf: &bytes.Buffer{},
127 TracerOptions: options,
128 }
129 }
130
131
132 func (f *Frontend) Untrace() {
133 f.tracer = nil
134 }
135
136
137
138 func (f *Frontend) SendBind(msg *Bind) {
139 if f.encodeError != nil {
140 return
141 }
142
143 prevLen := len(f.wbuf)
144 newBuf, err := msg.Encode(f.wbuf)
145 if err != nil {
146 f.encodeError = err
147 return
148 }
149 f.wbuf = newBuf
150
151 if f.tracer != nil {
152 f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
153 }
154 }
155
156
157
158 func (f *Frontend) SendParse(msg *Parse) {
159 if f.encodeError != nil {
160 return
161 }
162
163 prevLen := len(f.wbuf)
164 newBuf, err := msg.Encode(f.wbuf)
165 if err != nil {
166 f.encodeError = err
167 return
168 }
169 f.wbuf = newBuf
170
171 if f.tracer != nil {
172 f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
173 }
174 }
175
176
177
178 func (f *Frontend) SendClose(msg *Close) {
179 if f.encodeError != nil {
180 return
181 }
182
183 prevLen := len(f.wbuf)
184 newBuf, err := msg.Encode(f.wbuf)
185 if err != nil {
186 f.encodeError = err
187 return
188 }
189 f.wbuf = newBuf
190
191 if f.tracer != nil {
192 f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
193 }
194 }
195
196
197
198 func (f *Frontend) SendDescribe(msg *Describe) {
199 if f.encodeError != nil {
200 return
201 }
202
203 prevLen := len(f.wbuf)
204 newBuf, err := msg.Encode(f.wbuf)
205 if err != nil {
206 f.encodeError = err
207 return
208 }
209 f.wbuf = newBuf
210
211 if f.tracer != nil {
212 f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
213 }
214 }
215
216
217
218 func (f *Frontend) SendExecute(msg *Execute) {
219 if f.encodeError != nil {
220 return
221 }
222
223 prevLen := len(f.wbuf)
224 newBuf, err := msg.Encode(f.wbuf)
225 if err != nil {
226 f.encodeError = err
227 return
228 }
229 f.wbuf = newBuf
230
231 if f.tracer != nil {
232 f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
233 }
234 }
235
236
237
238 func (f *Frontend) SendSync(msg *Sync) {
239 if f.encodeError != nil {
240 return
241 }
242
243 prevLen := len(f.wbuf)
244 newBuf, err := msg.Encode(f.wbuf)
245 if err != nil {
246 f.encodeError = err
247 return
248 }
249 f.wbuf = newBuf
250
251 if f.tracer != nil {
252 f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
253 }
254 }
255
256
257
258 func (f *Frontend) SendQuery(msg *Query) {
259 if f.encodeError != nil {
260 return
261 }
262
263 prevLen := len(f.wbuf)
264 newBuf, err := msg.Encode(f.wbuf)
265 if err != nil {
266 f.encodeError = err
267 return
268 }
269 f.wbuf = newBuf
270
271 if f.tracer != nil {
272 f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
273 }
274 }
275
276
277
278
279 func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error {
280 err := f.Flush()
281 if err != nil {
282 return err
283 }
284
285 n, err := f.w.Write(msg)
286 if err != nil {
287 return &writeError{err: err, safeToRetry: n == 0}
288 }
289
290 if f.tracer != nil {
291 f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{})
292 }
293
294 return nil
295 }
296
// translateEOFtoErrUnexpectedEOF returns io.ErrUnexpectedEOF when err is exactly
// io.EOF and returns err unchanged otherwise. Receive uses it because running out
// of input while reading a message is an abnormal end of stream. The comparison
// is by identity, so an error that wraps io.EOF is passed through as-is.
func translateEOFtoErrUnexpectedEOF(err error) error {
	if err != io.EOF {
		return err
	}
	return io.ErrUnexpectedEOF
}
303
304
305 func (f *Frontend) Receive() (BackendMessage, error) {
306 if !f.partialMsg {
307 header, err := f.cr.Next(5)
308 if err != nil {
309 return nil, translateEOFtoErrUnexpectedEOF(err)
310 }
311
312 f.msgType = header[0]
313
314 msgLength := int(binary.BigEndian.Uint32(header[1:]))
315 if msgLength < 4 {
316 return nil, fmt.Errorf("invalid message length: %d", msgLength)
317 }
318
319 f.bodyLen = msgLength - 4
320 f.partialMsg = true
321 }
322
323 msgBody, err := f.cr.Next(f.bodyLen)
324 if err != nil {
325 return nil, translateEOFtoErrUnexpectedEOF(err)
326 }
327
328 f.partialMsg = false
329
330 var msg BackendMessage
331 switch f.msgType {
332 case '1':
333 msg = &f.parseComplete
334 case '2':
335 msg = &f.bindComplete
336 case '3':
337 msg = &f.closeComplete
338 case 'A':
339 msg = &f.notificationResponse
340 case 'c':
341 msg = &f.copyDone
342 case 'C':
343 msg = &f.commandComplete
344 case 'd':
345 msg = &f.copyData
346 case 'D':
347 msg = &f.dataRow
348 case 'E':
349 msg = &f.errorResponse
350 case 'G':
351 msg = &f.copyInResponse
352 case 'H':
353 msg = &f.copyOutResponse
354 case 'I':
355 msg = &f.emptyQueryResponse
356 case 'K':
357 msg = &f.backendKeyData
358 case 'n':
359 msg = &f.noData
360 case 'N':
361 msg = &f.noticeResponse
362 case 'R':
363 var err error
364 msg, err = f.findAuthenticationMessageType(msgBody)
365 if err != nil {
366 return nil, err
367 }
368 case 's':
369 msg = &f.portalSuspended
370 case 'S':
371 msg = &f.parameterStatus
372 case 't':
373 msg = &f.parameterDescription
374 case 'T':
375 msg = &f.rowDescription
376 case 'V':
377 msg = &f.functionCallResponse
378 case 'W':
379 msg = &f.copyBothResponse
380 case 'Z':
381 msg = &f.readyForQuery
382 default:
383 return nil, fmt.Errorf("unknown message type: %c", f.msgType)
384 }
385
386 err = msg.Decode(msgBody)
387 if err != nil {
388 return nil, err
389 }
390
391 if f.tracer != nil {
392 f.tracer.traceMessage('B', int32(5+len(msgBody)), msg)
393 }
394
395 return msg, nil
396 }
397
398
399
400
// Authentication request codes. findAuthenticationMessageType reads one from the
// first four bytes (big-endian) of an authentication ('R') message body and uses
// it to choose which Authentication* message decodes that body.
const (
	// AuthTypeOk selects AuthenticationOk.
	AuthTypeOk = 0
	// AuthTypeCleartextPassword selects AuthenticationCleartextPassword.
	AuthTypeCleartextPassword = 3
	// AuthTypeMD5Password selects AuthenticationMD5Password.
	AuthTypeMD5Password = 5
	// AuthTypeSCMCreds is recognized but not implemented by this package.
	AuthTypeSCMCreds = 6
	// AuthTypeGSS selects AuthenticationGSS.
	AuthTypeGSS = 7
	// AuthTypeGSSCont selects AuthenticationGSSContinue.
	AuthTypeGSSCont = 8
	// AuthTypeSSPI is recognized but not implemented by this package.
	AuthTypeSSPI = 9
	// AuthTypeSASL selects AuthenticationSASL.
	AuthTypeSASL = 10
	// AuthTypeSASLContinue selects AuthenticationSASLContinue.
	AuthTypeSASLContinue = 11
	// AuthTypeSASLFinal selects AuthenticationSASLFinal.
	AuthTypeSASLFinal = 12
)
413
414 func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
415 if len(src) < 4 {
416 return nil, errors.New("authentication message too short")
417 }
418 f.authType = binary.BigEndian.Uint32(src[:4])
419
420 switch f.authType {
421 case AuthTypeOk:
422 return &f.authenticationOk, nil
423 case AuthTypeCleartextPassword:
424 return &f.authenticationCleartextPassword, nil
425 case AuthTypeMD5Password:
426 return &f.authenticationMD5Password, nil
427 case AuthTypeSCMCreds:
428 return nil, errors.New("AuthTypeSCMCreds is unimplemented")
429 case AuthTypeGSS:
430 return &f.authenticationGSS, nil
431 case AuthTypeGSSCont:
432 return &f.authenticationGSSContinue, nil
433 case AuthTypeSSPI:
434 return nil, errors.New("AuthTypeSSPI is unimplemented")
435 case AuthTypeSASL:
436 return &f.authenticationSASL, nil
437 case AuthTypeSASLContinue:
438 return &f.authenticationSASLContinue, nil
439 case AuthTypeSASLFinal:
440 return &f.authenticationSASLFinal, nil
441 default:
442 return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
443 }
444 }
445
446
447
448 func (f *Frontend) GetAuthType() uint32 {
449 return f.authType
450 }
451
452 func (f *Frontend) ReadBufferLen() int {
453 return f.cr.wp - f.cr.rp
454 }
455
View as plain text