1
2
3 package scram
4
5 import (
6 "bytes"
7 "context"
8 "crypto/hmac"
9 "crypto/rand"
10 "crypto/sha256"
11 "crypto/sha512"
12 "encoding/base64"
13 "errors"
14 "fmt"
15 "hash"
16 "strconv"
17 "strings"
18
19 "golang.org/x/crypto/pbkdf2"
20
21 "github.com/twmb/franz-go/pkg/sasl"
22 )
23
24
25
26
27
// Auth contains the credentials used by this package's SCRAM mechanisms.
//
// A copy of the value is taken each time an exchange starts, so later
// changes to a caller's Auth do not affect an in-flight session.
type Auth struct {
	// Zid is an optional authorization ID. When non-empty it is escaped
	// and sent as the "a=" authzid inside the GS2 header of the
	// client-first-message.
	Zid string

	// User is the name to authenticate as. It must be non-empty; any '='
	// or ',' characters are escaped before it is sent.
	User string

	// Pass is the password. It must be non-empty. It is never sent over
	// the wire; it is only used to derive the SCRAM keys.
	Pass string

	// Nonce, when non-empty, is used as the raw client nonce instead of
	// 20 bytes read from crypto/rand. In both cases the raw bytes are
	// encoded with unpadded standard base64 before being sent.
	Nonce []byte

	// IsToken, when true, appends ",tokenauth=true" to the
	// client-first-message. Presumably intended for Kafka delegation-token
	// credentials — confirm against the broker's expectations.
	IsToken bool

	// Blank field: the usual idiom to force keyed struct literals.
	_ struct{}
}
57
58
59
60
61
62
63 func (a Auth) AsSha256Mechanism() sasl.Mechanism {
64 return Sha256(func(context.Context) (Auth, error) {
65 return a, nil
66 })
67 }
68
69
70
71
72
73
74 func (a Auth) AsSha512Mechanism() sasl.Mechanism {
75 return Sha512(func(context.Context) (Auth, error) {
76 return a, nil
77 })
78 }
79
80
81
82
83 func Sha256(authFn func(context.Context) (Auth, error)) sasl.Mechanism {
84 return scram{authFn, sha256.New, "SCRAM-SHA-256"}
85 }
86
87
88
89
90 func Sha512(authFn func(context.Context) (Auth, error)) sasl.Mechanism {
91 return scram{authFn, sha512.New, "SCRAM-SHA-512"}
92 }
93
94 type scram struct {
95 authFn func(context.Context) (Auth, error)
96 newhash func() hash.Hash
97 name string
98 }
99
// escaper encodes the characters SCRAM reserves inside saslname values:
// '=' is sent as "=3D" and ',' as "=2C" (RFC 5802 §5.1).
var escaper = strings.NewReplacer(
	"=", "=3D",
	",", "=2C",
)
101
102 func (s scram) Name() string { return s.name }
103 func (s scram) Authenticate(ctx context.Context, _ string) (sasl.Session, []byte, error) {
104 auth, err := s.authFn(ctx)
105 if err != nil {
106 return nil, nil, err
107 }
108 if auth.User == "" || auth.Pass == "" {
109 return nil, nil, errors.New(s.name + " user and pass must be non-empty")
110 }
111 if len(auth.Nonce) == 0 {
112 buf := make([]byte, 20)
113 if _, err = rand.Read(buf); err != nil {
114 return nil, nil, err
115 }
116 auth.Nonce = buf
117 }
118
119 auth.Nonce = []byte(base64.RawStdEncoding.EncodeToString(auth.Nonce))
120
121 clientFirstMsgBare := make([]byte, 0, 100)
122 clientFirstMsgBare = append(clientFirstMsgBare, "n="...)
123 clientFirstMsgBare = append(clientFirstMsgBare, escaper.Replace(auth.User)...)
124 clientFirstMsgBare = append(clientFirstMsgBare, ",r="...)
125 clientFirstMsgBare = append(clientFirstMsgBare, auth.Nonce...)
126 if auth.IsToken {
127 clientFirstMsgBare = append(clientFirstMsgBare, ",tokenauth=true"...)
128 }
129
130 gs2Header := "n,"
131 if auth.Zid != "" {
132 gs2Header += "a=" + escaper.Replace(auth.Zid)
133 }
134 gs2Header += ","
135 clientFirstMsg := append([]byte(gs2Header), clientFirstMsgBare...)
136 return &session{
137 step: 0,
138 auth: auth,
139 newhash: s.newhash,
140
141 clientFirstMsgBare: clientFirstMsgBare,
142 }, clientFirstMsg, nil
143 }
144
145 type session struct {
146 step int
147 auth Auth
148 newhash func() hash.Hash
149
150 clientFirstMsgBare []byte
151 expServerSignature []byte
152 }
153
154 func (s *session) Challenge(resp []byte) (bool, []byte, error) {
155 step := s.step
156 s.step++
157 switch step {
158 case 0:
159 response, err := s.authenticateClient(resp)
160 return false, response, err
161 case 1:
162 err := s.verifyServer(resp)
163 return err == nil, nil, err
164 default:
165 return false, nil, fmt.Errorf("challenge / response should be done, but still going at %d", step)
166 }
167 }
168
169
170
171 func (s *session) authenticateClient(serverFirstMsg []byte) ([]byte, error) {
172 kvs := bytes.Split(serverFirstMsg, []byte(","))
173 if len(kvs) < 3 {
174 return nil, fmt.Errorf("got %d kvs != exp min 3", len(kvs))
175 }
176
177
178 if !bytes.HasPrefix(kvs[0], []byte("r=")) {
179 return nil, fmt.Errorf("unexpected kv %q where nonce expected", kvs[0])
180 }
181 serverNonce := kvs[0][2:]
182 if !bytes.HasPrefix(serverNonce, s.auth.Nonce) {
183 return nil, errors.New("server did not reply with nonce beginning with client nonce")
184 }
185
186
187 if !bytes.HasPrefix(kvs[1], []byte("s=")) {
188 return nil, fmt.Errorf("unexpected kv %q where salt expected", kvs[1])
189 }
190 salt, err := base64.StdEncoding.DecodeString(string(kvs[1][2:]))
191 if err != nil {
192 return nil, fmt.Errorf("server salt %q decode err: %v", kvs[1][2:], err)
193 }
194
195
196 if !bytes.HasPrefix(kvs[2], []byte("i=")) {
197 return nil, fmt.Errorf("unexpected kv %q where iterations expected", kvs[2])
198 }
199 iters, err := strconv.Atoi(string(kvs[2][2:]))
200 if err != nil {
201 return nil, fmt.Errorf("server iterations %q parse err: %v", kvs[2][2:], err)
202 }
203 if iters < 4096 {
204 return nil, fmt.Errorf("server iterations %d less than minimum 4096", iters)
205 }
206
207
208
209
210
211 h := s.newhash()
212 saltedPassword := pbkdf2.Key([]byte(s.auth.Pass), salt, iters, h.Size(), s.newhash)
213
214 mac := hmac.New(s.newhash, saltedPassword)
215 if _, err = mac.Write([]byte("Client Key")); err != nil {
216 return nil, fmt.Errorf("hmac err: %v", err)
217 }
218 clientKey := mac.Sum(nil)
219 if _, err = h.Write(clientKey); err != nil {
220 return nil, fmt.Errorf("sha err: %v", err)
221 }
222 storedKey := h.Sum(nil)
223
224
225 clientFinalMsgWithoutProof := append([]byte("c=biws,r="), serverNonce...)
226 authMsg := append(s.clientFirstMsgBare, ',')
227 authMsg = append(authMsg, serverFirstMsg...)
228 authMsg = append(authMsg, ',')
229 authMsg = append(authMsg, clientFinalMsgWithoutProof...)
230
231 mac = hmac.New(s.newhash, storedKey)
232 if _, err = mac.Write(authMsg); err != nil {
233 return nil, fmt.Errorf("hmac err: %v", err)
234 }
235 clientSignature := mac.Sum(nil)
236
237 clientProof := clientSignature
238 for i, c := range clientKey {
239 clientProof[i] ^= c
240 }
241
242 mac = hmac.New(s.newhash, saltedPassword)
243 if _, err = mac.Write([]byte("Server Key")); err != nil {
244 return nil, fmt.Errorf("hmac err: %v", err)
245 }
246 serverKey := mac.Sum(nil)
247 mac = hmac.New(s.newhash, serverKey)
248 if _, err = mac.Write(authMsg); err != nil {
249 return nil, fmt.Errorf("hmac err: %v", err)
250 }
251 s.expServerSignature = []byte(base64.StdEncoding.EncodeToString(mac.Sum(nil)))
252
253 clientFinalMsg := append(clientFinalMsgWithoutProof, ",p="...)
254 clientFinalMsg = append(clientFinalMsg, base64.StdEncoding.EncodeToString(clientProof)...)
255 return clientFinalMsg, nil
256 }
257
258 func (s *session) verifyServer(serverFinalMsg []byte) error {
259 kvs := bytes.Split(serverFinalMsg, []byte(","))
260 if len(kvs) < 1 {
261 return errors.New("received no kvs, even though this should be impossible")
262 }
263
264 kv := kvs[0]
265 if isErr := bytes.HasPrefix(kv, []byte("e=")); isErr {
266 return fmt.Errorf("server sent authentication error %q", kv[2:])
267 }
268 if !bytes.HasPrefix(kv, []byte("v=")) {
269 return fmt.Errorf("server sent unexpected first kv %q", kv)
270 }
271 if !bytes.Equal(s.expServerSignature, kv[2:]) {
272 return fmt.Errorf("server signature mismatch; got %q != exp %q", kv[2:], s.expServerSignature)
273 }
274 return nil
275 }
276
View as plain text