1
2
3
4
5 package ssh
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "strings"
13 )
14
// authResult is the outcome of a single authentication attempt, as returned
// by AuthMethod.auth.
type authResult int

const (
	// authFailure is the zero value. clientAuthenticate records a method
	// that yields it as tried and does not select it again.
	authFailure authResult = iota
	// authPartialSuccess is returned when the server answers with a failure
	// message whose PartialSuccess flag is set.
	authPartialSuccess
	// authSuccess is returned when the server answers with
	// msgUserAuthSuccess.
	authSuccess
)
22
23
24 func (c *connection) clientAuthenticate(config *ClientConfig) error {
25
26 if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
27 return err
28 }
29 packet, err := c.transport.readPacket()
30 if err != nil {
31 return err
32 }
33
34
35
36 extensions := make(map[string][]byte)
37 if len(packet) > 0 && packet[0] == msgExtInfo {
38 var extInfo extInfoMsg
39 if err := Unmarshal(packet, &extInfo); err != nil {
40 return err
41 }
42 payload := extInfo.Payload
43 for i := uint32(0); i < extInfo.NumExtensions; i++ {
44 name, rest, ok := parseString(payload)
45 if !ok {
46 return parseError(msgExtInfo)
47 }
48 value, rest, ok := parseString(rest)
49 if !ok {
50 return parseError(msgExtInfo)
51 }
52 extensions[string(name)] = value
53 payload = rest
54 }
55 packet, err = c.transport.readPacket()
56 if err != nil {
57 return err
58 }
59 }
60 var serviceAccept serviceAcceptMsg
61 if err := Unmarshal(packet, &serviceAccept); err != nil {
62 return err
63 }
64
65
66
67 var tried []string
68 var lastMethods []string
69
70 sessionID := c.transport.getSessionID()
71 for auth := AuthMethod(new(noneAuth)); auth != nil; {
72 ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
73 if err != nil {
74
75 if _, ok := err.(*disconnectMsg); ok {
76 return err
77 }
78
79
80 ok = authFailure
81 }
82 if ok == authSuccess {
83
84 return nil
85 } else if ok == authFailure {
86 if m := auth.method(); !contains(tried, m) {
87 tried = append(tried, m)
88 }
89 }
90 if methods == nil {
91 methods = lastMethods
92 }
93 lastMethods = methods
94
95 auth = nil
96
97 findNext:
98 for _, a := range config.Auth {
99 candidateMethod := a.method()
100 if contains(tried, candidateMethod) {
101 continue
102 }
103 for _, meth := range methods {
104 if meth == candidateMethod {
105 auth = a
106 break findNext
107 }
108 }
109 }
110
111 if auth == nil && err != nil {
112
113
114 return err
115 }
116 }
117 return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
118 }
119
// contains reports whether e is an element of list. A nil or empty list
// contains nothing.
func contains(list []string, e string) bool {
	for i := range list {
		if list[i] == e {
			return true
		}
	}
	return false
}
128
129
// AuthMethod represents one way for a client to authenticate to an SSH
// server. Both methods are unexported, so implementations can only be
// defined inside this package.
type AuthMethod interface {
	// auth performs one authentication attempt for user over p. session is
	// the session identifier, rand is a source of randomness for methods
	// that sign data, and extensions holds the extensions the server
	// announced in its ext-info message.
	//
	// It returns the outcome of the attempt, the method names the server
	// listed in its failure message (nil if none were reported), and any
	// error encountered.
	auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error)

	// method returns the method name (for example "password" or
	// "publickey") that clientAuthenticate compares against the method
	// list sent by the server.
	method() string
}
141
142
143 type noneAuth int
144
145 func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
146 if err := c.writePacket(Marshal(&userAuthRequestMsg{
147 User: user,
148 Service: serviceSSH,
149 Method: "none",
150 })); err != nil {
151 return authFailure, nil, err
152 }
153
154 return handleAuthResponse(c)
155 }
156
157 func (n *noneAuth) method() string {
158 return "none"
159 }
160
161
162
163 type passwordCallback func() (password string, err error)
164
165 func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
166 type passwordAuthMsg struct {
167 User string `sshtype:"50"`
168 Service string
169 Method string
170 Reply bool
171 Password string
172 }
173
174 pw, err := cb()
175
176
177
178 if err != nil {
179 return authFailure, nil, err
180 }
181
182 if err := c.writePacket(Marshal(&passwordAuthMsg{
183 User: user,
184 Service: serviceSSH,
185 Method: cb.method(),
186 Reply: false,
187 Password: pw,
188 })); err != nil {
189 return authFailure, nil, err
190 }
191
192 return handleAuthResponse(c)
193 }
194
195 func (cb passwordCallback) method() string {
196 return "password"
197 }
198
199
200 func Password(secret string) AuthMethod {
201 return passwordCallback(func() (string, error) { return secret, nil })
202 }
203
204
205
206 func PasswordCallback(prompt func() (secret string, err error)) AuthMethod {
207 return passwordCallback(prompt)
208 }
209
// publickeyAuthMsg is the wire layout of a "publickey" user-auth request. It
// is used both to query whether a key is acceptable (validateKey) and to send
// the signed authentication request (publicKeyCallback.auth).
type publickeyAuthMsg struct {
	User    string `sshtype:"50"`
	Service string
	Method  string
	// HasSig is false for the key query sent by validateKey and true when
	// the request carries a signature.
	HasSig   bool
	Algoname string
	PubKey   []byte
	// Sig holds the signature, already wrapped as a length-prefixed string
	// by the caller. It is left empty for the key query.
	// NOTE(review): presumably the "rest" tag makes Marshal emit these bytes
	// verbatim as the tail of the packet; confirm against Marshal.
	Sig []byte `ssh:"rest"`
}
223
224
225
226 type publicKeyCallback func() ([]Signer, error)
227
228 func (cb publicKeyCallback) method() string {
229 return "publickey"
230 }
231
232 func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) {
233 var as MultiAlgorithmSigner
234 keyFormat := signer.PublicKey().Type()
235
236
237
238
239 switch s := signer.(type) {
240 case MultiAlgorithmSigner:
241 as = s
242 case AlgorithmSigner:
243 as = &multiAlgorithmSigner{
244 AlgorithmSigner: s,
245 supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)),
246 }
247 default:
248 as = &multiAlgorithmSigner{
249 AlgorithmSigner: algorithmSignerWrapper{signer},
250 supportedAlgorithms: []string{underlyingAlgo(keyFormat)},
251 }
252 }
253
254 getFallbackAlgo := func() (string, error) {
255
256
257
258 if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
259 return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v",
260 underlyingAlgo(keyFormat), keyFormat, as.Algorithms())
261 }
262 return keyFormat, nil
263 }
264
265 extPayload, ok := extensions["server-sig-algs"]
266 if !ok {
267
268
269 algo, err := getFallbackAlgo()
270 return as, algo, err
271 }
272
273
274
275
276
277 serverAlgos := strings.Split(string(extPayload), ",")
278 for _, algo := range serverAlgos {
279 if certAlgo, ok := certificateAlgo(algo); ok {
280 serverAlgos = append(serverAlgos, certAlgo)
281 }
282 }
283
284
285 var keyAlgos []string
286 for _, algo := range algorithmsForKeyFormat(keyFormat) {
287 if contains(as.Algorithms(), underlyingAlgo(algo)) {
288 keyAlgos = append(keyAlgos, algo)
289 }
290 }
291
292 algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos)
293 if err != nil {
294
295
296 algo, err := getFallbackAlgo()
297 return as, algo, err
298 }
299 return as, algo, nil
300 }
301
302 func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
303
304
305
306
307
308 signers, err := cb()
309 if err != nil {
310 return authFailure, nil, err
311 }
312 var methods []string
313 var errSigAlgo error
314
315 origSignersLen := len(signers)
316 for idx := 0; idx < len(signers); idx++ {
317 signer := signers[idx]
318 pub := signer.PublicKey()
319 as, algo, err := pickSignatureAlgorithm(signer, extensions)
320 if err != nil && errSigAlgo == nil {
321
322
323
324 errSigAlgo = err
325 continue
326 }
327 ok, err := validateKey(pub, algo, user, c)
328 if err != nil {
329 return authFailure, nil, err
330 }
331
332
333
334
335
336 if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 {
337 if contains(as.Algorithms(), KeyAlgoRSA) {
338
339
340 signers = append(signers, &multiAlgorithmSigner{
341 AlgorithmSigner: as,
342 supportedAlgorithms: []string{KeyAlgoRSA},
343 })
344 }
345 }
346 if !ok {
347 continue
348 }
349
350 pubKey := pub.Marshal()
351 data := buildDataSignedForAuth(session, userAuthRequestMsg{
352 User: user,
353 Service: serviceSSH,
354 Method: cb.method(),
355 }, algo, pubKey)
356 sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
357 if err != nil {
358 return authFailure, nil, err
359 }
360
361
362 s := Marshal(sign)
363 sig := make([]byte, stringLength(len(s)))
364 marshalString(sig, s)
365 msg := publickeyAuthMsg{
366 User: user,
367 Service: serviceSSH,
368 Method: cb.method(),
369 HasSig: true,
370 Algoname: algo,
371 PubKey: pubKey,
372 Sig: sig,
373 }
374 p := Marshal(&msg)
375 if err := c.writePacket(p); err != nil {
376 return authFailure, nil, err
377 }
378 var success authResult
379 success, methods, err = handleAuthResponse(c)
380 if err != nil {
381 return authFailure, nil, err
382 }
383
384
385
386
387
388 if success == authSuccess || !contains(methods, cb.method()) {
389 return success, methods, err
390 }
391 }
392
393 return authFailure, methods, errSigAlgo
394 }
395
396
397 func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) {
398 pubKey := key.Marshal()
399 msg := publickeyAuthMsg{
400 User: user,
401 Service: serviceSSH,
402 Method: "publickey",
403 HasSig: false,
404 Algoname: algo,
405 PubKey: pubKey,
406 }
407 if err := c.writePacket(Marshal(&msg)); err != nil {
408 return false, err
409 }
410
411 return confirmKeyAck(key, c)
412 }
413
414 func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
415 pubKey := key.Marshal()
416
417 for {
418 packet, err := c.readPacket()
419 if err != nil {
420 return false, err
421 }
422 switch packet[0] {
423 case msgUserAuthBanner:
424 if err := handleBannerResponse(c, packet); err != nil {
425 return false, err
426 }
427 case msgUserAuthPubKeyOk:
428 var msg userAuthPubKeyOkMsg
429 if err := Unmarshal(packet, &msg); err != nil {
430 return false, err
431 }
432
433
434
435
436
437 if !contains(algorithmsForKeyFormat(key.Type()), msg.Algo) {
438 return false, nil
439 }
440 if !bytes.Equal(msg.PubKey, pubKey) {
441 return false, nil
442 }
443 return true, nil
444 case msgUserAuthFailure:
445 return false, nil
446 default:
447 return false, unexpectedMessageError(msgUserAuthPubKeyOk, packet[0])
448 }
449 }
450 }
451
452
453
454 func PublicKeys(signers ...Signer) AuthMethod {
455 return publicKeyCallback(func() ([]Signer, error) { return signers, nil })
456 }
457
458
459
460 func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod {
461 return publicKeyCallback(getSigners)
462 }
463
464
465
466
467 func handleAuthResponse(c packetConn) (authResult, []string, error) {
468 gotMsgExtInfo := false
469 for {
470 packet, err := c.readPacket()
471 if err != nil {
472 return authFailure, nil, err
473 }
474
475 switch packet[0] {
476 case msgUserAuthBanner:
477 if err := handleBannerResponse(c, packet); err != nil {
478 return authFailure, nil, err
479 }
480 case msgExtInfo:
481
482 if gotMsgExtInfo {
483 return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
484 }
485 gotMsgExtInfo = true
486 case msgUserAuthFailure:
487 var msg userAuthFailureMsg
488 if err := Unmarshal(packet, &msg); err != nil {
489 return authFailure, nil, err
490 }
491 if msg.PartialSuccess {
492 return authPartialSuccess, msg.Methods, nil
493 }
494 return authFailure, msg.Methods, nil
495 case msgUserAuthSuccess:
496 return authSuccess, nil, nil
497 default:
498 return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
499 }
500 }
501 }
502
503 func handleBannerResponse(c packetConn, packet []byte) error {
504 var msg userAuthBannerMsg
505 if err := Unmarshal(packet, &msg); err != nil {
506 return err
507 }
508
509 transport, ok := c.(*handshakeTransport)
510 if !ok {
511 return nil
512 }
513
514 if transport.bannerCallback != nil {
515 return transport.bannerCallback(msg.Message)
516 }
517
518 return nil
519 }
520
521
522
523
524
525
526
527
528 type KeyboardInteractiveChallenge func(name, instruction string, questions []string, echos []bool) (answers []string, err error)
529
530
531
532 func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
533 return challenge
534 }
535
536 func (cb KeyboardInteractiveChallenge) method() string {
537 return "keyboard-interactive"
538 }
539
540 func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
541 type initiateMsg struct {
542 User string `sshtype:"50"`
543 Service string
544 Method string
545 Language string
546 Submethods string
547 }
548
549 if err := c.writePacket(Marshal(&initiateMsg{
550 User: user,
551 Service: serviceSSH,
552 Method: "keyboard-interactive",
553 })); err != nil {
554 return authFailure, nil, err
555 }
556
557 gotMsgExtInfo := false
558 gotUserAuthInfoRequest := false
559 for {
560 packet, err := c.readPacket()
561 if err != nil {
562 return authFailure, nil, err
563 }
564
565
566 switch packet[0] {
567 case msgUserAuthBanner:
568 if err := handleBannerResponse(c, packet); err != nil {
569 return authFailure, nil, err
570 }
571 continue
572 case msgExtInfo:
573
574 if gotMsgExtInfo {
575 return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
576 }
577 gotMsgExtInfo = true
578 continue
579 case msgUserAuthInfoRequest:
580
581 case msgUserAuthFailure:
582 var msg userAuthFailureMsg
583 if err := Unmarshal(packet, &msg); err != nil {
584 return authFailure, nil, err
585 }
586 if msg.PartialSuccess {
587 return authPartialSuccess, msg.Methods, nil
588 }
589 if !gotUserAuthInfoRequest {
590 return authFailure, msg.Methods, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
591 }
592 return authFailure, msg.Methods, nil
593 case msgUserAuthSuccess:
594 return authSuccess, nil, nil
595 default:
596 return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
597 }
598
599 var msg userAuthInfoRequestMsg
600 if err := Unmarshal(packet, &msg); err != nil {
601 return authFailure, nil, err
602 }
603 gotUserAuthInfoRequest = true
604
605
606 rest := msg.Prompts
607 var prompts []string
608 var echos []bool
609 for i := 0; i < int(msg.NumPrompts); i++ {
610 prompt, r, ok := parseString(rest)
611 if !ok || len(r) == 0 {
612 return authFailure, nil, errors.New("ssh: prompt format error")
613 }
614 prompts = append(prompts, string(prompt))
615 echos = append(echos, r[0] != 0)
616 rest = r[1:]
617 }
618
619 if len(rest) != 0 {
620 return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
621 }
622
623 answers, err := cb(msg.Name, msg.Instruction, prompts, echos)
624 if err != nil {
625 return authFailure, nil, err
626 }
627
628 if len(answers) != len(prompts) {
629 return authFailure, nil, fmt.Errorf("ssh: incorrect number of answers from keyboard-interactive callback %d (expected %d)", len(answers), len(prompts))
630 }
631 responseLength := 1 + 4
632 for _, a := range answers {
633 responseLength += stringLength(len(a))
634 }
635 serialized := make([]byte, responseLength)
636 p := serialized
637 p[0] = msgUserAuthInfoResponse
638 p = p[1:]
639 p = marshalUint32(p, uint32(len(answers)))
640 for _, a := range answers {
641 p = marshalString(p, []byte(a))
642 }
643
644 if err := c.writePacket(serialized); err != nil {
645 return authFailure, nil, err
646 }
647 }
648 }
649
650 type retryableAuthMethod struct {
651 authMethod AuthMethod
652 maxTries int
653 }
654
655 func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) {
656 for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
657 ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions)
658 if ok != authFailure || err != nil {
659 return ok, methods, err
660 }
661 }
662 return ok, methods, err
663 }
664
665 func (r *retryableAuthMethod) method() string {
666 return r.authMethod.method()
667 }
668
669
670
671
672
673
674
675
676
677
678
679
680 func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod {
681 return &retryableAuthMethod{authMethod: auth, maxTries: maxTries}
682 }
683
684
685
686
687
688 func GSSAPIWithMICAuthMethod(gssAPIClient GSSAPIClient, target string) AuthMethod {
689 if gssAPIClient == nil {
690 panic("gss-api client must be not nil with enable gssapi-with-mic")
691 }
692 return &gssAPIWithMICCallback{gssAPIClient: gssAPIClient, target: target}
693 }
694
// gssAPIWithMICCallback is the AuthMethod returned by
// GSSAPIWithMICAuthMethod.
type gssAPIWithMICCallback struct {
	// gssAPIClient supplies the security context operations used by auth
	// (InitSecContext, GetMIC, DeleteSecContext).
	gssAPIClient GSSAPIClient
	// target is the host part of the "host@target" name passed to
	// InitSecContext.
	target string
}
699
700 func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
701 m := &userAuthRequestMsg{
702 User: user,
703 Service: serviceSSH,
704 Method: g.method(),
705 }
706
707
708 m.Payload = appendU32(m.Payload, 1)
709 m.Payload = appendString(m.Payload, string(krb5OID))
710 if err := c.writePacket(Marshal(m)); err != nil {
711 return authFailure, nil, err
712 }
713
714
715
716
717
718
719 packet, err := c.readPacket()
720 if err != nil {
721 return authFailure, nil, err
722 }
723 userAuthGSSAPIResp := &userAuthGSSAPIResponse{}
724 if err := Unmarshal(packet, userAuthGSSAPIResp); err != nil {
725 return authFailure, nil, err
726 }
727
728
729 var token []byte
730 defer g.gssAPIClient.DeleteSecContext()
731 for {
732
733 nextToken, needContinue, err := g.gssAPIClient.InitSecContext("host@"+g.target, token, false)
734 if err != nil {
735 return authFailure, nil, err
736 }
737 if len(nextToken) > 0 {
738 if err := c.writePacket(Marshal(&userAuthGSSAPIToken{
739 Token: nextToken,
740 })); err != nil {
741 return authFailure, nil, err
742 }
743 }
744 if !needContinue {
745 break
746 }
747 packet, err = c.readPacket()
748 if err != nil {
749 return authFailure, nil, err
750 }
751 switch packet[0] {
752 case msgUserAuthFailure:
753 var msg userAuthFailureMsg
754 if err := Unmarshal(packet, &msg); err != nil {
755 return authFailure, nil, err
756 }
757 if msg.PartialSuccess {
758 return authPartialSuccess, msg.Methods, nil
759 }
760 return authFailure, msg.Methods, nil
761 case msgUserAuthGSSAPIError:
762 userAuthGSSAPIErrorResp := &userAuthGSSAPIError{}
763 if err := Unmarshal(packet, userAuthGSSAPIErrorResp); err != nil {
764 return authFailure, nil, err
765 }
766 return authFailure, nil, fmt.Errorf("GSS-API Error:\n"+
767 "Major Status: %d\n"+
768 "Minor Status: %d\n"+
769 "Error Message: %s\n", userAuthGSSAPIErrorResp.MajorStatus, userAuthGSSAPIErrorResp.MinorStatus,
770 userAuthGSSAPIErrorResp.Message)
771 case msgUserAuthGSSAPIToken:
772 userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
773 if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
774 return authFailure, nil, err
775 }
776 token = userAuthGSSAPITokenReq.Token
777 }
778 }
779
780
781 micField := buildMIC(string(session), user, "ssh-connection", "gssapi-with-mic")
782 micToken, err := g.gssAPIClient.GetMIC(micField)
783 if err != nil {
784 return authFailure, nil, err
785 }
786 if err := c.writePacket(Marshal(&userAuthGSSAPIMIC{
787 MIC: micToken,
788 })); err != nil {
789 return authFailure, nil, err
790 }
791 return handleAuthResponse(c)
792 }
793
794 func (g *gssAPIWithMICCallback) method() string {
795 return "gssapi-with-mic"
796 }
797
View as plain text