1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29 package scram
30
31 import (
32 "bytes"
33 "crypto/hmac"
34 "crypto/rand"
35 "encoding/base64"
36 "fmt"
37 "hash"
38 "strconv"
39 "strings"
40 )
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// Client implements the client side of a SCRAM exchange (RFC 5802) using a
// caller-chosen hash function; the messages in this package describe it as
// SCRAM-SHA-256. Create one with NewClient and drive it with Step.
//
// A Client may be used within a SASL conversation with logic resembling:
//
//	client := scram.NewClient(sha256.New, user, pass)
//	var in []byte
//	for !client.Step(in) {
//		// send client.Out() to the server and read its reply into in
//	}
//	if err := client.Err(); err != nil {
//		// authentication failed
//	}
type Client struct {
	// newHash constructs the hash function used both as the HMAC primitive
	// and for computing StoredKey = H(ClientKey).
	newHash func() hash.Hash

	// Credentials passed to NewClient.
	user, pass string

	// Exchange progress: step is the number of Step calls that have done
	// work (0-3), out is the message produced by the most recent Step, and
	// err is the first failure, after which Step does nothing.
	step int
	out  bytes.Buffer
	err  error

	// Protocol state accumulated across steps.
	clientNonce []byte       // nonce sent in client-first-message
	serverNonce []byte       // combined nonce received in server-first-message
	saltedPass  []byte       // SaltedPassword = Hi(password, salt, i)
	authMsg     bytes.Buffer // AuthMessage, appended to as the exchange proceeds
}
71
72
73
74
75
76
77
78 func NewClient(newHash func() hash.Hash, user, pass string) *Client {
79 c := &Client{
80 newHash: newHash,
81 user: user,
82 pass: pass,
83 }
84 c.out.Grow(256)
85 c.authMsg.Grow(256)
86 return c
87 }
88
89
90 func (c *Client) Out() []byte {
91 if c.out.Len() == 0 {
92 return nil
93 }
94 return c.out.Bytes()
95 }
96
97
98 func (c *Client) Err() error {
99 return c.err
100 }
101
102
103
104 func (c *Client) SetNonce(nonce []byte) {
105 c.clientNonce = nonce
106 }
107
// escaper encodes a user name for the "n=" attribute: '=' becomes "=3D" and
// ',' becomes "=2C", so the name cannot be confused with attribute syntax.
var (
	escaper = strings.NewReplacer("=", "=3D", ",", "=2C")
)
109
110
111
112
113
114 func (c *Client) Step(in []byte) bool {
115 c.out.Reset()
116 if c.step > 2 || c.err != nil {
117 return false
118 }
119 c.step++
120 switch c.step {
121 case 1:
122 c.err = c.step1(in)
123 case 2:
124 c.err = c.step2(in)
125 case 3:
126 c.err = c.step3(in)
127 }
128 return c.step > 2 || c.err != nil
129 }
130
131 func (c *Client) step1(in []byte) error {
132 if len(c.clientNonce) == 0 {
133 const nonceLen = 16
134 buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen))
135 if _, err := rand.Read(buf[:nonceLen]); err != nil {
136 return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err)
137 }
138 c.clientNonce = buf[nonceLen:]
139 b64.Encode(c.clientNonce, buf[:nonceLen])
140 }
141 c.authMsg.WriteString("n=")
142 escaper.WriteString(&c.authMsg, c.user)
143 c.authMsg.WriteString(",r=")
144 c.authMsg.Write(c.clientNonce)
145
146 c.out.WriteString("n,,")
147 c.out.Write(c.authMsg.Bytes())
148 return nil
149 }
150
// b64 is the padded standard base64 alphabet, used for nonces, salts, proofs
// and signatures exchanged during SCRAM.
var (
	b64 = base64.StdEncoding
)
152
153 func (c *Client) step2(in []byte) error {
154 c.authMsg.WriteByte(',')
155 c.authMsg.Write(in)
156
157 fields := bytes.Split(in, []byte(","))
158 if len(fields) != 3 {
159 return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(fields), in)
160 }
161 if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 {
162 return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", fields[0])
163 }
164 if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 {
165 return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", fields[1])
166 }
167 if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 {
168 return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2])
169 }
170
171 c.serverNonce = fields[0][2:]
172 if !bytes.HasPrefix(c.serverNonce, c.clientNonce) {
173 return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce)
174 }
175
176 salt := make([]byte, b64.DecodedLen(len(fields[1][2:])))
177 n, err := b64.Decode(salt, fields[1][2:])
178 if err != nil {
179 return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", fields[1])
180 }
181 salt = salt[:n]
182 iterCount, err := strconv.Atoi(string(fields[2][2:]))
183 if err != nil {
184 return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2])
185 }
186 c.saltPassword(salt, iterCount)
187
188 c.authMsg.WriteString(",c=biws,r=")
189 c.authMsg.Write(c.serverNonce)
190
191 c.out.WriteString("c=biws,r=")
192 c.out.Write(c.serverNonce)
193 c.out.WriteString(",p=")
194 c.out.Write(c.clientProof())
195 return nil
196 }
197
198 func (c *Client) step3(in []byte) error {
199 var isv, ise bool
200 var fields = bytes.Split(in, []byte(","))
201 if len(fields) == 1 {
202 isv = bytes.HasPrefix(fields[0], []byte("v="))
203 ise = bytes.HasPrefix(fields[0], []byte("e="))
204 }
205 if ise {
206 return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:])
207 } else if !isv {
208 return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in)
209 }
210 if !bytes.Equal(c.serverSignature(), fields[0][2:]) {
211 return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:])
212 }
213 return nil
214 }
215
216 func (c *Client) saltPassword(salt []byte, iterCount int) {
217 mac := hmac.New(c.newHash, []byte(c.pass))
218 mac.Write(salt)
219 mac.Write([]byte{0, 0, 0, 1})
220 ui := mac.Sum(nil)
221 hi := make([]byte, len(ui))
222 copy(hi, ui)
223 for i := 1; i < iterCount; i++ {
224 mac.Reset()
225 mac.Write(ui)
226 mac.Sum(ui[:0])
227 for j, b := range ui {
228 hi[j] ^= b
229 }
230 }
231 c.saltedPass = hi
232 }
233
234 func (c *Client) clientProof() []byte {
235 mac := hmac.New(c.newHash, c.saltedPass)
236 mac.Write([]byte("Client Key"))
237 clientKey := mac.Sum(nil)
238 hash := c.newHash()
239 hash.Write(clientKey)
240 storedKey := hash.Sum(nil)
241 mac = hmac.New(c.newHash, storedKey)
242 mac.Write(c.authMsg.Bytes())
243 clientProof := mac.Sum(nil)
244 for i, b := range clientKey {
245 clientProof[i] ^= b
246 }
247 clientProof64 := make([]byte, b64.EncodedLen(len(clientProof)))
248 b64.Encode(clientProof64, clientProof)
249 return clientProof64
250 }
251
252 func (c *Client) serverSignature() []byte {
253 mac := hmac.New(c.newHash, c.saltedPass)
254 mac.Write([]byte("Server Key"))
255 serverKey := mac.Sum(nil)
256
257 mac = hmac.New(c.newHash, serverKey)
258 mac.Write(c.authMsg.Bytes())
259 serverSignature := mac.Sum(nil)
260
261 encoded := make([]byte, b64.EncodedLen(len(serverSignature)))
262 b64.Encode(encoded, serverSignature)
263 return encoded
264 }
265
View as plain text