1
2
3
4
5
6
7
8
9
10
11
12
13 package pgconn
14
15 import (
16 "bytes"
17 "crypto/hmac"
18 "crypto/rand"
19 "crypto/sha256"
20 "encoding/base64"
21 "errors"
22 "fmt"
23 "strconv"
24
25 "github.com/jackc/pgproto3/v2"
26 "golang.org/x/crypto/pbkdf2"
27 "golang.org/x/text/secure/precis"
28 )
29
const (
	// clientNonceLen is the number of random bytes drawn for the SCRAM client
	// nonce. newScramClient base64-encodes them (unpadded) before use, so the
	// nonce sent on the wire is longer than this.
	clientNonceLen = 18
)
31
32
33 func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
34 sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
35 if err != nil {
36 return err
37 }
38
39
40 saslInitialResponse := &pgproto3.SASLInitialResponse{
41 AuthMechanism: "SCRAM-SHA-256",
42 Data: sc.clientFirstMessage(),
43 }
44 buf, err := saslInitialResponse.Encode(nil)
45 if err != nil {
46 return err
47 }
48 _, err = c.conn.Write(buf)
49 if err != nil {
50 return err
51 }
52
53
54 saslContinue, err := c.rxSASLContinue()
55 if err != nil {
56 return err
57 }
58 err = sc.recvServerFirstMessage(saslContinue.Data)
59 if err != nil {
60 return err
61 }
62
63
64 saslResponse := &pgproto3.SASLResponse{
65 Data: []byte(sc.clientFinalMessage()),
66 }
67 buf, err = saslResponse.Encode(nil)
68 if err != nil {
69 return err
70 }
71 _, err = c.conn.Write(buf)
72 if err != nil {
73 return err
74 }
75
76
77 saslFinal, err := c.rxSASLFinal()
78 if err != nil {
79 return err
80 }
81 return sc.recvServerFinalMessage(saslFinal.Data)
82 }
83
84 func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
85 msg, err := c.receiveMessage()
86 if err != nil {
87 return nil, err
88 }
89 switch m := msg.(type) {
90 case *pgproto3.AuthenticationSASLContinue:
91 return m, nil
92 case *pgproto3.ErrorResponse:
93 return nil, ErrorResponseToPgError(m)
94 }
95
96 return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
97 }
98
99 func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
100 msg, err := c.receiveMessage()
101 if err != nil {
102 return nil, err
103 }
104 switch m := msg.(type) {
105 case *pgproto3.AuthenticationSASLFinal:
106 return m, nil
107 case *pgproto3.ErrorResponse:
108 return nil, ErrorResponseToPgError(m)
109 }
110
111 return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
112 }
113
// scramClient holds the client-side state of one SCRAM-SHA-256 exchange. It is
// created by newScramClient and filled in step by step by clientFirstMessage,
// recvServerFirstMessage and clientFinalMessage. recvServerFinalMessage reads
// the final state.
type scramClient struct {
	// serverAuthMechanisms is the SASL mechanism list advertised by the server.
	serverAuthMechanisms []string

	// password is the password after PRECIS OpaqueString normalization, or the
	// raw bytes if normalization failed.
	password []byte

	// clientNonce is the unpadded base64 encoding of random bytes.
	clientNonce []byte

	// clientFirstMessageBare is the client-first-message without the GS2
	// header ("n=,r=<nonce>").
	clientFirstMessageBare []byte

	// serverFirstMessage is the server-first-message payload.
	serverFirstMessage []byte

	// clientAndServerNonce is the value of the server's r= attribute.
	clientAndServerNonce []byte

	// salt is the base64-decoded value of the server's s= attribute.
	salt []byte

	// iterations is the PBKDF2 iteration count from the server's i= attribute.
	iterations int

	// saltedPassword is the PBKDF2-HMAC-SHA-256 output derived from password.
	saltedPassword []byte

	// authMessage is the comma-joined message that both proofs are computed over.
	authMessage []byte
}
129
130 func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
131 sc := &scramClient{
132 serverAuthMechanisms: serverAuthMechanisms,
133 }
134
135
136 hasScramSHA256 := false
137 for _, mech := range sc.serverAuthMechanisms {
138 if mech == "SCRAM-SHA-256" {
139 hasScramSHA256 = true
140 break
141 }
142 }
143 if !hasScramSHA256 {
144 return nil, errors.New("server does not support SCRAM-SHA-256")
145 }
146
147
148 var err error
149 sc.password, err = precis.OpaqueString.Bytes([]byte(password))
150 if err != nil {
151
152 sc.password = []byte(password)
153 }
154
155 buf := make([]byte, clientNonceLen)
156 _, err = rand.Read(buf)
157 if err != nil {
158 return nil, err
159 }
160 sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
161 base64.RawStdEncoding.Encode(sc.clientNonce, buf)
162
163 return sc, nil
164 }
165
166 func (sc *scramClient) clientFirstMessage() []byte {
167 sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
168 return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
169 }
170
171 func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
172 sc.serverFirstMessage = serverFirstMessage
173 buf := serverFirstMessage
174 if !bytes.HasPrefix(buf, []byte("r=")) {
175 return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
176 }
177 buf = buf[2:]
178
179 idx := bytes.IndexByte(buf, ',')
180 if idx == -1 {
181 return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
182 }
183 sc.clientAndServerNonce = buf[:idx]
184 buf = buf[idx+1:]
185
186 if !bytes.HasPrefix(buf, []byte("s=")) {
187 return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
188 }
189 buf = buf[2:]
190
191 idx = bytes.IndexByte(buf, ',')
192 if idx == -1 {
193 return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
194 }
195 saltStr := buf[:idx]
196 buf = buf[idx+1:]
197
198 if !bytes.HasPrefix(buf, []byte("i=")) {
199 return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
200 }
201 buf = buf[2:]
202 iterationsStr := buf
203
204 var err error
205 sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
206 if err != nil {
207 return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
208 }
209
210 sc.iterations, err = strconv.Atoi(string(iterationsStr))
211 if err != nil || sc.iterations <= 0 {
212 return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
213 }
214
215 if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
216 return errors.New("invalid SCRAM nonce: did not start with client nonce")
217 }
218
219 if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
220 return errors.New("invalid SCRAM nonce: did not include server nonce")
221 }
222
223 return nil
224 }
225
226 func (sc *scramClient) clientFinalMessage() string {
227 clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
228
229 sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
230 sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
231
232 clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
233
234 return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
235 }
236
237 func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
238 if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
239 return errors.New("invalid SCRAM server-final-message received from server")
240 }
241
242 serverSignature := serverFinalMessage[2:]
243
244 if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
245 return errors.New("invalid SCRAM ServerSignature received from server")
246 }
247
248 return nil
249 }
250
// computeHMAC returns HMAC-SHA-256(key, msg) as a raw 32-byte digest.
func computeHMAC(key, msg []byte) []byte {
	h := hmac.New(sha256.New, key)
	// hash.Hash documents that Write never returns an error.
	_, _ = h.Write(msg)
	return h.Sum(make([]byte, 0, sha256.Size))
}
256
257 func computeClientProof(saltedPassword, authMessage []byte) []byte {
258 clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
259 storedKey := sha256.Sum256(clientKey)
260 clientSignature := computeHMAC(storedKey[:], authMessage)
261
262 clientProof := make([]byte, len(clientSignature))
263 for i := 0; i < len(clientSignature); i++ {
264 clientProof[i] = clientKey[i] ^ clientSignature[i]
265 }
266
267 buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
268 base64.StdEncoding.Encode(buf, clientProof)
269 return buf
270 }
271
272 func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
273 serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
274 serverSignature := computeHMAC(serverKey, authMessage)
275 buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
276 base64.StdEncoding.Encode(buf, serverSignature)
277 return buf
278 }
279
View as plain text