...
1 package pgproto3
2
3 import (
4 "encoding/binary"
5 "errors"
6 "fmt"
7 "io"
8 )
9
10
11 type Backend struct {
12 cr ChunkReader
13 w io.Writer
14
15
16 bind Bind
17 cancelRequest CancelRequest
18 _close Close
19 copyFail CopyFail
20 copyData CopyData
21 copyDone CopyDone
22 describe Describe
23 execute Execute
24 flush Flush
25 functionCall FunctionCall
26 gssEncRequest GSSEncRequest
27 parse Parse
28 query Query
29 sslRequest SSLRequest
30 startupMessage StartupMessage
31 sync Sync
32 terminate Terminate
33
34 bodyLen int
35 msgType byte
36 partialMsg bool
37 authType uint32
38 }
39
// Bounds applied by ReceiveStartupMessage to the size of a startup packet
// body, i.e. the declared packet length minus the 4-byte length field itself.
const (
	// minStartupPacketLen is the smallest accepted body: just enough to
	// hold the 4-byte request code that every startup packet begins with.
	minStartupPacketLen = 4

	// maxStartupPacketLen is the largest accepted body.
	maxStartupPacketLen = 10000
)
44
45
46 func NewBackend(cr ChunkReader, w io.Writer) *Backend {
47 return &Backend{cr: cr, w: w}
48 }
49
50
51 func (b *Backend) Send(msg BackendMessage) error {
52 buf, err := msg.Encode(nil)
53 if err != nil {
54 return err
55 }
56
57 _, err = b.w.Write(buf)
58 return err
59 }
60
61
62
63
64 func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
65 buf, err := b.cr.Next(4)
66 if err != nil {
67 return nil, err
68 }
69 msgSize := int(binary.BigEndian.Uint32(buf) - 4)
70
71 if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
72 return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
73 }
74
75 buf, err = b.cr.Next(msgSize)
76 if err != nil {
77 return nil, translateEOFtoErrUnexpectedEOF(err)
78 }
79
80 code := binary.BigEndian.Uint32(buf)
81
82 switch code {
83 case ProtocolVersionNumber:
84 err = b.startupMessage.Decode(buf)
85 if err != nil {
86 return nil, err
87 }
88 return &b.startupMessage, nil
89 case sslRequestNumber:
90 err = b.sslRequest.Decode(buf)
91 if err != nil {
92 return nil, err
93 }
94 return &b.sslRequest, nil
95 case cancelRequestCode:
96 err = b.cancelRequest.Decode(buf)
97 if err != nil {
98 return nil, err
99 }
100 return &b.cancelRequest, nil
101 case gssEncReqNumber:
102 err = b.gssEncRequest.Decode(buf)
103 if err != nil {
104 return nil, err
105 }
106 return &b.gssEncRequest, nil
107 default:
108 return nil, fmt.Errorf("unknown startup message code: %d", code)
109 }
110 }
111
112
113 func (b *Backend) Receive() (FrontendMessage, error) {
114 if !b.partialMsg {
115 header, err := b.cr.Next(5)
116 if err != nil {
117 return nil, translateEOFtoErrUnexpectedEOF(err)
118 }
119
120 b.msgType = header[0]
121 b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
122 b.partialMsg = true
123 if b.bodyLen < 0 {
124 return nil, errors.New("invalid message with negative body length received")
125 }
126 }
127
128 var msg FrontendMessage
129 switch b.msgType {
130 case 'B':
131 msg = &b.bind
132 case 'C':
133 msg = &b._close
134 case 'D':
135 msg = &b.describe
136 case 'E':
137 msg = &b.execute
138 case 'F':
139 msg = &b.functionCall
140 case 'f':
141 msg = &b.copyFail
142 case 'd':
143 msg = &b.copyData
144 case 'c':
145 msg = &b.copyDone
146 case 'H':
147 msg = &b.flush
148 case 'P':
149 msg = &b.parse
150 case 'p':
151 switch b.authType {
152 case AuthTypeSASL:
153 msg = &SASLInitialResponse{}
154 case AuthTypeSASLContinue:
155 msg = &SASLResponse{}
156 case AuthTypeSASLFinal:
157 msg = &SASLResponse{}
158 case AuthTypeGSS, AuthTypeGSSCont:
159 msg = &GSSResponse{}
160 case AuthTypeCleartextPassword, AuthTypeMD5Password:
161 fallthrough
162 default:
163
164 msg = &PasswordMessage{}
165 }
166 case 'Q':
167 msg = &b.query
168 case 'S':
169 msg = &b.sync
170 case 'X':
171 msg = &b.terminate
172 default:
173 return nil, fmt.Errorf("unknown message type: %c", b.msgType)
174 }
175
176 msgBody, err := b.cr.Next(b.bodyLen)
177 if err != nil {
178 return nil, translateEOFtoErrUnexpectedEOF(err)
179 }
180
181 b.partialMsg = false
182
183 err = msg.Decode(msgBody)
184 return msg, err
185 }
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200 func (b *Backend) SetAuthType(authType uint32) error {
201 switch authType {
202 case AuthTypeOk,
203 AuthTypeCleartextPassword,
204 AuthTypeMD5Password,
205 AuthTypeSCMCreds,
206 AuthTypeGSS,
207 AuthTypeGSSCont,
208 AuthTypeSSPI,
209 AuthTypeSASL,
210 AuthTypeSASLContinue,
211 AuthTypeSASLFinal:
212 b.authType = authType
213 default:
214 return fmt.Errorf("authType not recognized: %d", authType)
215 }
216
217 return nil
218 }
219
View as plain text