...
1 package conn
2
3 import (
4 "errors"
5 "math/rand"
6 "net"
7 "time"
8
9 "github.com/go-kit/log"
10 )
11
12
13
// Dialer opens a connection to address on the named network. Its signature
// matches net.Dial, so net.Dial can be used directly as a Dialer.
type Dialer func(network string, address string) (net.Conn, error)
15
16
// AfterFunc returns a channel that delivers a value once d has elapsed. Its
// signature matches time.After; it is injectable so callers (e.g. tests) can
// control retry timing.
type AfterFunc func(d time.Duration) <-chan time.Time
18
19
20
21
22
23
24
25
26 type Manager struct {
27 dialer Dialer
28 network string
29 address string
30 after AfterFunc
31 logger log.Logger
32
33 takec chan net.Conn
34 putc chan error
35 }
36
37
38
39
40
41 func NewManager(d Dialer, network, address string, after AfterFunc, logger log.Logger) *Manager {
42 m := &Manager{
43 dialer: d,
44 network: network,
45 address: address,
46 after: after,
47 logger: logger,
48
49 takec: make(chan net.Conn),
50 putc: make(chan error),
51 }
52 go m.loop()
53 return m
54 }
55
56
57
58 func NewDefaultManager(network, address string, logger log.Logger) *Manager {
59 return NewManager(net.Dial, network, address, time.After, logger)
60 }
61
62
63 func (m *Manager) Take() net.Conn {
64 return <-m.takec
65 }
66
67
68
69
70 func (m *Manager) Put(err error) {
71 m.putc <- err
72 }
73
74
75 func (m *Manager) Write(b []byte) (int, error) {
76 conn := m.Take()
77 if conn == nil {
78 return 0, ErrConnectionUnavailable
79 }
80 n, err := conn.Write(b)
81 defer m.Put(err)
82 return n, err
83 }
84
// loop is the Manager's owning goroutine, started by NewManager. It alone
// reads and writes the current connection and the retry state; Take and Put
// interact with it only through m.takec and m.putc. It never returns.
// NOTE(review): there is no shutdown path, so this goroutine (and any held
// connection) lives for the rest of the process.
func (m *Manager) loop() {
	var (
		conn       = dial(m.dialer, m.network, m.address, m.logger) // initial dial blocks this goroutine, so Take waits for it
		connc      = make(chan net.Conn, 1)                         // results of dial attempts (nil on failure)
		reconnectc <-chan time.Time                                 // nil means no retry is scheduled
		backoff    = time.Second                                    // delay basis for the next failed-dial retry
	)

	// Feed the initial dial result through connc so the first loop iteration
	// handles a failed initial dial (nil conn) exactly like a failed
	// reconnect. connc has capacity 1, so this send does not block.
	connc <- conn

	for {
		select {
		case <-reconnectc:
			// Retry timer fired. Disarm it (one-shot), then dial in a separate
			// goroutine so this loop keeps serving Take/Put (handing out nil)
			// while the dial is in progress. The result arrives on connc.
			reconnectc = nil
			go func() { connc <- dial(m.dialer, m.network, m.address, m.logger) }()

		case conn = <-connc:
			if conn == nil {
				// Dial failed: lengthen the (jittered, capped) delay and
				// schedule another attempt.
				backoff = Exponential(backoff)
				reconnectc = m.after(backoff)
			} else {
				// Dial succeeded: reset the delay and drop any pending retry.
				backoff = time.Second
				reconnectc = nil
			}

		case m.takec <- conn:
			// Hand the current connection (possibly nil) to a Take caller.

		case err := <-m.putc:
			// An error reported against a live connection invalidates it: log
			// it, close the connection (Close's own error is not checked), and
			// schedule a near-immediate reconnect. While conn is nil a retry is
			// already scheduled or in flight, so such reports are ignored, as
			// are nil errors.
			if err != nil && conn != nil {
				m.logger.Log("err", err)
				conn.Close()
				conn = nil
				reconnectc = m.after(time.Nanosecond)
			}
		}
	}
}
127
128 func dial(d Dialer, network, address string, logger log.Logger) net.Conn {
129 conn, err := d(network, address)
130 if err != nil {
131 logger.Log("err", err)
132 conn = nil
133 }
134 return conn
135 }
136
137
138
139
// Exponential computes the next backoff delay from the current one, d. It
// doubles d, scales the result by a random factor in [0.5, 1.5) so that many
// clients backing off together do not retry in lockstep, and caps the outcome
// at one minute.
func Exponential(d time.Duration) time.Duration {
	const ceiling = time.Minute

	doubled := 2 * d
	factor := 0.5 + rand.Float64()
	next := time.Duration(int64(float64(doubled.Nanoseconds()) * factor))

	if next > ceiling {
		return ceiling
	}
	return next
}
150
151
152
// ErrConnectionUnavailable is returned by Manager.Write when the manager has
// no established connection to write to.
var ErrConnectionUnavailable error = errors.New("connection unavailable")
154
View as plain text