1
2
3
4
5
6
7 package test
8
9
10
11 import (
12 "bytes"
13 "crypto/rand"
14 "encoding/base64"
15 "fmt"
16 "log"
17 "net"
18 "os"
19 "os/exec"
20 "os/user"
21 "path/filepath"
22 "testing"
23 "text/template"
24
25 "golang.org/x/crypto/internal/testenv"
26 "golang.org/x/crypto/ssh"
27 "golang.org/x/crypto/ssh/testdata"
28 )
29
const (
	// defaultSshdConfig is the base sshd_config text/template shared by all
	// server configurations in configTmpl. The {{.Dir}} placeholder is filled
	// in by newServerForConfig with the per-test temporary directory that
	// holds the banner, host keys, certificates and authorized_keys file.
	defaultSshdConfig = `
Protocol 2
Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
HostCertificate {{.Dir}}/id_rsa-sha2-512-cert.pub
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile {{.Dir}}/authorized_keys
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
PubkeyAcceptedKeyTypes=*
`
	// multiAuthSshdConfigTail is appended to defaultSshdConfig for the
	// "MultiAuth" configuration. {{.AuthMethods}} is expected to come from
	// the configVars map passed to newServerForConfig.
	multiAuthSshdConfigTail = `
UsePAM yes
PasswordAuthentication yes
ChallengeResponseAuthentication yes
AuthenticationMethods {{.AuthMethods}}
`
	// maxAuthTriesSshdConfigTail is appended to defaultSshdConfig for the
	// "MaxAuthTries" configuration.
	maxAuthTriesSshdConfigTail = `
PasswordAuthentication yes
MaxAuthTries 1
`
)
67
68 var configTmpl = map[string]*template.Template{
69 "default": template.Must(template.New("").Parse(defaultSshdConfig)),
70 "MultiAuth": template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail)),
71 "MaxAuthTries": template.Must(template.New("").Parse(defaultSshdConfig + maxAuthTriesSshdConfigTail))}
72
// server describes one sshd test fixture: the test it belongs to, the
// generated sshd_config, and the optional password-injection settings that
// TryDialWithAddr forwards to the sshd process.
type server struct {
	t          *testing.T // test that owns this server; used for Skip/Fatal/Cleanup
	configfile string     // path of the sshd_config handed to sshd via -f

	// testUser and testPasswd are exported to sshd as TEST_USER and
	// TEST_PASSWD when sshdTestPwSo is set.
	testUser, testPasswd string
	// sshdTestPwSo, when non-empty, is the shared object passed to sshd
	// through LD_PRELOAD.
	sshdTestPwSo string

	// lastDialConn is the client side of the most recent connection created
	// by TryDialWithAddr.
	lastDialConn net.Conn
}
83
// username returns the login name of the user running the tests.
//
// It asks os/user first; if that lookup fails, the failure is logged and the
// $USER environment variable is used instead. It panics when neither source
// yields a non-empty name.
func username() string {
	name := ""
	cur, err := user.Current()
	switch {
	case err == nil:
		name = cur.Username
	default:
		log.Printf("user.Current: %v; falling back on $USER", err)
		name = os.Getenv("USER")
	}
	if name == "" {
		panic("Unable to get username")
	}
	return name
}
99
// storedHostKey is a small in-memory host key database. Keys are registered
// with Add and verified with Check, which has the shape of an
// ssh.HostKeyCallback (see clientConfig).
type storedHostKey struct {
	// keys holds the marshaled form of each accepted host key, indexed by
	// the key's algorithm name as reported by ssh.PublicKey.Type.
	keys map[string][]byte

	// checkCount is incremented on every call to Check, whether or not the
	// check succeeds.
	checkCount int
}
108
109 func (k *storedHostKey) Add(key ssh.PublicKey) {
110 if k.keys == nil {
111 k.keys = map[string][]byte{}
112 }
113 k.keys[key.Type()] = key.Marshal()
114 }
115
116 func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
117 k.checkCount++
118 algo := key.Type()
119
120 if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
121 return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
122 }
123 return nil
124 }
125
126 func hostKeyDB() *storedHostKey {
127 keyChecker := &storedHostKey{}
128 keyChecker.Add(testPublicKeys["ecdsa"])
129 keyChecker.Add(testPublicKeys["rsa"])
130 keyChecker.Add(testPublicKeys["dsa"])
131 return keyChecker
132 }
133
134 func clientConfig() *ssh.ClientConfig {
135 config := &ssh.ClientConfig{
136 User: username(),
137 Auth: []ssh.AuthMethod{
138 ssh.PublicKeys(testSigners["user"]),
139 },
140 HostKeyCallback: hostKeyDB().Check,
141 HostKeyAlgorithms: []string{
142 ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
143 ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
144 ssh.KeyAlgoED25519,
145 },
146 }
147 return config
148 }
149
150
151
152
// unixConnection returns the two ends of a connected Unix-domain stream
// socket: first the dialing side, then the accepted side.
//
// The socket lives in a throwaway temporary directory. Both the listener and
// the directory are released before returning; only the two connections
// outlive the call, and the caller must close them.
func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
	tmp, err := os.MkdirTemp("", "unixConnection")
	if err != nil {
		return nil, nil, err
	}
	// Deferred calls run last-in first-out, so the listener below is closed
	// before this removal of the directory is attempted.
	defer os.Remove(tmp)

	sockPath := filepath.Join(tmp, "ssh")
	ln, err := net.Listen("unix", sockPath)
	if err != nil {
		return nil, nil, err
	}
	defer ln.Close()

	dialed, err := net.Dial("unix", sockPath)
	if err != nil {
		return nil, nil, err
	}

	accepted, err := ln.Accept()
	if err != nil {
		// Do not leak the client half when the server half never arrives.
		dialed.Close()
		return nil, nil, err
	}

	return dialed.(*net.UnixConn), accepted.(*net.UnixConn), nil
}
179
180 func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
181 return s.TryDialWithAddr(config, "")
182 }
183
184
185
// TryDialWithAddr starts a new sshd process wired to one end of a
// Unix-domain socket pair and performs the SSH client handshake over the
// other end.
//
// addr is never dialed; it is only handed to ssh.NewClientConn alongside the
// connection. config is the client configuration used for the handshake.
//
// The test is skipped when no sshd binary is found on PATH, and failed when
// the socket pair or the sshd process cannot be set up. Handshake errors are
// returned to the caller. The sshd process is stopped by a cleanup function
// registered on the test.
func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (client *ssh.Client, err error) {
	sshd, err := exec.LookPath("sshd")
	if err != nil {
		s.t.Skipf("skipping test: %v", err)
	}

	// c1 is the client side of the connection, c2 the side given to sshd.
	c1, c2, err := unixConnection()
	if err != nil {
		s.t.Fatalf("unixConnection: %v", err)
	}
	defer func() {
		// By the time this runs, sshd has been started with its own handle
		// (f, below) on the server side, so this process no longer needs c2.
		c2.Close()

		// The named result tells us whether a client is being returned.
		// A returned client wraps c1, so c1 must stay open in that case;
		// otherwise nobody else holds it and it is closed here.
		if client == nil {
			c1.Close()
		}
	}()

	// Obtain an *os.File for the server side so it can be used as the child
	// process's stdin and stdout.
	f, err := c2.File()
	if err != nil {
		s.t.Fatalf("UnixConn.File: %v", err)
	}
	defer f.Close()

	// sshd talks SSH over its stdin/stdout; its stderr is captured in a
	// buffer so it can be logged by the cleanup function below.
	cmd := testenv.Command(s.t, sshd, "-f", s.configfile, "-i", "-e")
	cmd.Stdin = f
	cmd.Stdout = f
	cmd.Stderr = new(bytes.Buffer)

	// When a password wrapper has been configured (see setTestPassword),
	// preload it into sshd and pass the credentials through the environment.
	if s.sshdTestPwSo != "" {
		if s.testUser == "" {
			s.t.Fatal("user missing from sshd_test_pw.so config")
		}
		if s.testPasswd == "" {
			s.t.Fatal("password missing from sshd_test_pw.so config")
		}
		cmd.Env = append(os.Environ(),
			fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
			fmt.Sprintf("TEST_USER=%s", s.testUser),
			fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
	}

	if err := cmd.Start(); err != nil {
		s.t.Fatalf("s.cmd.Start: %v", err)
	}
	s.lastDialConn = c1
	s.t.Cleanup(func() {
		// Errors from Signal and Wait are deliberately ignored: this is
		// best-effort teardown, and the process may already have exited on
		// its own by the time the test finishes.
		cmd.Process.Signal(os.Interrupt)
		cmd.Wait()
		if s.t.Failed() || testing.Verbose() {
			// Surface whatever sshd wrote to stderr to help debugging.
			s.t.Logf("sshd:\n%s", cmd.Stderr)
		}
	})

	conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
	if err != nil {
		return nil, err
	}
	return ssh.NewClient(conn, chans, reqs), nil
}
256
257 func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
258 conn, err := s.TryDial(config)
259 if err != nil {
260 s.t.Fatalf("ssh.Client: %v", err)
261 }
262 return conn
263 }
264
// writeFile writes contents to path, creating the file with mode 0600 if it
// does not exist and truncating it if it does. It panics on any failure,
// including an error reported when the file is closed, so a short or
// unflushed write cannot go unnoticed.
func writeFile(path string, contents []byte) {
	if err := os.WriteFile(path, contents, 0600); err != nil {
		panic(err)
	}
}
275
276
// randomPassword returns a password built from 12 bytes read from
// crypto/rand, encoded with the unpadded URL-safe base64 alphabet (16
// characters). An error is returned if the random source fails.
func randomPassword() (string, error) {
	var raw [12]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
285
286
287
288 func (s *server) setTestPassword(user, passwd string) error {
289 wd, _ := os.Getwd()
290 wrapper := filepath.Join(wd, "sshd_test_pw.so")
291 if _, err := os.Stat(wrapper); err != nil {
292 s.t.Skip(fmt.Errorf("sshd_test_pw.so is not available"))
293 return err
294 }
295
296 s.sshdTestPwSo = wrapper
297 s.testUser = user
298 s.testPasswd = passwd
299 return nil
300 }
301
302
303 func newServer(t *testing.T) *server {
304 return newServerForConfig(t, "default", map[string]string{})
305 }
306
307
308 func newServerForConfig(t *testing.T, config string, configVars map[string]string) *server {
309 if testing.Short() {
310 t.Skip("skipping test due to -short")
311 }
312 u, err := user.Current()
313 if err != nil {
314 t.Fatalf("user.Current: %v", err)
315 }
316 uname := u.Name
317 if uname == "" {
318
319
320 uname = u.Username
321 }
322 if uname == "root" {
323 t.Skip("skipping test because current user is root")
324 }
325 dir, err := os.MkdirTemp("", "sshtest")
326 if err != nil {
327 t.Fatal(err)
328 }
329 f, err := os.Create(filepath.Join(dir, "sshd_config"))
330 if err != nil {
331 t.Fatal(err)
332 }
333 if _, ok := configTmpl[config]; ok == false {
334 t.Fatal(fmt.Errorf("Invalid server config '%s'", config))
335 }
336 configVars["Dir"] = dir
337 err = configTmpl[config].Execute(f, configVars)
338 if err != nil {
339 t.Fatal(err)
340 }
341 f.Close()
342
343 writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
344
345 for k, v := range testdata.PEMBytes {
346 filename := "id_" + k
347 writeFile(filepath.Join(dir, filename), v)
348 writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
349 }
350
351 for k, v := range testdata.SSHCertificates {
352 filename := "id_" + k + "-cert.pub"
353 writeFile(filepath.Join(dir, filename), v)
354 }
355
356 var authkeys bytes.Buffer
357 for k := range testdata.PEMBytes {
358 authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
359 }
360 writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
361 t.Cleanup(func() {
362 if err := os.RemoveAll(dir); err != nil {
363 t.Error(err)
364 }
365 })
366
367 return &server{
368 t: t,
369 configfile: f.Name(),
370 }
371 }
372
// newTempSocket reserves a path for a Unix-domain socket inside a new
// temporary directory. It returns that path (the socket itself is not
// created) and a function that removes the directory and everything in it.
// The test fails if the directory cannot be created.
func newTempSocket(t *testing.T) (string, func()) {
	tmpDir, err := os.MkdirTemp("", "socket")
	if err != nil {
		t.Fatal(err)
	}
	cleanup := func() { os.RemoveAll(tmpDir) }
	return filepath.Join(tmpDir, "sock"), cleanup
}
382
View as plain text