...

Source file src/golang.org/x/crypto/ssh/server_test.go

Documentation: golang.org/x/crypto/ssh

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net"
    13  	"reflect"
    14  	"strings"
    15  	"sync/atomic"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) {
    21  	for _, tt := range []struct {
    22  		name      string
    23  		key       Signer
    24  		wantError bool
    25  	}{
    26  		{"rsa", testSigners["rsa"], false},
    27  		{"dsa", testSigners["dsa"], true},
    28  		{"ed25519", testSigners["ed25519"], true},
    29  	} {
    30  		c1, c2, err := netPipe()
    31  		if err != nil {
    32  			t.Fatalf("netPipe: %v", err)
    33  		}
    34  		defer c1.Close()
    35  		defer c2.Close()
    36  		serverConf := &ServerConfig{
    37  			PublicKeyAuthAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
    38  			PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
    39  				return nil, nil
    40  			},
    41  		}
    42  		serverConf.AddHostKey(testSigners["ecdsap256"])
    43  
    44  		done := make(chan struct{})
    45  		go func() {
    46  			defer close(done)
    47  			NewServerConn(c1, serverConf)
    48  		}()
    49  
    50  		clientConf := ClientConfig{
    51  			User: "user",
    52  			Auth: []AuthMethod{
    53  				PublicKeys(tt.key),
    54  			},
    55  			HostKeyCallback: InsecureIgnoreHostKey(),
    56  		}
    57  
    58  		_, _, _, err = NewClientConn(c2, "", &clientConf)
    59  		if err != nil {
    60  			if !tt.wantError {
    61  				t.Errorf("%s: got unexpected error %q", tt.name, err.Error())
    62  			}
    63  		} else if tt.wantError {
    64  			t.Errorf("%s: succeeded, but want error", tt.name)
    65  		}
    66  		<-done
    67  	}
    68  }
    69  
    70  func TestMaxAuthTriesNoneMethod(t *testing.T) {
    71  	username := "testuser"
    72  	serverConfig := &ServerConfig{
    73  		MaxAuthTries: 2,
    74  		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
    75  			if conn.User() == username && string(password) == clientPassword {
    76  				return nil, nil
    77  			}
    78  			return nil, errors.New("invalid credentials")
    79  		},
    80  	}
    81  	c1, c2, err := netPipe()
    82  	if err != nil {
    83  		t.Fatalf("netPipe: %v", err)
    84  	}
    85  	defer c1.Close()
    86  	defer c2.Close()
    87  
    88  	var serverAuthErrors []error
    89  
    90  	serverConfig.AddHostKey(testSigners["rsa"])
    91  	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
    92  		serverAuthErrors = append(serverAuthErrors, err)
    93  	}
    94  	go newServer(c1, serverConfig)
    95  
    96  	clientConfig := ClientConfig{
    97  		User:            username,
    98  		HostKeyCallback: InsecureIgnoreHostKey(),
    99  	}
   100  	clientConfig.SetDefaults()
   101  	// Our client will send 'none' auth only once, so we need to send the
   102  	// requests manually.
   103  	c := &connection{
   104  		sshConn: sshConn{
   105  			conn:          c2,
   106  			user:          username,
   107  			clientVersion: []byte(packageVersion),
   108  		},
   109  	}
   110  	c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
   111  	if err != nil {
   112  		t.Fatalf("unable to exchange version: %v", err)
   113  	}
   114  	c.transport = newClientTransport(
   115  		newTransport(c.sshConn.conn, clientConfig.Rand, true /* is client */),
   116  		c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr())
   117  	if err := c.transport.waitSession(); err != nil {
   118  		t.Fatalf("unable to wait session: %v", err)
   119  	}
   120  	c.sessionID = c.transport.getSessionID()
   121  	if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
   122  		t.Fatalf("unable to send ssh-userauth message: %v", err)
   123  	}
   124  	packet, err := c.transport.readPacket()
   125  	if err != nil {
   126  		t.Fatal(err)
   127  	}
   128  	if len(packet) > 0 && packet[0] == msgExtInfo {
   129  		packet, err = c.transport.readPacket()
   130  		if err != nil {
   131  			t.Fatal(err)
   132  		}
   133  	}
   134  	var serviceAccept serviceAcceptMsg
   135  	if err := Unmarshal(packet, &serviceAccept); err != nil {
   136  		t.Fatal(err)
   137  	}
   138  	for i := 0; i <= serverConfig.MaxAuthTries; i++ {
   139  		auth := new(noneAuth)
   140  		_, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil)
   141  		if i < serverConfig.MaxAuthTries {
   142  			if err != nil {
   143  				t.Fatal(err)
   144  			}
   145  			continue
   146  		}
   147  		if err == nil {
   148  			t.Fatal("client: got no error")
   149  		} else if !strings.Contains(err.Error(), "too many authentication failures") {
   150  			t.Fatalf("client: got unexpected error: %v", err)
   151  		}
   152  	}
   153  	if len(serverAuthErrors) != 3 {
   154  		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
   155  	}
   156  	for _, err := range serverAuthErrors {
   157  		if !errors.Is(err, ErrNoAuth) {
   158  			t.Errorf("go error: %v; want: %v", err, ErrNoAuth)
   159  		}
   160  	}
   161  }
   162  
   163  func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) {
   164  	username := "testuser"
   165  	serverConfig := &ServerConfig{
   166  		MaxAuthTries: 1,
   167  		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
   168  			if conn.User() == username && string(password) == clientPassword {
   169  				return nil, nil
   170  			}
   171  			return nil, errors.New("invalid credentials")
   172  		},
   173  	}
   174  	clientConfig := &ClientConfig{
   175  		User: username,
   176  		Auth: []AuthMethod{
   177  			Password(clientPassword),
   178  		},
   179  		HostKeyCallback: InsecureIgnoreHostKey(),
   180  	}
   181  
   182  	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
   183  	if err != nil {
   184  		t.Fatalf("client login error: %s", err)
   185  	}
   186  	if len(serverAuthErrors) != 2 {
   187  		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
   188  	}
   189  	if !errors.Is(serverAuthErrors[0], ErrNoAuth) {
   190  		t.Errorf("go error: %v; want: %v", serverAuthErrors[0], ErrNoAuth)
   191  	}
   192  	if serverAuthErrors[1] != nil {
   193  		t.Errorf("unexpected error: %v", serverAuthErrors[1])
   194  	}
   195  }
   196  
   197  func TestNewServerConnValidationErrors(t *testing.T) {
   198  	serverConf := &ServerConfig{
   199  		PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
   200  	}
   201  	c := &markerConn{}
   202  	_, _, _, err := NewServerConn(c, serverConf)
   203  	if err == nil {
   204  		t.Fatal("NewServerConn with invalid public key auth algorithms succeeded")
   205  	}
   206  	if !c.isClosed() {
   207  		t.Fatal("NewServerConn with invalid public key auth algorithms left connection open")
   208  	}
   209  	if c.isUsed() {
   210  		t.Fatal("NewServerConn with invalid public key auth algorithms used connection")
   211  	}
   212  
   213  	serverConf = &ServerConfig{
   214  		Config: Config{
   215  			KeyExchanges: []string{kexAlgoDHGEXSHA256},
   216  		},
   217  	}
   218  	c = &markerConn{}
   219  	_, _, _, err = NewServerConn(c, serverConf)
   220  	if err == nil {
   221  		t.Fatal("NewServerConn with unsupported key exchange succeeded")
   222  	}
   223  	if !c.isClosed() {
   224  		t.Fatal("NewServerConn with unsupported key exchange left connection open")
   225  	}
   226  	if c.isUsed() {
   227  		t.Fatal("NewServerConn with unsupported key exchange used connection")
   228  	}
   229  }
   230  
   231  func TestBannerError(t *testing.T) {
   232  	serverConfig := &ServerConfig{
   233  		BannerCallback: func(ConnMetadata) string {
   234  			return "banner from BannerCallback"
   235  		},
   236  		NoClientAuth: true,
   237  		NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
   238  			err := &BannerError{
   239  				Err:     errors.New("error from NoClientAuthCallback"),
   240  				Message: "banner from NoClientAuthCallback",
   241  			}
   242  			return nil, fmt.Errorf("wrapped: %w", err)
   243  		},
   244  		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
   245  			return &Permissions{}, nil
   246  		},
   247  		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
   248  			return nil, &BannerError{
   249  				Err:     errors.New("error from PublicKeyCallback"),
   250  				Message: "banner from PublicKeyCallback",
   251  			}
   252  		},
   253  		KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) {
   254  			return nil, &BannerError{
   255  				Err:     nil, // make sure that a nil inner error is allowed
   256  				Message: "banner from KeyboardInteractiveCallback",
   257  			}
   258  		},
   259  	}
   260  	serverConfig.AddHostKey(testSigners["rsa"])
   261  
   262  	var banners []string
   263  	clientConfig := &ClientConfig{
   264  		User: "test",
   265  		Auth: []AuthMethod{
   266  			PublicKeys(testSigners["rsa"]),
   267  			KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
   268  				return []string{"letmein"}, nil
   269  			}),
   270  			Password(clientPassword),
   271  		},
   272  		HostKeyCallback: InsecureIgnoreHostKey(),
   273  		BannerCallback: func(msg string) error {
   274  			banners = append(banners, msg)
   275  			return nil
   276  		},
   277  	}
   278  
   279  	c1, c2, err := netPipe()
   280  	if err != nil {
   281  		t.Fatalf("netPipe: %v", err)
   282  	}
   283  	defer c1.Close()
   284  	defer c2.Close()
   285  	go newServer(c1, serverConfig)
   286  	c, _, _, err := NewClientConn(c2, "", clientConfig)
   287  	if err != nil {
   288  		t.Fatalf("client connection failed: %v", err)
   289  	}
   290  	defer c.Close()
   291  
   292  	wantBanners := []string{
   293  		"banner from BannerCallback",
   294  		"banner from NoClientAuthCallback",
   295  		"banner from PublicKeyCallback",
   296  		"banner from KeyboardInteractiveCallback",
   297  	}
   298  	if !reflect.DeepEqual(banners, wantBanners) {
   299  		t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
   300  	}
   301  }
   302  
   303  func TestPublicKeyCallbackLastSeen(t *testing.T) {
   304  	var lastSeenKey PublicKey
   305  
   306  	c1, c2, err := netPipe()
   307  	if err != nil {
   308  		t.Fatalf("netPipe: %v", err)
   309  	}
   310  	defer c1.Close()
   311  	defer c2.Close()
   312  	serverConf := &ServerConfig{
   313  		PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
   314  			lastSeenKey = key
   315  			fmt.Printf("seen %#v\n", key)
   316  			if _, ok := key.(*dsaPublicKey); !ok {
   317  				return nil, errors.New("nope")
   318  			}
   319  			return nil, nil
   320  		},
   321  	}
   322  	serverConf.AddHostKey(testSigners["ecdsap256"])
   323  
   324  	done := make(chan struct{})
   325  	go func() {
   326  		defer close(done)
   327  		NewServerConn(c1, serverConf)
   328  	}()
   329  
   330  	clientConf := ClientConfig{
   331  		User: "user",
   332  		Auth: []AuthMethod{
   333  			PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]),
   334  		},
   335  		HostKeyCallback: InsecureIgnoreHostKey(),
   336  	}
   337  
   338  	_, _, _, err = NewClientConn(c2, "", &clientConf)
   339  	if err != nil {
   340  		t.Fatal(err)
   341  	}
   342  	<-done
   343  
   344  	expectedPublicKey := testSigners["dsa"].PublicKey().Marshal()
   345  	lastSeenMarshalled := lastSeenKey.Marshal()
   346  	if !bytes.Equal(lastSeenMarshalled, expectedPublicKey) {
   347  		t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, testSigners["dsa"].PublicKey())
   348  	}
   349  }
   350  
   351  type markerConn struct {
   352  	closed uint32
   353  	used   uint32
   354  }
   355  
   356  func (c *markerConn) isClosed() bool {
   357  	return atomic.LoadUint32(&c.closed) != 0
   358  }
   359  
   360  func (c *markerConn) isUsed() bool {
   361  	return atomic.LoadUint32(&c.used) != 0
   362  }
   363  
   364  func (c *markerConn) Close() error {
   365  	atomic.StoreUint32(&c.closed, 1)
   366  	return nil
   367  }
   368  
   369  func (c *markerConn) Read(b []byte) (n int, err error) {
   370  	atomic.StoreUint32(&c.used, 1)
   371  	if atomic.LoadUint32(&c.closed) != 0 {
   372  		return 0, net.ErrClosed
   373  	} else {
   374  		return 0, io.EOF
   375  	}
   376  }
   377  
   378  func (c *markerConn) Write(b []byte) (n int, err error) {
   379  	atomic.StoreUint32(&c.used, 1)
   380  	if atomic.LoadUint32(&c.closed) != 0 {
   381  		return 0, net.ErrClosed
   382  	} else {
   383  		return 0, io.ErrClosedPipe
   384  	}
   385  }
   386  
   387  func (*markerConn) LocalAddr() net.Addr  { return nil }
   388  func (*markerConn) RemoteAddr() net.Addr { return nil }
   389  
   390  func (*markerConn) SetDeadline(t time.Time) error      { return nil }
   391  func (*markerConn) SetReadDeadline(t time.Time) error  { return nil }
   392  func (*markerConn) SetWriteDeadline(t time.Time) error { return nil }
   393  

View as plain text