...

Source file src/github.com/prometheus/alertmanager/cluster/tls_transport_test.go

Documentation: github.com/prometheus/alertmanager/cluster

     1  // Copyright 2020 The Prometheus Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package cluster
    15  
    16  import (
    17  	"bufio"
    18  	context2 "context"
    19  	"fmt"
    20  	"io"
    21  	"net"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/go-kit/log"
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
// logger is a no-op logger shared by the tests in this file that do not
// inspect log output.
var logger = log.NewNopLogger()
    31  
    32  func TestNewTLSTransport(t *testing.T) {
    33  	testCases := []struct {
    34  		bindAddr    string
    35  		bindPort    int
    36  		tlsConfFile string
    37  		err         string
    38  	}{
    39  		{err: "must specify TLSTransportConfig"},
    40  		{err: "invalid bind address \"\"", tlsConfFile: "testdata/tls_config_node1.yml"},
    41  		{bindAddr: "abc123", err: "invalid bind address \"abc123\"", tlsConfFile: "testdata/tls_config_node1.yml"},
    42  		{bindAddr: localhost, bindPort: 0, tlsConfFile: "testdata/tls_config_node1.yml"},
    43  		{bindAddr: localhost, bindPort: 9094, tlsConfFile: "testdata/tls_config_node2.yml"},
    44  	}
    45  	l := log.NewNopLogger()
    46  	for _, tc := range testCases {
    47  		cfg := mustTLSTransportConfig(tc.tlsConfFile)
    48  		transport, err := NewTLSTransport(context2.Background(), l, nil, tc.bindAddr, tc.bindPort, cfg)
    49  		if len(tc.err) > 0 {
    50  			require.Equal(t, tc.err, err.Error())
    51  			require.Nil(t, transport)
    52  		} else {
    53  			require.Nil(t, err)
    54  			require.Equal(t, tc.bindAddr, transport.bindAddr)
    55  			require.Equal(t, tc.bindPort, transport.bindPort)
    56  			require.Equal(t, l, transport.logger)
    57  			require.NotNil(t, transport.listener)
    58  			transport.Shutdown()
    59  		}
    60  	}
    61  }
    62  
// localhost is the IPv4 loopback address the tests use as a bind address.
const localhost = "127.0.0.1"
    64  
    65  func TestFinalAdvertiseAddr(t *testing.T) {
    66  	testCases := []struct {
    67  		bindAddr      string
    68  		bindPort      int
    69  		inputIP       string
    70  		inputPort     int
    71  		expectedIP    string
    72  		expectedPort  int
    73  		expectedError string
    74  	}{
    75  		{bindAddr: localhost, bindPort: 9094, inputIP: "10.0.0.5", inputPort: 54231, expectedIP: "10.0.0.5", expectedPort: 54231},
    76  		{bindAddr: localhost, bindPort: 9093, inputIP: "invalid", inputPort: 54231, expectedError: "failed to parse advertise address \"invalid\""},
    77  		{bindAddr: "0.0.0.0", bindPort: 0, inputIP: "", inputPort: 0, expectedIP: "random"},
    78  		{bindAddr: localhost, bindPort: 0, inputIP: "", inputPort: 0, expectedIP: localhost},
    79  		{bindAddr: localhost, bindPort: 9095, inputIP: "", inputPort: 0, expectedIP: localhost, expectedPort: 9095},
    80  	}
    81  	for _, tc := range testCases {
    82  		tlsConf := mustTLSTransportConfig("testdata/tls_config_node1.yml")
    83  		transport, err := NewTLSTransport(context2.Background(), logger, nil, tc.bindAddr, tc.bindPort, tlsConf)
    84  		require.Nil(t, err)
    85  		ip, port, err := transport.FinalAdvertiseAddr(tc.inputIP, tc.inputPort)
    86  		if len(tc.expectedError) > 0 {
    87  			require.Equal(t, tc.expectedError, err.Error())
    88  		} else {
    89  			require.Nil(t, err)
    90  			if tc.expectedPort == 0 {
    91  				require.True(t, tc.expectedPort < port)
    92  				require.True(t, uint32(port) <= uint32(1<<32-1))
    93  			} else {
    94  				require.Equal(t, tc.expectedPort, port)
    95  			}
    96  			if tc.expectedIP == "random" {
    97  				require.NotNil(t, ip)
    98  			} else {
    99  				require.Equal(t, tc.expectedIP, ip.String())
   100  			}
   101  		}
   102  		transport.Shutdown()
   103  	}
   104  }
   105  
   106  func TestWriteTo(t *testing.T) {
   107  	tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
   108  	t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
   109  	defer t1.Shutdown()
   110  
   111  	tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
   112  	t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
   113  	defer t2.Shutdown()
   114  
   115  	from := fmt.Sprintf("%s:%d", t1.bindAddr, t1.GetAutoBindPort())
   116  	to := fmt.Sprintf("%s:%d", t2.bindAddr, t2.GetAutoBindPort())
   117  	sent := []byte(("test packet"))
   118  	_, err := t1.WriteTo(sent, to)
   119  	require.Nil(t, err)
   120  	packet := <-t2.PacketCh()
   121  	require.Equal(t, sent, packet.Buf)
   122  	require.Equal(t, from, packet.From.String())
   123  }
   124  
   125  func BenchmarkWriteTo(b *testing.B) {
   126  	tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
   127  	t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
   128  	defer t1.Shutdown()
   129  
   130  	tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
   131  	t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
   132  	defer t2.Shutdown()
   133  
   134  	b.ResetTimer()
   135  	from := fmt.Sprintf("%s:%d", t1.bindAddr, t1.GetAutoBindPort())
   136  	to := fmt.Sprintf("%s:%d", t2.bindAddr, t2.GetAutoBindPort())
   137  	sent := []byte(("test packet"))
   138  
   139  	_, err := t1.WriteTo(sent, to)
   140  	require.Nil(b, err)
   141  	packet := <-t2.PacketCh()
   142  
   143  	require.Equal(b, sent, packet.Buf)
   144  	require.Equal(b, from, packet.From.String())
   145  }
   146  
   147  func TestDialTimout(t *testing.T) {
   148  	tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
   149  	t1, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
   150  	require.Nil(t, err)
   151  	defer t1.Shutdown()
   152  
   153  	tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
   154  	t2, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
   155  	require.Nil(t, err)
   156  	defer t2.Shutdown()
   157  
   158  	addr := fmt.Sprintf("%s:%d", t2.bindAddr, t2.GetAutoBindPort())
   159  	from, err := t1.DialTimeout(addr, 5*time.Second)
   160  	require.Nil(t, err)
   161  	defer from.Close()
   162  
   163  	var to net.Conn
   164  	var wg sync.WaitGroup
   165  	wg.Add(1)
   166  	go func() {
   167  		to = <-t2.StreamCh()
   168  		wg.Done()
   169  	}()
   170  
   171  	sent := []byte(("test stream"))
   172  	m, err := from.Write(sent)
   173  	require.Nil(t, err)
   174  	require.Greater(t, m, 0)
   175  
   176  	wg.Wait()
   177  
   178  	reader := bufio.NewReader(to)
   179  	buf := make([]byte, len(sent))
   180  	n, err := io.ReadFull(reader, buf)
   181  	require.Nil(t, err)
   182  	require.Equal(t, len(sent), n)
   183  	require.Equal(t, sent, buf)
   184  }
   185  
   186  type logWr struct {
   187  	bytes []byte
   188  }
   189  
   190  func (l *logWr) Write(p []byte) (n int, err error) {
   191  	l.bytes = append(l.bytes, p...)
   192  	return len(p), nil
   193  }
   194  
   195  func TestShutdown(t *testing.T) {
   196  	tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
   197  	l := &logWr{}
   198  	t1, _ := NewTLSTransport(context2.Background(), log.NewLogfmtLogger(l), nil, "127.0.0.1", 0, tlsConf1)
   199  	// Sleeping to make sure listeners have started and can subsequently be shut down gracefully.
   200  	time.Sleep(500 * time.Millisecond)
   201  	err := t1.Shutdown()
   202  	require.Nil(t, err)
   203  	require.NotContains(t, string(l.bytes), "use of closed network connection")
   204  	require.Contains(t, string(l.bytes), "shutting down tls transport")
   205  }
   206  
   207  func mustTLSTransportConfig(filename string) *TLSTransportConfig {
   208  	config, err := GetTLSTransportConfig(filename)
   209  	if err != nil {
   210  		panic(err)
   211  	}
   212  	return config
   213  }
   214  

View as plain text