...

Source file src/golang.org/x/crypto/ssh/test/test_unix_test.go

Documentation: golang.org/x/crypto/ssh/test

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || plan9 || solaris
     6  
     7  package test
     8  
     9  // functional test harness for unix.
    10  
    11  import (
    12  	"bytes"
    13  	"crypto/rand"
    14  	"encoding/base64"
    15  	"fmt"
    16  	"log"
    17  	"net"
    18  	"os"
    19  	"os/exec"
    20  	"os/user"
    21  	"path/filepath"
    22  	"testing"
    23  	"text/template"
    24  
    25  	"golang.org/x/crypto/internal/testenv"
    26  	"golang.org/x/crypto/ssh"
    27  	"golang.org/x/crypto/ssh/testdata"
    28  )
    29  
    30  const (
    31  	defaultSshdConfig = `
    32  Protocol 2
    33  Banner {{.Dir}}/banner
    34  HostKey {{.Dir}}/id_rsa
    35  HostKey {{.Dir}}/id_dsa
    36  HostKey {{.Dir}}/id_ecdsa
    37  HostCertificate {{.Dir}}/id_rsa-sha2-512-cert.pub
    38  Pidfile {{.Dir}}/sshd.pid
    39  #UsePrivilegeSeparation no
    40  KeyRegenerationInterval 3600
    41  ServerKeyBits 768
    42  SyslogFacility AUTH
    43  LogLevel DEBUG2
    44  LoginGraceTime 120
    45  PermitRootLogin no
    46  StrictModes no
    47  RSAAuthentication yes
    48  PubkeyAuthentication yes
    49  AuthorizedKeysFile	{{.Dir}}/authorized_keys
    50  TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
    51  IgnoreRhosts yes
    52  RhostsRSAAuthentication no
    53  HostbasedAuthentication no
    54  PubkeyAcceptedKeyTypes=*
    55  `
    56  	multiAuthSshdConfigTail = `
    57  UsePAM yes
    58  PasswordAuthentication yes
    59  ChallengeResponseAuthentication yes
    60  AuthenticationMethods {{.AuthMethods}}
    61  `
    62  	maxAuthTriesSshdConfigTail = `
    63  PasswordAuthentication yes
    64  MaxAuthTries 1
    65  `
    66  )
    67  
    68  var configTmpl = map[string]*template.Template{
    69  	"default":      template.Must(template.New("").Parse(defaultSshdConfig)),
    70  	"MultiAuth":    template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail)),
    71  	"MaxAuthTries": template.Must(template.New("").Parse(defaultSshdConfig + maxAuthTriesSshdConfigTail))}
    72  
    73  type server struct {
    74  	t          *testing.T
    75  	configfile string
    76  
    77  	testUser     string // test username for sshd
    78  	testPasswd   string // test password for sshd
    79  	sshdTestPwSo string // dynamic library to inject a custom password into sshd
    80  
    81  	lastDialConn net.Conn
    82  }
    83  
    84  func username() string {
    85  	var username string
    86  	if user, err := user.Current(); err == nil {
    87  		username = user.Username
    88  	} else {
    89  		// user.Current() currently requires cgo. If an error is
    90  		// returned attempt to get the username from the environment.
    91  		log.Printf("user.Current: %v; falling back on $USER", err)
    92  		username = os.Getenv("USER")
    93  	}
    94  	if username == "" {
    95  		panic("Unable to get username")
    96  	}
    97  	return username
    98  }
    99  
   100  type storedHostKey struct {
   101  	// keys map from an algorithm string to binary key data.
   102  	keys map[string][]byte
   103  
   104  	// checkCount counts the Check calls. Used for testing
   105  	// rekeying.
   106  	checkCount int
   107  }
   108  
   109  func (k *storedHostKey) Add(key ssh.PublicKey) {
   110  	if k.keys == nil {
   111  		k.keys = map[string][]byte{}
   112  	}
   113  	k.keys[key.Type()] = key.Marshal()
   114  }
   115  
   116  func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
   117  	k.checkCount++
   118  	algo := key.Type()
   119  
   120  	if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
   121  		return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
   122  	}
   123  	return nil
   124  }
   125  
   126  func hostKeyDB() *storedHostKey {
   127  	keyChecker := &storedHostKey{}
   128  	keyChecker.Add(testPublicKeys["ecdsa"])
   129  	keyChecker.Add(testPublicKeys["rsa"])
   130  	keyChecker.Add(testPublicKeys["dsa"])
   131  	return keyChecker
   132  }
   133  
   134  func clientConfig() *ssh.ClientConfig {
   135  	config := &ssh.ClientConfig{
   136  		User: username(),
   137  		Auth: []ssh.AuthMethod{
   138  			ssh.PublicKeys(testSigners["user"]),
   139  		},
   140  		HostKeyCallback: hostKeyDB().Check,
   141  		HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker
   142  			ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
   143  			ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
   144  			ssh.KeyAlgoED25519,
   145  		},
   146  	}
   147  	return config
   148  }
   149  
   150  // unixConnection creates two halves of a connected net.UnixConn.  It
   151  // is used for connecting the Go SSH client with sshd without opening
   152  // ports.
   153  func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
   154  	dir, err := os.MkdirTemp("", "unixConnection")
   155  	if err != nil {
   156  		return nil, nil, err
   157  	}
   158  	defer os.Remove(dir)
   159  
   160  	addr := filepath.Join(dir, "ssh")
   161  	listener, err := net.Listen("unix", addr)
   162  	if err != nil {
   163  		return nil, nil, err
   164  	}
   165  	defer listener.Close()
   166  	c1, err := net.Dial("unix", addr)
   167  	if err != nil {
   168  		return nil, nil, err
   169  	}
   170  
   171  	c2, err := listener.Accept()
   172  	if err != nil {
   173  		c1.Close()
   174  		return nil, nil, err
   175  	}
   176  
   177  	return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
   178  }
   179  
   180  func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
   181  	return s.TryDialWithAddr(config, "")
   182  }
   183  
   184  // addr is the user specified host:port. While we don't actually dial it,
   185  // we need to know this for host key matching
   186  func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (client *ssh.Client, err error) {
   187  	sshd, err := exec.LookPath("sshd")
   188  	if err != nil {
   189  		s.t.Skipf("skipping test: %v", err)
   190  	}
   191  
   192  	c1, c2, err := unixConnection()
   193  	if err != nil {
   194  		s.t.Fatalf("unixConnection: %v", err)
   195  	}
   196  	defer func() {
   197  		// Close c2 after we've started the sshd command so that it won't prevent c1
   198  		// from returning EOF when the sshd command exits.
   199  		c2.Close()
   200  
   201  		// Leave c1 open if we're returning a client that wraps it.
   202  		// (The client is responsible for closing it.)
   203  		// Otherwise, close it to free up the socket.
   204  		if client == nil {
   205  			c1.Close()
   206  		}
   207  	}()
   208  
   209  	f, err := c2.File()
   210  	if err != nil {
   211  		s.t.Fatalf("UnixConn.File: %v", err)
   212  	}
   213  	defer f.Close()
   214  
   215  	cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
   216  	cmd.Stdin = f
   217  	cmd.Stdout = f
   218  	cmd.Stderr = new(bytes.Buffer)
   219  
   220  	if s.sshdTestPwSo != "" {
   221  		if s.testUser == "" {
   222  			s.t.Fatal("user missing from sshd_test_pw.so config")
   223  		}
   224  		if s.testPasswd == "" {
   225  			s.t.Fatal("password missing from sshd_test_pw.so config")
   226  		}
   227  		cmd.Env = append(os.Environ(),
   228  			fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
   229  			fmt.Sprintf("TEST_USER=%s", s.testUser),
   230  			fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
   231  	}
   232  
   233  	if err := cmd.Start(); err != nil {
   234  		s.t.Fatalf("s.cmd.Start: %v", err)
   235  	}
   236  	s.lastDialConn = c1
   237  	s.t.Cleanup(func() {
   238  		// Don't check for errors; if it fails it's most
   239  		// likely "os: process already finished", and we don't
   240  		// care about that. Use os.Interrupt, so child
   241  		// processes are killed too.
   242  		cmd.Process.Signal(os.Interrupt)
   243  		cmd.Wait()
   244  		if s.t.Failed() || testing.Verbose() {
   245  			// log any output from sshd process
   246  			s.t.Logf("sshd:\n%s", cmd.Stderr)
   247  		}
   248  	})
   249  
   250  	conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  	return ssh.NewClient(conn, chans, reqs), nil
   255  }
   256  
   257  func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
   258  	conn, err := s.TryDial(config)
   259  	if err != nil {
   260  		s.t.Fatalf("ssh.Client: %v", err)
   261  	}
   262  	return conn
   263  }
   264  
   265  func writeFile(path string, contents []byte) {
   266  	f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
   267  	if err != nil {
   268  		panic(err)
   269  	}
   270  	defer f.Close()
   271  	if _, err := f.Write(contents); err != nil {
   272  		panic(err)
   273  	}
   274  }
   275  
   276  // generate random password
   277  func randomPassword() (string, error) {
   278  	b := make([]byte, 12)
   279  	_, err := rand.Read(b)
   280  	if err != nil {
   281  		return "", err
   282  	}
   283  	return base64.RawURLEncoding.EncodeToString(b), nil
   284  }
   285  
   286  // setTestPassword is used for setting user and password data for sshd_test_pw.so
   287  // This function also checks that ./sshd_test_pw.so exists and if not calls s.t.Skip()
   288  func (s *server) setTestPassword(user, passwd string) error {
   289  	wd, _ := os.Getwd()
   290  	wrapper := filepath.Join(wd, "sshd_test_pw.so")
   291  	if _, err := os.Stat(wrapper); err != nil {
   292  		s.t.Skip(fmt.Errorf("sshd_test_pw.so is not available"))
   293  		return err
   294  	}
   295  
   296  	s.sshdTestPwSo = wrapper
   297  	s.testUser = user
   298  	s.testPasswd = passwd
   299  	return nil
   300  }
   301  
   302  // newServer returns a new mock ssh server.
   303  func newServer(t *testing.T) *server {
   304  	return newServerForConfig(t, "default", map[string]string{})
   305  }
   306  
   307  // newServerForConfig returns a new mock ssh server.
   308  func newServerForConfig(t *testing.T, config string, configVars map[string]string) *server {
   309  	if testing.Short() {
   310  		t.Skip("skipping test due to -short")
   311  	}
   312  	u, err := user.Current()
   313  	if err != nil {
   314  		t.Fatalf("user.Current: %v", err)
   315  	}
   316  	uname := u.Name
   317  	if uname == "" {
   318  		// Check the value of u.Username as u.Name
   319  		// can be "" on some OSes like AIX.
   320  		uname = u.Username
   321  	}
   322  	if uname == "root" {
   323  		t.Skip("skipping test because current user is root")
   324  	}
   325  	dir, err := os.MkdirTemp("", "sshtest")
   326  	if err != nil {
   327  		t.Fatal(err)
   328  	}
   329  	f, err := os.Create(filepath.Join(dir, "sshd_config"))
   330  	if err != nil {
   331  		t.Fatal(err)
   332  	}
   333  	if _, ok := configTmpl[config]; ok == false {
   334  		t.Fatal(fmt.Errorf("Invalid server config '%s'", config))
   335  	}
   336  	configVars["Dir"] = dir
   337  	err = configTmpl[config].Execute(f, configVars)
   338  	if err != nil {
   339  		t.Fatal(err)
   340  	}
   341  	f.Close()
   342  
   343  	writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
   344  
   345  	for k, v := range testdata.PEMBytes {
   346  		filename := "id_" + k
   347  		writeFile(filepath.Join(dir, filename), v)
   348  		writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
   349  	}
   350  
   351  	for k, v := range testdata.SSHCertificates {
   352  		filename := "id_" + k + "-cert.pub"
   353  		writeFile(filepath.Join(dir, filename), v)
   354  	}
   355  
   356  	var authkeys bytes.Buffer
   357  	for k := range testdata.PEMBytes {
   358  		authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
   359  	}
   360  	writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
   361  	t.Cleanup(func() {
   362  		if err := os.RemoveAll(dir); err != nil {
   363  			t.Error(err)
   364  		}
   365  	})
   366  
   367  	return &server{
   368  		t:          t,
   369  		configfile: f.Name(),
   370  	}
   371  }
   372  
   373  func newTempSocket(t *testing.T) (string, func()) {
   374  	dir, err := os.MkdirTemp("", "socket")
   375  	if err != nil {
   376  		t.Fatal(err)
   377  	}
   378  	deferFunc := func() { os.RemoveAll(dir) }
   379  	addr := filepath.Join(dir, "sock")
   380  	return addr, deferFunc
   381  }
   382  

View as plain text