1
2
3
4
5
6
7
8
9
10
11
12
13 package pgconn
14
15 import (
16 "bytes"
17 "crypto/hmac"
18 "crypto/rand"
19 "crypto/sha256"
20 "encoding/base64"
21 "errors"
22 "fmt"
23 "strconv"
24
25 "github.com/jackc/pgx/v5/pgproto3"
26 "golang.org/x/crypto/pbkdf2"
27 "golang.org/x/text/secure/precis"
28 )
29
const (
	// clientNonceLen is the number of random bytes read from crypto/rand to
	// form the SCRAM client nonce. newScramClient base64-encodes them (raw,
	// unpadded) before they are placed in the client-first-message.
	clientNonceLen = 18
)
31
32
33 func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
34 sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
35 if err != nil {
36 return err
37 }
38
39
40 saslInitialResponse := &pgproto3.SASLInitialResponse{
41 AuthMechanism: "SCRAM-SHA-256",
42 Data: sc.clientFirstMessage(),
43 }
44 c.frontend.Send(saslInitialResponse)
45 err = c.flushWithPotentialWriteReadDeadlock()
46 if err != nil {
47 return err
48 }
49
50
51 saslContinue, err := c.rxSASLContinue()
52 if err != nil {
53 return err
54 }
55 err = sc.recvServerFirstMessage(saslContinue.Data)
56 if err != nil {
57 return err
58 }
59
60
61 saslResponse := &pgproto3.SASLResponse{
62 Data: []byte(sc.clientFinalMessage()),
63 }
64 c.frontend.Send(saslResponse)
65 err = c.flushWithPotentialWriteReadDeadlock()
66 if err != nil {
67 return err
68 }
69
70
71 saslFinal, err := c.rxSASLFinal()
72 if err != nil {
73 return err
74 }
75 return sc.recvServerFinalMessage(saslFinal.Data)
76 }
77
78 func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
79 msg, err := c.receiveMessage()
80 if err != nil {
81 return nil, err
82 }
83 switch m := msg.(type) {
84 case *pgproto3.AuthenticationSASLContinue:
85 return m, nil
86 case *pgproto3.ErrorResponse:
87 return nil, ErrorResponseToPgError(m)
88 }
89
90 return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
91 }
92
93 func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
94 msg, err := c.receiveMessage()
95 if err != nil {
96 return nil, err
97 }
98 switch m := msg.(type) {
99 case *pgproto3.AuthenticationSASLFinal:
100 return m, nil
101 case *pgproto3.ErrorResponse:
102 return nil, ErrorResponseToPgError(m)
103 }
104
105 return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
106 }
107
// scramClient holds the client-side state of a single SCRAM-SHA-256
// exchange. Its fields are filled in step by step as the exchange advances;
// each group below notes which method populates it.
type scramClient struct {
	// Populated by newScramClient.
	serverAuthMechanisms []string
	password             []byte // PRECIS OpaqueString-normalized, or the raw bytes if normalization failed
	clientNonce          []byte // raw (unpadded) std-base64 of clientNonceLen random bytes

	// Populated by clientFirstMessage.
	clientFirstMessageBare []byte

	// Populated by recvServerFirstMessage.
	serverFirstMessage, clientAndServerNonce, salt []byte
	iterations                                     int

	// Populated by clientFinalMessage; consumed by recvServerFinalMessage.
	saltedPassword, authMessage []byte
}
123
124 func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
125 sc := &scramClient{
126 serverAuthMechanisms: serverAuthMechanisms,
127 }
128
129
130 hasScramSHA256 := false
131 for _, mech := range sc.serverAuthMechanisms {
132 if mech == "SCRAM-SHA-256" {
133 hasScramSHA256 = true
134 break
135 }
136 }
137 if !hasScramSHA256 {
138 return nil, errors.New("server does not support SCRAM-SHA-256")
139 }
140
141
142 var err error
143 sc.password, err = precis.OpaqueString.Bytes([]byte(password))
144 if err != nil {
145
146 sc.password = []byte(password)
147 }
148
149 buf := make([]byte, clientNonceLen)
150 _, err = rand.Read(buf)
151 if err != nil {
152 return nil, err
153 }
154 sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
155 base64.RawStdEncoding.Encode(sc.clientNonce, buf)
156
157 return sc, nil
158 }
159
160 func (sc *scramClient) clientFirstMessage() []byte {
161 sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
162 return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
163 }
164
165 func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
166 sc.serverFirstMessage = serverFirstMessage
167 buf := serverFirstMessage
168 if !bytes.HasPrefix(buf, []byte("r=")) {
169 return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
170 }
171 buf = buf[2:]
172
173 idx := bytes.IndexByte(buf, ',')
174 if idx == -1 {
175 return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
176 }
177 sc.clientAndServerNonce = buf[:idx]
178 buf = buf[idx+1:]
179
180 if !bytes.HasPrefix(buf, []byte("s=")) {
181 return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
182 }
183 buf = buf[2:]
184
185 idx = bytes.IndexByte(buf, ',')
186 if idx == -1 {
187 return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
188 }
189 saltStr := buf[:idx]
190 buf = buf[idx+1:]
191
192 if !bytes.HasPrefix(buf, []byte("i=")) {
193 return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
194 }
195 buf = buf[2:]
196 iterationsStr := buf
197
198 var err error
199 sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
200 if err != nil {
201 return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
202 }
203
204 sc.iterations, err = strconv.Atoi(string(iterationsStr))
205 if err != nil || sc.iterations <= 0 {
206 return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
207 }
208
209 if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
210 return errors.New("invalid SCRAM nonce: did not start with client nonce")
211 }
212
213 if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
214 return errors.New("invalid SCRAM nonce: did not include server nonce")
215 }
216
217 return nil
218 }
219
220 func (sc *scramClient) clientFinalMessage() string {
221 clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
222
223 sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
224 sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
225
226 clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
227
228 return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
229 }
230
231 func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
232 if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
233 return errors.New("invalid SCRAM server-final-message received from server")
234 }
235
236 serverSignature := serverFinalMessage[2:]
237
238 if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
239 return errors.New("invalid SCRAM ServerSignature received from server")
240 }
241
242 return nil
243 }
244
// computeHMAC returns HMAC-SHA-256(key, msg) as a freshly allocated 32-byte
// slice.
func computeHMAC(key, msg []byte) []byte {
	h := hmac.New(sha256.New, key)
	// hash.Hash's Write is documented never to return an error.
	_, _ = h.Write(msg)
	return h.Sum(nil)
}
250
251 func computeClientProof(saltedPassword, authMessage []byte) []byte {
252 clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
253 storedKey := sha256.Sum256(clientKey)
254 clientSignature := computeHMAC(storedKey[:], authMessage)
255
256 clientProof := make([]byte, len(clientSignature))
257 for i := 0; i < len(clientSignature); i++ {
258 clientProof[i] = clientKey[i] ^ clientSignature[i]
259 }
260
261 buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
262 base64.StdEncoding.Encode(buf, clientProof)
263 return buf
264 }
265
266 func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
267 serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
268 serverSignature := computeHMAC(serverKey, authMessage)
269 buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
270 base64.StdEncoding.Encode(buf, serverSignature)
271 return buf
272 }
273
View as plain text