1
2
3
4
5
6
7
8
9
10
11
12
13
14 package cluster
15
16 import (
17 "bufio"
18 "crypto/tls"
19 "encoding/binary"
20 "io"
21 "net"
22 "sync"
23 "time"
24
25 "github.com/gogo/protobuf/proto"
26 "github.com/hashicorp/memberlist"
27 "github.com/pkg/errors"
28
29 "github.com/prometheus/alertmanager/cluster/clusterpb"
30 )
31
// Wire-format constants shared by the reader and writers in this file.
const (
	// uint32length is the size in bytes of the little-endian length prefix
	// that precedes every marshalled message on the wire.
	uint32length = 4

	// version is stamped into every outgoing MemberlistMessage; read rejects
	// incoming messages whose Version field differs from it.
	version = "v0.1.0"
)
36
37
// tlsConn pairs a net.Conn with a liveness flag and a mutex that serialises
// access to both.
type tlsConn struct {
	// mtx guards connection and live.
	mtx sync.Mutex
	// connection is the wrapped connection. getRawConn sets it to nil when
	// the connection is handed over to the caller.
	connection net.Conn
	// live is true on construction and cleared by a failed Write, by
	// getRawConn and by Close.
	live bool
}
43
44 func dialTLSConn(addr string, timeout time.Duration, tlsConfig *tls.Config) (*tlsConn, error) {
45 dialer := &net.Dialer{Timeout: timeout}
46 conn, err := tls.DialWithDialer(dialer, network, addr, tlsConfig)
47 if err != nil {
48 return nil, err
49 }
50
51 return &tlsConn{
52 connection: conn,
53 live: true,
54 }, nil
55 }
56
57 func rcvTLSConn(conn net.Conn) *tlsConn {
58 return &tlsConn{
59 connection: conn,
60 live: true,
61 }
62 }
63
64
65 func (conn *tlsConn) Write(b []byte) (int, error) {
66 conn.mtx.Lock()
67 defer conn.mtx.Unlock()
68 n, err := conn.connection.Write(b)
69 if err != nil {
70 conn.live = false
71 }
72 return n, err
73 }
74
75 func (conn *tlsConn) alive() bool {
76 conn.mtx.Lock()
77 defer conn.mtx.Unlock()
78 return conn.live
79 }
80
81 func (conn *tlsConn) getRawConn() net.Conn {
82 conn.mtx.Lock()
83 defer conn.mtx.Unlock()
84 raw := conn.connection
85 conn.live = false
86 conn.connection = nil
87 return raw
88 }
89
90
91
92 func (conn *tlsConn) writePacket(fromAddr string, b []byte) error {
93 msg, err := proto.Marshal(
94 &clusterpb.MemberlistMessage{
95 Version: version,
96 Kind: clusterpb.MemberlistMessage_PACKET,
97 FromAddr: fromAddr,
98 Msg: b,
99 },
100 )
101 if err != nil {
102 return errors.Wrap(err, "unable to marshal memeberlist packet message")
103 }
104 buf := make([]byte, uint32length, uint32length+len(msg))
105 binary.LittleEndian.PutUint32(buf, uint32(len(msg)))
106 _, err = conn.Write(append(buf, msg...))
107 return err
108 }
109
110
111 func (conn *tlsConn) writeStream() error {
112 msg, err := proto.Marshal(
113 &clusterpb.MemberlistMessage{
114 Version: version,
115 Kind: clusterpb.MemberlistMessage_STREAM,
116 },
117 )
118 if err != nil {
119 return errors.Wrap(err, "unable to marshal memeberlist stream message")
120 }
121 buf := make([]byte, uint32length, uint32length+len(msg))
122 binary.LittleEndian.PutUint32(buf, uint32(len(msg)))
123 _, err = conn.Write(append(buf, msg...))
124 return err
125 }
126
127
128
129 func (conn *tlsConn) read() (*memberlist.Packet, error) {
130 if conn.connection == nil {
131 return nil, errors.New("nil connection")
132 }
133
134 conn.mtx.Lock()
135 reader := bufio.NewReader(conn.connection)
136 lenBuf := make([]byte, uint32length)
137 _, err := io.ReadFull(reader, lenBuf)
138 if err != nil {
139 return nil, errors.Wrap(err, "error reading message length")
140 }
141 msgLen := binary.LittleEndian.Uint32(lenBuf)
142 msgBuf := make([]byte, msgLen)
143 _, err = io.ReadFull(reader, msgBuf)
144 conn.mtx.Unlock()
145
146 if err != nil {
147 return nil, errors.Wrap(err, "error reading message")
148 }
149 pb := clusterpb.MemberlistMessage{}
150 err = proto.Unmarshal(msgBuf, &pb)
151 if err != nil {
152 return nil, errors.Wrap(err, "error parsing message")
153 }
154 if pb.Version != version {
155 return nil, errors.New("tls memberlist message version incompatible")
156 }
157 switch pb.Kind {
158 case clusterpb.MemberlistMessage_STREAM:
159 return nil, nil
160 case clusterpb.MemberlistMessage_PACKET:
161 return toPacket(pb)
162 default:
163 return nil, errors.New("could not read from either stream or packet channel")
164 }
165 }
166
167 func toPacket(pb clusterpb.MemberlistMessage) (*memberlist.Packet, error) {
168 addr, err := net.ResolveTCPAddr(network, pb.FromAddr)
169 if err != nil {
170 return nil, errors.Wrap(err, "error parsing packet sender address")
171 }
172 return &memberlist.Packet{
173 Buf: pb.Msg,
174 From: addr,
175 Timestamp: time.Now(),
176 }, nil
177 }
178
179 func (conn *tlsConn) Close() error {
180 conn.mtx.Lock()
181 defer conn.mtx.Unlock()
182 conn.live = false
183 if conn.connection == nil {
184 return nil
185 }
186 return conn.connection.Close()
187 }
188
View as plain text