1 package bdns
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "math/rand"
8 "net"
9 "strconv"
10 "sync"
11 "time"
12
13 "github.com/letsencrypt/boulder/cmd"
14 "github.com/miekg/dns"
15 "github.com/prometheus/client_golang/prometheus"
16 )
17
18
19
20
21
// ServerProvider supplies the set of DNS server addresses, in "host:port"
// form, that a resolver should send queries to.
type ServerProvider interface {
	// Addrs returns the current list of server addresses, or an error when
	// no address can be supplied.
	Addrs() ([]string, error)

	// Stop releases whatever the provider holds (for example a background
	// refresh goroutine). Providers without such resources may no-op.
	Stop()
}
26
27
28
29
30 type staticProvider struct {
31 servers []string
32 }
33
34 var _ ServerProvider = &staticProvider{}
35
36
37
38
39
40
41 func validateServerAddress(address string) error {
42
43 host, port, err := net.SplitHostPort(address)
44 if err != nil {
45 return err
46 }
47
48
49 if host == "" || port == "" {
50 return errors.New("port cannot be missing")
51 }
52
53
54 portNum, err := strconv.Atoi(port)
55 if err != nil {
56 return fmt.Errorf("parsing port number: %s", err)
57 }
58 if portNum <= 0 || portNum > 65535 {
59 return errors.New("port must be an integer between 0 - 65535")
60 }
61
62
63 IPv6 := net.ParseIP(host).To16()
64 IPv4 := net.ParseIP(host).To4()
65 FQDN := dns.IsFqdn(dns.Fqdn(host))
66 if IPv6 == nil && IPv4 == nil && !FQDN {
67 return errors.New("host is not an FQDN or IP address")
68 }
69 return nil
70 }
71
72 func NewStaticProvider(servers []string) (*staticProvider, error) {
73 var serverAddrs []string
74 for _, server := range servers {
75 err := validateServerAddress(server)
76 if err != nil {
77 return nil, fmt.Errorf("server address %q invalid: %s", server, err)
78 }
79 serverAddrs = append(serverAddrs, server)
80 }
81 return &staticProvider{servers: serverAddrs}, nil
82 }
83
84 func (sp *staticProvider) Addrs() ([]string, error) {
85 if len(sp.servers) == 0 {
86 return nil, fmt.Errorf("no servers configured")
87 }
88 r := make([]string, len(sp.servers))
89 perm := rand.Perm(len(sp.servers))
90 for i, v := range perm {
91 r[i] = sp.servers[v]
92 }
93 return r, nil
94 }
95
96 func (sp *staticProvider) Stop() {}
97
98
99
100
101
102 type dynamicProvider struct {
103
104
105
106
107
108 dnsAuthority string
109
110
111 service string
112
113 domain string
114
115
116 addrs map[string][]uint16
117
118 cancel chan interface{}
119 mu sync.RWMutex
120 refresh time.Duration
121 updateCounter *prometheus.CounterVec
122 }
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// ParseTarget splits target into a host and a port, supplying defaultPort
// when target does not carry one.
//
// Accepted forms:
//   - bare IP literals, including unbracketed IPv6 such as "::1";
//   - "host:port" and "[ipv6]:port";
//   - bare hostnames and bracketed IPv6 without a port;
//   - ":port", where the host becomes "localhost".
//
// An empty target, and a trailing separator with no port ("host:"), are
// rejected.
func ParseTarget(target, defaultPort string) (host, port string, err error) {
	if target == "" {
		return "", "", errors.New("missing address")
	}

	// A bare IP literal has no port to split off. Checking this first also
	// lets an unbracketed IPv6 address through; SplitHostPort would otherwise
	// fail on it with "too many colons".
	if net.ParseIP(target) != nil {
		return target, defaultPort, nil
	}

	if h, p, splitErr := net.SplitHostPort(target); splitErr == nil {
		switch {
		case p == "":
			// "host:" has a separator but nothing after it.
			return "", "", errors.New("missing port after port-separator colon")
		case h == "":
			// ":port" is taken to mean the local machine.
			h = "localhost"
		}
		return h, p, nil
	}

	// target had no usable port: retry with the default port appended.
	h, p, fallbackErr := net.SplitHostPort(target + ":" + defaultPort)
	if fallbackErr != nil {
		return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, fallbackErr)
	}
	return h, p, nil
}
169
170 var _ ServerProvider = &dynamicProvider{}
171
172
173
174
175
176
177 func StartDynamicProvider(c *cmd.DNSProvider, refresh time.Duration) (*dynamicProvider, error) {
178 if c.SRVLookup.Domain == "" {
179 return nil, fmt.Errorf("'domain' cannot be empty")
180 }
181
182 service := c.SRVLookup.Service
183 if service == "" {
184
185
186 service = "dns"
187 }
188
189 host, port, err := ParseTarget(c.DNSAuthority, "53")
190 if err != nil {
191 return nil, err
192 }
193
194 dnsAuthority := net.JoinHostPort(host, port)
195 err = validateServerAddress(dnsAuthority)
196 if err != nil {
197 return nil, err
198 }
199
200 dp := dynamicProvider{
201 dnsAuthority: dnsAuthority,
202 service: service,
203 domain: c.SRVLookup.Domain,
204 addrs: make(map[string][]uint16),
205 cancel: make(chan interface{}),
206 refresh: refresh,
207 updateCounter: prometheus.NewCounterVec(
208 prometheus.CounterOpts{
209 Name: "dns_update",
210 Help: "Counter of attempts to update a dynamic provider",
211 },
212 []string{"success"},
213 ),
214 }
215
216
217
218 err = dp.update()
219 if err != nil {
220 return nil, fmt.Errorf("failed to start dynamic provider: %w", err)
221 }
222 go dp.run()
223
224 return &dp, nil
225 }
226
227
228
229 func (dp *dynamicProvider) run() {
230 t := time.NewTicker(dp.refresh)
231 for {
232 select {
233 case <-t.C:
234 err := dp.update()
235 if err != nil {
236 dp.updateCounter.With(prometheus.Labels{
237 "success": "false",
238 }).Inc()
239 continue
240 }
241 dp.updateCounter.With(prometheus.Labels{
242 "success": "true",
243 }).Inc()
244 case <-dp.cancel:
245 return
246 }
247 }
248 }
249
250
251
252
253 func (dp *dynamicProvider) update() error {
254 ctx, cancel := context.WithTimeout(context.Background(), dp.refresh/2)
255 defer cancel()
256
257 resolver := &net.Resolver{
258 PreferGo: true,
259 Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
260 d := &net.Dialer{}
261 return d.DialContext(ctx, network, dp.dnsAuthority)
262 },
263 }
264
265
266 record := fmt.Sprintf("_%s._udp.%s.", dp.service, dp.domain)
267
268 _, srvs, err := resolver.LookupSRV(ctx, dp.service, "udp", dp.domain)
269 if err != nil {
270 return fmt.Errorf("during SRV lookup of %q: %w", record, err)
271 }
272 if len(srvs) == 0 {
273 return fmt.Errorf("SRV lookup of %q returned 0 results", record)
274 }
275
276 addrPorts := make(map[string][]uint16)
277 for _, srv := range srvs {
278 addrs, err := resolver.LookupHost(ctx, srv.Target)
279 if err != nil {
280 return fmt.Errorf("during A/AAAA lookup of target %q from SRV record %q: %w", srv.Target, record, err)
281 }
282 for _, addr := range addrs {
283 joinedHostPort := net.JoinHostPort(addr, fmt.Sprint(srv.Port))
284 err := validateServerAddress(joinedHostPort)
285 if err != nil {
286 return fmt.Errorf("invalid addr %q from SRV record %q: %w", joinedHostPort, record, err)
287 }
288 addrPorts[addr] = append(addrPorts[addr], srv.Port)
289 }
290 }
291
292 dp.mu.Lock()
293 dp.addrs = addrPorts
294 dp.mu.Unlock()
295 return nil
296 }
297
298
299
300 func (dp *dynamicProvider) Addrs() ([]string, error) {
301 var r []string
302 dp.mu.RLock()
303 for ip, ports := range dp.addrs {
304 port := fmt.Sprint(ports[rand.Intn(len(ports))])
305 addr := net.JoinHostPort(ip, port)
306 r = append(r, addr)
307 }
308 dp.mu.RUnlock()
309 rand.Shuffle(len(r), func(i, j int) {
310 r[i], r[j] = r[j], r[i]
311 })
312 return r, nil
313 }
314
315
316
317 func (dp *dynamicProvider) Stop() {
318 close(dp.cancel)
319 }
320
View as plain text