...
1 package pgproto3
2
3 import (
4 "encoding/binary"
5 "errors"
6 "fmt"
7 "io"
8 )
9
10
11 type Frontend struct {
12 cr ChunkReader
13 w io.Writer
14
15
16 authenticationOk AuthenticationOk
17 authenticationCleartextPassword AuthenticationCleartextPassword
18 authenticationMD5Password AuthenticationMD5Password
19 authenticationGSS AuthenticationGSS
20 authenticationGSSContinue AuthenticationGSSContinue
21 authenticationSASL AuthenticationSASL
22 authenticationSASLContinue AuthenticationSASLContinue
23 authenticationSASLFinal AuthenticationSASLFinal
24 backendKeyData BackendKeyData
25 bindComplete BindComplete
26 closeComplete CloseComplete
27 commandComplete CommandComplete
28 copyBothResponse CopyBothResponse
29 copyData CopyData
30 copyInResponse CopyInResponse
31 copyOutResponse CopyOutResponse
32 copyDone CopyDone
33 dataRow DataRow
34 emptyQueryResponse EmptyQueryResponse
35 errorResponse ErrorResponse
36 functionCallResponse FunctionCallResponse
37 noData NoData
38 noticeResponse NoticeResponse
39 notificationResponse NotificationResponse
40 parameterDescription ParameterDescription
41 parameterStatus ParameterStatus
42 parseComplete ParseComplete
43 readyForQuery ReadyForQuery
44 rowDescription RowDescription
45 portalSuspended PortalSuspended
46
47 bodyLen int
48 msgType byte
49 partialMsg bool
50 authType uint32
51 }
52
53
54 func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
55 return &Frontend{cr: cr, w: w}
56 }
57
58
59 func (f *Frontend) Send(msg FrontendMessage) error {
60 buf, err := msg.Encode(nil)
61 if err != nil {
62 return err
63 }
64 _, err = f.w.Write(buf)
65 return err
66 }
67
// translateEOFtoErrUnexpectedEOF reports io.EOF as io.ErrUnexpectedEOF.
// Every other value, including nil and errors that merely wrap io.EOF, is
// returned unchanged. The match is by identity, following the io convention
// that readers return io.EOF itself rather than a wrapped form.
func translateEOFtoErrUnexpectedEOF(err error) error {
	switch err {
	case io.EOF:
		return io.ErrUnexpectedEOF
	default:
		return err
	}
}
74
75
76 func (f *Frontend) Receive() (BackendMessage, error) {
77 if !f.partialMsg {
78 header, err := f.cr.Next(5)
79 if err != nil {
80 return nil, translateEOFtoErrUnexpectedEOF(err)
81 }
82
83 f.msgType = header[0]
84 f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
85 f.partialMsg = true
86 if f.bodyLen < 0 {
87 return nil, errors.New("invalid message with negative body length received")
88 }
89 }
90
91 msgBody, err := f.cr.Next(f.bodyLen)
92 if err != nil {
93 return nil, translateEOFtoErrUnexpectedEOF(err)
94 }
95
96 f.partialMsg = false
97
98 var msg BackendMessage
99 switch f.msgType {
100 case '1':
101 msg = &f.parseComplete
102 case '2':
103 msg = &f.bindComplete
104 case '3':
105 msg = &f.closeComplete
106 case 'A':
107 msg = &f.notificationResponse
108 case 'c':
109 msg = &f.copyDone
110 case 'C':
111 msg = &f.commandComplete
112 case 'd':
113 msg = &f.copyData
114 case 'D':
115 msg = &f.dataRow
116 case 'E':
117 msg = &f.errorResponse
118 case 'G':
119 msg = &f.copyInResponse
120 case 'H':
121 msg = &f.copyOutResponse
122 case 'I':
123 msg = &f.emptyQueryResponse
124 case 'K':
125 msg = &f.backendKeyData
126 case 'n':
127 msg = &f.noData
128 case 'N':
129 msg = &f.noticeResponse
130 case 'R':
131 var err error
132 msg, err = f.findAuthenticationMessageType(msgBody)
133 if err != nil {
134 return nil, err
135 }
136 case 's':
137 msg = &f.portalSuspended
138 case 'S':
139 msg = &f.parameterStatus
140 case 't':
141 msg = &f.parameterDescription
142 case 'T':
143 msg = &f.rowDescription
144 case 'V':
145 msg = &f.functionCallResponse
146 case 'W':
147 msg = &f.copyBothResponse
148 case 'Z':
149 msg = &f.readyForQuery
150 default:
151 return nil, fmt.Errorf("unknown message type: %c", f.msgType)
152 }
153
154 err = msg.Decode(msgBody)
155 return msg, err
156 }
157
158
159
160
// Authentication type codes: the first four bytes (big-endian) of an 'R'
// authentication message body. findAuthenticationMessageType uses them to
// pick the concrete message.
const (
	AuthTypeOk                = 0
	AuthTypeCleartextPassword = 3
	AuthTypeMD5Password       = 5
	AuthTypeSCMCreds          = 6 // recognised but unimplemented by Frontend
	AuthTypeGSS               = 7
	AuthTypeGSSCont           = 8
	AuthTypeSSPI              = 9 // recognised but unimplemented by Frontend
	AuthTypeSASL              = 10
	AuthTypeSASLContinue      = 11
	AuthTypeSASLFinal         = 12
)
173
174 func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
175 if len(src) < 4 {
176 return nil, errors.New("authentication message too short")
177 }
178 f.authType = binary.BigEndian.Uint32(src[:4])
179
180 switch f.authType {
181 case AuthTypeOk:
182 return &f.authenticationOk, nil
183 case AuthTypeCleartextPassword:
184 return &f.authenticationCleartextPassword, nil
185 case AuthTypeMD5Password:
186 return &f.authenticationMD5Password, nil
187 case AuthTypeSCMCreds:
188 return nil, errors.New("AuthTypeSCMCreds is unimplemented")
189 case AuthTypeGSS:
190 return &f.authenticationGSS, nil
191 case AuthTypeGSSCont:
192 return &f.authenticationGSSContinue, nil
193 case AuthTypeSSPI:
194 return nil, errors.New("AuthTypeSSPI is unimplemented")
195 case AuthTypeSASL:
196 return &f.authenticationSASL, nil
197 case AuthTypeSASLContinue:
198 return &f.authenticationSASLContinue, nil
199 case AuthTypeSASLFinal:
200 return &f.authenticationSASLFinal, nil
201 default:
202 return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
203 }
204 }
205
206
207
208 func (f *Frontend) GetAuthType() uint32 {
209 return f.authType
210 }
211
View as plain text