...

Source file src/github.com/prometheus/alertmanager/cluster/tls_transport.go

Documentation: github.com/prometheus/alertmanager/cluster

     1  // Copyright 2020 The Prometheus Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  // Forked from https://github.com/mxinden/memberlist-tls-transport.
    15  
    16  // Implements Transport interface so that all gossip communications occur via TLS over TCP.
    17  
    18  package cluster
    19  
    20  import (
    21  	"context"
    22  	"crypto/tls"
    23  	"fmt"
    24  	"net"
    25  	"strings"
    26  	"time"
    27  
    28  	"github.com/go-kit/log"
    29  	"github.com/go-kit/log/level"
    30  	"github.com/hashicorp/go-sockaddr"
    31  	"github.com/hashicorp/memberlist"
    32  	"github.com/pkg/errors"
    33  	"github.com/prometheus/client_golang/prometheus"
    34  	common "github.com/prometheus/common/config"
    35  	"github.com/prometheus/exporter-toolkit/web"
    36  )
    37  
const (
	// metricNamespace is the Prometheus namespace used for every metric
	// created in registerMetrics.
	metricNamespace = "alertmanager"
	// metricSubsystem is the Prometheus subsystem used for every metric
	// created in registerMetrics.
	metricSubsystem = "tls_transport"
	// network is the network type handed to tls.Listen when the transport's
	// listener is created.
	network = "tcp"
)
    43  
