...

Source file src/go.etcd.io/etcd/server/v3/embed/config_test.go

Documentation: go.etcd.io/etcd/server/v3/embed

     1  // Copyright 2016 The etcd Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package embed
    16  
    17  import (
    18  	"crypto/tls"
    19  	"errors"
    20  	"fmt"
    21  	"io/ioutil"
    22  	"net"
    23  	"net/url"
    24  	"os"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/stretchr/testify/assert"
    29  	"go.etcd.io/etcd/client/pkg/v3/srv"
    30  	"go.etcd.io/etcd/client/pkg/v3/transport"
    31  	"go.etcd.io/etcd/client/pkg/v3/types"
    32  
    33  	"sigs.k8s.io/yaml"
    34  )
    35  
    36  func notFoundErr(service, domain string) error {
    37  	name := fmt.Sprintf("_%s._tcp.%s", service, domain)
    38  	return &net.DNSError{Err: "no such host", Name: name, Server: "10.0.0.53:53", IsTimeout: false, IsTemporary: false, IsNotFound: true}
    39  }
    40  
    41  func TestConfigFileOtherFields(t *testing.T) {
    42  	ctls := securityConfig{TrustedCAFile: "cca", CertFile: "ccert", KeyFile: "ckey"}
    43  	ptls := securityConfig{TrustedCAFile: "pca", CertFile: "pcert", KeyFile: "pkey"}
    44  	yc := struct {
    45  		ClientSecurityCfgFile securityConfig       `json:"client-transport-security"`
    46  		PeerSecurityCfgFile   securityConfig       `json:"peer-transport-security"`
    47  		ForceNewCluster       bool                 `json:"force-new-cluster"`
    48  		Logger                string               `json:"logger"`
    49  		LogOutputs            []string             `json:"log-outputs"`
    50  		Debug                 bool                 `json:"debug"`
    51  		SocketOpts            transport.SocketOpts `json:"socket-options"`
    52  	}{
    53  		ctls,
    54  		ptls,
    55  		true,
    56  		"zap",
    57  		[]string{"/dev/null"},
    58  		false,
    59  		transport.SocketOpts{
    60  			ReusePort: true,
    61  		},
    62  	}
    63  
    64  	b, err := yaml.Marshal(&yc)
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  
    69  	tmpfile := mustCreateCfgFile(t, b)
    70  	defer os.Remove(tmpfile.Name())
    71  
    72  	cfg, err := ConfigFromFile(tmpfile.Name())
    73  	if err != nil {
    74  		t.Fatal(err)
    75  	}
    76  
    77  	if !ctls.equals(&cfg.ClientTLSInfo) {
    78  		t.Errorf("ClientTLS = %v, want %v", cfg.ClientTLSInfo, ctls)
    79  	}
    80  	if !ptls.equals(&cfg.PeerTLSInfo) {
    81  		t.Errorf("PeerTLS = %v, want %v", cfg.PeerTLSInfo, ptls)
    82  	}
    83  
    84  	assert.Equal(t, true, cfg.ForceNewCluster, "ForceNewCluster does not match")
    85  
    86  	assert.Equal(t, true, cfg.SocketOpts.ReusePort, "ReusePort does not match")
    87  
    88  	assert.Equal(t, false, cfg.SocketOpts.ReuseAddress, "ReuseAddress does not match")
    89  }
    90  
    91  // TestUpdateDefaultClusterFromName ensures that etcd can start with 'etcd --name=abc'.
    92  func TestUpdateDefaultClusterFromName(t *testing.T) {
    93  	cfg := NewConfig()
    94  	defaultInitialCluster := cfg.InitialCluster
    95  	oldscheme := cfg.AdvertisePeerUrls[0].Scheme
    96  	origpeer := cfg.AdvertisePeerUrls[0].String()
    97  	origadvc := cfg.AdvertiseClientUrls[0].String()
    98  
    99  	cfg.Name = "abc"
   100  	lpport := cfg.ListenPeerUrls[0].Port()
   101  
   102  	// in case of 'etcd --name=abc'
   103  	exp := fmt.Sprintf("%s=%s://localhost:%s", cfg.Name, oldscheme, lpport)
   104  	_, _ = cfg.UpdateDefaultClusterFromName(defaultInitialCluster)
   105  	if exp != cfg.InitialCluster {
   106  		t.Fatalf("initial-cluster expected %q, got %q", exp, cfg.InitialCluster)
   107  	}
   108  	// advertise peer URL should not be affected
   109  	if origpeer != cfg.AdvertisePeerUrls[0].String() {
   110  		t.Fatalf("advertise peer url expected %q, got %q", origadvc, cfg.AdvertisePeerUrls[0].String())
   111  	}
   112  	// advertise client URL should not be affected
   113  	if origadvc != cfg.AdvertiseClientUrls[0].String() {
   114  		t.Fatalf("advertise client url expected %q, got %q", origadvc, cfg.AdvertiseClientUrls[0].String())
   115  	}
   116  }
   117  
   118  // TestUpdateDefaultClusterFromNameOverwrite ensures that machine's default host is only used
   119  // if advertise URLs are default values(localhost:2379,2380) AND if listen URL is 0.0.0.0.
   120  func TestUpdateDefaultClusterFromNameOverwrite(t *testing.T) {
   121  	if defaultHostname == "" {
   122  		t.Skip("machine's default host not found")
   123  	}
   124  
   125  	cfg := NewConfig()
   126  	defaultInitialCluster := cfg.InitialCluster
   127  	oldscheme := cfg.AdvertisePeerUrls[0].Scheme
   128  	origadvc := cfg.AdvertiseClientUrls[0].String()
   129  
   130  	cfg.Name = "abc"
   131  	lpport := cfg.ListenPeerUrls[0].Port()
   132  	cfg.ListenPeerUrls[0] = url.URL{Scheme: cfg.ListenPeerUrls[0].Scheme, Host: fmt.Sprintf("0.0.0.0:%s", lpport)}
   133  	dhost, _ := cfg.UpdateDefaultClusterFromName(defaultInitialCluster)
   134  	if dhost != defaultHostname {
   135  		t.Fatalf("expected default host %q, got %q", defaultHostname, dhost)
   136  	}
   137  	aphost, apport := cfg.AdvertisePeerUrls[0].Hostname(), cfg.AdvertisePeerUrls[0].Port()
   138  	if apport != lpport {
   139  		t.Fatalf("advertise peer url got different port %s, expected %s", apport, lpport)
   140  	}
   141  	if aphost != defaultHostname {
   142  		t.Fatalf("advertise peer url expected machine default host %q, got %q", defaultHostname, aphost)
   143  	}
   144  	expected := fmt.Sprintf("%s=%s://%s:%s", cfg.Name, oldscheme, defaultHostname, lpport)
   145  	if expected != cfg.InitialCluster {
   146  		t.Fatalf("initial-cluster expected %q, got %q", expected, cfg.InitialCluster)
   147  	}
   148  
   149  	// advertise client URL should not be affected
   150  	if origadvc != cfg.AdvertiseClientUrls[0].String() {
   151  		t.Fatalf("advertise-client-url expected %q, got %q", origadvc, cfg.AdvertiseClientUrls[0].String())
   152  	}
   153  }
   154  
   155  func (s *securityConfig) equals(t *transport.TLSInfo) bool {
   156  	return s.CertFile == t.CertFile &&
   157  		s.CertAuth == t.ClientCertAuth &&
   158  		s.TrustedCAFile == t.TrustedCAFile
   159  }
   160  
   161  func mustCreateCfgFile(t *testing.T, b []byte) *os.File {
   162  	tmpfile, err := ioutil.TempFile("", "servercfg")
   163  	if err != nil {
   164  		t.Fatal(err)
   165  	}
   166  	if _, err = tmpfile.Write(b); err != nil {
   167  		t.Fatal(err)
   168  	}
   169  	if err = tmpfile.Close(); err != nil {
   170  		t.Fatal(err)
   171  	}
   172  	return tmpfile
   173  }
   174  
   175  func TestAutoCompactionModeInvalid(t *testing.T) {
   176  	cfg := NewConfig()
   177  	cfg.Logger = "zap"
   178  	cfg.LogOutputs = []string{"/dev/null"}
   179  	cfg.AutoCompactionMode = "period"
   180  	err := cfg.Validate()
   181  	if err == nil {
   182  		t.Errorf("expected non-nil error, got %v", err)
   183  	}
   184  }
   185  
   186  func TestAutoCompactionModeParse(t *testing.T) {
   187  	tests := []struct {
   188  		mode      string
   189  		retention string
   190  		werr      bool
   191  		wdur      time.Duration
   192  	}{
   193  		// revision
   194  		{"revision", "1", false, 1},
   195  		{"revision", "1h", false, time.Hour},
   196  		{"revision", "a", true, 0},
   197  		{"revision", "-1", true, 0},
   198  		// periodic
   199  		{"periodic", "1", false, time.Hour},
   200  		{"periodic", "a", true, 0},
   201  		{"revision", "-1", true, 0},
   202  		// err mode
   203  		{"errmode", "1", false, 0},
   204  		{"errmode", "1h", false, time.Hour},
   205  	}
   206  
   207  	hasErr := func(err error) bool {
   208  		return err != nil
   209  	}
   210  
   211  	for i, tt := range tests {
   212  		dur, err := parseCompactionRetention(tt.mode, tt.retention)
   213  		if hasErr(err) != tt.werr {
   214  			t.Errorf("#%d: err = %v, want %v", i, err, tt.werr)
   215  		}
   216  		if dur != tt.wdur {
   217  			t.Errorf("#%d: duration = %s, want %s", i, dur, tt.wdur)
   218  		}
   219  	}
   220  }
   221  
   222  func TestPeerURLsMapAndTokenFromSRV(t *testing.T) {
   223  	defer func() { getCluster = srv.GetCluster }()
   224  
   225  	tests := []struct {
   226  		withSSL    []string
   227  		withoutSSL []string
   228  		apurls     []string
   229  		wurls      string
   230  		werr       bool
   231  	}{
   232  		{
   233  			[]string{},
   234  			[]string{},
   235  			[]string{"http://localhost:2380"},
   236  			"",
   237  			true,
   238  		},
   239  		{
   240  			[]string{"1.example.com=https://1.example.com:2380", "0=https://2.example.com:2380", "1=https://3.example.com:2380"},
   241  			[]string{},
   242  			[]string{"https://1.example.com:2380"},
   243  			"0=https://2.example.com:2380,1.example.com=https://1.example.com:2380,1=https://3.example.com:2380",
   244  			false,
   245  		},
   246  		{
   247  			[]string{"1.example.com=https://1.example.com:2380"},
   248  			[]string{"0=http://2.example.com:2380", "1=http://3.example.com:2380"},
   249  			[]string{"https://1.example.com:2380"},
   250  			"0=http://2.example.com:2380,1.example.com=https://1.example.com:2380,1=http://3.example.com:2380",
   251  			false,
   252  		},
   253  		{
   254  			[]string{},
   255  			[]string{"1.example.com=http://1.example.com:2380", "0=http://2.example.com:2380", "1=http://3.example.com:2380"},
   256  			[]string{"http://1.example.com:2380"},
   257  			"0=http://2.example.com:2380,1.example.com=http://1.example.com:2380,1=http://3.example.com:2380",
   258  			false,
   259  		},
   260  	}
   261  
   262  	hasErr := func(err error) bool {
   263  		return err != nil
   264  	}
   265  
   266  	for i, tt := range tests {
   267  		getCluster = func(serviceScheme string, service string, name string, dns string, apurls types.URLs) ([]string, error) {
   268  			var urls []string
   269  			if serviceScheme == "https" && service == "etcd-server-ssl" {
   270  				urls = tt.withSSL
   271  			} else if serviceScheme == "http" && service == "etcd-server" {
   272  				urls = tt.withoutSSL
   273  			}
   274  			if len(urls) > 0 {
   275  				return urls, nil
   276  			}
   277  			return urls, notFoundErr(service, dns)
   278  		}
   279  
   280  		cfg := NewConfig()
   281  		cfg.Name = "1.example.com"
   282  		cfg.InitialCluster = ""
   283  		cfg.InitialClusterToken = ""
   284  		cfg.DNSCluster = "example.com"
   285  		cfg.AdvertisePeerUrls = types.MustNewURLs(tt.apurls)
   286  
   287  		if err := cfg.Validate(); err != nil {
   288  			t.Errorf("#%d: failed to validate test Config: %v", i, err)
   289  			continue
   290  		}
   291  
   292  		urlsmap, _, err := cfg.PeerURLsMapAndToken("etcd")
   293  		if urlsmap.String() != tt.wurls {
   294  			t.Errorf("#%d: urlsmap = %s, want = %s", i, urlsmap.String(), tt.wurls)
   295  		}
   296  		if hasErr(err) != tt.werr {
   297  			t.Errorf("#%d: err = %v, want = %v", i, err, tt.werr)
   298  		}
   299  	}
   300  }
   301  
   302  func TestLeaseCheckpointValidate(t *testing.T) {
   303  	tcs := []struct {
   304  		name        string
   305  		configFunc  func() Config
   306  		expectError bool
   307  	}{
   308  		{
   309  			name: "Default config should pass",
   310  			configFunc: func() Config {
   311  				return *NewConfig()
   312  			},
   313  		},
   314  		{
   315  			name: "Enabling checkpoint leases should pass",
   316  			configFunc: func() Config {
   317  				cfg := *NewConfig()
   318  				cfg.ExperimentalEnableLeaseCheckpoint = true
   319  				return cfg
   320  			},
   321  		},
   322  		{
   323  			name: "Enabling checkpoint leases and persist should pass",
   324  			configFunc: func() Config {
   325  				cfg := *NewConfig()
   326  				cfg.ExperimentalEnableLeaseCheckpoint = true
   327  				cfg.ExperimentalEnableLeaseCheckpointPersist = true
   328  				return cfg
   329  			},
   330  		},
   331  		{
   332  			name: "Enabling checkpoint leases persist without checkpointing itself should fail",
   333  			configFunc: func() Config {
   334  				cfg := *NewConfig()
   335  				cfg.ExperimentalEnableLeaseCheckpointPersist = true
   336  				return cfg
   337  			},
   338  			expectError: true,
   339  		},
   340  	}
   341  	for _, tc := range tcs {
   342  		t.Run(tc.name, func(t *testing.T) {
   343  			cfg := tc.configFunc()
   344  			err := cfg.Validate()
   345  			if (err != nil) != tc.expectError {
   346  				t.Errorf("config.Validate() = %q, expected error: %v", err, tc.expectError)
   347  			}
   348  		})
   349  	}
   350  }
   351  
   352  func TestLogRotation(t *testing.T) {
   353  	tests := []struct {
   354  		name              string
   355  		logOutputs        []string
   356  		logRotationConfig string
   357  		wantErr           bool
   358  		wantErrMsg        error
   359  	}{
   360  		{
   361  			name:              "mixed log output targets",
   362  			logOutputs:        []string{"stderr", "/tmp/path"},
   363  			logRotationConfig: `{"maxsize": 1}`,
   364  		},
   365  		{
   366  			name:              "log output relative path",
   367  			logOutputs:        []string{"stderr", "tmp/path"},
   368  			logRotationConfig: `{"maxsize": 1}`,
   369  		},
   370  		{
   371  			name:              "no file targets",
   372  			logOutputs:        []string{"stderr"},
   373  			logRotationConfig: `{"maxsize": 1}`,
   374  			wantErr:           true,
   375  			wantErrMsg:        ErrLogRotationInvalidLogOutput,
   376  		},
   377  		{
   378  			name:              "multiple file targets",
   379  			logOutputs:        []string{"/tmp/path1", "/tmp/path2"},
   380  			logRotationConfig: DefaultLogRotationConfig,
   381  			wantErr:           true,
   382  			wantErrMsg:        ErrLogRotationInvalidLogOutput,
   383  		},
   384  		{
   385  			name:              "default output",
   386  			logRotationConfig: `{"maxsize": 1}`,
   387  			wantErr:           true,
   388  			wantErrMsg:        ErrLogRotationInvalidLogOutput,
   389  		},
   390  		{
   391  			name:              "default log rotation config",
   392  			logOutputs:        []string{"/tmp/path"},
   393  			logRotationConfig: DefaultLogRotationConfig,
   394  		},
   395  		{
   396  			name:              "invalid logger config",
   397  			logOutputs:        []string{"/tmp/path"},
   398  			logRotationConfig: `{"maxsize": true}`,
   399  			wantErr:           true,
   400  			wantErrMsg:        errors.New("invalid log rotation config: json: cannot unmarshal bool into Go struct field logRotationConfig.maxsize of type int"),
   401  		},
   402  		{
   403  			name:              "improperly formatted logger config",
   404  			logOutputs:        []string{"/tmp/path"},
   405  			logRotationConfig: `{"maxsize": true`,
   406  			wantErr:           true,
   407  			wantErrMsg:        errors.New("improperly formatted log rotation config: unexpected end of JSON input"),
   408  		},
   409  	}
   410  	for _, tt := range tests {
   411  		t.Run(tt.name, func(t *testing.T) {
   412  			cfg := NewConfig()
   413  			cfg.Logger = "zap"
   414  			cfg.LogOutputs = tt.logOutputs
   415  			cfg.EnableLogRotation = true
   416  			cfg.LogRotationConfigJSON = tt.logRotationConfig
   417  			err := cfg.Validate()
   418  			if err != nil && !tt.wantErr {
   419  				t.Errorf("test %q, unexpected error %v", tt.name, err)
   420  			}
   421  			if err != nil && tt.wantErr && tt.wantErrMsg.Error() != err.Error() {
   422  				t.Errorf("test %q, expected error: %+v, got: %+v", tt.name, tt.wantErrMsg, err)
   423  			}
   424  			if err == nil && tt.wantErr {
   425  				t.Errorf("test %q, expected error, got nil", tt.name)
   426  			}
   427  			if err == nil {
   428  				cfg.GetLogger().Info("test log")
   429  			}
   430  		})
   431  	}
   432  }
   433  
   434  func TestTLSVersionMinMax(t *testing.T) {
   435  	tests := []struct {
   436  		name                  string
   437  		givenTLSMinVersion    string
   438  		givenTLSMaxVersion    string
   439  		givenCipherSuites     []string
   440  		expectError           bool
   441  		expectedMinTLSVersion uint16
   442  		expectedMaxTLSVersion uint16
   443  	}{
   444  		{
   445  			name:                  "Minimum TLS version is set",
   446  			givenTLSMinVersion:    "TLS1.3",
   447  			expectedMinTLSVersion: tls.VersionTLS13,
   448  			expectedMaxTLSVersion: 0,
   449  		},
   450  		{
   451  			name:                  "Maximum TLS version is set",
   452  			givenTLSMaxVersion:    "TLS1.2",
   453  			expectedMinTLSVersion: 0,
   454  			expectedMaxTLSVersion: tls.VersionTLS12,
   455  		},
   456  		{
   457  			name:                  "Minimum and Maximum TLS versions are set",
   458  			givenTLSMinVersion:    "TLS1.3",
   459  			givenTLSMaxVersion:    "TLS1.3",
   460  			expectedMinTLSVersion: tls.VersionTLS13,
   461  			expectedMaxTLSVersion: tls.VersionTLS13,
   462  		},
   463  		{
   464  			name:               "Minimum and Maximum TLS versions are set in reverse order",
   465  			givenTLSMinVersion: "TLS1.3",
   466  			givenTLSMaxVersion: "TLS1.2",
   467  			expectError:        true,
   468  		},
   469  		{
   470  			name:               "Invalid minimum TLS version",
   471  			givenTLSMinVersion: "invalid version",
   472  			expectError:        true,
   473  		},
   474  		{
   475  			name:               "Invalid maximum TLS version",
   476  			givenTLSMaxVersion: "invalid version",
   477  			expectError:        true,
   478  		},
   479  		{
   480  			name:               "Cipher suites configured for TLS 1.3",
   481  			givenTLSMinVersion: "TLS1.3",
   482  			givenCipherSuites:  []string{"TLS_AES_128_GCM_SHA256"},
   483  			expectError:        true,
   484  		},
   485  	}
   486  
   487  	for _, tt := range tests {
   488  		t.Run(tt.name, func(t *testing.T) {
   489  			cfg := NewConfig()
   490  			cfg.TlsMinVersion = tt.givenTLSMinVersion
   491  			cfg.TlsMaxVersion = tt.givenTLSMaxVersion
   492  			cfg.CipherSuites = tt.givenCipherSuites
   493  
   494  			err := cfg.Validate()
   495  			if err != nil {
   496  				assert.True(t, tt.expectError, "Validate() returned error while expecting success: %v", err)
   497  				return
   498  			}
   499  
   500  			updateMinMaxVersions(&cfg.PeerTLSInfo, cfg.TlsMinVersion, cfg.TlsMaxVersion)
   501  			updateMinMaxVersions(&cfg.ClientTLSInfo, cfg.TlsMinVersion, cfg.TlsMaxVersion)
   502  
   503  			assert.Equal(t, tt.expectedMinTLSVersion, cfg.PeerTLSInfo.MinVersion)
   504  			assert.Equal(t, tt.expectedMaxTLSVersion, cfg.PeerTLSInfo.MaxVersion)
   505  			assert.Equal(t, tt.expectedMinTLSVersion, cfg.ClientTLSInfo.MinVersion)
   506  			assert.Equal(t, tt.expectedMaxTLSVersion, cfg.ClientTLSInfo.MaxVersion)
   507  		})
   508  	}
   509  }
   510  

View as plain text