1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package embed
16
17 import (
18 "crypto/tls"
19 "errors"
20 "fmt"
21 "io/ioutil"
22 "net"
23 "net/url"
24 "os"
25 "testing"
26 "time"
27
28 "github.com/stretchr/testify/assert"
29 "go.etcd.io/etcd/client/pkg/v3/srv"
30 "go.etcd.io/etcd/client/pkg/v3/transport"
31 "go.etcd.io/etcd/client/pkg/v3/types"
32
33 "sigs.k8s.io/yaml"
34 )
35
// notFoundErr returns a *net.DNSError, flagged as "not found", for the SRV
// record name "_<service>._tcp.<domain>". It is used to fake a failed lookup.
func notFoundErr(service, domain string) error {
	dnsErr := new(net.DNSError)
	dnsErr.Name = fmt.Sprintf("_%s._tcp.%s", service, domain)
	dnsErr.Err = "no such host"
	dnsErr.Server = "10.0.0.53:53"
	// IsTimeout and IsTemporary keep their zero value (false).
	dnsErr.IsNotFound = true
	return dnsErr
}
40
41 func TestConfigFileOtherFields(t *testing.T) {
42 ctls := securityConfig{TrustedCAFile: "cca", CertFile: "ccert", KeyFile: "ckey"}
43 ptls := securityConfig{TrustedCAFile: "pca", CertFile: "pcert", KeyFile: "pkey"}
44 yc := struct {
45 ClientSecurityCfgFile securityConfig `json:"client-transport-security"`
46 PeerSecurityCfgFile securityConfig `json:"peer-transport-security"`
47 ForceNewCluster bool `json:"force-new-cluster"`
48 Logger string `json:"logger"`
49 LogOutputs []string `json:"log-outputs"`
50 Debug bool `json:"debug"`
51 SocketOpts transport.SocketOpts `json:"socket-options"`
52 }{
53 ctls,
54 ptls,
55 true,
56 "zap",
57 []string{"/dev/null"},
58 false,
59 transport.SocketOpts{
60 ReusePort: true,
61 },
62 }
63
64 b, err := yaml.Marshal(&yc)
65 if err != nil {
66 t.Fatal(err)
67 }
68
69 tmpfile := mustCreateCfgFile(t, b)
70 defer os.Remove(tmpfile.Name())
71
72 cfg, err := ConfigFromFile(tmpfile.Name())
73 if err != nil {
74 t.Fatal(err)
75 }
76
77 if !ctls.equals(&cfg.ClientTLSInfo) {
78 t.Errorf("ClientTLS = %v, want %v", cfg.ClientTLSInfo, ctls)
79 }
80 if !ptls.equals(&cfg.PeerTLSInfo) {
81 t.Errorf("PeerTLS = %v, want %v", cfg.PeerTLSInfo, ptls)
82 }
83
84 assert.Equal(t, true, cfg.ForceNewCluster, "ForceNewCluster does not match")
85
86 assert.Equal(t, true, cfg.SocketOpts.ReusePort, "ReusePort does not match")
87
88 assert.Equal(t, false, cfg.SocketOpts.ReuseAddress, "ReuseAddress does not match")
89 }
90
91
92 func TestUpdateDefaultClusterFromName(t *testing.T) {
93 cfg := NewConfig()
94 defaultInitialCluster := cfg.InitialCluster
95 oldscheme := cfg.AdvertisePeerUrls[0].Scheme
96 origpeer := cfg.AdvertisePeerUrls[0].String()
97 origadvc := cfg.AdvertiseClientUrls[0].String()
98
99 cfg.Name = "abc"
100 lpport := cfg.ListenPeerUrls[0].Port()
101
102
103 exp := fmt.Sprintf("%s=%s://localhost:%s", cfg.Name, oldscheme, lpport)
104 _, _ = cfg.UpdateDefaultClusterFromName(defaultInitialCluster)
105 if exp != cfg.InitialCluster {
106 t.Fatalf("initial-cluster expected %q, got %q", exp, cfg.InitialCluster)
107 }
108
109 if origpeer != cfg.AdvertisePeerUrls[0].String() {
110 t.Fatalf("advertise peer url expected %q, got %q", origadvc, cfg.AdvertisePeerUrls[0].String())
111 }
112
113 if origadvc != cfg.AdvertiseClientUrls[0].String() {
114 t.Fatalf("advertise client url expected %q, got %q", origadvc, cfg.AdvertiseClientUrls[0].String())
115 }
116 }
117
118
119
120 func TestUpdateDefaultClusterFromNameOverwrite(t *testing.T) {
121 if defaultHostname == "" {
122 t.Skip("machine's default host not found")
123 }
124
125 cfg := NewConfig()
126 defaultInitialCluster := cfg.InitialCluster
127 oldscheme := cfg.AdvertisePeerUrls[0].Scheme
128 origadvc := cfg.AdvertiseClientUrls[0].String()
129
130 cfg.Name = "abc"
131 lpport := cfg.ListenPeerUrls[0].Port()
132 cfg.ListenPeerUrls[0] = url.URL{Scheme: cfg.ListenPeerUrls[0].Scheme, Host: fmt.Sprintf("0.0.0.0:%s", lpport)}
133 dhost, _ := cfg.UpdateDefaultClusterFromName(defaultInitialCluster)
134 if dhost != defaultHostname {
135 t.Fatalf("expected default host %q, got %q", defaultHostname, dhost)
136 }
137 aphost, apport := cfg.AdvertisePeerUrls[0].Hostname(), cfg.AdvertisePeerUrls[0].Port()
138 if apport != lpport {
139 t.Fatalf("advertise peer url got different port %s, expected %s", apport, lpport)
140 }
141 if aphost != defaultHostname {
142 t.Fatalf("advertise peer url expected machine default host %q, got %q", defaultHostname, aphost)
143 }
144 expected := fmt.Sprintf("%s=%s://%s:%s", cfg.Name, oldscheme, defaultHostname, lpport)
145 if expected != cfg.InitialCluster {
146 t.Fatalf("initial-cluster expected %q, got %q", expected, cfg.InitialCluster)
147 }
148
149
150 if origadvc != cfg.AdvertiseClientUrls[0].String() {
151 t.Fatalf("advertise-client-url expected %q, got %q", origadvc, cfg.AdvertiseClientUrls[0].String())
152 }
153 }
154
155 func (s *securityConfig) equals(t *transport.TLSInfo) bool {
156 return s.CertFile == t.CertFile &&
157 s.CertAuth == t.ClientCertAuth &&
158 s.TrustedCAFile == t.TrustedCAFile
159 }
160
// mustCreateCfgFile writes b to a new temporary file whose name starts with
// "servercfg" and returns the file, already closed. The caller is
// responsible for removing the file once done with it.
//
// Any failure aborts the test via t.Fatal. If the failure occurs after the
// file was created, the file is removed first so that it is not left behind.
func mustCreateCfgFile(t *testing.T, b []byte) *os.File {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	tmpfile, err := ioutil.TempFile("", "servercfg")
	if err != nil {
		t.Fatal(err)
	}
	if _, err = tmpfile.Write(b); err != nil {
		// Best-effort cleanup; the write error is the one worth reporting.
		_ = tmpfile.Close()
		_ = os.Remove(tmpfile.Name())
		t.Fatal(err)
	}
	if err = tmpfile.Close(); err != nil {
		// Best-effort cleanup; the close error is the one worth reporting.
		_ = os.Remove(tmpfile.Name())
		t.Fatal(err)
	}
	return tmpfile
}
174
175 func TestAutoCompactionModeInvalid(t *testing.T) {
176 cfg := NewConfig()
177 cfg.Logger = "zap"
178 cfg.LogOutputs = []string{"/dev/null"}
179 cfg.AutoCompactionMode = "period"
180 err := cfg.Validate()
181 if err == nil {
182 t.Errorf("expected non-nil error, got %v", err)
183 }
184 }
185
186 func TestAutoCompactionModeParse(t *testing.T) {
187 tests := []struct {
188 mode string
189 retention string
190 werr bool
191 wdur time.Duration
192 }{
193
194 {"revision", "1", false, 1},
195 {"revision", "1h", false, time.Hour},
196 {"revision", "a", true, 0},
197 {"revision", "-1", true, 0},
198
199 {"periodic", "1", false, time.Hour},
200 {"periodic", "a", true, 0},
201 {"revision", "-1", true, 0},
202
203 {"errmode", "1", false, 0},
204 {"errmode", "1h", false, time.Hour},
205 }
206
207 hasErr := func(err error) bool {
208 return err != nil
209 }
210
211 for i, tt := range tests {
212 dur, err := parseCompactionRetention(tt.mode, tt.retention)
213 if hasErr(err) != tt.werr {
214 t.Errorf("#%d: err = %v, want %v", i, err, tt.werr)
215 }
216 if dur != tt.wdur {
217 t.Errorf("#%d: duration = %s, want %s", i, dur, tt.wdur)
218 }
219 }
220 }
221
222 func TestPeerURLsMapAndTokenFromSRV(t *testing.T) {
223 defer func() { getCluster = srv.GetCluster }()
224
225 tests := []struct {
226 withSSL []string
227 withoutSSL []string
228 apurls []string
229 wurls string
230 werr bool
231 }{
232 {
233 []string{},
234 []string{},
235 []string{"http://localhost:2380"},
236 "",
237 true,
238 },
239 {
240 []string{"1.example.com=https://1.example.com:2380", "0=https://2.example.com:2380", "1=https://3.example.com:2380"},
241 []string{},
242 []string{"https://1.example.com:2380"},
243 "0=https://2.example.com:2380,1.example.com=https://1.example.com:2380,1=https://3.example.com:2380",
244 false,
245 },
246 {
247 []string{"1.example.com=https://1.example.com:2380"},
248 []string{"0=http://2.example.com:2380", "1=http://3.example.com:2380"},
249 []string{"https://1.example.com:2380"},
250 "0=http://2.example.com:2380,1.example.com=https://1.example.com:2380,1=http://3.example.com:2380",
251 false,
252 },
253 {
254 []string{},
255 []string{"1.example.com=http://1.example.com:2380", "0=http://2.example.com:2380", "1=http://3.example.com:2380"},
256 []string{"http://1.example.com:2380"},
257 "0=http://2.example.com:2380,1.example.com=http://1.example.com:2380,1=http://3.example.com:2380",
258 false,
259 },
260 }
261
262 hasErr := func(err error) bool {
263 return err != nil
264 }
265
266 for i, tt := range tests {
267 getCluster = func(serviceScheme string, service string, name string, dns string, apurls types.URLs) ([]string, error) {
268 var urls []string
269 if serviceScheme == "https" && service == "etcd-server-ssl" {
270 urls = tt.withSSL
271 } else if serviceScheme == "http" && service == "etcd-server" {
272 urls = tt.withoutSSL
273 }
274 if len(urls) > 0 {
275 return urls, nil
276 }
277 return urls, notFoundErr(service, dns)
278 }
279
280 cfg := NewConfig()
281 cfg.Name = "1.example.com"
282 cfg.InitialCluster = ""
283 cfg.InitialClusterToken = ""
284 cfg.DNSCluster = "example.com"
285 cfg.AdvertisePeerUrls = types.MustNewURLs(tt.apurls)
286
287 if err := cfg.Validate(); err != nil {
288 t.Errorf("#%d: failed to validate test Config: %v", i, err)
289 continue
290 }
291
292 urlsmap, _, err := cfg.PeerURLsMapAndToken("etcd")
293 if urlsmap.String() != tt.wurls {
294 t.Errorf("#%d: urlsmap = %s, want = %s", i, urlsmap.String(), tt.wurls)
295 }
296 if hasErr(err) != tt.werr {
297 t.Errorf("#%d: err = %v, want = %v", i, err, tt.werr)
298 }
299 }
300 }
301
302 func TestLeaseCheckpointValidate(t *testing.T) {
303 tcs := []struct {
304 name string
305 configFunc func() Config
306 expectError bool
307 }{
308 {
309 name: "Default config should pass",
310 configFunc: func() Config {
311 return *NewConfig()
312 },
313 },
314 {
315 name: "Enabling checkpoint leases should pass",
316 configFunc: func() Config {
317 cfg := *NewConfig()
318 cfg.ExperimentalEnableLeaseCheckpoint = true
319 return cfg
320 },
321 },
322 {
323 name: "Enabling checkpoint leases and persist should pass",
324 configFunc: func() Config {
325 cfg := *NewConfig()
326 cfg.ExperimentalEnableLeaseCheckpoint = true
327 cfg.ExperimentalEnableLeaseCheckpointPersist = true
328 return cfg
329 },
330 },
331 {
332 name: "Enabling checkpoint leases persist without checkpointing itself should fail",
333 configFunc: func() Config {
334 cfg := *NewConfig()
335 cfg.ExperimentalEnableLeaseCheckpointPersist = true
336 return cfg
337 },
338 expectError: true,
339 },
340 }
341 for _, tc := range tcs {
342 t.Run(tc.name, func(t *testing.T) {
343 cfg := tc.configFunc()
344 err := cfg.Validate()
345 if (err != nil) != tc.expectError {
346 t.Errorf("config.Validate() = %q, expected error: %v", err, tc.expectError)
347 }
348 })
349 }
350 }
351
352 func TestLogRotation(t *testing.T) {
353 tests := []struct {
354 name string
355 logOutputs []string
356 logRotationConfig string
357 wantErr bool
358 wantErrMsg error
359 }{
360 {
361 name: "mixed log output targets",
362 logOutputs: []string{"stderr", "/tmp/path"},
363 logRotationConfig: `{"maxsize": 1}`,
364 },
365 {
366 name: "log output relative path",
367 logOutputs: []string{"stderr", "tmp/path"},
368 logRotationConfig: `{"maxsize": 1}`,
369 },
370 {
371 name: "no file targets",
372 logOutputs: []string{"stderr"},
373 logRotationConfig: `{"maxsize": 1}`,
374 wantErr: true,
375 wantErrMsg: ErrLogRotationInvalidLogOutput,
376 },
377 {
378 name: "multiple file targets",
379 logOutputs: []string{"/tmp/path1", "/tmp/path2"},
380 logRotationConfig: DefaultLogRotationConfig,
381 wantErr: true,
382 wantErrMsg: ErrLogRotationInvalidLogOutput,
383 },
384 {
385 name: "default output",
386 logRotationConfig: `{"maxsize": 1}`,
387 wantErr: true,
388 wantErrMsg: ErrLogRotationInvalidLogOutput,
389 },
390 {
391 name: "default log rotation config",
392 logOutputs: []string{"/tmp/path"},
393 logRotationConfig: DefaultLogRotationConfig,
394 },
395 {
396 name: "invalid logger config",
397 logOutputs: []string{"/tmp/path"},
398 logRotationConfig: `{"maxsize": true}`,
399 wantErr: true,
400 wantErrMsg: errors.New("invalid log rotation config: json: cannot unmarshal bool into Go struct field logRotationConfig.maxsize of type int"),
401 },
402 {
403 name: "improperly formatted logger config",
404 logOutputs: []string{"/tmp/path"},
405 logRotationConfig: `{"maxsize": true`,
406 wantErr: true,
407 wantErrMsg: errors.New("improperly formatted log rotation config: unexpected end of JSON input"),
408 },
409 }
410 for _, tt := range tests {
411 t.Run(tt.name, func(t *testing.T) {
412 cfg := NewConfig()
413 cfg.Logger = "zap"
414 cfg.LogOutputs = tt.logOutputs
415 cfg.EnableLogRotation = true
416 cfg.LogRotationConfigJSON = tt.logRotationConfig
417 err := cfg.Validate()
418 if err != nil && !tt.wantErr {
419 t.Errorf("test %q, unexpected error %v", tt.name, err)
420 }
421 if err != nil && tt.wantErr && tt.wantErrMsg.Error() != err.Error() {
422 t.Errorf("test %q, expected error: %+v, got: %+v", tt.name, tt.wantErrMsg, err)
423 }
424 if err == nil && tt.wantErr {
425 t.Errorf("test %q, expected error, got nil", tt.name)
426 }
427 if err == nil {
428 cfg.GetLogger().Info("test log")
429 }
430 })
431 }
432 }
433
434 func TestTLSVersionMinMax(t *testing.T) {
435 tests := []struct {
436 name string
437 givenTLSMinVersion string
438 givenTLSMaxVersion string
439 givenCipherSuites []string
440 expectError bool
441 expectedMinTLSVersion uint16
442 expectedMaxTLSVersion uint16
443 }{
444 {
445 name: "Minimum TLS version is set",
446 givenTLSMinVersion: "TLS1.3",
447 expectedMinTLSVersion: tls.VersionTLS13,
448 expectedMaxTLSVersion: 0,
449 },
450 {
451 name: "Maximum TLS version is set",
452 givenTLSMaxVersion: "TLS1.2",
453 expectedMinTLSVersion: 0,
454 expectedMaxTLSVersion: tls.VersionTLS12,
455 },
456 {
457 name: "Minimum and Maximum TLS versions are set",
458 givenTLSMinVersion: "TLS1.3",
459 givenTLSMaxVersion: "TLS1.3",
460 expectedMinTLSVersion: tls.VersionTLS13,
461 expectedMaxTLSVersion: tls.VersionTLS13,
462 },
463 {
464 name: "Minimum and Maximum TLS versions are set in reverse order",
465 givenTLSMinVersion: "TLS1.3",
466 givenTLSMaxVersion: "TLS1.2",
467 expectError: true,
468 },
469 {
470 name: "Invalid minimum TLS version",
471 givenTLSMinVersion: "invalid version",
472 expectError: true,
473 },
474 {
475 name: "Invalid maximum TLS version",
476 givenTLSMaxVersion: "invalid version",
477 expectError: true,
478 },
479 {
480 name: "Cipher suites configured for TLS 1.3",
481 givenTLSMinVersion: "TLS1.3",
482 givenCipherSuites: []string{"TLS_AES_128_GCM_SHA256"},
483 expectError: true,
484 },
485 }
486
487 for _, tt := range tests {
488 t.Run(tt.name, func(t *testing.T) {
489 cfg := NewConfig()
490 cfg.TlsMinVersion = tt.givenTLSMinVersion
491 cfg.TlsMaxVersion = tt.givenTLSMaxVersion
492 cfg.CipherSuites = tt.givenCipherSuites
493
494 err := cfg.Validate()
495 if err != nil {
496 assert.True(t, tt.expectError, "Validate() returned error while expecting success: %v", err)
497 return
498 }
499
500 updateMinMaxVersions(&cfg.PeerTLSInfo, cfg.TlsMinVersion, cfg.TlsMaxVersion)
501 updateMinMaxVersions(&cfg.ClientTLSInfo, cfg.TlsMinVersion, cfg.TlsMaxVersion)
502
503 assert.Equal(t, tt.expectedMinTLSVersion, cfg.PeerTLSInfo.MinVersion)
504 assert.Equal(t, tt.expectedMaxTLSVersion, cfg.PeerTLSInfo.MaxVersion)
505 assert.Equal(t, tt.expectedMinTLSVersion, cfg.ClientTLSInfo.MinVersion)
506 assert.Equal(t, tt.expectedMaxTLSVersion, cfg.ClientTLSInfo.MaxVersion)
507 })
508 }
509 }
510
View as plain text