1
2
3
4
5 package agent
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "errors"
11 "io"
12 "net"
13 "os"
14 "os/exec"
15 "path/filepath"
16 "runtime"
17 "strconv"
18 "strings"
19 "testing"
20 "time"
21
22 "golang.org/x/crypto/ssh"
23 )
24
25
26 func startOpenSSHAgent(t *testing.T) (client ExtendedAgent, socket string, cleanup func()) {
27 if testing.Short() {
28
29
30 t.Skip("skipping test due to -short")
31 }
32 if runtime.GOOS == "windows" {
33 t.Skip("skipping on windows, we don't support connecting to the ssh-agent via a named pipe")
34 }
35
36 bin, err := exec.LookPath("ssh-agent")
37 if err != nil {
38 t.Skip("could not find ssh-agent")
39 }
40
41 cmd := exec.Command(bin, "-s")
42 cmd.Env = []string{}
43 cmd.Stderr = new(bytes.Buffer)
44 out, err := cmd.Output()
45 if err != nil {
46 t.Fatalf("%s failed: %v\n%s", strings.Join(cmd.Args, " "), err, cmd.Stderr)
47 }
48
49
50
51
52
53
54
55 fields := bytes.Split(out, []byte(";"))
56 line := bytes.SplitN(fields[0], []byte("="), 2)
57 line[0] = bytes.TrimLeft(line[0], "\n")
58 if string(line[0]) != "SSH_AUTH_SOCK" {
59 t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
60 }
61 socket = string(line[1])
62
63 line = bytes.SplitN(fields[2], []byte("="), 2)
64 line[0] = bytes.TrimLeft(line[0], "\n")
65 if string(line[0]) != "SSH_AGENT_PID" {
66 t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
67 }
68 pidStr := line[1]
69 pid, err := strconv.Atoi(string(pidStr))
70 if err != nil {
71 t.Fatalf("Atoi(%q): %v", pidStr, err)
72 }
73
74 conn, err := net.Dial("unix", string(socket))
75 if err != nil {
76 t.Fatalf("net.Dial: %v", err)
77 }
78
79 ac := NewClient(conn)
80 return ac, socket, func() {
81 proc, _ := os.FindProcess(pid)
82 if proc != nil {
83 proc.Kill()
84 }
85 conn.Close()
86 os.RemoveAll(filepath.Dir(socket))
87 }
88 }
89
90 func startAgent(t *testing.T, agent Agent) (client ExtendedAgent, cleanup func()) {
91 c1, c2, err := netPipe()
92 if err != nil {
93 t.Fatalf("netPipe: %v", err)
94 }
95 go ServeAgent(agent, c2)
96
97 return NewClient(c1), func() {
98 c1.Close()
99 c2.Close()
100 }
101 }
102
103
104 func startKeyringAgent(t *testing.T) (client ExtendedAgent, cleanup func()) {
105 return startAgent(t, NewKeyring())
106 }
107
108 func testOpenSSHAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
109 agent, _, cleanup := startOpenSSHAgent(t)
110 defer cleanup()
111
112 testAgentInterface(t, agent, key, cert, lifetimeSecs)
113 }
114
115 func testKeyringAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
116 agent, cleanup := startKeyringAgent(t)
117 defer cleanup()
118
119 testAgentInterface(t, agent, key, cert, lifetimeSecs)
120 }
121
122 func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
123 signer, err := ssh.NewSignerFromKey(key)
124 if err != nil {
125 t.Fatalf("NewSignerFromKey(%T): %v", key, err)
126 }
127
128 if keys, err := agent.List(); err != nil {
129 t.Fatalf("RequestIdentities: %v", err)
130 } else if len(keys) > 0 {
131 t.Fatalf("got %d keys, want 0: %v", len(keys), keys)
132 }
133
134
135 var pubKey ssh.PublicKey
136 if cert != nil {
137 err = agent.Add(AddedKey{
138 PrivateKey: key,
139 Certificate: cert,
140 Comment: "comment",
141 LifetimeSecs: lifetimeSecs,
142 })
143 pubKey = cert
144 } else {
145 err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs})
146 pubKey = signer.PublicKey()
147 }
148 if err != nil {
149 t.Fatalf("insert(%T): %v", key, err)
150 }
151
152
153 if keys, err := agent.List(); err != nil {
154 t.Fatalf("List: %v", err)
155 } else if len(keys) != 1 {
156 t.Fatalf("got %v, want 1 key", keys)
157 } else if keys[0].Comment != "comment" {
158 t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment")
159 } else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) {
160 t.Fatalf("key mismatch")
161 }
162
163
164 data := []byte("hello")
165 sig, err := agent.Sign(pubKey, data)
166 if err != nil {
167 t.Logf("sign failed with key type %q", pubKey.Type())
168
169
170 if pubKey.Type() != ssh.KeyAlgoRSA && pubKey.Type() != ssh.CertAlgoRSAv01 {
171 t.Fatalf("Sign(%s): %v", pubKey.Type(), err)
172 }
173 } else {
174 if err := pubKey.Verify(data, sig); err != nil {
175 t.Logf("verify failed with key type %q", pubKey.Type())
176 if pubKey.Type() != ssh.KeyAlgoRSA {
177 t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
178 }
179 }
180 }
181
182
183 if pubKey.Type() == ssh.KeyAlgoRSA {
184 sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) {
185 sig, err = agent.SignWithFlags(pubKey, data, flag)
186 if err != nil {
187 t.Fatalf("SignWithFlags(%s): %v", pubKey.Type(), err)
188 }
189 if sig.Format != expectedSigFormat {
190 t.Fatalf("Signature format didn't match expected value: %s != %s", sig.Format, expectedSigFormat)
191 }
192 if err := pubKey.Verify(data, sig); err != nil {
193 t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
194 }
195 }
196 sshFlagTest(SignatureFlagRsaSha256, ssh.KeyAlgoRSASHA256)
197 sshFlagTest(SignatureFlagRsaSha512, ssh.KeyAlgoRSASHA512)
198 }
199
200
201 if lifetimeSecs > 0 {
202 time.Sleep(time.Second*time.Duration(lifetimeSecs) + 100*time.Millisecond)
203 keys, err := agent.List()
204 if err != nil {
205 t.Fatalf("List: %v", err)
206 }
207 if len(keys) > 0 {
208 t.Fatalf("key not expired")
209 }
210 }
211
212 }
213
214 func TestMalformedRequests(t *testing.T) {
215 keyringAgent := NewKeyring()
216
217 testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) {
218 c, s := net.Pipe()
219 defer c.Close()
220 defer s.Close()
221 go func() {
222 _, err := c.Write(requestBytes)
223 if err != nil {
224 t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
225 }
226 c.Close()
227 }()
228 err := ServeAgent(keyringAgent, s)
229 if err == nil {
230 t.Error("ServeAgent should have returned an error to malformed input")
231 } else {
232 if (err != io.EOF) != wantServerErr {
233 t.Errorf("ServeAgent returned expected error: %v", err)
234 }
235 }
236 }
237
238 var testCases = []struct {
239 name string
240 requestBytes []byte
241 wantServerErr bool
242 }{
243 {"Empty request", []byte{}, false},
244 {"Short header", []byte{0x00}, true},
245 {"Empty body", []byte{0x00, 0x00, 0x00, 0x00}, true},
246 {"Short body", []byte{0x00, 0x00, 0x00, 0x01}, false},
247 }
248 for _, tc := range testCases {
249 t.Run(tc.name, func(t *testing.T) { testCase(t, tc.requestBytes, tc.wantServerErr) })
250 }
251 }
252
253 func TestAgent(t *testing.T) {
254 for _, keyType := range []string{"rsa", "ecdsa", "ed25519"} {
255 testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
256 testKeyringAgent(t, testPrivateKeys[keyType], nil, 0)
257 }
258 }
259
260 func TestCert(t *testing.T) {
261 cert := &ssh.Certificate{
262 Key: testPublicKeys["rsa"],
263 ValidBefore: ssh.CertTimeInfinity,
264 CertType: ssh.UserCert,
265 }
266 cert.SignCert(rand.Reader, testSigners["ecdsa"])
267
268 testOpenSSHAgent(t, testPrivateKeys["rsa"], cert, 0)
269 testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
270 }
271
272
273 func netListener() (net.Listener, error) {
274 listener, err := net.Listen("tcp", "127.0.0.1:0")
275 if err != nil {
276 listener, err = net.Listen("tcp", "[::1]:0")
277 if err != nil {
278 return nil, err
279 }
280 }
281 return listener, nil
282 }
283
284
285
286
287 func netPipe() (net.Conn, net.Conn, error) {
288 listener, err := netListener()
289 if err != nil {
290 return nil, nil, err
291 }
292 defer listener.Close()
293 c1, err := net.Dial("tcp", listener.Addr().String())
294 if err != nil {
295 return nil, nil, err
296 }
297
298 c2, err := listener.Accept()
299 if err != nil {
300 c1.Close()
301 return nil, nil, err
302 }
303
304 return c1, c2, nil
305 }
306
307 func TestServerResponseTooLarge(t *testing.T) {
308 a, b, err := netPipe()
309 if err != nil {
310 t.Fatalf("netPipe: %v", err)
311 }
312 done := make(chan struct{})
313 defer func() { <-done }()
314
315 defer a.Close()
316 defer b.Close()
317
318 var response identitiesAnswerAgentMsg
319 response.NumKeys = 1
320 response.Keys = make([]byte, maxAgentResponseBytes+1)
321
322 agent := NewClient(a)
323 go func() {
324 defer close(done)
325 n, err := b.Write(ssh.Marshal(response))
326 if n < 4 {
327 if runtime.GOOS == "plan9" {
328 if e1, ok := err.(*net.OpError); ok {
329 if e2, ok := e1.Err.(*os.PathError); ok {
330 switch e2.Err.Error() {
331 case "Hangup", "i/o on hungup channel":
332
333 return
334 }
335 }
336 }
337 }
338 t.Errorf("At least 4 bytes (the response size) should have been successfully written: %d < 4: %v", n, err)
339 }
340 }()
341 _, err = agent.List()
342 if err == nil {
343 t.Fatal("Did not get error result")
344 }
345 if err.Error() != "agent: client error: response too large" {
346 t.Fatal("Did not get expected error result")
347 }
348 }
349
350 func TestAuth(t *testing.T) {
351 agent, _, cleanup := startOpenSSHAgent(t)
352 defer cleanup()
353
354 a, b, err := netPipe()
355 if err != nil {
356 t.Fatalf("netPipe: %v", err)
357 }
358
359 defer a.Close()
360 defer b.Close()
361
362 if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil {
363 t.Errorf("Add: %v", err)
364 }
365
366 serverConf := ssh.ServerConfig{}
367 serverConf.AddHostKey(testSigners["rsa"])
368 serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
369 if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
370 return nil, nil
371 }
372
373 return nil, errors.New("pubkey rejected")
374 }
375
376 go func() {
377 conn, _, _, err := ssh.NewServerConn(a, &serverConf)
378 if err != nil {
379 t.Errorf("NewServerConn error: %v", err)
380 return
381 }
382 conn.Close()
383 }()
384
385 conf := ssh.ClientConfig{
386 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
387 }
388 conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
389 conn, _, _, err := ssh.NewClientConn(b, "", &conf)
390 if err != nil {
391 t.Fatalf("NewClientConn: %v", err)
392 }
393 conn.Close()
394 }
395
396 func TestLockOpenSSHAgent(t *testing.T) {
397 agent, _, cleanup := startOpenSSHAgent(t)
398 defer cleanup()
399 testLockAgent(agent, t)
400 }
401
402 func TestLockKeyringAgent(t *testing.T) {
403 agent, cleanup := startKeyringAgent(t)
404 defer cleanup()
405 testLockAgent(agent, t)
406 }
407
408 func testLockAgent(agent Agent, t *testing.T) {
409 if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil {
410 t.Errorf("Add: %v", err)
411 }
412 if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["ecdsa"], Comment: "comment ecdsa"}); err != nil {
413 t.Errorf("Add: %v", err)
414 }
415 if keys, err := agent.List(); err != nil {
416 t.Errorf("List: %v", err)
417 } else if len(keys) != 2 {
418 t.Errorf("Want 2 keys, got %v", keys)
419 }
420
421 passphrase := []byte("secret")
422 if err := agent.Lock(passphrase); err != nil {
423 t.Errorf("Lock: %v", err)
424 }
425
426 if keys, err := agent.List(); err != nil {
427 t.Errorf("List: %v", err)
428 } else if len(keys) != 0 {
429 t.Errorf("Want 0 keys, got %v", keys)
430 }
431
432 signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"])
433 if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil {
434 t.Fatalf("Sign did not fail")
435 }
436
437 if err := agent.Remove(signer.PublicKey()); err == nil {
438 t.Fatalf("Remove did not fail")
439 }
440
441 if err := agent.RemoveAll(); err == nil {
442 t.Fatalf("RemoveAll did not fail")
443 }
444
445 if err := agent.Unlock(nil); err == nil {
446 t.Errorf("Unlock with wrong passphrase succeeded")
447 }
448 if err := agent.Unlock(passphrase); err != nil {
449 t.Errorf("Unlock: %v", err)
450 }
451
452 if err := agent.Remove(signer.PublicKey()); err != nil {
453 t.Fatalf("Remove: %v", err)
454 }
455
456 if keys, err := agent.List(); err != nil {
457 t.Errorf("List: %v", err)
458 } else if len(keys) != 1 {
459 t.Errorf("Want 1 keys, got %v", keys)
460 }
461 }
462
463 func testOpenSSHAgentLifetime(t *testing.T) {
464 agent, _, cleanup := startOpenSSHAgent(t)
465 defer cleanup()
466 testAgentLifetime(t, agent)
467 }
468
469 func testKeyringAgentLifetime(t *testing.T) {
470 agent, cleanup := startKeyringAgent(t)
471 defer cleanup()
472 testAgentLifetime(t, agent)
473 }
474
475 func testAgentLifetime(t *testing.T, agent Agent) {
476 for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
477
478 err := agent.Add(AddedKey{
479 PrivateKey: testPrivateKeys[keyType],
480 Comment: "comment",
481 LifetimeSecs: 1,
482 })
483 if err != nil {
484 t.Fatalf("add: %v", err)
485 }
486
487 cert := &ssh.Certificate{
488 Key: testPublicKeys[keyType],
489 ValidBefore: ssh.CertTimeInfinity,
490 CertType: ssh.UserCert,
491 }
492 cert.SignCert(rand.Reader, testSigners[keyType])
493 err = agent.Add(AddedKey{
494 PrivateKey: testPrivateKeys[keyType],
495 Certificate: cert,
496 Comment: "comment",
497 LifetimeSecs: 1,
498 })
499 if err != nil {
500 t.Fatalf("add: %v", err)
501 }
502 }
503 time.Sleep(1100 * time.Millisecond)
504 if keys, err := agent.List(); err != nil {
505 t.Errorf("List: %v", err)
506 } else if len(keys) != 0 {
507 t.Errorf("Want 0 keys, got %v", len(keys))
508 }
509 }
510
511 type keyringExtended struct {
512 *keyring
513 }
514
515 func (r *keyringExtended) Extension(extensionType string, contents []byte) ([]byte, error) {
516 if extensionType != "my-extension@example.com" {
517 return []byte{agentExtensionFailure}, nil
518 }
519 return append([]byte{agentSuccess}, contents...), nil
520 }
521
522 func TestAgentExtensions(t *testing.T) {
523 agent, _, cleanup := startOpenSSHAgent(t)
524 defer cleanup()
525 _, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
526 if err == nil {
527 t.Fatal("should have gotten agent extension failure")
528 }
529
530 agent, cleanup = startAgent(t, &keyringExtended{})
531 defer cleanup()
532 result, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
533 if err != nil {
534 t.Fatalf("agent extension failure: %v", err)
535 }
536 if len(result) != 4 || !bytes.Equal(result, []byte{agentSuccess, 0x00, 0x01, 0x02}) {
537 t.Fatalf("agent extension result invalid: %v", result)
538 }
539
540 _, err = agent.Extension("bad-extension@example.com", []byte{0x00, 0x01, 0x02})
541 if err == nil {
542 t.Fatal("should have gotten agent extension failure")
543 }
544 }
545
View as plain text