1
16
17 package ssh
18
19 import (
20 "bytes"
21 "context"
22 "fmt"
23 "net"
24 "os"
25 "path/filepath"
26 "sync"
27 "time"
28
29 "github.com/onsi/gomega"
30
31 "golang.org/x/crypto/ssh"
32
33 v1 "k8s.io/api/core/v1"
34 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
35 "k8s.io/apimachinery/pkg/fields"
36 "k8s.io/apimachinery/pkg/util/wait"
37 clientset "k8s.io/client-go/kubernetes"
38 "k8s.io/kubernetes/test/e2e/framework"
39 )
40
// SSHPort is the TCP port, as a string, on which SSH is expected on a host.
const SSHPort = "22"

const (
	// sshBastionEnvKey is the name of the environment variable that holds the
	// address of an SSH bastion host. When it is set, commands are run
	// through the bastion instead of dialing the target host directly.
	sshBastionEnvKey = "KUBE_SSH_BASTION"

	// pollNodeInterval is the interval between attempts when polling the API
	// server for the node list.
	pollNodeInterval = 2 * time.Second

	// singleCallTimeout is the overall timeout applied when polling the API
	// server for the node list.
	singleCallTimeout = 5 * time.Minute
)
55
56
57
58 func GetSigner(provider string) (ssh.Signer, error) {
59
60 if path := os.Getenv("KUBE_SSH_KEY_PATH"); len(path) > 0 {
61 return makePrivateKeySignerFromFile(path)
62 }
63
64
65
66
67 keyfile := ""
68 switch provider {
69 case "gce", "gke", "kubemark":
70 keyfile = os.Getenv("GCE_SSH_KEY")
71 if keyfile == "" {
72 keyfile = "google_compute_engine"
73 }
74 case "aws", "eks":
75 keyfile = os.Getenv("AWS_SSH_KEY")
76 if keyfile == "" {
77 keyfile = "kube_aws_rsa"
78 }
79 case "local", "vsphere":
80 keyfile = os.Getenv("LOCAL_SSH_KEY")
81 if keyfile == "" {
82 keyfile = "id_rsa"
83 }
84 case "skeleton":
85 keyfile = os.Getenv("KUBE_SSH_KEY")
86 if keyfile == "" {
87 keyfile = "id_rsa"
88 }
89 case "azure":
90 keyfile = os.Getenv("AZURE_SSH_KEY")
91 if keyfile == "" {
92 keyfile = "id_rsa"
93 }
94 default:
95 return nil, fmt.Errorf("GetSigner(...) not implemented for %s", provider)
96 }
97
98
99
100 if !filepath.IsAbs(keyfile) {
101 keydir := filepath.Join(os.Getenv("HOME"), ".ssh")
102 keyfile = filepath.Join(keydir, keyfile)
103 }
104
105 return makePrivateKeySignerFromFile(keyfile)
106 }
107
108 func makePrivateKeySignerFromFile(key string) (ssh.Signer, error) {
109 buffer, err := os.ReadFile(key)
110 if err != nil {
111 return nil, fmt.Errorf("error reading SSH key %s: %w", key, err)
112 }
113
114 signer, err := ssh.ParsePrivateKey(buffer)
115 if err != nil {
116 return nil, fmt.Errorf("error parsing SSH key: %w", err)
117 }
118
119 return signer, err
120 }
121
122
123
124
125
126
127 func NodeSSHHosts(ctx context.Context, c clientset.Interface) ([]string, error) {
128 nodelist := waitListSchedulableNodesOrDie(ctx, c)
129
130 hosts := nodeAddresses(nodelist, v1.NodeExternalIP)
131
132 if len(hosts) < len(nodelist.Items) {
133 framework.Logf("No external IP address on nodes, falling back to internal IPs")
134 hosts = nodeAddresses(nodelist, v1.NodeInternalIP)
135 }
136
137
138 if len(hosts) != len(nodelist.Items) {
139 return hosts, fmt.Errorf(
140 "only found %d IPs on nodes, but found %d nodes. Nodelist: %v",
141 len(hosts), len(nodelist.Items), nodelist)
142 }
143
144 lenHosts := len(hosts)
145 wg := &sync.WaitGroup{}
146 wg.Add(lenHosts)
147 sshHosts := make([]string, 0, lenHosts)
148 var sshHostsLock sync.Mutex
149
150 for _, host := range hosts {
151 go func(host string) {
152 defer wg.Done()
153 if canConnect(host) {
154 framework.Logf("Assuming SSH on host %s", host)
155 sshHostsLock.Lock()
156 sshHosts = append(sshHosts, net.JoinHostPort(host, SSHPort))
157 sshHostsLock.Unlock()
158 } else {
159 framework.Logf("Skipping host %s because it does not run anything on port %s", host, SSHPort)
160 }
161 }(host)
162 }
163 wg.Wait()
164
165 return sshHosts, nil
166 }
167
168
169 func canConnect(host string) bool {
170 if _, ok := os.LookupEnv(sshBastionEnvKey); ok {
171 return true
172 }
173 hostPort := net.JoinHostPort(host, SSHPort)
174 conn, err := net.DialTimeout("tcp", hostPort, 3*time.Second)
175 if err != nil {
176 framework.Logf("cannot dial %s: %v", hostPort, err)
177 return false
178 }
179 conn.Close()
180 return true
181 }
182
183
// Result holds the outcome of running a single command over SSH.
type Result struct {
	// User and Host identify the remote endpoint the command was sent to;
	// Cmd is the command line that was executed.
	User, Host, Cmd string
	// Stdout and Stderr are the captured output streams of the command.
	Stdout, Stderr string
	// Code is the exit status reported for the command.
	Code int
}
192
193
194
195
196 func NodeExec(ctx context.Context, nodeName, cmd, provider string) (Result, error) {
197 return SSH(ctx, cmd, net.JoinHostPort(nodeName, SSHPort), provider)
198 }
199
200
201
202
203 func SSH(ctx context.Context, cmd, host, provider string) (Result, error) {
204 result := Result{Host: host, Cmd: cmd}
205
206
207 signer, err := GetSigner(provider)
208 if err != nil {
209 return result, fmt.Errorf("error getting signer for provider %s: %w", provider, err)
210 }
211
212
213
214 result.User = os.Getenv("KUBE_SSH_USER")
215 if result.User == "" {
216 result.User = os.Getenv("USER")
217 }
218
219 if bastion := os.Getenv(sshBastionEnvKey); len(bastion) > 0 {
220 stdout, stderr, code, err := runSSHCommandViaBastion(ctx, cmd, result.User, bastion, host, signer)
221 result.Stdout = stdout
222 result.Stderr = stderr
223 result.Code = code
224 return result, err
225 }
226
227 stdout, stderr, code, err := runSSHCommand(ctx, cmd, result.User, host, signer)
228 result.Stdout = stdout
229 result.Stderr = stderr
230 result.Code = code
231
232 return result, err
233 }
234
235
236
237 func runSSHCommand(ctx context.Context, cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
238 if user == "" {
239 user = os.Getenv("USER")
240 }
241
242 config := &ssh.ClientConfig{
243 User: user,
244 Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
245 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
246 }
247 client, err := ssh.Dial("tcp", host, config)
248 if err != nil {
249 err = wait.PollWithContext(ctx, 5*time.Second, 20*time.Second, func(ctx context.Context) (bool, error) {
250 fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, host, err)
251 if client, err = ssh.Dial("tcp", host, config); err != nil {
252 return false, nil
253 }
254 return true, nil
255 })
256 }
257 if err != nil {
258 return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: %w", user, host, err)
259 }
260 defer client.Close()
261 session, err := client.NewSession()
262 if err != nil {
263 return "", "", 0, fmt.Errorf("error creating session to %s@%s: %w", user, host, err)
264 }
265 defer session.Close()
266
267
268 code := 0
269 var bout, berr bytes.Buffer
270 session.Stdout, session.Stderr = &bout, &berr
271 if err = session.Run(cmd); err != nil {
272
273 if exiterr, ok := err.(*ssh.ExitError); ok {
274
275
276
277 if code = exiterr.ExitStatus(); code != 0 {
278 err = nil
279 }
280 } else {
281
282
283 err = fmt.Errorf("failed running `%s` on %s@%s: %w", cmd, user, host, err)
284 }
285 }
286 return bout.String(), berr.String(), code, err
287 }
288
289
290
291
292
293 func runSSHCommandViaBastion(ctx context.Context, cmd, user, bastion, host string, signer ssh.Signer) (string, string, int, error) {
294
295 config := &ssh.ClientConfig{
296 User: user,
297 Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
298 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
299 Timeout: 150 * time.Second,
300 }
301 bastionClient, err := ssh.Dial("tcp", bastion, config)
302 if err != nil {
303 err = wait.PollWithContext(ctx, 5*time.Second, 20*time.Second, func(ctx context.Context) (bool, error) {
304 fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, bastion, err)
305 if bastionClient, err = ssh.Dial("tcp", bastion, config); err != nil {
306 return false, err
307 }
308 return true, nil
309 })
310 }
311 if err != nil {
312 return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: %w", user, bastion, err)
313 }
314 defer bastionClient.Close()
315
316 conn, err := bastionClient.Dial("tcp", host)
317 if err != nil {
318 return "", "", 0, fmt.Errorf("error dialing %s from bastion: %w", host, err)
319 }
320 defer conn.Close()
321
322 ncc, chans, reqs, err := ssh.NewClientConn(conn, host, config)
323 if err != nil {
324 return "", "", 0, fmt.Errorf("error creating forwarding connection %s from bastion: %w", host, err)
325 }
326 client := ssh.NewClient(ncc, chans, reqs)
327 defer client.Close()
328
329 session, err := client.NewSession()
330 if err != nil {
331 return "", "", 0, fmt.Errorf("error creating session to %s@%s from bastion: %w", user, host, err)
332 }
333 defer session.Close()
334
335
336 code := 0
337 var bout, berr bytes.Buffer
338 session.Stdout, session.Stderr = &bout, &berr
339 if err = session.Run(cmd); err != nil {
340
341 if exiterr, ok := err.(*ssh.ExitError); ok {
342
343
344
345 if code = exiterr.ExitStatus(); code != 0 {
346 err = nil
347 }
348 } else {
349
350
351 err = fmt.Errorf("failed running `%s` on %s@%s: %w", cmd, user, host, err)
352 }
353 }
354 return bout.String(), berr.String(), code, err
355 }
356
357
358 func LogResult(result Result) {
359 remote := fmt.Sprintf("%s@%s", result.User, result.Host)
360 framework.Logf("ssh %s: command: %s", remote, result.Cmd)
361 framework.Logf("ssh %s: stdout: %q", remote, result.Stdout)
362 framework.Logf("ssh %s: stderr: %q", remote, result.Stderr)
363 framework.Logf("ssh %s: exit code: %d", remote, result.Code)
364 }
365
366
367 func IssueSSHCommandWithResult(ctx context.Context, cmd, provider string, node *v1.Node) (*Result, error) {
368 framework.Logf("Getting external IP address for %s", node.Name)
369 host := ""
370 for _, a := range node.Status.Addresses {
371 if a.Type == v1.NodeExternalIP && a.Address != "" {
372 host = net.JoinHostPort(a.Address, SSHPort)
373 break
374 }
375 }
376
377 if host == "" {
378
379 for _, a := range node.Status.Addresses {
380 if a.Type == v1.NodeInternalIP && a.Address != "" {
381 host = net.JoinHostPort(a.Address, SSHPort)
382 break
383 }
384 }
385 }
386
387 if host == "" {
388 return nil, fmt.Errorf("couldn't find any IP address for node %s", node.Name)
389 }
390
391 framework.Logf("SSH %q on %s(%s)", cmd, node.Name, host)
392 result, err := SSH(ctx, cmd, host, provider)
393 LogResult(result)
394
395 if result.Code != 0 || err != nil {
396 return nil, fmt.Errorf("failed running %q: %v (exit code %d, stderr %v)",
397 cmd, err, result.Code, result.Stderr)
398 }
399
400 return &result, nil
401 }
402
403
404 func IssueSSHCommand(ctx context.Context, cmd, provider string, node *v1.Node) error {
405 _, err := IssueSSHCommandWithResult(ctx, cmd, provider, node)
406 if err != nil {
407 return err
408 }
409 return nil
410 }
411
412
413 func nodeAddresses(nodelist *v1.NodeList, addrType v1.NodeAddressType) []string {
414 hosts := []string{}
415 for _, n := range nodelist.Items {
416 for _, addr := range n.Status.Addresses {
417 if addr.Type == addrType && addr.Address != "" {
418 hosts = append(hosts, addr.Address)
419 break
420 }
421 }
422 }
423 return hosts
424 }
425
426
427 func waitListSchedulableNodes(ctx context.Context, c clientset.Interface) (*v1.NodeList, error) {
428 var nodes *v1.NodeList
429 var err error
430 if wait.PollUntilContextTimeout(ctx, pollNodeInterval, singleCallTimeout, true, func(ctx context.Context) (bool, error) {
431 nodes, err = c.CoreV1().Nodes().List(ctx, metav1.ListOptions{FieldSelector: fields.Set{
432 "spec.unschedulable": "false",
433 }.AsSelector().String()})
434 if err != nil {
435 return false, err
436 }
437 return true, nil
438 }) != nil {
439 return nodes, err
440 }
441 return nodes, nil
442 }
443
444
445 func waitListSchedulableNodesOrDie(ctx context.Context, c clientset.Interface) *v1.NodeList {
446 nodes, err := waitListSchedulableNodes(ctx, c)
447 if err != nil {
448 expectNoError(err, "Non-retryable failure or timed out while listing nodes for e2e cluster.")
449 }
450 return nodes
451 }
452
453
// expectNoError asserts that err is nil, passing explain through as the
// optional failure description. It delegates to expectNoErrorWithOffset with
// an offset of 1 to account for this wrapper's own stack frame.
func expectNoError(err error, explain ...interface{}) {
	expectNoErrorWithOffset(1, err, explain...)
}
457
458
459
// expectNoErrorWithOffset asserts that err is nil using gomega, logging the
// error first when it is non-nil. explain is passed through as the optional
// failure description. offset is added to 1 (for this function's own frame)
// and handed to gomega.ExpectWithOffset, which per gomega's documentation
// adjusts the call-stack depth used when reporting the failure location.
func expectNoErrorWithOffset(offset int, err error, explain ...interface{}) {
	if err != nil {
		framework.Logf("Unexpected error occurred: %v", err)
	}
	gomega.ExpectWithOffset(1+offset, err).NotTo(gomega.HaveOccurred(), explain...)
}
466
View as plain text