1
2
3
4
5
6
7
8
9
10
11
12
13
14 package cluster
15
16 import (
17 "bufio"
18 context2 "context"
19 "fmt"
20 "io"
21 "net"
22 "sync"
23 "testing"
24 "time"
25
26 "github.com/go-kit/log"
27 "github.com/stretchr/testify/require"
28 )
29
30 var logger = log.NewNopLogger()
31
32 func TestNewTLSTransport(t *testing.T) {
33 testCases := []struct {
34 bindAddr string
35 bindPort int
36 tlsConfFile string
37 err string
38 }{
39 {err: "must specify TLSTransportConfig"},
40 {err: "invalid bind address \"\"", tlsConfFile: "testdata/tls_config_node1.yml"},
41 {bindAddr: "abc123", err: "invalid bind address \"abc123\"", tlsConfFile: "testdata/tls_config_node1.yml"},
42 {bindAddr: localhost, bindPort: 0, tlsConfFile: "testdata/tls_config_node1.yml"},
43 {bindAddr: localhost, bindPort: 9094, tlsConfFile: "testdata/tls_config_node2.yml"},
44 }
45 l := log.NewNopLogger()
46 for _, tc := range testCases {
47 cfg := mustTLSTransportConfig(tc.tlsConfFile)
48 transport, err := NewTLSTransport(context2.Background(), l, nil, tc.bindAddr, tc.bindPort, cfg)
49 if len(tc.err) > 0 {
50 require.Equal(t, tc.err, err.Error())
51 require.Nil(t, transport)
52 } else {
53 require.Nil(t, err)
54 require.Equal(t, tc.bindAddr, transport.bindAddr)
55 require.Equal(t, tc.bindPort, transport.bindPort)
56 require.Equal(t, l, transport.logger)
57 require.NotNil(t, transport.listener)
58 transport.Shutdown()
59 }
60 }
61 }
62
const (
	// localhost is the IPv4 loopback address the transports in these tests
	// bind to.
	localhost = "127.0.0.1"
)
64
65 func TestFinalAdvertiseAddr(t *testing.T) {
66 testCases := []struct {
67 bindAddr string
68 bindPort int
69 inputIP string
70 inputPort int
71 expectedIP string
72 expectedPort int
73 expectedError string
74 }{
75 {bindAddr: localhost, bindPort: 9094, inputIP: "10.0.0.5", inputPort: 54231, expectedIP: "10.0.0.5", expectedPort: 54231},
76 {bindAddr: localhost, bindPort: 9093, inputIP: "invalid", inputPort: 54231, expectedError: "failed to parse advertise address \"invalid\""},
77 {bindAddr: "0.0.0.0", bindPort: 0, inputIP: "", inputPort: 0, expectedIP: "random"},
78 {bindAddr: localhost, bindPort: 0, inputIP: "", inputPort: 0, expectedIP: localhost},
79 {bindAddr: localhost, bindPort: 9095, inputIP: "", inputPort: 0, expectedIP: localhost, expectedPort: 9095},
80 }
81 for _, tc := range testCases {
82 tlsConf := mustTLSTransportConfig("testdata/tls_config_node1.yml")
83 transport, err := NewTLSTransport(context2.Background(), logger, nil, tc.bindAddr, tc.bindPort, tlsConf)
84 require.Nil(t, err)
85 ip, port, err := transport.FinalAdvertiseAddr(tc.inputIP, tc.inputPort)
86 if len(tc.expectedError) > 0 {
87 require.Equal(t, tc.expectedError, err.Error())
88 } else {
89 require.Nil(t, err)
90 if tc.expectedPort == 0 {
91 require.True(t, tc.expectedPort < port)
92 require.True(t, uint32(port) <= uint32(1<<32-1))
93 } else {
94 require.Equal(t, tc.expectedPort, port)
95 }
96 if tc.expectedIP == "random" {
97 require.NotNil(t, ip)
98 } else {
99 require.Equal(t, tc.expectedIP, ip.String())
100 }
101 }
102 transport.Shutdown()
103 }
104 }
105
106 func TestWriteTo(t *testing.T) {
107 tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
108 t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
109 defer t1.Shutdown()
110
111 tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
112 t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
113 defer t2.Shutdown()
114
115 from := fmt.Sprintf("%s:%d", t1.bindAddr, t1.GetAutoBindPort())
116 to := fmt.Sprintf("%s:%d", t2.bindAddr, t2.GetAutoBindPort())
117 sent := []byte(("test packet"))
118 _, err := t1.WriteTo(sent, to)
119 require.Nil(t, err)
120 packet := <-t2.PacketCh()
121 require.Equal(t, sent, packet.Buf)
122 require.Equal(t, from, packet.From.String())
123 }
124
125 func BenchmarkWriteTo(b *testing.B) {
126 tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
127 t1, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
128 defer t1.Shutdown()
129
130 tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
131 t2, _ := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
132 defer t2.Shutdown()
133
134 b.ResetTimer()
135 from := fmt.Sprintf("%s:%d", t1.bindAddr, t1.GetAutoBindPort())
136 to := fmt.Sprintf("%s:%d", t2.bindAddr, t2.GetAutoBindPort())
137 sent := []byte(("test packet"))
138
139 _, err := t1.WriteTo(sent, to)
140 require.Nil(b, err)
141 packet := <-t2.PacketCh()
142
143 require.Equal(b, sent, packet.Buf)
144 require.Equal(b, from, packet.From.String())
145 }
146
147 func TestDialTimout(t *testing.T) {
148 tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
149 t1, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf1)
150 require.Nil(t, err)
151 defer t1.Shutdown()
152
153 tlsConf2 := mustTLSTransportConfig("testdata/tls_config_node2.yml")
154 t2, err := NewTLSTransport(context2.Background(), logger, nil, "127.0.0.1", 0, tlsConf2)
155 require.Nil(t, err)
156 defer t2.Shutdown()
157
158 addr := fmt.Sprintf("%s:%d", t2.bindAddr, t2.GetAutoBindPort())
159 from, err := t1.DialTimeout(addr, 5*time.Second)
160 require.Nil(t, err)
161 defer from.Close()
162
163 var to net.Conn
164 var wg sync.WaitGroup
165 wg.Add(1)
166 go func() {
167 to = <-t2.StreamCh()
168 wg.Done()
169 }()
170
171 sent := []byte(("test stream"))
172 m, err := from.Write(sent)
173 require.Nil(t, err)
174 require.Greater(t, m, 0)
175
176 wg.Wait()
177
178 reader := bufio.NewReader(to)
179 buf := make([]byte, len(sent))
180 n, err := io.ReadFull(reader, buf)
181 require.Nil(t, err)
182 require.Equal(t, len(sent), n)
183 require.Equal(t, sent, buf)
184 }
185
// logWr is an in-memory io.Writer that keeps everything written to it, so
// tests can inspect log output afterwards.
type logWr struct {
	bytes []byte
}

// Write appends p to the accumulated buffer and always reports success.
func (l *logWr) Write(p []byte) (int, error) {
	written := len(p)
	l.bytes = append(l.bytes, p...)
	return written, nil
}
194
195 func TestShutdown(t *testing.T) {
196 tlsConf1 := mustTLSTransportConfig("testdata/tls_config_node1.yml")
197 l := &logWr{}
198 t1, _ := NewTLSTransport(context2.Background(), log.NewLogfmtLogger(l), nil, "127.0.0.1", 0, tlsConf1)
199
200 time.Sleep(500 * time.Millisecond)
201 err := t1.Shutdown()
202 require.Nil(t, err)
203 require.NotContains(t, string(l.bytes), "use of closed network connection")
204 require.Contains(t, string(l.bytes), "shutting down tls transport")
205 }
206
207 func mustTLSTransportConfig(filename string) *TLSTransportConfig {
208 config, err := GetTLSTransportConfig(filename)
209 if err != nil {
210 panic(err)
211 }
212 return config
213 }
214
View as plain text