1
2
3
4
19
20 package main
21
22 import (
23 "bytes"
24 cryptorand "crypto/rand"
25 "crypto/rsa"
26 "crypto/x509"
27 "crypto/x509/pkix"
28 "encoding/pem"
29 "flag"
30 "fmt"
31 "math/big"
32 "net"
33 "os"
34 "path/filepath"
35 "strings"
36 "sync"
37 "testing"
38 "time"
39
40 "github.com/blang/semver/v4"
41 "k8s.io/klog/v2"
42 netutils "k8s.io/utils/net"
43 )
44
45 var (
46 testSupportedVersions = mustParseSupportedVersions([]string{"3.0.17", "3.1.12"})
47 testVersionPrevious = &EtcdVersion{semver.MustParse("3.0.17")}
48 testVersionLatest = &EtcdVersion{semver.MustParse("3.1.12")}
49 )
50
// init registers klog's flags on the default command-line flag set and sets
// logtostderr=true and v=9, presumably so klog output from the migration code
// is visible in test output.
func init() {
	klog.InitFlags(nil)
	// Errors from flag.Set are ignored: both flags are expected to have been
	// registered by klog.InitFlags just above.
	for _, kv := range [][2]string{
		{"logtostderr", "true"},
		{"v", "9"},
	} {
		flag.Set(kv[0], kv[1])
	}
}
57
58 func TestMigrate(t *testing.T) {
59 migrations := []struct {
60 title string
61 memberCount int
62 startVersion string
63 endVersion string
64 protocol string
65 clientListenUrls string
66 }{
67
68 {"v3-v3-up", 1, "3.0.17/etcd3", "3.1.12/etcd3", "https", ""},
69 {"oldest-newest-up", 1, "3.0.17/etcd3", "3.1.12/etcd3", "https", ""},
70 {"v3-v3-up-with-additional-client-url", 1, "3.0.17/etcd3", "3.1.12/etcd3", "https", "http://127.0.0.1:2379,http://10.128.0.1:2379"},
71
72
73 {"ha-v3-v3-up", 3, "3.0.17/etcd3", "3.1.12/etcd3", "https", ""},
74
75
76 {"v3-v3-down", 1, "3.1.12/etcd3", "3.0.17/etcd3", "https", ""},
77
78
79 }
80
81 for _, m := range migrations {
82 t.Run(m.title, func(t *testing.T) {
83 start := mustParseEtcdVersionPair(m.startVersion)
84 end := mustParseEtcdVersionPair(m.endVersion)
85
86 testCfgs := clusterConfig(t, m.title, m.memberCount, m.protocol, m.clientListenUrls)
87
88 servers := []*EtcdMigrateServer{}
89 for _, cfg := range testCfgs {
90 client, err := NewEtcdMigrateClient(cfg)
91 if err != nil {
92 t.Fatalf("Failed to create client: %v", err)
93 }
94 server := NewEtcdMigrateServer(cfg, client)
95 servers = append(servers, server)
96 }
97
98
99 parallel(servers, func(server *EtcdMigrateServer) {
100 dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
101 if err != nil {
102 t.Fatalf("Error opening or creating data directory %s: %v", server.cfg.dataDirectory, err)
103 }
104 migrator := &Migrator{server.cfg, dataDir, server.client}
105 err = migrator.MigrateIfNeeded(start)
106 if err != nil {
107 t.Fatalf("Migration failed: %v", err)
108 }
109 err = server.Start(start.version)
110 if err != nil {
111 t.Fatalf("Failed to start server: %v", err)
112 }
113 })
114
115
116 parallel(servers, func(server *EtcdMigrateServer) {
117 key := fmt.Sprintf("/registry/%s", server.cfg.name)
118 value := fmt.Sprintf("value-%s", server.cfg.name)
119 err := server.client.Put(start.version, key, value)
120 if err != nil {
121 t.Fatalf("failed to write text value: %v", err)
122 }
123
124 checkVal, err := server.client.Get(start.version, key)
125 if err != nil {
126 t.Errorf("Error getting %s for validation: %v", key, err)
127 }
128 if checkVal != value {
129 t.Errorf("Expected %s from %s but got %s", value, key, checkVal)
130 }
131 })
132
133
134 serial(servers, func(server *EtcdMigrateServer) {
135 err := server.Stop()
136 if err != nil {
137 t.Fatalf("Stop server failed: %v", err)
138 }
139 dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
140 if err != nil {
141 t.Fatalf("Error opening or creating data directory %s: %v", server.cfg.dataDirectory, err)
142 }
143 migrator := &Migrator{server.cfg, dataDir, server.client}
144 err = migrator.MigrateIfNeeded(end)
145 if err != nil {
146 t.Fatalf("Migration failed: %v", err)
147 }
148 err = server.Start(end.version)
149 if err != nil {
150 t.Fatalf("Start server failed: %v", err)
151 }
152 })
153
154
155 parallel(servers, func(server *EtcdMigrateServer) {
156 for _, s := range servers {
157 key := fmt.Sprintf("/registry/%s", s.cfg.name)
158 value := fmt.Sprintf("value-%s", s.cfg.name)
159 checkVal, err := server.client.Get(end.version, key)
160 if err != nil {
161 t.Errorf("Error getting %s from etcd 2.x after rollback from 3.x: %v", key, err)
162 }
163 if checkVal != value {
164 t.Errorf("Expected %s from %s but got %s when reading after rollback from %s to %s", value, key, checkVal, start, end)
165 }
166 }
167 })
168
169
170 parallel(servers, func(server *EtcdMigrateServer) {
171 err := server.Stop()
172 if err != nil {
173 t.Fatalf("Failed to stop server: %v", err)
174 }
175 })
176
177
178 parallel(servers, func(server *EtcdMigrateServer) {
179 dataDir, err := OpenOrCreateDataDirectory(server.cfg.dataDirectory)
180 v, err := dataDir.versionFile.Read()
181 if err != nil {
182 t.Fatalf("Failed to read version.txt file: %v", err)
183 }
184 if !v.Equals(end) {
185 t.Errorf("Expected version.txt to contain %s but got %s", end, v)
186 }
187
188 checkPermissions(t, server.cfg.dataDirectory, 0755|os.ModeDir)
189 checkPermissions(t, dataDir.versionFile.path, 0644)
190 })
191 })
192 }
193 }
194
// parallel calls fn once per server, each call in its own goroutine, and
// returns only after every call has finished. fn does not run on the test
// goroutine, so it must not call t.Fatal/t.FailNow (t.Error is fine).
func parallel(servers []*EtcdMigrateServer, fn func(server *EtcdMigrateServer)) {
	var pending sync.WaitGroup
	for i := range servers {
		pending.Add(1)
		go func(target *EtcdMigrateServer) {
			defer pending.Done()
			fn(target)
		}(servers[i])
	}
	pending.Wait()
}
206
207 func serial(servers []*EtcdMigrateServer, fn func(server *EtcdMigrateServer)) {
208 for _, server := range servers {
209 fn(server)
210 }
211 }
212
// checkPermissions stats path and reports a test error if its complete mode
// (permission bits plus type bits such as os.ModeDir) differs from expected.
// It calls t.Fatalf when path cannot be stat'ed, so it must be called from the
// test goroutine.
func checkPermissions(t *testing.T, path string, expected os.FileMode) {
	info, statErr := os.Stat(path)
	switch {
	case statErr != nil:
		t.Fatalf("Failed to stat file %s: %v", path, statErr)
	case info.Mode() != expected:
		t.Errorf("Expected permissions for file %s of %s, but got %s", path, expected, info.Mode())
	}
}
222
223 func clusterConfig(t *testing.T, name string, memberCount int, protocol string, clientListenUrls string) []*EtcdMigrateCfg {
224 peers := []string{}
225 for i := 0; i < memberCount; i++ {
226 memberName := fmt.Sprintf("%s-%d", name, i)
227 peerPort := uint64(2380 + i*10000)
228 peer := fmt.Sprintf("%s=%s://127.0.0.1:%d", memberName, protocol, peerPort)
229 peers = append(peers, peer)
230 }
231 initialCluster := strings.Join(peers, ",")
232
233 extraArgs := ""
234 if protocol == "https" {
235 extraArgs = getOrCreateTLSPeerCertArgs(t)
236 }
237
238 cfgs := []*EtcdMigrateCfg{}
239 for i := 0; i < memberCount; i++ {
240 memberName := fmt.Sprintf("%s-%d", name, i)
241 peerURL := fmt.Sprintf("%s://127.0.0.1:%d", protocol, uint64(2380+i*10000))
242 cfg := &EtcdMigrateCfg{
243 binPath: "/usr/local/bin",
244 name: memberName,
245 initialCluster: initialCluster,
246 port: uint64(2379 + i*10000),
247 peerListenUrls: peerURL,
248 peerAdvertiseUrls: peerURL,
249 clientListenUrls: clientListenUrls,
250 etcdDataPrefix: "/registry",
251 ttlKeysDirectory: "/registry/events",
252 supportedVersions: testSupportedVersions,
253 dataDirectory: fmt.Sprintf("/tmp/etcd-data-dir-%s", memberName),
254 etcdServerArgs: extraArgs,
255 }
256 cfgs = append(cfgs, cfg)
257 }
258 return cfgs
259 }
260
261 func getOrCreateTLSPeerCertArgs(t *testing.T) string {
262 spec := TestCertSpec{
263 host: "localhost",
264 ips: []string{"127.0.0.1"},
265 }
266 certDir := "/tmp/certs"
267 certFile := filepath.Join(certDir, "test.crt")
268 keyFile := filepath.Join(certDir, "test.key")
269 err := getOrCreateTestCertFiles(certFile, keyFile, spec)
270 if err != nil {
271 t.Fatalf("failed to create server cert: %v", err)
272 }
273 return fmt.Sprintf("--peer-client-cert-auth --peer-trusted-ca-file=%s --peer-cert-file=%s --peer-key-file=%s", certFile, certFile, keyFile)
274 }
275
// TestCertSpec describes the subject of a generated test certificate.
type TestCertSpec struct {
	// host is the certificate's primary name: added as an IP SAN if it parses
	// as an IP address, otherwise as a DNS SAN.
	host string
	// names are extra DNS SANs.
	names []string
	// ips are extra IP SANs, in textual form.
	ips []string
}
280
// getOrCreateTestCertFiles ensures a PEM certificate exists at certFileName and
// its PEM private key at keyFileName.
//
// If both files already exist, they are reused as-is and their contents are
// not validated. Otherwise a new self-signed certificate and key are generated
// from spec and both files are (re)written, creating parent directories as
// needed.
//
// Directories are created with mode 0755, the certificate with 0644 and the
// private key with 0600 (all subject to the process umask). Previously
// everything was written with 0777, which left the private key world-readable.
func getOrCreateTestCertFiles(certFileName, keyFileName string, spec TestCertSpec) error {
	_, certErr := os.Stat(certFileName)
	_, keyErr := os.Stat(keyFileName)
	if certErr == nil && keyErr == nil {
		return nil
	}

	certPem, keyPem, err := generateSelfSignedCertKey(spec.host, parseIPList(spec.ips), spec.names)
	if err != nil {
		return fmt.Errorf("generating self-signed certificate for %s: %w", spec.host, err)
	}

	outputs := []struct {
		path string
		data []byte
		mode os.FileMode
	}{
		{certFileName, certPem, 0644},
		{keyFileName, keyPem, 0600},
	}
	for _, out := range outputs {
		if err := os.MkdirAll(filepath.Dir(out.path), 0755); err != nil {
			return fmt.Errorf("creating directory for %s: %w", out.path, err)
		}
		// os.WriteFile applies mode only when it creates the file.
		if err := os.WriteFile(out.path, out.data, out.mode); err != nil {
			return fmt.Errorf("writing %s: %w", out.path, err)
		}
	}
	return nil
}
307
// parseIPList converts textual IP addresses to net.IP values with
// netutils.ParseIPSloppy, preserving order.
//
// An entry that does not parse is kept as a nil net.IP rather than dropped.
// An empty input yields a nil slice.
func parseIPList(ips []string) []net.IP {
	var parsed []net.IP
	for i := range ips {
		parsed = append(parsed, netutils.ParseIPSloppy(ips[i]))
	}
	return parsed
}
315
316
317
318
319 func generateSelfSignedCertKey(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) {
320 priv, err := rsa.GenerateKey(cryptorand.Reader, 2048)
321 if err != nil {
322 return nil, nil, err
323 }
324
325 template := x509.Certificate{
326 SerialNumber: big.NewInt(1),
327 Subject: pkix.Name{
328 CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()),
329 },
330 NotBefore: time.Unix(0, 0),
331 NotAfter: time.Now().Add(time.Hour * 24 * 365 * 100),
332
333 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
334 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
335 BasicConstraintsValid: true,
336 IsCA: true,
337 }
338
339 if ip := netutils.ParseIPSloppy(host); ip != nil {
340 template.IPAddresses = append(template.IPAddresses, ip)
341 } else {
342 template.DNSNames = append(template.DNSNames, host)
343 }
344
345 template.IPAddresses = append(template.IPAddresses, alternateIPs...)
346 template.DNSNames = append(template.DNSNames, alternateDNS...)
347
348 derBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, &priv.PublicKey, priv)
349 if err != nil {
350 return nil, nil, err
351 }
352
353
354 certBuffer := bytes.Buffer{}
355 if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
356 return nil, nil, err
357 }
358
359
360 keyBuffer := bytes.Buffer{}
361 if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
362 return nil, nil, err
363 }
364
365 return certBuffer.Bytes(), keyBuffer.Bytes(), nil
366 }
367
368
369
370 func mustParseEtcdVersionPair(s string) *EtcdVersionPair {
371 pair, err := ParseEtcdVersionPair(s)
372 if err != nil {
373 panic(err)
374 }
375 return pair
376 }
377
378
379 func mustParseSupportedVersions(list []string) SupportedVersions {
380 versions, err := ParseSupportedVersions(list)
381 if err != nil {
382 panic(err)
383 }
384 return versions
385 }
386
View as plain text