...

Source file src/github.com/prometheus/alertmanager/cluster/tls_connection.go

Documentation: github.com/prometheus/alertmanager/cluster

     1  // Copyright 2020 The Prometheus Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package cluster
    15  
    16  import (
    17  	"bufio"
    18  	"crypto/tls"
    19  	"encoding/binary"
    20  	"io"
    21  	"net"
    22  	"sync"
    23  	"time"
    24  
    25  	"github.com/gogo/protobuf/proto"
    26  	"github.com/hashicorp/memberlist"
    27  	"github.com/pkg/errors"
    28  
    29  	"github.com/prometheus/alertmanager/cluster/clusterpb"
    30  )
    31  
    32  const (
    33  	version      = "v0.1.0"
    34  	uint32length = 4
    35  )
    36  
    37  // tlsConn wraps net.Conn with connection pooling data.
    38  type tlsConn struct {
    39  	mtx        sync.Mutex
    40  	connection net.Conn
    41  	live       bool
    42  }
    43  
    44  func dialTLSConn(addr string, timeout time.Duration, tlsConfig *tls.Config) (*tlsConn, error) {
    45  	dialer := &net.Dialer{Timeout: timeout}
    46  	conn, err := tls.DialWithDialer(dialer, network, addr, tlsConfig)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	return &tlsConn{
    52  		connection: conn,
    53  		live:       true,
    54  	}, nil
    55  }
    56  
    57  func rcvTLSConn(conn net.Conn) *tlsConn {
    58  	return &tlsConn{
    59  		connection: conn,
    60  		live:       true,
    61  	}
    62  }
    63  
    64  // Write writes a byte array into the connection. It returns the number of bytes written and an error.
    65  func (conn *tlsConn) Write(b []byte) (int, error) {
    66  	conn.mtx.Lock()
    67  	defer conn.mtx.Unlock()
    68  	n, err := conn.connection.Write(b)
    69  	if err != nil {
    70  		conn.live = false
    71  	}
    72  	return n, err
    73  }
    74  
    75  func (conn *tlsConn) alive() bool {
    76  	conn.mtx.Lock()
    77  	defer conn.mtx.Unlock()
    78  	return conn.live
    79  }
    80  
    81  func (conn *tlsConn) getRawConn() net.Conn {
    82  	conn.mtx.Lock()
    83  	defer conn.mtx.Unlock()
    84  	raw := conn.connection
    85  	conn.live = false
    86  	conn.connection = nil
    87  	return raw
    88  }
    89  
    90  // writePacket writes all the bytes in one operation so no concurrent write happens in between.
    91  // It prefixes the message length.
    92  func (conn *tlsConn) writePacket(fromAddr string, b []byte) error {
    93  	msg, err := proto.Marshal(
    94  		&clusterpb.MemberlistMessage{
    95  			Version:  version,
    96  			Kind:     clusterpb.MemberlistMessage_PACKET,
    97  			FromAddr: fromAddr,
    98  			Msg:      b,
    99  		},
   100  	)
   101  	if err != nil {
   102  		return errors.Wrap(err, "unable to marshal memeberlist packet message")
   103  	}
   104  	buf := make([]byte, uint32length, uint32length+len(msg))
   105  	binary.LittleEndian.PutUint32(buf, uint32(len(msg)))
   106  	_, err = conn.Write(append(buf, msg...))
   107  	return err
   108  }
   109  
   110  // writeStream simply signals that this is a stream connection by sending the connection type.
   111  func (conn *tlsConn) writeStream() error {
   112  	msg, err := proto.Marshal(
   113  		&clusterpb.MemberlistMessage{
   114  			Version: version,
   115  			Kind:    clusterpb.MemberlistMessage_STREAM,
   116  		},
   117  	)
   118  	if err != nil {
   119  		return errors.Wrap(err, "unable to marshal memeberlist stream message")
   120  	}
   121  	buf := make([]byte, uint32length, uint32length+len(msg))
   122  	binary.LittleEndian.PutUint32(buf, uint32(len(msg)))
   123  	_, err = conn.Write(append(buf, msg...))
   124  	return err
   125  }
   126  
   127  // read returns a packet for packet connections or an error if there is one.
   128  // It returns nothing if the connection is meant to be streamed.
   129  func (conn *tlsConn) read() (*memberlist.Packet, error) {
   130  	if conn.connection == nil {
   131  		return nil, errors.New("nil connection")
   132  	}
   133  
   134  	conn.mtx.Lock()
   135  	reader := bufio.NewReader(conn.connection)
   136  	lenBuf := make([]byte, uint32length)
   137  	_, err := io.ReadFull(reader, lenBuf)
   138  	if err != nil {
   139  		return nil, errors.Wrap(err, "error reading message length")
   140  	}
   141  	msgLen := binary.LittleEndian.Uint32(lenBuf)
   142  	msgBuf := make([]byte, msgLen)
   143  	_, err = io.ReadFull(reader, msgBuf)
   144  	conn.mtx.Unlock()
   145  
   146  	if err != nil {
   147  		return nil, errors.Wrap(err, "error reading message")
   148  	}
   149  	pb := clusterpb.MemberlistMessage{}
   150  	err = proto.Unmarshal(msgBuf, &pb)
   151  	if err != nil {
   152  		return nil, errors.Wrap(err, "error parsing message")
   153  	}
   154  	if pb.Version != version {
   155  		return nil, errors.New("tls memberlist message version incompatible")
   156  	}
   157  	switch pb.Kind {
   158  	case clusterpb.MemberlistMessage_STREAM:
   159  		return nil, nil
   160  	case clusterpb.MemberlistMessage_PACKET:
   161  		return toPacket(pb)
   162  	default:
   163  		return nil, errors.New("could not read from either stream or packet channel")
   164  	}
   165  }
   166  
   167  func toPacket(pb clusterpb.MemberlistMessage) (*memberlist.Packet, error) {
   168  	addr, err := net.ResolveTCPAddr(network, pb.FromAddr)
   169  	if err != nil {
   170  		return nil, errors.Wrap(err, "error parsing packet sender address")
   171  	}
   172  	return &memberlist.Packet{
   173  		Buf:       pb.Msg,
   174  		From:      addr,
   175  		Timestamp: time.Now(),
   176  	}, nil
   177  }
   178  
   179  func (conn *tlsConn) Close() error {
   180  	conn.mtx.Lock()
   181  	defer conn.mtx.Unlock()
   182  	conn.live = false
   183  	if conn.connection == nil {
   184  		return nil
   185  	}
   186  	return conn.connection.Close()
   187  }
   188  

View as plain text