1
2
3
4
5 package ssh
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "strings"
12 "testing"
13 )
14
15 func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) {
16 c1, c2, err := netPipe()
17 if err != nil {
18 t.Fatalf("netPipe: %v", err)
19 }
20 defer c1.Close()
21 defer c2.Close()
22
23 var serverAuthErrors []error
24
25 serverConfig.AddHostKey(testSigners["rsa"])
26 serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
27 serverAuthErrors = append(serverAuthErrors, err)
28 }
29 go newServer(c1, serverConfig)
30 c, _, _, err := NewClientConn(c2, "", clientConfig)
31 if err == nil {
32 c.Close()
33 }
34 return serverAuthErrors, err
35 }
36
37 func TestMultiStepAuth(t *testing.T) {
38
39 username := "testuser"
40
41 usernameSecondFactor := "testuser_second_factor"
42 errPwdAuthFailed := errors.New("password auth failed")
43 errWrongSequence := errors.New("wrong sequence")
44
45 serverConfig := &ServerConfig{
46 PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
47 if conn.User() == usernameSecondFactor {
48 return nil, errWrongSequence
49 }
50 if conn.User() == username && string(password) == clientPassword {
51 return nil, nil
52 }
53 return nil, errPwdAuthFailed
54 },
55 PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
56 if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
57 if conn.User() == usernameSecondFactor {
58 return nil, &PartialSuccessError{
59 Next: ServerAuthCallbacks{
60 PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
61 if string(password) == clientPassword {
62 return nil, nil
63 }
64 return nil, errPwdAuthFailed
65 },
66 },
67 }
68 }
69 return nil, nil
70 }
71 return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
72 },
73 }
74
75 clientConfig := &ClientConfig{
76 User: usernameSecondFactor,
77 Auth: []AuthMethod{
78 PublicKeys(testSigners["rsa"]),
79 Password(clientPassword),
80 },
81 HostKeyCallback: InsecureIgnoreHostKey(),
82 }
83
84 serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
85 if err != nil {
86 t.Fatalf("client login error: %s", err)
87 }
88
89
90
91
92
93 if len(serverAuthErrors) != 3 {
94 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
95 }
96 if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
97 t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1])
98 }
99
100 clientConfig.Auth = []AuthMethod{
101 Password(clientPassword),
102 PublicKeys(testSigners["rsa"]),
103 }
104
105 serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
106 if err == nil {
107 t.Fatal("client login with wrong sequence must fail")
108 }
109
110
111
112
113 if len(serverAuthErrors) != 3 {
114 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
115 }
116 if serverAuthErrors[1] != errWrongSequence {
117 t.Fatal("server not returned wrong sequence")
118 }
119 if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok {
120 t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2])
121 }
122
123
124 n := 0
125 passwords := []string{"WRONG", "WRONG", clientPassword}
126 clientConfig.Auth = []AuthMethod{
127 PublicKeys(testSigners["rsa"]),
128 RetryableAuthMethod(PasswordCallback(func() (string, error) {
129 p := passwords[n]
130 n++
131 return p, nil
132 }), 3),
133 }
134
135 serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
136 if err != nil {
137 t.Fatalf("client login error: %s", err)
138 }
139
140
141
142
143
144
145 if len(serverAuthErrors) != 5 {
146 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
147 }
148 if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
149 t.Fatal("server not returned partial success")
150 }
151 if serverAuthErrors[2] != errPwdAuthFailed {
152 t.Fatal("server not returned password authentication failed")
153 }
154 if serverAuthErrors[3] != errPwdAuthFailed {
155 t.Fatal("server not returned password authentication failed")
156 }
157
158 clientConfig.Auth = []AuthMethod{
159 Password(clientPassword),
160 }
161
162 serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
163 if err == nil {
164 t.Fatal("client login with password only must fail")
165 }
166
167
168
169 if len(serverAuthErrors) != 2 {
170 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
171 }
172 if serverAuthErrors[1] != errWrongSequence {
173 t.Fatal("server not returned wrong sequence")
174 }
175
176
177 clientConfig.Auth = []AuthMethod{
178 PublicKeys(testSigners["rsa"]),
179 }
180
181 serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
182 if err == nil {
183 t.Fatal("client login with public key only must fail")
184 }
185
186
187
188 if len(serverAuthErrors) != 2 {
189 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
190 }
191 if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
192 t.Fatal("server not returned partial success")
193 }
194
195
196 clientConfig.Auth = []AuthMethod{
197 PublicKeys(testSigners["rsa"]),
198 Password("WRONG"),
199 }
200
201 serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
202 if err == nil {
203 t.Fatal("client login with wrong password after public key must fail")
204 }
205
206
207
208
209 if len(serverAuthErrors) != 3 {
210 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
211 }
212 if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
213 t.Fatal("server not returned partial success")
214 }
215 if serverAuthErrors[2] != errPwdAuthFailed {
216 t.Fatal("server not returned password authentication failed")
217 }
218
219
220
221
222 clientConfig.Auth = []AuthMethod{
223 PublicKeys(testSigners["rsa"]),
224 PublicKeys(testSigners["rsa"]),
225 Password(clientPassword),
226 }
227
228 serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
229 if err != nil {
230 t.Fatalf("client login error: %s", err)
231 }
232
233
234
235
236 if len(serverAuthErrors) != 3 {
237 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
238 }
239 if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok {
240 t.Fatal("server not returned partial success")
241 }
242
243
244 clientConfig = &ClientConfig{
245 User: username,
246 Auth: []AuthMethod{
247 PublicKeys(testSigners["rsa"]),
248 Password(clientPassword),
249 },
250 HostKeyCallback: InsecureIgnoreHostKey(),
251 }
252
253 _, err = doClientServerAuth(t, serverConfig, clientConfig)
254 if err != nil {
255 t.Fatalf("unrestricted client login error: %s", err)
256 }
257
258 clientConfig = &ClientConfig{
259 User: username,
260 Auth: []AuthMethod{
261 PublicKeys(testSigners["rsa"]),
262 },
263 HostKeyCallback: InsecureIgnoreHostKey(),
264 }
265
266 _, err = doClientServerAuth(t, serverConfig, clientConfig)
267 if err != nil {
268 t.Fatalf("unrestricted client login error: %s", err)
269 }
270
271 clientConfig = &ClientConfig{
272 User: username,
273 Auth: []AuthMethod{
274 Password(clientPassword),
275 },
276 HostKeyCallback: InsecureIgnoreHostKey(),
277 }
278
279 _, err = doClientServerAuth(t, serverConfig, clientConfig)
280 if err != nil {
281 t.Fatalf("unrestricted client login error: %s", err)
282 }
283 }
284
285 func TestDynamicAuthCallbacks(t *testing.T) {
286 user1 := "user1"
287 user2 := "user2"
288 errInvalidCredentials := errors.New("invalid credentials")
289
290 serverConfig := &ServerConfig{
291 NoClientAuth: true,
292 NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) {
293 switch conn.User() {
294 case user1:
295 return nil, &PartialSuccessError{
296 Next: ServerAuthCallbacks{
297 PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
298 if conn.User() == user1 && string(password) == clientPassword {
299 return nil, nil
300 }
301 return nil, errInvalidCredentials
302 },
303 },
304 }
305 case user2:
306 return nil, &PartialSuccessError{
307 Next: ServerAuthCallbacks{
308 PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
309 if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
310 if conn.User() == user2 {
311 return nil, nil
312 }
313 }
314 return nil, errInvalidCredentials
315 },
316 },
317 }
318 default:
319 return nil, errInvalidCredentials
320 }
321 },
322 }
323
324 clientConfig := &ClientConfig{
325 User: user1,
326 Auth: []AuthMethod{
327 Password(clientPassword),
328 },
329 HostKeyCallback: InsecureIgnoreHostKey(),
330 }
331
332 serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
333 if err != nil {
334 t.Fatalf("client login error: %s", err)
335 }
336
337
338
339 if len(serverAuthErrors) != 2 {
340 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
341 }
342 if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
343 t.Fatal("server not returned partial success")
344 }
345
346 clientConfig = &ClientConfig{
347 User: user2,
348 Auth: []AuthMethod{
349 PublicKeys(testSigners["rsa"]),
350 },
351 HostKeyCallback: InsecureIgnoreHostKey(),
352 }
353
354 serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
355 if err != nil {
356 t.Fatalf("client login error: %s", err)
357 }
358
359
360
361 if len(serverAuthErrors) != 2 {
362 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
363 }
364 if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
365 t.Fatal("server not returned partial success")
366 }
367
368
369 clientConfig = &ClientConfig{
370 User: user1,
371 Auth: []AuthMethod{
372 PublicKeys(testSigners["rsa"]),
373 },
374 HostKeyCallback: InsecureIgnoreHostKey(),
375 }
376
377 serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
378 if err == nil {
379 t.Fatal("user1 login with public key must fail")
380 }
381 if !strings.Contains(err.Error(), "no supported methods remain") {
382 t.Errorf("got %v, expected 'no supported methods remain'", err)
383 }
384 if len(serverAuthErrors) != 1 {
385 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
386 }
387 if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
388 t.Fatal("server not returned partial success")
389 }
390
391 clientConfig = &ClientConfig{
392 User: user2,
393 Auth: []AuthMethod{
394 Password(clientPassword),
395 },
396 HostKeyCallback: InsecureIgnoreHostKey(),
397 }
398
399 serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig)
400 if err == nil {
401 t.Fatal("user2 login with password must fail")
402 }
403 if !strings.Contains(err.Error(), "no supported methods remain") {
404 t.Errorf("got %v, expected 'no supported methods remain'", err)
405 }
406 if len(serverAuthErrors) != 1 {
407 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
408 }
409 if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok {
410 t.Fatal("server not returned partial success")
411 }
412 }
413
View as plain text