1
16
17 package remote
18
19 import (
20 "flag"
21 "fmt"
22 "os"
23 "os/exec"
24 "os/user"
25 "strings"
26 "sync"
27
28 "k8s.io/klog/v2"
29 )
30
// Command-line flags controlling how ssh is invoked.
var (
	// sshOptions holds extra command-line options for ssh (a single string
	// that is split into separate arguments before use).
	sshOptions = flag.String("ssh-options", "", "Commandline options passed to ssh.")
	// sshEnv selects a predefined environment; its value is used as the key
	// into sshOptionsMap, sshDefaultKeyMap and sshDefaultUserMap.
	sshEnv = flag.String("ssh-env", "", "Use predefined ssh options for environment. Options: gce, aws")
	// sshKey is the path to the ssh private key. When set, it takes precedence
	// over per-host overrides and per-environment defaults.
	sshKey = flag.String("ssh-key", "", "Path to ssh private key.")
	// sshUser is the login user for ssh.
	sshUser = flag.String("ssh-user", "", "Use predefined user for ssh.")
)

// Per-environment defaults keyed by the --ssh-env value; populated in init.
var (
	// sshOptionsMap holds the ssh options to use for each environment.
	sshOptionsMap map[string]string
	// sshDefaultKeyMap holds the default private key path for each environment.
	sshDefaultKeyMap map[string]string
	// sshDefaultUserMap holds the default login user for each environment.
	sshDefaultUserMap map[string]string
)
39
40 func init() {
41 usr, err := user.Current()
42 if err != nil {
43 klog.Fatal(err)
44 }
45 sshOptionsMap = map[string]string{
46 "gce": "-o UserKnownHostsFile=/dev/null -o IdentitiesOnly=yes -o CheckHostIP=no -o StrictHostKeyChecking=no -o ServerAliveInterval=30 -o LogLevel=ERROR",
47 "aws": "-o UserKnownHostsFile=/dev/null -o IdentitiesOnly=yes -o CheckHostIP=no -o StrictHostKeyChecking=no -o ServerAliveInterval=30 -o LogLevel=ERROR",
48 }
49 defaultGceKey := os.Getenv("GCE_SSH_PRIVATE_KEY_FILE")
50 if defaultGceKey == "" {
51 defaultGceKey = fmt.Sprintf("%s/.ssh/google_compute_engine", usr.HomeDir)
52 }
53 sshDefaultKeyMap = map[string]string{
54 "gce": defaultGceKey,
55 }
56 sshDefaultUserMap = map[string]string{
57 "aws": "ec2-user",
58 }
59 }
60
// hostnameIPOverrides maps a hostname to the address ssh should connect to
// instead of the hostname itself. Access to m is guarded by the embedded
// RWMutex.
var hostnameIPOverrides = struct {
	sync.RWMutex
	m map[string]string
}{
	m: map[string]string{},
}

// AddHostnameIP records ip as the address to use when connecting to
// hostname, replacing any previous entry for that hostname.
func AddHostnameIP(hostname, ip string) {
	overrides := &hostnameIPOverrides
	overrides.Lock()
	defer overrides.Unlock()
	overrides.m[hostname] = ip
}
72
// sshKeyOverrides maps a hostname to the private key file to use for it.
// Access to m is guarded by the embedded RWMutex.
var sshKeyOverrides = struct {
	sync.RWMutex
	m map[string]string
}{
	m: map[string]string{},
}

// AddSSHKey registers keyFilePath as the private key for hostname, replacing
// any key previously registered for that hostname.
func AddSSHKey(hostname, keyFilePath string) {
	overrides := &sshKeyOverrides
	overrides.Lock()
	defer overrides.Unlock()
	overrides.m[hostname] = keyFilePath
}
84
85
86
87 func GetSSHUser() string {
88 if *sshUser == "" {
89 *sshUser = os.Getenv("KUBE_SSH_USER")
90 }
91 if *sshUser == "" {
92 *sshUser = sshDefaultUserMap[*sshEnv]
93 }
94 return *sshUser
95 }
96
97
98 func GetHostnameOrIP(hostname string) string {
99 hostnameIPOverrides.RLock()
100 defer hostnameIPOverrides.RUnlock()
101 host := hostname
102 if ip, found := hostnameIPOverrides.m[hostname]; found {
103 host = ip
104 }
105
106 sshUser := GetSSHUser()
107 if sshUser != "" {
108 host = fmt.Sprintf("%s@%s", sshUser, host)
109 }
110 return host
111 }
112
113
// getSSHCommand joins args with sep and wraps the result in single quotes, so
// that the remote shell receives the whole command as one word.
//
// Any single quote inside args is rewritten as '\'' (close quote, escaped
// quote, reopen quote). Without this, an embedded quote would end the quoting
// early and the remote shell would re-split the command on whitespace.
func getSSHCommand(sep string, args ...string) string {
	joined := strings.Join(args, sep)
	return fmt.Sprintf("'%s'", strings.ReplaceAll(joined, "'", `'\''`))
}
117
118
119
120 func SSH(host string, cmd ...string) (string, error) {
121 return runSSHCommand(host, "ssh", append([]string{GetHostnameOrIP(host), "--", "sudo"}, cmd...)...)
122 }
123
124
125
126 func SSHNoSudo(host string, cmd ...string) (string, error) {
127 return runSSHCommand(host, "ssh", append([]string{GetHostnameOrIP(host), "--"}, cmd...)...)
128 }
129
130
131 func runSSHCommand(host, cmd string, args ...string) (string, error) {
132 if key, err := getPrivateSSHKey(host); len(key) != 0 {
133 if err != nil {
134 klog.Errorf("private SSH key (%s) not found. Check if the SSH key is configured properly:, err: %v", key, err)
135 return "", fmt.Errorf("private SSH key (%s) does not exist", key)
136 }
137
138 args = append([]string{"-i", key}, args...)
139 }
140 if env, found := sshOptionsMap[*sshEnv]; found {
141 args = append(strings.Split(env, " "), args...)
142 }
143 if *sshOptions != "" {
144 args = append(strings.Split(*sshOptions, " "), args...)
145 }
146 klog.Infof("Running the command %s, with args: %v", cmd, args)
147 output, err := exec.Command(cmd, args...).CombinedOutput()
148 if err != nil {
149 klog.Errorf("failed to run SSH command: out: %s, err: %v", output, err)
150 return string(output), fmt.Errorf("command [%s %s] failed with error: %w", cmd, strings.Join(args, " "), err)
151 }
152 return string(output), nil
153 }
154
155
156 func getPrivateSSHKey(host string) (string, error) {
157 if *sshKey != "" {
158 if _, err := os.Stat(*sshKey); err != nil {
159 return *sshKey, err
160 }
161
162 return *sshKey, nil
163 }
164
165 sshKeyOverrides.Lock()
166 defer sshKeyOverrides.Unlock()
167 if key, ok := sshKeyOverrides.m[host]; ok {
168 return key, nil
169 }
170
171 if key, found := sshDefaultKeyMap[*sshEnv]; found {
172 if _, err := os.Stat(key); err != nil {
173 return key, err
174 }
175
176 return key, nil
177 }
178
179 return "", nil
180 }
181
View as plain text