1
2
3
4
5 package ssh
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "net"
13 "reflect"
14 "strings"
15 "sync/atomic"
16 "testing"
17 "time"
18 )
19
20 func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) {
21 for _, tt := range []struct {
22 name string
23 key Signer
24 wantError bool
25 }{
26 {"rsa", testSigners["rsa"], false},
27 {"dsa", testSigners["dsa"], true},
28 {"ed25519", testSigners["ed25519"], true},
29 } {
30 c1, c2, err := netPipe()
31 if err != nil {
32 t.Fatalf("netPipe: %v", err)
33 }
34 defer c1.Close()
35 defer c2.Close()
36 serverConf := &ServerConfig{
37 PublicKeyAuthAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512},
38 PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
39 return nil, nil
40 },
41 }
42 serverConf.AddHostKey(testSigners["ecdsap256"])
43
44 done := make(chan struct{})
45 go func() {
46 defer close(done)
47 NewServerConn(c1, serverConf)
48 }()
49
50 clientConf := ClientConfig{
51 User: "user",
52 Auth: []AuthMethod{
53 PublicKeys(tt.key),
54 },
55 HostKeyCallback: InsecureIgnoreHostKey(),
56 }
57
58 _, _, _, err = NewClientConn(c2, "", &clientConf)
59 if err != nil {
60 if !tt.wantError {
61 t.Errorf("%s: got unexpected error %q", tt.name, err.Error())
62 }
63 } else if tt.wantError {
64 t.Errorf("%s: succeeded, but want error", tt.name)
65 }
66 <-done
67 }
68 }
69
70 func TestMaxAuthTriesNoneMethod(t *testing.T) {
71 username := "testuser"
72 serverConfig := &ServerConfig{
73 MaxAuthTries: 2,
74 PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
75 if conn.User() == username && string(password) == clientPassword {
76 return nil, nil
77 }
78 return nil, errors.New("invalid credentials")
79 },
80 }
81 c1, c2, err := netPipe()
82 if err != nil {
83 t.Fatalf("netPipe: %v", err)
84 }
85 defer c1.Close()
86 defer c2.Close()
87
88 var serverAuthErrors []error
89
90 serverConfig.AddHostKey(testSigners["rsa"])
91 serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
92 serverAuthErrors = append(serverAuthErrors, err)
93 }
94 go newServer(c1, serverConfig)
95
96 clientConfig := ClientConfig{
97 User: username,
98 HostKeyCallback: InsecureIgnoreHostKey(),
99 }
100 clientConfig.SetDefaults()
101
102
103 c := &connection{
104 sshConn: sshConn{
105 conn: c2,
106 user: username,
107 clientVersion: []byte(packageVersion),
108 },
109 }
110 c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
111 if err != nil {
112 t.Fatalf("unable to exchange version: %v", err)
113 }
114 c.transport = newClientTransport(
115 newTransport(c.sshConn.conn, clientConfig.Rand, true ),
116 c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr())
117 if err := c.transport.waitSession(); err != nil {
118 t.Fatalf("unable to wait session: %v", err)
119 }
120 c.sessionID = c.transport.getSessionID()
121 if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
122 t.Fatalf("unable to send ssh-userauth message: %v", err)
123 }
124 packet, err := c.transport.readPacket()
125 if err != nil {
126 t.Fatal(err)
127 }
128 if len(packet) > 0 && packet[0] == msgExtInfo {
129 packet, err = c.transport.readPacket()
130 if err != nil {
131 t.Fatal(err)
132 }
133 }
134 var serviceAccept serviceAcceptMsg
135 if err := Unmarshal(packet, &serviceAccept); err != nil {
136 t.Fatal(err)
137 }
138 for i := 0; i <= serverConfig.MaxAuthTries; i++ {
139 auth := new(noneAuth)
140 _, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil)
141 if i < serverConfig.MaxAuthTries {
142 if err != nil {
143 t.Fatal(err)
144 }
145 continue
146 }
147 if err == nil {
148 t.Fatal("client: got no error")
149 } else if !strings.Contains(err.Error(), "too many authentication failures") {
150 t.Fatalf("client: got unexpected error: %v", err)
151 }
152 }
153 if len(serverAuthErrors) != 3 {
154 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
155 }
156 for _, err := range serverAuthErrors {
157 if !errors.Is(err, ErrNoAuth) {
158 t.Errorf("go error: %v; want: %v", err, ErrNoAuth)
159 }
160 }
161 }
162
163 func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) {
164 username := "testuser"
165 serverConfig := &ServerConfig{
166 MaxAuthTries: 1,
167 PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
168 if conn.User() == username && string(password) == clientPassword {
169 return nil, nil
170 }
171 return nil, errors.New("invalid credentials")
172 },
173 }
174 clientConfig := &ClientConfig{
175 User: username,
176 Auth: []AuthMethod{
177 Password(clientPassword),
178 },
179 HostKeyCallback: InsecureIgnoreHostKey(),
180 }
181
182 serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
183 if err != nil {
184 t.Fatalf("client login error: %s", err)
185 }
186 if len(serverAuthErrors) != 2 {
187 t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
188 }
189 if !errors.Is(serverAuthErrors[0], ErrNoAuth) {
190 t.Errorf("go error: %v; want: %v", serverAuthErrors[0], ErrNoAuth)
191 }
192 if serverAuthErrors[1] != nil {
193 t.Errorf("unexpected error: %v", serverAuthErrors[1])
194 }
195 }
196
197 func TestNewServerConnValidationErrors(t *testing.T) {
198 serverConf := &ServerConfig{
199 PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
200 }
201 c := &markerConn{}
202 _, _, _, err := NewServerConn(c, serverConf)
203 if err == nil {
204 t.Fatal("NewServerConn with invalid public key auth algorithms succeeded")
205 }
206 if !c.isClosed() {
207 t.Fatal("NewServerConn with invalid public key auth algorithms left connection open")
208 }
209 if c.isUsed() {
210 t.Fatal("NewServerConn with invalid public key auth algorithms used connection")
211 }
212
213 serverConf = &ServerConfig{
214 Config: Config{
215 KeyExchanges: []string{kexAlgoDHGEXSHA256},
216 },
217 }
218 c = &markerConn{}
219 _, _, _, err = NewServerConn(c, serverConf)
220 if err == nil {
221 t.Fatal("NewServerConn with unsupported key exchange succeeded")
222 }
223 if !c.isClosed() {
224 t.Fatal("NewServerConn with unsupported key exchange left connection open")
225 }
226 if c.isUsed() {
227 t.Fatal("NewServerConn with unsupported key exchange used connection")
228 }
229 }
230
231 func TestBannerError(t *testing.T) {
232 serverConfig := &ServerConfig{
233 BannerCallback: func(ConnMetadata) string {
234 return "banner from BannerCallback"
235 },
236 NoClientAuth: true,
237 NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
238 err := &BannerError{
239 Err: errors.New("error from NoClientAuthCallback"),
240 Message: "banner from NoClientAuthCallback",
241 }
242 return nil, fmt.Errorf("wrapped: %w", err)
243 },
244 PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
245 return &Permissions{}, nil
246 },
247 PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
248 return nil, &BannerError{
249 Err: errors.New("error from PublicKeyCallback"),
250 Message: "banner from PublicKeyCallback",
251 }
252 },
253 KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) {
254 return nil, &BannerError{
255 Err: nil,
256 Message: "banner from KeyboardInteractiveCallback",
257 }
258 },
259 }
260 serverConfig.AddHostKey(testSigners["rsa"])
261
262 var banners []string
263 clientConfig := &ClientConfig{
264 User: "test",
265 Auth: []AuthMethod{
266 PublicKeys(testSigners["rsa"]),
267 KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
268 return []string{"letmein"}, nil
269 }),
270 Password(clientPassword),
271 },
272 HostKeyCallback: InsecureIgnoreHostKey(),
273 BannerCallback: func(msg string) error {
274 banners = append(banners, msg)
275 return nil
276 },
277 }
278
279 c1, c2, err := netPipe()
280 if err != nil {
281 t.Fatalf("netPipe: %v", err)
282 }
283 defer c1.Close()
284 defer c2.Close()
285 go newServer(c1, serverConfig)
286 c, _, _, err := NewClientConn(c2, "", clientConfig)
287 if err != nil {
288 t.Fatalf("client connection failed: %v", err)
289 }
290 defer c.Close()
291
292 wantBanners := []string{
293 "banner from BannerCallback",
294 "banner from NoClientAuthCallback",
295 "banner from PublicKeyCallback",
296 "banner from KeyboardInteractiveCallback",
297 }
298 if !reflect.DeepEqual(banners, wantBanners) {
299 t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
300 }
301 }
302
303 func TestPublicKeyCallbackLastSeen(t *testing.T) {
304 var lastSeenKey PublicKey
305
306 c1, c2, err := netPipe()
307 if err != nil {
308 t.Fatalf("netPipe: %v", err)
309 }
310 defer c1.Close()
311 defer c2.Close()
312 serverConf := &ServerConfig{
313 PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
314 lastSeenKey = key
315 fmt.Printf("seen %#v\n", key)
316 if _, ok := key.(*dsaPublicKey); !ok {
317 return nil, errors.New("nope")
318 }
319 return nil, nil
320 },
321 }
322 serverConf.AddHostKey(testSigners["ecdsap256"])
323
324 done := make(chan struct{})
325 go func() {
326 defer close(done)
327 NewServerConn(c1, serverConf)
328 }()
329
330 clientConf := ClientConfig{
331 User: "user",
332 Auth: []AuthMethod{
333 PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]),
334 },
335 HostKeyCallback: InsecureIgnoreHostKey(),
336 }
337
338 _, _, _, err = NewClientConn(c2, "", &clientConf)
339 if err != nil {
340 t.Fatal(err)
341 }
342 <-done
343
344 expectedPublicKey := testSigners["dsa"].PublicKey().Marshal()
345 lastSeenMarshalled := lastSeenKey.Marshal()
346 if !bytes.Equal(lastSeenMarshalled, expectedPublicKey) {
347 t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, testSigners["dsa"].PublicKey())
348 }
349 }
350
351 type markerConn struct {
352 closed uint32
353 used uint32
354 }
355
356 func (c *markerConn) isClosed() bool {
357 return atomic.LoadUint32(&c.closed) != 0
358 }
359
360 func (c *markerConn) isUsed() bool {
361 return atomic.LoadUint32(&c.used) != 0
362 }
363
364 func (c *markerConn) Close() error {
365 atomic.StoreUint32(&c.closed, 1)
366 return nil
367 }
368
369 func (c *markerConn) Read(b []byte) (n int, err error) {
370 atomic.StoreUint32(&c.used, 1)
371 if atomic.LoadUint32(&c.closed) != 0 {
372 return 0, net.ErrClosed
373 } else {
374 return 0, io.EOF
375 }
376 }
377
378 func (c *markerConn) Write(b []byte) (n int, err error) {
379 atomic.StoreUint32(&c.used, 1)
380 if atomic.LoadUint32(&c.closed) != 0 {
381 return 0, net.ErrClosed
382 } else {
383 return 0, io.ErrClosedPipe
384 }
385 }
386
387 func (*markerConn) LocalAddr() net.Addr { return nil }
388 func (*markerConn) RemoteAddr() net.Addr { return nil }
389
390 func (*markerConn) SetDeadline(t time.Time) error { return nil }
391 func (*markerConn) SetReadDeadline(t time.Time) error { return nil }
392 func (*markerConn) SetWriteDeadline(t time.Time) error { return nil }
393
View as plain text