// TLSTransport is a Transport implementation that uses TLS over TCP for both
// packet and stream operations.
//
// Instances are created by NewTLSTransport, which also starts the accept
// loop; Shutdown stops that loop and releases the listener.
type TLSTransport struct {
	// ctx is derived from the context given to NewTLSTransport. It is
	// cancelled by Shutdown and checked by the accept and handler loops.
	ctx context.Context
	// cancel cancels ctx; it is invoked by Shutdown.
	cancel context.CancelFunc
	// logger receives debug messages about shutdown, accept and read errors.
	logger log.Logger
	// bindAddr is the bind address exactly as passed to NewTLSTransport.
	// FinalAdvertiseAddr compares it against "0.0.0.0".
	bindAddr string
	// bindPort is the bind port exactly as passed to NewTLSTransport. It may
	// be 0; GetAutoBindPort reports the port the listener actually uses.
	bindPort int
	// done is closed once the accept loop started by NewTLSTransport has
	// returned; Shutdown blocks on it.
	done chan struct{}
	// listener is the TLS listener created by NewTLSTransport.
	listener net.Listener
	// packetCh is an unbuffered channel on which handle delivers received
	// packets; it is exposed read-only through PacketCh.
	packetCh chan *memberlist.Packet
	// streamCh is an unbuffered channel on which handle delivers incoming
	// stream connections; it is exposed read-only through StreamCh.
	streamCh chan net.Conn
	// connPool is built from tlsClientCfg and supplies the connections that
	// WriteTo uses for outgoing packets.
	connPool *connectionPool
	// tlsServerCfg is the TLS configuration the listener was created with.
	tlsServerCfg *tls.Config
	// tlsClientCfg is the TLS configuration used for outgoing connections
	// (DialTimeout and the connection pool).
	tlsClientCfg *tls.Config

	// Metrics, created (and optionally registered) by registerMetrics.
	// packetsSent and packetsRcvd count packet payload bytes, streamsSent and
	// streamsRcvd count stream connections, readErrs counts accept/read
	// failures, and writeErrs counts write failures labelled by
	// connection_type ("packet" or "stream").
	packetsSent prometheus.Counter
	packetsRcvd prometheus.Counter
	streamsSent prometheus.Counter
	streamsRcvd prometheus.Counter
	readErrs    prometheus.Counter
	writeErrs   *prometheus.CounterVec
}
    67  
    68  // NewTLSTransport returns a TLS transport with the given configuration.
    69  // On successful initialization, a tls listener will be created and listening.
    70  // A valid bindAddr is required. If bindPort == 0, the system will assign
    71  // a free port automatically.
    72  func NewTLSTransport(
    73  	ctx context.Context,
    74  	logger log.Logger,
    75  	reg prometheus.Registerer,
    76  	bindAddr string,
    77  	bindPort int,
    78  	cfg *TLSTransportConfig,
    79  ) (*TLSTransport, error) {
    80  	if cfg == nil {
    81  		return nil, errors.New("must specify TLSTransportConfig")
    82  	}
    83  	tlsServerCfg, err := web.ConfigToTLSConfig(cfg.TLSServerConfig)
    84  	if err != nil {
    85  		return nil, errors.Wrap(err, "invalid TLS server config")
    86  	}
    87  	tlsClientCfg, err := common.NewTLSConfig(cfg.TLSClientConfig)
    88  	if err != nil {
    89  		return nil, errors.Wrap(err, "invalid TLS client config")
    90  	}
    91  	ip := net.ParseIP(bindAddr)
    92  	if ip == nil {
    93  		return nil, fmt.Errorf("invalid bind address \"%s\"", bindAddr)
    94  	}
    95  	addr := &net.TCPAddr{IP: ip, Port: bindPort}
    96  	listener, err := tls.Listen(network, addr.String(), tlsServerCfg)
    97  	if err != nil {
    98  		return nil, errors.Wrap(err, fmt.Sprintf("failed to start TLS listener on %q port %d", bindAddr, bindPort))
    99  	}
   100  	connPool, err := newConnectionPool(tlsClientCfg)
   101  	if err != nil {
   102  		return nil, errors.Wrap(err, "failed to initialize tls transport connection pool")
   103  	}
   104  	ctx, cancel := context.WithCancel(ctx)
   105  	t := &TLSTransport{
   106  		ctx:          ctx,
   107  		cancel:       cancel,
   108  		logger:       logger,
   109  		bindAddr:     bindAddr,
   110  		bindPort:     bindPort,
   111  		done:         make(chan struct{}),
   112  		listener:     listener,
   113  		packetCh:     make(chan *memberlist.Packet),
   114  		streamCh:     make(chan net.Conn),
   115  		connPool:     connPool,
   116  		tlsServerCfg: tlsServerCfg,
   117  		tlsClientCfg: tlsClientCfg,
   118  	}
   119  
   120  	t.registerMetrics(reg)
   121  
   122  	go func() {
   123  		t.listen()
   124  		close(t.done)
   125  	}()
   126  	return t, nil
   127  }
   128  
   129  // FinalAdvertiseAddr is given the user's configured values (which
   130  // might be empty) and returns the desired IP and port to advertise to
   131  // the rest of the cluster.
   132  func (t *TLSTransport) FinalAdvertiseAddr(ip string, port int) (net.IP, int, error) {
   133  	var advertiseAddr net.IP
   134  	var advertisePort int
   135  	if ip != "" {
   136  		advertiseAddr = net.ParseIP(ip)
   137  		if advertiseAddr == nil {
   138  			return nil, 0, fmt.Errorf("failed to parse advertise address %q", ip)
   139  		}
   140  
   141  		if ip4 := advertiseAddr.To4(); ip4 != nil {
   142  			advertiseAddr = ip4
   143  		}
   144  		advertisePort = port
   145  	} else {
   146  		if t.bindAddr == "0.0.0.0" {
   147  			// Otherwise, if we're not bound to a specific IP, let's
   148  			// use a suitable private IP address.
   149  			var err error
   150  			ip, err = sockaddr.GetPrivateIP()
   151  			if err != nil {
   152  				return nil, 0, fmt.Errorf("failed to get interface addresses: %v", err)
   153  			}
   154  			if ip == "" {
   155  				return nil, 0, fmt.Errorf("no private IP address found, and explicit IP not provided")
   156  			}
   157  
   158  			advertiseAddr = net.ParseIP(ip)
   159  			if advertiseAddr == nil {
   160  				return nil, 0, fmt.Errorf("failed to parse advertise address: %q", ip)
   161  			}
   162  		} else {
   163  			advertiseAddr = t.listener.Addr().(*net.TCPAddr).IP
   164  		}
   165  		advertisePort = t.GetAutoBindPort()
   166  	}
   167  	return advertiseAddr, advertisePort, nil
   168  }
   169  
