// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package ssh import ( "bytes" "errors" "fmt" "io" "net" "reflect" "strings" "sync/atomic" "testing" "time" ) func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) { for _, tt := range []struct { name string key Signer wantError bool }{ {"rsa", testSigners["rsa"], false}, {"dsa", testSigners["dsa"], true}, {"ed25519", testSigners["ed25519"], true}, } { c1, c2, err := netPipe() if err != nil { t.Fatalf("netPipe: %v", err) } defer c1.Close() defer c2.Close() serverConf := &ServerConfig{ PublicKeyAuthAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { return nil, nil }, } serverConf.AddHostKey(testSigners["ecdsap256"]) done := make(chan struct{}) go func() { defer close(done) NewServerConn(c1, serverConf) }() clientConf := ClientConfig{ User: "user", Auth: []AuthMethod{ PublicKeys(tt.key), }, HostKeyCallback: InsecureIgnoreHostKey(), } _, _, _, err = NewClientConn(c2, "", &clientConf) if err != nil { if !tt.wantError { t.Errorf("%s: got unexpected error %q", tt.name, err.Error()) } } else if tt.wantError { t.Errorf("%s: succeeded, but want error", tt.name) } <-done } } func TestMaxAuthTriesNoneMethod(t *testing.T) { username := "testuser" serverConfig := &ServerConfig{ MaxAuthTries: 2, PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { if conn.User() == username && string(password) == clientPassword { return nil, nil } return nil, errors.New("invalid credentials") }, } c1, c2, err := netPipe() if err != nil { t.Fatalf("netPipe: %v", err) } defer c1.Close() defer c2.Close() var serverAuthErrors []error serverConfig.AddHostKey(testSigners["rsa"]) serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { serverAuthErrors = append(serverAuthErrors, err) } go newServer(c1, serverConfig) clientConfig := ClientConfig{ User: username, HostKeyCallback: InsecureIgnoreHostKey(), } clientConfig.SetDefaults() // Our client will send 'none' auth only once, so we need to send the // requests manually. c := &connection{ sshConn: sshConn{ conn: c2, user: username, clientVersion: []byte(packageVersion), }, } c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) if err != nil { t.Fatalf("unable to exchange version: %v", err) } c.transport = newClientTransport( newTransport(c.sshConn.conn, clientConfig.Rand, true /* is client */), c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr()) if err := c.transport.waitSession(); err != nil { t.Fatalf("unable to wait session: %v", err) } c.sessionID = c.transport.getSessionID() if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { t.Fatalf("unable to send ssh-userauth message: %v", err) } packet, err := c.transport.readPacket() if err != nil { t.Fatal(err) } if len(packet) > 0 && packet[0] == msgExtInfo { packet, err = c.transport.readPacket() if err != nil { t.Fatal(err) } } var serviceAccept serviceAcceptMsg if err := Unmarshal(packet, &serviceAccept); err != nil { t.Fatal(err) } for i := 0; i <= serverConfig.MaxAuthTries; i++ { auth := new(noneAuth) _, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil) if i < serverConfig.MaxAuthTries { if err != nil { t.Fatal(err) } continue } if err == nil { t.Fatal("client: got no error") } else if !strings.Contains(err.Error(), "too many authentication failures") { t.Fatalf("client: got unexpected error: %v", err) } } if len(serverAuthErrors) != 3 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } for _, err := range serverAuthErrors { if !errors.Is(err, ErrNoAuth) { t.Errorf("go error: %v; want: %v", err, ErrNoAuth) } } } func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) { username := "testuser" serverConfig := &ServerConfig{ MaxAuthTries: 1, PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { if conn.User() == username && string(password) == clientPassword { return nil, nil } return nil, errors.New("invalid credentials") }, } clientConfig := &ClientConfig{ User: username, Auth: []AuthMethod{ Password(clientPassword), }, HostKeyCallback: InsecureIgnoreHostKey(), } serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) if err != nil { t.Fatalf("client login error: %s", err) } if len(serverAuthErrors) != 2 { t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) } if !errors.Is(serverAuthErrors[0], ErrNoAuth) { t.Errorf("go error: %v; want: %v", serverAuthErrors[0], ErrNoAuth) } if serverAuthErrors[1] != nil { t.Errorf("unexpected error: %v", serverAuthErrors[1]) } } func TestNewServerConnValidationErrors(t *testing.T) { serverConf := &ServerConfig{ PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01}, } c := &markerConn{} _, _, _, err := NewServerConn(c, serverConf) if err == nil { t.Fatal("NewServerConn with invalid public key auth algorithms succeeded") } if !c.isClosed() { t.Fatal("NewServerConn with invalid public key auth algorithms left connection open") } if c.isUsed() { t.Fatal("NewServerConn with invalid public key auth algorithms used connection") } serverConf = &ServerConfig{ Config: Config{ KeyExchanges: []string{kexAlgoDHGEXSHA256}, }, } c = &markerConn{} _, _, _, err = NewServerConn(c, serverConf) if err == nil { t.Fatal("NewServerConn with unsupported key exchange succeeded") } if !c.isClosed() { t.Fatal("NewServerConn with unsupported key exchange left connection open") } if c.isUsed() { t.Fatal("NewServerConn with unsupported key exchange used connection") } } func TestBannerError(t *testing.T) { serverConfig := &ServerConfig{ BannerCallback: func(ConnMetadata) string { return "banner from BannerCallback" }, NoClientAuth: true, NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { err := &BannerError{ Err: errors.New("error from NoClientAuthCallback"), Message: "banner from NoClientAuthCallback", } return nil, fmt.Errorf("wrapped: %w", err) }, PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { return &Permissions{}, nil }, PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { return nil, &BannerError{ Err: errors.New("error from PublicKeyCallback"), Message: "banner from PublicKeyCallback", } }, KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) { return nil, &BannerError{ Err: nil, // make sure that a nil inner error is allowed Message: "banner from KeyboardInteractiveCallback", } }, } serverConfig.AddHostKey(testSigners["rsa"]) var banners []string clientConfig := &ClientConfig{ User: "test", Auth: []AuthMethod{ PublicKeys(testSigners["rsa"]), KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) { return []string{"letmein"}, nil }), Password(clientPassword), }, HostKeyCallback: InsecureIgnoreHostKey(), BannerCallback: func(msg string) error { banners = append(banners, msg) return nil }, } c1, c2, err := netPipe() if err != nil { t.Fatalf("netPipe: %v", err) } defer c1.Close() defer c2.Close() go newServer(c1, serverConfig) c, _, _, err := NewClientConn(c2, "", clientConfig) if err != nil { t.Fatalf("client connection failed: %v", err) } defer c.Close() wantBanners := []string{ "banner from BannerCallback", "banner from NoClientAuthCallback", "banner from PublicKeyCallback", "banner from KeyboardInteractiveCallback", } if !reflect.DeepEqual(banners, wantBanners) { t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) } } func TestPublicKeyCallbackLastSeen(t *testing.T) { var lastSeenKey PublicKey c1, c2, err := netPipe() if err != nil { t.Fatalf("netPipe: %v", err) } defer c1.Close() defer c2.Close() serverConf := &ServerConfig{ PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { lastSeenKey = key fmt.Printf("seen %#v\n", key) if _, ok := key.(*dsaPublicKey); !ok { return nil, errors.New("nope") } return nil, nil }, } serverConf.AddHostKey(testSigners["ecdsap256"]) done := make(chan struct{}) go func() { defer close(done) NewServerConn(c1, serverConf) }() clientConf := ClientConfig{ User: "user", Auth: []AuthMethod{ PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]), }, HostKeyCallback: InsecureIgnoreHostKey(), } _, _, _, err = NewClientConn(c2, "", &clientConf) if err != nil { t.Fatal(err) } <-done expectedPublicKey := testSigners["dsa"].PublicKey().Marshal() lastSeenMarshalled := lastSeenKey.Marshal() if !bytes.Equal(lastSeenMarshalled, expectedPublicKey) { t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, testSigners["dsa"].PublicKey()) } } type markerConn struct { closed uint32 used uint32 } func (c *markerConn) isClosed() bool { return atomic.LoadUint32(&c.closed) != 0 } func (c *markerConn) isUsed() bool { return atomic.LoadUint32(&c.used) != 0 } func (c *markerConn) Close() error { atomic.StoreUint32(&c.closed, 1) return nil } func (c *markerConn) Read(b []byte) (n int, err error) { atomic.StoreUint32(&c.used, 1) if atomic.LoadUint32(&c.closed) != 0 { return 0, net.ErrClosed } else { return 0, io.EOF } } func (c *markerConn) Write(b []byte) (n int, err error) { atomic.StoreUint32(&c.used, 1) if atomic.LoadUint32(&c.closed) != 0 { return 0, net.ErrClosed } else { return 0, io.ErrClosedPipe } } func (*markerConn) LocalAddr() net.Addr { return nil } func (*markerConn) RemoteAddr() net.Addr { return nil } func (*markerConn) SetDeadline(t time.Time) error { return nil } func (*markerConn) SetReadDeadline(t time.Time) error { return nil } func (*markerConn) SetWriteDeadline(t time.Time) error { return nil }