1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package cluster
19
20 import (
21 "context"
22 "crypto/tls"
23 "fmt"
24 "net"
25 "strings"
26 "time"
27
28 "github.com/go-kit/log"
29 "github.com/go-kit/log/level"
30 "github.com/hashicorp/go-sockaddr"
31 "github.com/hashicorp/memberlist"
32 "github.com/pkg/errors"
33 "github.com/prometheus/client_golang/prometheus"
34 common "github.com/prometheus/common/config"
35 "github.com/prometheus/exporter-toolkit/web"
36 )
37
// Constants shared by the TLS transport's listener and its metrics.
const (
	// network is the network name handed to tls.Listen for the transport's
	// listener.
	network = "tcp"

	// metricNamespace and metricSubsystem prefix every metric the TLS
	// transport creates in registerMetrics.
	metricNamespace = "alertmanager"
	metricSubsystem = "tls_transport"
)
43
44
45
46 type TLSTransport struct {
47 ctx context.Context
48 cancel context.CancelFunc
49 logger log.Logger
50 bindAddr string
51 bindPort int
52 done chan struct{}
53 listener net.Listener
54 packetCh chan *memberlist.Packet
55 streamCh chan net.Conn
56 connPool *connectionPool
57 tlsServerCfg *tls.Config
58 tlsClientCfg *tls.Config
59
60 packetsSent prometheus.Counter
61 packetsRcvd prometheus.Counter
62 streamsSent prometheus.Counter
63 streamsRcvd prometheus.Counter
64 readErrs prometheus.Counter
65 writeErrs *prometheus.CounterVec
66 }
67
68
69
70
71
72 func NewTLSTransport(
73 ctx context.Context,
74 logger log.Logger,
75 reg prometheus.Registerer,
76 bindAddr string,
77 bindPort int,
78 cfg *TLSTransportConfig,
79 ) (*TLSTransport, error) {
80 if cfg == nil {
81 return nil, errors.New("must specify TLSTransportConfig")
82 }
83 tlsServerCfg, err := web.ConfigToTLSConfig(cfg.TLSServerConfig)
84 if err != nil {
85 return nil, errors.Wrap(err, "invalid TLS server config")
86 }
87 tlsClientCfg, err := common.NewTLSConfig(cfg.TLSClientConfig)
88 if err != nil {
89 return nil, errors.Wrap(err, "invalid TLS client config")
90 }
91 ip := net.ParseIP(bindAddr)
92 if ip == nil {
93 return nil, fmt.Errorf("invalid bind address \"%s\"", bindAddr)
94 }
95 addr := &net.TCPAddr{IP: ip, Port: bindPort}
96 listener, err := tls.Listen(network, addr.String(), tlsServerCfg)
97 if err != nil {
98 return nil, errors.Wrap(err, fmt.Sprintf("failed to start TLS listener on %q port %d", bindAddr, bindPort))
99 }
100 connPool, err := newConnectionPool(tlsClientCfg)
101 if err != nil {
102 return nil, errors.Wrap(err, "failed to initialize tls transport connection pool")
103 }
104 ctx, cancel := context.WithCancel(ctx)
105 t := &TLSTransport{
106 ctx: ctx,
107 cancel: cancel,
108 logger: logger,
109 bindAddr: bindAddr,
110 bindPort: bindPort,
111 done: make(chan struct{}),
112 listener: listener,
113 packetCh: make(chan *memberlist.Packet),
114 streamCh: make(chan net.Conn),
115 connPool: connPool,
116 tlsServerCfg: tlsServerCfg,
117 tlsClientCfg: tlsClientCfg,
118 }
119
120 t.registerMetrics(reg)
121
122 go func() {
123 t.listen()
124 close(t.done)
125 }()
126 return t, nil
127 }
128
129
130
131
132 func (t *TLSTransport) FinalAdvertiseAddr(ip string, port int) (net.IP, int, error) {
133 var advertiseAddr net.IP
134 var advertisePort int
135 if ip != "" {
136 advertiseAddr = net.ParseIP(ip)
137 if advertiseAddr == nil {
138 return nil, 0, fmt.Errorf("failed to parse advertise address %q", ip)
139 }
140
141 if ip4 := advertiseAddr.To4(); ip4 != nil {
142 advertiseAddr = ip4
143 }
144 advertisePort = port
145 } else {
146 if t.bindAddr == "0.0.0.0" {
147
148
149 var err error
150 ip, err = sockaddr.GetPrivateIP()
151 if err != nil {
152 return nil, 0, fmt.Errorf("failed to get interface addresses: %v", err)
153 }
154 if ip == "" {
155 return nil, 0, fmt.Errorf("no private IP address found, and explicit IP not provided")
156 }
157
158 advertiseAddr = net.ParseIP(ip)
159 if advertiseAddr == nil {
160 return nil, 0, fmt.Errorf("failed to parse advertise address: %q", ip)
161 }
162 } else {
163 advertiseAddr = t.listener.Addr().(*net.TCPAddr).IP
164 }
165 advertisePort = t.GetAutoBindPort()
166 }
167 return advertiseAddr, advertisePort, nil
168 }
169
170
171
172 func (t *TLSTransport) PacketCh() <-chan *memberlist.Packet {
173 return t.packetCh
174 }
175
176
177
178 func (t *TLSTransport) StreamCh() <-chan net.Conn {
179 return t.streamCh
180 }
181
182
183
184 func (t *TLSTransport) Shutdown() error {
185 level.Debug(t.logger).Log("msg", "shutting down tls transport")
186 t.cancel()
187 err := t.listener.Close()
188 t.connPool.shutdown()
189 <-t.done
190 return err
191 }
192
193
194
195
196 func (t *TLSTransport) WriteTo(b []byte, addr string) (time.Time, error) {
197 conn, err := t.connPool.borrowConnection(addr, DefaultTCPTimeout)
198 if err != nil {
199 t.writeErrs.WithLabelValues("packet").Inc()
200 return time.Now(), errors.Wrap(err, "failed to dial")
201 }
202 fromAddr := t.listener.Addr().String()
203 err = conn.writePacket(fromAddr, b)
204 if err != nil {
205 t.writeErrs.WithLabelValues("packet").Inc()
206 return time.Now(), errors.Wrap(err, "failed to write packet")
207 }
208 t.packetsSent.Add(float64(len(b)))
209 return time.Now(), nil
210 }
211
212
213
214 func (t *TLSTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
215 conn, err := dialTLSConn(addr, timeout, t.tlsClientCfg)
216 if err != nil {
217 t.writeErrs.WithLabelValues("stream").Inc()
218 return nil, errors.Wrap(err, "failed to dial")
219 }
220 err = conn.writeStream()
221 netConn := conn.getRawConn()
222 if err != nil {
223 t.writeErrs.WithLabelValues("stream").Inc()
224 return netConn, errors.Wrap(err, "failed to create stream connection")
225 }
226 t.streamsSent.Inc()
227 return netConn, nil
228 }
229
230
231
232 func (t *TLSTransport) GetAutoBindPort() int {
233 return t.listener.Addr().(*net.TCPAddr).Port
234 }
235
236
237 func (t *TLSTransport) listen() {
238 for {
239 select {
240 case <-t.ctx.Done():
241
242 return
243 default:
244 conn, err := t.listener.Accept()
245 if err != nil {
246
247
248 if strings.Contains(err.Error(), "use of closed network connection") {
249 return
250 }
251 t.readErrs.Inc()
252 level.Debug(t.logger).Log("msg", "error accepting connection", "err", err)
253
254 } else {
255 go t.handle(conn)
256 }
257 }
258 }
259 }
260
261 func (t *TLSTransport) handle(conn net.Conn) {
262 for {
263 packet, err := rcvTLSConn(conn).read()
264 if err != nil {
265 level.Debug(t.logger).Log("msg", "error reading from connection", "err", err)
266 t.readErrs.Inc()
267 return
268 }
269 select {
270 case <-t.ctx.Done():
271 return
272 default:
273 if packet != nil {
274 n := len(packet.Buf)
275 t.packetCh <- packet
276 t.packetsRcvd.Add(float64(n))
277 } else {
278 t.streamCh <- conn
279 t.streamsRcvd.Inc()
280 return
281 }
282 }
283 }
284 }
285
286 func (t *TLSTransport) registerMetrics(reg prometheus.Registerer) {
287 t.packetsSent = prometheus.NewCounter(
288 prometheus.CounterOpts{
289 Namespace: metricNamespace,
290 Subsystem: metricSubsystem,
291 Name: "packet_bytes_sent_total",
292 Help: "The number of packet bytes sent to outgoing connections (excluding internal metadata).",
293 },
294 )
295 t.packetsRcvd = prometheus.NewCounter(
296 prometheus.CounterOpts{
297 Namespace: metricNamespace,
298 Subsystem: metricSubsystem,
299 Name: "packet_bytes_received_total",
300 Help: "The number of packet bytes received from incoming connections (excluding internal metadata).",
301 },
302 )
303 t.streamsSent = prometheus.NewCounter(
304 prometheus.CounterOpts{
305 Namespace: metricNamespace,
306 Subsystem: metricSubsystem,
307 Name: "stream_connections_sent_total",
308 Help: "The number of stream connections sent.",
309 },
310 )
311
312 t.streamsRcvd = prometheus.NewCounter(
313 prometheus.CounterOpts{
314 Namespace: metricNamespace,
315 Subsystem: metricSubsystem,
316 Name: "stream_connections_received_total",
317 Help: "The number of stream connections received.",
318 },
319 )
320 t.readErrs = prometheus.NewCounter(
321 prometheus.CounterOpts{
322 Namespace: metricNamespace,
323 Subsystem: metricSubsystem,
324 Name: "read_errors_total",
325 Help: "The number of errors encountered while reading from incoming connections.",
326 },
327 )
328 t.writeErrs = prometheus.NewCounterVec(
329 prometheus.CounterOpts{
330 Namespace: metricNamespace,
331 Subsystem: metricSubsystem,
332 Name: "write_errors_total",
333 Help: "The number of errors encountered while writing to outgoing connections.",
334 },
335 []string{"connection_type"},
336 )
337
338 if reg != nil {
339 reg.MustRegister(t.packetsSent)
340 reg.MustRegister(t.packetsRcvd)
341 reg.MustRegister(t.streamsSent)
342 reg.MustRegister(t.streamsRcvd)
343 reg.MustRegister(t.readErrs)
344 reg.MustRegister(t.writeErrs)
345 }
346 }
347
View as plain text