// PacketCh returns a channel that can be read to receive incoming
// packets from other peers.
//
// The channel is the transport's internal, unbuffered packet channel exposed
// as receive-only; the connection handlers send each received packet on it.
func (t *TLSTransport) PacketCh() <-chan *memberlist.Packet {
	return t.packetCh
}
   175  
// StreamCh returns a channel that can be read to handle incoming stream
// connections from other peers.
//
// The channel is the transport's internal, unbuffered stream channel exposed
// as receive-only; the connection handlers send each incoming stream
// connection on it.
func (t *TLSTransport) StreamCh() <-chan net.Conn {
	return t.streamCh
}
   181  
   182  // Shutdown is called when memberlist is shutting down; this gives the
   183  // TLS Transport a chance to clean up the listener and other goroutines.
   184  func (t *TLSTransport) Shutdown() error {
   185  	level.Debug(t.logger).Log("msg", "shutting down tls transport")
   186  	t.cancel()
   187  	err := t.listener.Close()
   188  	t.connPool.shutdown()
   189  	<-t.done
   190  	return err
   191  }
   192  
   193  // WriteTo is a packet-oriented interface that borrows a connection
   194  // from the pool, and writes to it. It also returns a timestamp of when
   195  // the packet was written.
   196  func (t *TLSTransport) WriteTo(b []byte, addr string) (time.Time, error) {
   197  	conn, err := t.connPool.borrowConnection(addr, DefaultTCPTimeout)
   198  	if err != nil {
   199  		t.writeErrs.WithLabelValues("packet").Inc()
   200  		return time.Now(), errors.Wrap(err, "failed to dial")
   201  	}
   202  	fromAddr := t.listener.Addr().String()
   203  	err = conn.writePacket(fromAddr, b)
   204  	if err != nil {
   205  		t.writeErrs.WithLabelValues("packet").Inc()
   206  		return time.Now(), errors.Wrap(err, "failed to write packet")
   207  	}
   208  	t.packetsSent.Add(float64(len(b)))
   209  	return time.Now(), nil
   210  }
   211  
   212  // DialTimeout is used to create a connection that allows memberlist
   213  // to perform two-way communications with a peer.
   214  func (t *TLSTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
   215  	conn, err := dialTLSConn(addr, timeout, t.tlsClientCfg)
   216  	if err != nil {
   217  		t.writeErrs.WithLabelValues("stream").Inc()
   218  		return nil, errors.Wrap(err, "failed to dial")
   219  	}
   220  	err = conn.writeStream()
   221  	netConn := conn.getRawConn()
   222  	if err != nil {
   223  		t.writeErrs.WithLabelValues("stream").Inc()
   224  		return netConn, errors.Wrap(err, "failed to create stream connection")
   225  	}
   226  	t.streamsSent.Inc()
   227  	return netConn, nil
   228  }
   229  
   230  // GetAutoBindPort returns the bind port that was automatically given by the system
   231  // if a bindPort of 0 was specified during instantiation.
   232  func (t *TLSTransport) GetAutoBindPort() int {
   233  	return t.listener.Addr().(*net.TCPAddr).Port
   234  }
   235  
   236  // listen starts up multiple handlers accepting concurrent connections.
   237  func (t *TLSTransport) listen() {
   238  	for {
   239  		select {
   240  		case <-t.ctx.Done():
   241  
   242  			return
   243  		default:
   244  			conn, err := t.listener.Accept()
   245  			if err != nil {
   246  				// The error "use of closed network connection" is returned when the listener is closed.
   247  				// It is not exported in a more reasonable way. See https://github.com/golang/go/issues/4373.
   248  				if strings.Contains(err.Error(), "use of closed network connection") {
   249  					return
   250  				}
   251  				t.readErrs.Inc()
   252  				level.Debug(t.logger).Log("msg", "error accepting connection", "err", err)
   253  
   254  			} else {
   255  				go t.handle(conn)
   256  			}
   257  		}
   258  	}
   259  }
   260  
   261  func (t *TLSTransport) handle(conn net.Conn) {
   262  	for {
   263  		packet, err := rcvTLSConn(conn).read()
   264  		if err != nil {
   265  			level.Debug(t.logger).Log("msg", "error reading from connection", "err", err)
   266  			t.readErrs.Inc()
   267  			return
   268  		}
   269  		select {
   270  		case <-t.ctx.Done():
   271  			return
   272  		default:
   273  			if packet != nil {
   274  				n := len(packet.Buf)
   275  				t.packetCh <- packet
   276  				t.packetsRcvd.Add(float64(n))
   277  			} else {
   278  				t.streamCh <- conn
   279  				t.streamsRcvd.Inc()
   280  				return
   281  			}
   282  		}
   283  	}
   284  }
   285  
   286  func (t *TLSTransport) registerMetrics(reg prometheus.Registerer) {
   287  	t.packetsSent = prometheus.NewCounter(
   288  		prometheus.CounterOpts{
   289  			Namespace: metricNamespace,
   290  			Subsystem: metricSubsystem,
   291  			Name:      "packet_bytes_sent_total",
   292  			Help:      "The number of packet bytes sent to outgoing connections (excluding internal metadata).",
   293  		},
   294  	)
   295  	t.packetsRcvd = prometheus.NewCounter(
   296  		prometheus.CounterOpts{
   297  			Namespace: metricNamespace,
   298  			Subsystem: metricSubsystem,
   299  			Name:      "packet_bytes_received_total",
   300  			Help:      "The number of packet bytes received from incoming connections (excluding internal metadata).",
   301  		},
   302  	)
   303  	t.streamsSent = prometheus.NewCounter(
   304  		prometheus.CounterOpts{
   305  			Namespace: metricNamespace,
   306  			Subsystem: metricSubsystem,
   307  			Name:      "stream_connections_sent_total",
   308  			Help:      "The number of stream connections sent.",
   309  		},
   310  	)
   311  
   312  	t.streamsRcvd = prometheus.NewCounter(
   313  		prometheus.CounterOpts{
   314  			Namespace: metricNamespace,
   315  			Subsystem: metricSubsystem,
   316  			Name:      "stream_connections_received_total",
   317  			Help:      "The number of stream connections received.",
   318  		},
   319  	)
   320  	t.readErrs = prometheus.NewCounter(
   321  		prometheus.CounterOpts{
   322  			Namespace: metricNamespace,
   323  			Subsystem: metricSubsystem,
   324  			Name:      "read_errors_total",
   325  			Help:      "The number of errors encountered while reading from incoming connections.",
   326  		},
   327  	)
   328  	t.writeErrs = prometheus.NewCounterVec(
   329  		prometheus.CounterOpts{
   330  			Namespace: metricNamespace,
   331  			Subsystem: metricSubsystem,
   332  			Name:      "write_errors_total",
   333  			Help:      "The number of errors encountered while writing to outgoing connections.",
   334  		},
   335  		[]string{"connection_type"},
   336  	)
   337  
   338  	if reg != nil {
   339  		reg.MustRegister(t.packetsSent)
   340  		reg.MustRegister(t.packetsRcvd)
   341  		reg.MustRegister(t.streamsSent)
   342  		reg.MustRegister(t.streamsRcvd)
   343  		reg.MustRegister(t.readErrs)
   344  		reg.MustRegister(t.writeErrs)
   345  	}
   346  }
   347  

View as plain text